Merge pull request #1519 from sechkova/fetcher-max-length

Remove max_length parameter from fetch
This commit is contained in:
Jussi Kukkonen 2021-09-01 17:27:12 +03:00 committed by GitHub
commit deec2eaaa0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 172 additions and 41 deletions

150
tests/test_fetcher_ng.py Normal file
View file

@ -0,0 +1,150 @@
#!/usr/bin/env python
# Copyright 2021, New York University and the TUF contributors
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Unit test for RequestsFetcher.
"""
import io
import logging
import os
import sys
import unittest
import tempfile
import math
from tests import utils
from tuf import exceptions, unittest_toolbox
from tuf.ngclient._internal.requests_fetcher import RequestsFetcher
logger = logging.getLogger(__name__)
class TestFetcher(unittest_toolbox.Modified_TestCase):
    """Tests for RequestsFetcher run against a local test HTTP server.

    One server process is shared by the whole class. Every test gets its own
    temporary data file created in the current working directory, which is
    then requested from the server by its base name.
    """

    @classmethod
    def setUpClass(cls):
        """Start the shared test server process."""
        # Launch a SimpleHTTPServer (serves files in the current dir).
        cls.server_process_handler = utils.TestServerProcess(log=logger)

    @classmethod
    def tearDownClass(cls):
        """Stop the shared test server process."""
        # Stop server process and perform clean up.
        cls.server_process_handler.clean()

    def setUp(self):
        """
        Create a temporary data file in the current working directory,
        record its contents, length and URL, and instantiate the fetcher.
        """
        unittest_toolbox.Modified_TestCase.setUp(self)

        # Making a temporary data file.
        current_dir = os.getcwd()
        target_filepath = self.make_temp_data_file(directory=current_dir)
        # The handle stays open for the duration of the test; tearDown()
        # closes it.
        self.target_fileobj = open(target_filepath, "r")
        self.file_contents = self.target_fileobj.read()
        self.file_length = len(self.file_contents)
        self.rel_target_filepath = os.path.basename(target_filepath)
        self.url = f"http://{utils.TEST_HOST_ADDRESS}:{str(self.server_process_handler.port)}/{self.rel_target_filepath}"

        # Instantiate a concrete instance of FetcherInterface
        self.fetcher = RequestsFetcher()

    def tearDown(self):
        """Close the data file and let the base class clean up."""
        self.target_fileobj.close()
        # Remove temporary directory
        unittest_toolbox.Modified_TestCase.tearDown(self)

    # Simple fetch.
    def test_fetch(self):
        """fetch() yields chunks that add up to the served file contents."""
        with tempfile.TemporaryFile() as temp_file:
            for chunk in self.fetcher.fetch(self.url):
                temp_file.write(chunk)

            temp_file.seek(0)
            self.assertEqual(
                self.file_contents, temp_file.read().decode("utf-8")
            )

    # URL data downloaded in more than one chunk
    def test_fetch_in_chunks(self):
        """fetch() honours chunk_size and yields the expected chunk count."""
        # Set a smaller chunk size to ensure that the file will be downloaded
        # in more than one chunk
        self.fetcher.chunk_size = 4

        # expected_chunks_count: 3
        expected_chunks_count = math.ceil(
            self.file_length / self.fetcher.chunk_size
        )
        self.assertEqual(expected_chunks_count, 3)

        chunks_count = 0
        with tempfile.TemporaryFile() as temp_file:
            for chunk in self.fetcher.fetch(self.url):
                temp_file.write(chunk)
                chunks_count += 1

            temp_file.seek(0)
            self.assertEqual(
                self.file_contents, temp_file.read().decode("utf-8")
            )

        # Check that we calculate chunks as expected
        self.assertEqual(chunks_count, expected_chunks_count)

    # Incorrect URL parsing
    def test_url_parsing(self):
        """fetch() raises URLParsingError for a string that is not a URL."""
        with self.assertRaises(exceptions.URLParsingError):
            self.fetcher.fetch(self.random_string())

    # File not found error
    def test_http_error(self):
        """fetch() raises FetcherHTTPError with status 404 for a bad path."""
        with self.assertRaises(exceptions.FetcherHTTPError) as cm:
            self.url = f"http://{utils.TEST_HOST_ADDRESS}:{str(self.server_process_handler.port)}/non-existing-path"
            self.fetcher.fetch(self.url)
        self.assertEqual(cm.exception.status_code, 404)

    # Simple bytes download
    def test_download_bytes(self):
        """download_bytes() returns the contents when max_length is exact."""
        data = self.fetcher.download_bytes(self.url, self.file_length)
        self.assertEqual(self.file_contents, data.decode("utf-8"))

    # Download file smaller than required max_length
    def test_download_bytes_upper_length(self):
        """download_bytes() accepts a file shorter than max_length."""
        data = self.fetcher.download_bytes(self.url, self.file_length + 4)
        self.assertEqual(self.file_contents, data.decode("utf-8"))

    # Download a file bigger than expected
    def test_download_bytes_length_mismatch(self):
        """download_bytes() rejects a file longer than max_length."""
        with self.assertRaises(exceptions.DownloadLengthMismatchError):
            self.fetcher.download_bytes(self.url, self.file_length - 4)

    # Simple file download
    def test_download_file(self):
        """download_file() yields a file of the expected length."""
        with self.fetcher.download_file(
            self.url, self.file_length
        ) as temp_file:
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    # Download file smaller than required max_length
    def test_download_file_upper_length(self):
        """download_file() accepts a file shorter than max_length."""
        with self.fetcher.download_file(
            self.url, self.file_length + 4
        ) as temp_file:
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    # Download a file bigger than expected
    def test_download_file_length_mismatch(self):
        """download_file() rejects a file longer than max_length."""
        with self.assertRaises(exceptions.DownloadLengthMismatchError):
            # download_file() is used as a context manager, so it has to be
            # entered for the download to start. Do not use `yield` here: it
            # would turn this test method into a generator function whose
            # body unittest never executes, so the check would never run.
            with self.fetcher.download_file(self.url, self.file_length - 4):
                pass
# Script entry point: configure test logging from the command-line
# arguments, then hand control to the unittest runner.
if __name__ == "__main__":
    utils.configure_test_logging(sys.argv)
    unittest.main()

View file

@ -53,15 +53,11 @@ def __init__(self) -> None:
self.chunk_size: int = 400000 # bytes
self.sleep_before_round: Optional[int] = None
def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
"""Fetches the contents of HTTP/HTTPS url from a remote server.
Ensures the length of the downloaded data is up to 'max_length'.
def fetch(self, url: str) -> Iterator[bytes]:
"""Fetches the contents of HTTP/HTTPS url from a remote server
Arguments:
url: A URL string that represents a file location.
max_length: An integer value representing the maximum
number of bytes to be downloaded.
Raises:
exceptions.SlowRetrievalError: A timeout occurs while receiving
@ -90,17 +86,14 @@ def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
status = e.response.status_code
raise exceptions.FetcherHTTPError(str(e), status)
return self._chunks(response, max_length)
return self._chunks(response)
def _chunks(
self, response: "requests.Response", max_length: int
) -> Iterator[bytes]:
def _chunks(self, response: "requests.Response") -> Iterator[bytes]:
"""A generator function to be returned by fetch. This way the
caller of fetch can differentiate between connection and actual data
download."""
try:
bytes_received = 0
while True:
# We download a fixed chunk of data in every round. This is
# so that we can defend against slow retrieval attacks.
@ -111,35 +104,19 @@ def _chunks(
if self.sleep_before_round:
time.sleep(self.sleep_before_round)
read_amount = min(
self.chunk_size,
max_length - bytes_received,
)
# NOTE: This may not handle some servers adding a
# Content-Encoding header, which may cause urllib3 to
# misbehave:
# https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
data = response.raw.read(read_amount)
bytes_received += len(data)
data = response.raw.read(self.chunk_size)
# We might have no more data to read. Check number of bytes
# downloaded.
# We might have no more data to read, we signal
# that the download is complete.
if not data:
# Finally, we signal that the download is complete.
break
yield data
if bytes_received >= max_length:
break
logger.debug(
"Downloaded %d out of %d bytes",
bytes_received,
max_length,
)
except urllib3.exceptions.ReadTimeoutError as e:
raise exceptions.SlowRetrievalError from e

View file

@ -29,15 +29,11 @@ class FetcherInterface:
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
def fetch(self, url: str) -> Iterator[bytes]:
"""Fetches the contents of HTTP/HTTPS url from a remote server.
Ensures the length of the downloaded data is up to 'max_length'.
Arguments:
url: A URL string that represents a file location.
max_length: An integer value representing the maximum
number of bytes to be downloaded.
Raises:
tuf.exceptions.SlowRetrievalError: A timeout occurs while receiving
@ -77,14 +73,22 @@ def download_file(self, url: str, max_length: int) -> Iterator[IO]:
number_of_bytes_received = 0
with tempfile.TemporaryFile() as temp_file:
chunks = self.fetch(url, max_length)
chunks = self.fetch(url)
for chunk in chunks:
temp_file.write(chunk)
number_of_bytes_received += len(chunk)
if number_of_bytes_received > max_length:
raise exceptions.DownloadLengthMismatchError(
max_length, number_of_bytes_received
)
if number_of_bytes_received > max_length:
raise exceptions.DownloadLengthMismatchError(
max_length, number_of_bytes_received
)
temp_file.write(chunk)
logger.debug(
"Downloaded %d out of %d bytes",
number_of_bytes_received,
max_length,
)
temp_file.seek(0)
yield temp_file