mirror of
https://github.com/theupdateframework/python-tuf
synced 2026-05-24 10:08:28 +00:00
Merge pull request #1519 from sechkova/fetcher-max-length
Remove max_length parameter from fetch
This commit is contained in:
commit
deec2eaaa0
3 changed files with 172 additions and 41 deletions
150
tests/test_fetcher_ng.py
Normal file
150
tests/test_fetcher_ng.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2021, New York University and the TUF contributors
|
||||
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||
|
||||
"""Unit test for RequestsFetcher.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
import tempfile
|
||||
import math
|
||||
|
||||
from tests import utils
|
||||
from tuf import exceptions, unittest_toolbox
|
||||
from tuf.ngclient._internal.requests_fetcher import RequestsFetcher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestFetcher(unittest_toolbox.Modified_TestCase):
    """Tests for RequestsFetcher run against a local test HTTP server.

    One server process is started for the whole class. Every test gets a
    fresh temporary data file created in the current working directory and
    a fresh RequestsFetcher instance.
    """

    @classmethod
    def setUpClass(cls):
        """Start the server process shared by all tests in this class."""
        # Launch a SimpleHTTPServer (serves files in the current dir).
        cls.server_process_handler = utils.TestServerProcess(log=logger)

    @classmethod
    def tearDownClass(cls):
        """Stop the server process and perform clean up."""
        cls.server_process_handler.clean()

    def setUp(self):
        """Create a temporary data file and the fetcher under test."""
        unittest_toolbox.Modified_TestCase.setUp(self)

        # The data file is created in the current working directory, which
        # is the directory the test server is documented to serve from.
        target_filepath = self.make_temp_data_file(directory=os.getcwd())

        # Read the expected contents once and release the handle right
        # away; nothing else needs the file object to stay open.
        with open(target_filepath, "r") as target_fileobj:
            self.file_contents = target_fileobj.read()

        # NOTE(review): this is a character count that the tests use as a
        # byte count; assumes the temp data is ASCII-only - TODO confirm.
        self.file_length = len(self.file_contents)
        self.rel_target_filepath = os.path.basename(target_filepath)
        self.url = self._url_for(self.rel_target_filepath)

        # Instantiate a concrete instance of FetcherInterface
        self.fetcher = RequestsFetcher()

    def tearDown(self):
        """Run the base class clean up."""
        # Remove temporary directory
        unittest_toolbox.Modified_TestCase.tearDown(self)

    def _url_for(self, path):
        """Return the test server URL for 'path' (relative, no slash)."""
        port = self.server_process_handler.port
        return f"http://{utils.TEST_HOST_ADDRESS}:{port}/{path}"

    def test_fetch(self):
        """Simple fetch: the joined chunks equal the file contents."""
        with tempfile.TemporaryFile() as temp_file:
            for chunk in self.fetcher.fetch(self.url):
                temp_file.write(chunk)

            temp_file.seek(0)
            self.assertEqual(
                self.file_contents, temp_file.read().decode("utf-8")
            )

    def test_fetch_in_chunks(self):
        """URL data is downloaded in more than one chunk."""
        # Set a smaller chunk size to ensure that the file will be downloaded
        # in more than one chunk
        self.fetcher.chunk_size = 4

        # Pin the expectation that the temp data spans exactly three chunks,
        # so a change of the default test data is noticed here.
        expected_chunks_count = math.ceil(
            self.file_length / self.fetcher.chunk_size
        )
        self.assertEqual(expected_chunks_count, 3)

        chunks_count = 0
        with tempfile.TemporaryFile() as temp_file:
            for chunk in self.fetcher.fetch(self.url):
                temp_file.write(chunk)
                chunks_count += 1

            temp_file.seek(0)
            self.assertEqual(
                self.file_contents, temp_file.read().decode("utf-8")
            )

        # Check that we calculate chunks as expected
        self.assertEqual(chunks_count, expected_chunks_count)

    def test_url_parsing(self):
        """Incorrect URL parsing: a random string is not a valid URL."""
        with self.assertRaises(exceptions.URLParsingError):
            self.fetcher.fetch(self.random_string())

    def test_http_error(self):
        """File not found: the server's 404 surfaces as FetcherHTTPError."""
        # Build the URL outside the assertRaises block so that only the
        # fetch call itself can satisfy the assertion.
        url = self._url_for("non-existing-path")
        with self.assertRaises(exceptions.FetcherHTTPError) as cm:
            self.fetcher.fetch(url)
        self.assertEqual(cm.exception.status_code, 404)

    def test_download_bytes(self):
        """Simple bytes download with max_length equal to the file size."""
        data = self.fetcher.download_bytes(self.url, self.file_length)
        self.assertEqual(self.file_contents, data.decode("utf-8"))

    def test_download_bytes_upper_length(self):
        """Download file smaller than required max_length."""
        data = self.fetcher.download_bytes(self.url, self.file_length + 4)
        self.assertEqual(self.file_contents, data.decode("utf-8"))

    def test_download_bytes_length_mismatch(self):
        """Download a file bigger than expected."""
        with self.assertRaises(exceptions.DownloadLengthMismatchError):
            self.fetcher.download_bytes(self.url, self.file_length - 4)

    def test_download_file(self):
        """Simple file download with max_length equal to the file size."""
        with self.fetcher.download_file(
            self.url, self.file_length
        ) as temp_file:
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    def test_download_file_upper_length(self):
        """Download file smaller than required max_length."""
        with self.fetcher.download_file(
            self.url, self.file_length + 4
        ) as temp_file:
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    def test_download_file_length_mismatch(self):
        """Download a file bigger than expected."""
        with self.assertRaises(exceptions.DownloadLengthMismatchError):
            # download_file is used as a context manager elsewhere in this
            # class, so its body only runs once the context is entered. A
            # 'yield' here would turn this test into a generator function
            # that unittest never executes, letting it pass vacuously.
            with self.fetcher.download_file(
                self.url, self.file_length - 4
            ):
                pass
|
||||
|
||||
|
||||
# Run the unit tests when this module is executed as a script.
|
||||
if __name__ == "__main__":
|
||||
utils.configure_test_logging(sys.argv)
|
||||
unittest.main()
|
||||
|
|
@ -53,15 +53,11 @@ def __init__(self) -> None:
|
|||
self.chunk_size: int = 400000 # bytes
|
||||
self.sleep_before_round: Optional[int] = None
|
||||
|
||||
def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
|
||||
"""Fetches the contents of HTTP/HTTPS url from a remote server.
|
||||
|
||||
Ensures the length of the downloaded data is up to 'max_length'.
|
||||
def fetch(self, url: str) -> Iterator[bytes]:
|
||||
"""Fetches the contents of HTTP/HTTPS url from a remote server
|
||||
|
||||
Arguments:
|
||||
url: A URL string that represents a file location.
|
||||
max_length: An integer value representing the maximum
|
||||
number of bytes to be downloaded.
|
||||
|
||||
Raises:
|
||||
exceptions.SlowRetrievalError: A timeout occurs while receiving
|
||||
|
|
@ -90,17 +86,14 @@ def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
|
|||
status = e.response.status_code
|
||||
raise exceptions.FetcherHTTPError(str(e), status)
|
||||
|
||||
return self._chunks(response, max_length)
|
||||
return self._chunks(response)
|
||||
|
||||
def _chunks(
|
||||
self, response: "requests.Response", max_length: int
|
||||
) -> Iterator[bytes]:
|
||||
def _chunks(self, response: "requests.Response") -> Iterator[bytes]:
|
||||
"""A generator function to be returned by fetch. This way the
|
||||
caller of fetch can differentiate between connection and actual data
|
||||
download."""
|
||||
|
||||
try:
|
||||
bytes_received = 0
|
||||
while True:
|
||||
# We download a fixed chunk of data in every round. This is
|
||||
# so that we can defend against slow retrieval attacks.
|
||||
|
|
@ -111,35 +104,19 @@ def _chunks(
|
|||
if self.sleep_before_round:
|
||||
time.sleep(self.sleep_before_round)
|
||||
|
||||
read_amount = min(
|
||||
self.chunk_size,
|
||||
max_length - bytes_received,
|
||||
)
|
||||
|
||||
# NOTE: This may not handle some servers adding a
|
||||
# Content-Encoding header, which may cause urllib3 to
|
||||
# misbehave:
|
||||
# https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
|
||||
data = response.raw.read(read_amount)
|
||||
bytes_received += len(data)
|
||||
data = response.raw.read(self.chunk_size)
|
||||
|
||||
# We might have no more data to read. Check number of bytes
|
||||
# downloaded.
|
||||
# We might have no more data to read, we signal
|
||||
# that the download is complete.
|
||||
if not data:
|
||||
# Finally, we signal that the download is complete.
|
||||
break
|
||||
|
||||
yield data
|
||||
|
||||
if bytes_received >= max_length:
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
"Downloaded %d out of %d bytes",
|
||||
bytes_received,
|
||||
max_length,
|
||||
)
|
||||
|
||||
except urllib3.exceptions.ReadTimeoutError as e:
|
||||
raise exceptions.SlowRetrievalError from e
|
||||
|
||||
|
|
|
|||
|
|
@ -29,15 +29,11 @@ class FetcherInterface:
|
|||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
@abc.abstractmethod
|
||||
def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
|
||||
def fetch(self, url: str) -> Iterator[bytes]:
|
||||
"""Fetches the contents of HTTP/HTTPS url from a remote server.
|
||||
|
||||
Ensures the length of the downloaded data is up to 'max_length'.
|
||||
|
||||
Arguments:
|
||||
url: A URL string that represents a file location.
|
||||
max_length: An integer value representing the maximum
|
||||
number of bytes to be downloaded.
|
||||
|
||||
Raises:
|
||||
tuf.exceptions.SlowRetrievalError: A timeout occurs while receiving
|
||||
|
|
@ -77,14 +73,22 @@ def download_file(self, url: str, max_length: int) -> Iterator[IO]:
|
|||
number_of_bytes_received = 0
|
||||
|
||||
with tempfile.TemporaryFile() as temp_file:
|
||||
chunks = self.fetch(url, max_length)
|
||||
chunks = self.fetch(url)
|
||||
for chunk in chunks:
|
||||
temp_file.write(chunk)
|
||||
number_of_bytes_received += len(chunk)
|
||||
if number_of_bytes_received > max_length:
|
||||
raise exceptions.DownloadLengthMismatchError(
|
||||
max_length, number_of_bytes_received
|
||||
)
|
||||
if number_of_bytes_received > max_length:
|
||||
raise exceptions.DownloadLengthMismatchError(
|
||||
max_length, number_of_bytes_received
|
||||
)
|
||||
|
||||
temp_file.write(chunk)
|
||||
|
||||
logger.debug(
|
||||
"Downloaded %d out of %d bytes",
|
||||
number_of_bytes_received,
|
||||
max_length,
|
||||
)
|
||||
|
||||
temp_file.seek(0)
|
||||
yield temp_file
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue