# Mirror of https://github.com/theupdateframework/python-tuf
# (synced 2026-05-24 10:08:28 +00:00; 183 lines, 6.2 KiB, Python)
# Copyright 2021, New York University and the TUF contributors
# SPDX-License-Identifier: MIT OR Apache-2.0
|
|
|
|
"""Unit test for Urllib3Fetcher."""
|
|
|
|
import io
|
|
import logging
|
|
import math
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
from typing import ClassVar
|
|
from unittest.mock import Mock, patch
|
|
|
|
import urllib3
|
|
|
|
from tests import utils
|
|
from tuf.api import exceptions
|
|
from tuf.ngclient import Urllib3Fetcher
|
|
|
|
# Logger named after this module; it is passed to the test server process
# so that the server's output ends up in this module's log.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class TestFetcher(unittest.TestCase):
    """Test Urllib3Fetcher class."""

    server_process_handler: ClassVar[utils.TestServerProcess]
    # Fixture values shared by every test; all are set once in setUpClass().
    file_contents: ClassVar[bytes]
    file_length: ClassVar[int]
    url_prefix: ClassVar[str]
    url: ClassVar[str]

    @classmethod
    def setUpClass(cls) -> None:
        """
        Create a temporary file and launch a simple server in the
        current working directory.

        ``cls.url`` points at the temporary file through the server.

        Cleanup is registered with ``addClassCleanup`` instead of being
        done in ``tearDownClass``. unittest skips ``tearDownClass`` when
        ``setUpClass`` raises, but it still runs class cleanups. So the
        server is stopped even if creating the temporary file fails.
        """
        cls.server_process_handler = utils.TestServerProcess(log=logger)
        # Register this immediately so no later failure can leak the server.
        cls.addClassCleanup(cls.server_process_handler.clean)

        cls.file_contents = b"junk data"
        cls.file_length = len(cls.file_contents)
        # delete=False: the file must outlive this `with` block so that the
        # server can serve it; it is removed by the class cleanup below.
        with tempfile.NamedTemporaryFile(
            dir=os.getcwd(), delete=False
        ) as cls.target_file:
            cls.target_file.write(cls.file_contents)
        cls.addClassCleanup(os.remove, cls.target_file.name)

        cls.url_prefix = (
            f"http://{utils.TEST_HOST_ADDRESS}:"
            f"{cls.server_process_handler.port!s}"
        )
        target_filename = os.path.basename(cls.target_file.name)
        cls.url = f"{cls.url_prefix}/{target_filename}"

    def setUp(self) -> None:
        # Instantiate a concrete instance of FetcherInterface
        self.fetcher = Urllib3Fetcher()

    # Simple fetch.
    def test_fetch(self) -> None:
        with tempfile.TemporaryFile() as temp_file:
            for chunk in self.fetcher.fetch(self.url):
                temp_file.write(chunk)

            temp_file.seek(0)
            self.assertEqual(self.file_contents, temp_file.read())

    # URL data downloaded in more than one chunk
    def test_fetch_in_chunks(self) -> None:
        # Set a smaller chunk size to ensure that the file will be downloaded
        # in more than one chunk
        self.fetcher.chunk_size = 4

        # expected_chunks_count: 3 (depends on self.file_length)
        expected_chunks_count = math.ceil(
            self.file_length / self.fetcher.chunk_size
        )
        # Guard against the fixture changing and silently weakening the test.
        self.assertEqual(expected_chunks_count, 3)

        chunks_count = 0
        with tempfile.TemporaryFile() as temp_file:
            for chunk in self.fetcher.fetch(self.url):
                temp_file.write(chunk)
                chunks_count += 1

            temp_file.seek(0)
            self.assertEqual(self.file_contents, temp_file.read())
            # Check that we calculate chunks as expected
            self.assertEqual(chunks_count, expected_chunks_count)

    # Unreachable host: "http://invalid/" is syntactically valid, but its
    # host is not expected to resolve, so fetch() must raise DownloadError.
    def test_url_parsing(self) -> None:
        with self.assertRaises(exceptions.DownloadError):
            self.fetcher.fetch("http://invalid/")

    # File not found error
    def test_http_error(self) -> None:
        # Use a local URL rather than overwriting the shared self.url
        # fixture, and keep only the raising call inside assertRaises.
        missing_url = f"{self.url_prefix}/non-existing-path"
        with self.assertRaises(exceptions.DownloadHTTPError) as cm:
            self.fetcher.fetch(missing_url)
        self.assertEqual(cm.exception.status_code, 404)

    # Response read timeout error
    @patch.object(urllib3.PoolManager, "request")
    def test_response_read_timeout(self, mock_request: Mock) -> None:
        mock_response = Mock()
        mock_response.status = 200
        # The request itself succeeds; streaming the body then fails with a
        # MaxRetryError whose reason is a TimeoutError. The fetcher is
        # expected to report that as SlowRetrievalError.
        mock_response.stream.side_effect = urllib3.exceptions.MaxRetryError(
            urllib3.connectionpool.ConnectionPool("localhost"),
            "",
            urllib3.exceptions.TimeoutError(),
        )
        mock_request.return_value = mock_response

        with self.assertRaises(exceptions.SlowRetrievalError):
            next(self.fetcher.fetch(self.url))
        mock_response.stream.assert_called_once()

    # Read/connect session timeout error
    @patch.object(
        urllib3.PoolManager,
        "request",
        side_effect=urllib3.exceptions.MaxRetryError(
            urllib3.connectionpool.ConnectionPool("localhost"),
            "",
            urllib3.exceptions.TimeoutError(),
        ),
    )
    def test_session_get_timeout(self, mock_request: Mock) -> None:
        # Here the request call itself raises the timeout-flavoured
        # MaxRetryError, before any response is returned.
        with self.assertRaises(exceptions.SlowRetrievalError):
            self.fetcher.fetch(self.url)
        mock_request.assert_called_once()

    # Simple bytes download
    def test_download_bytes(self) -> None:
        data = self.fetcher.download_bytes(self.url, self.file_length)
        self.assertEqual(self.file_contents, data)

    # Download file smaller than required max_length
    def test_download_bytes_upper_length(self) -> None:
        data = self.fetcher.download_bytes(self.url, self.file_length + 4)
        self.assertEqual(self.file_contents, data)

    # Download a file bigger than expected
    def test_download_bytes_length_mismatch(self) -> None:
        with self.assertRaises(exceptions.DownloadLengthMismatchError):
            self.fetcher.download_bytes(self.url, self.file_length - 4)

    # Simple file download
    def test_download_file(self) -> None:
        with self.fetcher.download_file(
            self.url, self.file_length
        ) as temp_file:
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    # Download file smaller than required max_length
    def test_download_file_upper_length(self) -> None:
        with self.fetcher.download_file(
            self.url, self.file_length + 4
        ) as temp_file:
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    # Download a file bigger than expected
    def test_download_file_length_mismatch(self) -> None:
        with (
            self.assertRaises(exceptions.DownloadLengthMismatchError),
            self.fetcher.download_file(self.url, self.file_length - 4),
        ):
            pass  # we never get here as download_file() raises
|
|
|
|
|
|
def _run_tests() -> None:
    """Configure logging from the command-line arguments, then run the tests."""
    utils.configure_test_logging(sys.argv)
    unittest.main()


# Run unit test.
if __name__ == "__main__":
    _run_tests()
|