diff --git a/tests/test_fetcher_ng.py b/tests/test_fetcher_ng.py
new file mode 100644
index 00000000..b5714452
--- /dev/null
+++ b/tests/test_fetcher_ng.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python
+
+# Copyright 2021, New York University and the TUF contributors
+# SPDX-License-Identifier: MIT OR Apache-2.0
+
+"""Unit test for RequestsFetcher.
+"""
+
+import io
+import logging
+import os
+import sys
+import unittest
+import tempfile
+import math
+
+from tests import utils
+from tuf import exceptions, unittest_toolbox
+from tuf.ngclient._internal.requests_fetcher import RequestsFetcher
+
+logger = logging.getLogger(__name__)
+
+
+class TestFetcher(unittest_toolbox.Modified_TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        # Launch a SimpleHTTPServer (serves files in the current dir).
+        cls.server_process_handler = utils.TestServerProcess(log=logger)
+
+    @classmethod
+    def tearDownClass(cls):
+        # Stop server process and perform clean up.
+        cls.server_process_handler.clean()
+
+    def setUp(self):
+        """
+        Create a temporary file and launch a simple server in the
+        current working directory.
+        """
+
+        unittest_toolbox.Modified_TestCase.setUp(self)
+
+        # Making a temporary data file.
+        current_dir = os.getcwd()
+        target_filepath = self.make_temp_data_file(directory=current_dir)
+
+        self.target_fileobj = open(target_filepath, "r")
+        self.file_contents = self.target_fileobj.read()
+        self.file_length = len(self.file_contents)
+        self.rel_target_filepath = os.path.basename(target_filepath)
+        self.url = f"http://{utils.TEST_HOST_ADDRESS}:{str(self.server_process_handler.port)}/{self.rel_target_filepath}"
+
+        # Instantiate a concrete instance of FetcherInterface
+        self.fetcher = RequestsFetcher()
+
+    def tearDown(self):
+        self.target_fileobj.close()
+        # Remove temporary directory
+        unittest_toolbox.Modified_TestCase.tearDown(self)
+
+    # Simple fetch.
+    def test_fetch(self):
+        with tempfile.TemporaryFile() as temp_file:
+            for chunk in self.fetcher.fetch(self.url):
+                temp_file.write(chunk)
+
+            temp_file.seek(0)
+            self.assertEqual(
+                self.file_contents, temp_file.read().decode("utf-8")
+            )
+
+    # URL data downloaded in more than one chunk
+    def test_fetch_in_chunks(self):
+        # Set a smaller chunk size to ensure that the file will be downloaded
+        # in more than one chunk
+        self.fetcher.chunk_size = 4
+
+        # expected_chunks_count: 3
+        expected_chunks_count = math.ceil(
+            self.file_length / self.fetcher.chunk_size
+        )
+        self.assertEqual(expected_chunks_count, 3)
+
+        chunks_count = 0
+        with tempfile.TemporaryFile() as temp_file:
+            for chunk in self.fetcher.fetch(self.url):
+                temp_file.write(chunk)
+                chunks_count += 1
+
+            temp_file.seek(0)
+            self.assertEqual(
+                self.file_contents, temp_file.read().decode("utf-8")
+            )
+            # Check that we calculate chunks as expected
+            self.assertEqual(chunks_count, expected_chunks_count)
+
+    # Incorrect URL parsing
+    def test_url_parsing(self):
+        with self.assertRaises(exceptions.URLParsingError):
+            self.fetcher.fetch(self.random_string())
+
+    # File not found error
+    def test_http_error(self):
+        with self.assertRaises(exceptions.FetcherHTTPError) as cm:
+            self.url = f"http://{utils.TEST_HOST_ADDRESS}:{str(self.server_process_handler.port)}/non-existing-path"
+            self.fetcher.fetch(self.url)
+        self.assertEqual(cm.exception.status_code, 404)
+
+    # Simple bytes download
+    def test_download_bytes(self):
+        data = self.fetcher.download_bytes(self.url, self.file_length)
+        self.assertEqual(self.file_contents, data.decode("utf-8"))
+
+    # Download file smaller than required max_length
+    def test_download_bytes_upper_length(self):
+        data = self.fetcher.download_bytes(self.url, self.file_length + 4)
+        self.assertEqual(self.file_contents, data.decode("utf-8"))
+
+    # Download a file bigger than expected
+    def test_download_bytes_length_mismatch(self):
+        with self.assertRaises(exceptions.DownloadLengthMismatchError):
+            self.fetcher.download_bytes(self.url, self.file_length - 4)
+
+    # Simple file download
+    def test_download_file(self):
+        with self.fetcher.download_file(
+            self.url, self.file_length
+        ) as temp_file:
+            temp_file.seek(0, io.SEEK_END)
+            self.assertEqual(self.file_length, temp_file.tell())
+
+    # Download file smaller than required max_length
+    def test_download_file_upper_length(self):
+        with self.fetcher.download_file(
+            self.url, self.file_length + 4
+        ) as temp_file:
+            temp_file.seek(0, io.SEEK_END)
+            self.assertEqual(self.file_length, temp_file.tell())
+
+    # Download a file bigger than expected
+    def test_download_file_length_mismatch(self):
+        with self.assertRaises(exceptions.DownloadLengthMismatchError):
+            # download_file is lazy: nothing is fetched until the context
+            # manager is entered. A `yield` here would turn this test into
+            # a generator function whose body unittest never executes.
+            with self.fetcher.download_file(self.url, self.file_length - 4):
+                pass
+
+
+# Run unit test.
+if __name__ == "__main__":
+    utils.configure_test_logging(sys.argv)
+    unittest.main()
diff --git a/tuf/ngclient/_internal/requests_fetcher.py b/tuf/ngclient/_internal/requests_fetcher.py
index ae68f1a3..3647e58f 100644
--- a/tuf/ngclient/_internal/requests_fetcher.py
+++ b/tuf/ngclient/_internal/requests_fetcher.py
@@ -53,15 +53,11 @@ def __init__(self) -> None:
         self.chunk_size: int = 400000  # bytes
         self.sleep_before_round: Optional[int] = None
 
-    def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
-        """Fetches the contents of HTTP/HTTPS url from a remote server.
-
-        Ensures the length of the downloaded data is up to 'max_length'.
+    def fetch(self, url: str) -> Iterator[bytes]:
+        """Fetches the contents of HTTP/HTTPS url from a remote server.
 
         Arguments:
             url: A URL string that represents a file location.
-            max_length: An integer value representing the maximum
-                number of bytes to be downloaded.
 
         Raises:
             exceptions.SlowRetrievalError: A timeout occurs while receiving
@@ -90,17 +86,14 @@ def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
             status = e.response.status_code
             raise exceptions.FetcherHTTPError(str(e), status)
 
-        return self._chunks(response, max_length)
+        return self._chunks(response)
 
-    def _chunks(
-        self, response: "requests.Response", max_length: int
-    ) -> Iterator[bytes]:
+    def _chunks(self, response: "requests.Response") -> Iterator[bytes]:
         """A generator function to be returned by fetch. This way the
         caller of fetch can differentiate between connection and actual data
         download."""
 
         try:
-            bytes_received = 0
             while True:
                 # We download a fixed chunk of data in every round. This is
                 # so that we can defend against slow retrieval attacks.
@@ -111,35 +104,19 @@ def _chunks(
                 if self.sleep_before_round:
                     time.sleep(self.sleep_before_round)
 
-                read_amount = min(
-                    self.chunk_size,
-                    max_length - bytes_received,
-                )
-
                 # NOTE: This may not handle some servers adding a
                 # Content-Encoding header, which may cause urllib3 to
                 # misbehave:
                 # https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
-                data = response.raw.read(read_amount)
-                bytes_received += len(data)
+                data = response.raw.read(self.chunk_size)
 
-                # We might have no more data to read. Check number of bytes
-                # downloaded.
+                # We might have no more data to read, we signal
+                # that the download is complete.
                 if not data:
-                    # Finally, we signal that the download is complete.
                     break
 
                 yield data
 
-                if bytes_received >= max_length:
-                    break
-
-            logger.debug(
-                "Downloaded %d out of %d bytes",
-                bytes_received,
-                max_length,
-            )
-
         except urllib3.exceptions.ReadTimeoutError as e:
             raise exceptions.SlowRetrievalError from e
 
diff --git a/tuf/ngclient/fetcher.py b/tuf/ngclient/fetcher.py
index 89d5a984..6e8f2df2 100644
--- a/tuf/ngclient/fetcher.py
+++ b/tuf/ngclient/fetcher.py
@@ -29,15 +29,11 @@ class FetcherInterface:
     __metaclass__ = abc.ABCMeta
 
     @abc.abstractmethod
-    def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
+    def fetch(self, url: str) -> Iterator[bytes]:
         """Fetches the contents of HTTP/HTTPS url from a remote server.
 
-        Ensures the length of the downloaded data is up to 'max_length'.
-
         Arguments:
             url: A URL string that represents a file location.
-            max_length: An integer value representing the maximum
-                number of bytes to be downloaded.
 
         Raises:
             tuf.exceptions.SlowRetrievalError: A timeout occurs while receiving
@@ -77,14 +73,22 @@ def download_file(self, url: str, max_length: int) -> Iterator[IO]:
         number_of_bytes_received = 0
 
         with tempfile.TemporaryFile() as temp_file:
-            chunks = self.fetch(url, max_length)
+            chunks = self.fetch(url)
             for chunk in chunks:
-                temp_file.write(chunk)
                 number_of_bytes_received += len(chunk)
-            if number_of_bytes_received > max_length:
-                raise exceptions.DownloadLengthMismatchError(
-                    max_length, number_of_bytes_received
-                )
+                if number_of_bytes_received > max_length:
+                    raise exceptions.DownloadLengthMismatchError(
+                        max_length, number_of_bytes_received
+                    )
+
+                temp_file.write(chunk)
+
+            logger.debug(
+                "Downloaded %d out of %d bytes",
+                number_of_bytes_received,
+                max_length,
+            )
+
             temp_file.seek(0)
             yield temp_file
 