mirror of
https://github.com/theupdateframework/python-tuf
synced 2026-05-24 10:08:28 +00:00
Address mypy warnings
This commit includes manual fixes for a lot of mypy warnings. When there were warnings that we are calling non-annotated function in annotated context I decided to add annotations instead of ignoring those warnings. That's how I end up adding annotations in the whole tests/utils.py module. Signed-off-by: Martin Vrachev <mvrachev@vmware.com>
This commit is contained in:
parent
0d4d7f820c
commit
e2deff3148
12 changed files with 224 additions and 177 deletions
|
|
@ -93,11 +93,7 @@ class RepositorySimulator(FetcherInterface):
|
|||
"""Simulates a repository that can be used for testing."""
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
def __init__(self):
|
||||
self.md_root: Metadata[Root] = None
|
||||
self.md_timestamp: Metadata[Timestamp] = None
|
||||
self.md_snapshot: Metadata[Snapshot] = None
|
||||
self.md_targets: Metadata[Targets] = None
|
||||
def __init__(self) -> None:
|
||||
self.md_delegates: Dict[str, Metadata[Targets]] = {}
|
||||
|
||||
# other metadata is signed on-demand (when fetched) but roots must be
|
||||
|
|
@ -117,7 +113,7 @@ def __init__(self):
|
|||
# Enable hash-prefixed target file names
|
||||
self.prefix_targets_with_hash = True
|
||||
|
||||
self.dump_dir = None
|
||||
self.dump_dir: Optional[str] = None
|
||||
self.dump_version = 0
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
|
@ -152,12 +148,12 @@ def create_key() -> Tuple[Key, SSlibSigner]:
|
|||
sslib_key = generate_ed25519_key()
|
||||
return Key.from_securesystemslib_key(sslib_key), SSlibSigner(sslib_key)
|
||||
|
||||
def add_signer(self, role: str, signer: SSlibSigner):
|
||||
def add_signer(self, role: str, signer: SSlibSigner) -> None:
|
||||
if role not in self.signers:
|
||||
self.signers[role] = {}
|
||||
self.signers[role][signer.key_dict["keyid"]] = signer
|
||||
|
||||
def _initialize(self):
|
||||
def _initialize(self) -> None:
|
||||
"""Setup a minimal valid repository."""
|
||||
|
||||
targets = Targets(1, SPEC_VER, self.safe_expiry, {}, None)
|
||||
|
|
@ -182,7 +178,7 @@ def _initialize(self):
|
|||
self.md_root = Metadata(root, OrderedDict())
|
||||
self.publish_root()
|
||||
|
||||
def publish_root(self):
|
||||
def publish_root(self) -> None:
|
||||
"""Sign and store a new serialized version of root."""
|
||||
self.md_root.signatures.clear()
|
||||
for signer in self.signers["root"].values():
|
||||
|
|
@ -199,12 +195,12 @@ def fetch(self, url: str) -> Iterator[bytes]:
|
|||
if path.startswith("/metadata/") and path.endswith(".json"):
|
||||
# figure out rolename and version
|
||||
ver_and_name = path[len("/metadata/") :][: -len(".json")]
|
||||
version, _, role = ver_and_name.partition(".")
|
||||
version_str, _, role = ver_and_name.partition(".")
|
||||
# root is always version-prefixed while timestamp is always NOT
|
||||
if role == "root" or (
|
||||
self.root.consistent_snapshot and ver_and_name != "timestamp"
|
||||
):
|
||||
version = int(version)
|
||||
version: Optional[int] = int(version_str)
|
||||
else:
|
||||
# the file is not version-prefixed
|
||||
role = ver_and_name
|
||||
|
|
@ -216,11 +212,10 @@ def fetch(self, url: str) -> Iterator[bytes]:
|
|||
target_path = path[len("/targets/") :]
|
||||
dir_parts, sep, prefixed_filename = target_path.rpartition("/")
|
||||
# extract the hash prefix, if any
|
||||
prefix: Optional[str] = None
|
||||
filename = prefixed_filename
|
||||
if self.root.consistent_snapshot and self.prefix_targets_with_hash:
|
||||
prefix, _, filename = prefixed_filename.partition(".")
|
||||
else:
|
||||
filename = prefixed_filename
|
||||
prefix = None
|
||||
target_path = f"{dir_parts}{sep}{filename}"
|
||||
|
||||
yield self._fetch_target(target_path, prefix)
|
||||
|
|
@ -261,8 +256,9 @@ def _fetch_metadata(
|
|||
return self.signed_roots[version - 1]
|
||||
|
||||
# sign and serialize the requested metadata
|
||||
md: Optional[Metadata]
|
||||
if role == "timestamp":
|
||||
md: Metadata = self.md_timestamp
|
||||
md = self.md_timestamp
|
||||
elif role == "snapshot":
|
||||
md = self.md_snapshot
|
||||
elif role == "targets":
|
||||
|
|
@ -294,7 +290,7 @@ def _compute_hashes_and_length(
|
|||
hashes = {sslib_hash.DEFAULT_HASH_ALGORITHM: digest_object.hexdigest()}
|
||||
return hashes, len(data)
|
||||
|
||||
def update_timestamp(self):
|
||||
def update_timestamp(self) -> None:
|
||||
"""Update timestamp and assign snapshot version to snapshot_meta
|
||||
version.
|
||||
"""
|
||||
|
|
@ -307,7 +303,7 @@ def update_timestamp(self):
|
|||
|
||||
self.timestamp.version += 1
|
||||
|
||||
def update_snapshot(self):
|
||||
def update_snapshot(self) -> None:
|
||||
"""Update snapshot, assign targets versions and update timestamp."""
|
||||
for role, delegate in self.all_targets():
|
||||
hashes = None
|
||||
|
|
@ -322,7 +318,7 @@ def update_snapshot(self):
|
|||
self.snapshot.version += 1
|
||||
self.update_timestamp()
|
||||
|
||||
def add_target(self, role: str, data: bytes, path: str):
|
||||
def add_target(self, role: str, data: bytes, path: str) -> None:
|
||||
"""Create a target from data and add it to the target_files."""
|
||||
if role == "targets":
|
||||
targets = self.targets
|
||||
|
|
@ -341,7 +337,7 @@ def add_delegation(
|
|||
terminating: bool,
|
||||
paths: Optional[List[str]],
|
||||
hash_prefixes: Optional[List[str]],
|
||||
):
|
||||
) -> None:
|
||||
"""Add delegated target role to the repository."""
|
||||
if delegator_name == "targets":
|
||||
delegator = self.targets
|
||||
|
|
@ -351,7 +347,7 @@ def add_delegation(
|
|||
# Create delegation
|
||||
role = DelegatedRole(name, [], 1, terminating, paths, hash_prefixes)
|
||||
if delegator.delegations is None:
|
||||
delegator.delegations = Delegations({}, {})
|
||||
delegator.delegations = Delegations({}, OrderedDict())
|
||||
# put delegation last by default
|
||||
delegator.delegations.roles[role.name] = role
|
||||
|
||||
|
|
@ -363,7 +359,7 @@ def add_delegation(
|
|||
# Add metadata for the role
|
||||
self.md_delegates[role.name] = Metadata(targets, OrderedDict())
|
||||
|
||||
def write(self):
|
||||
def write(self) -> None:
|
||||
"""Dump current repository metadata to self.dump_dir
|
||||
|
||||
This is a debugging tool: dumping repository state before running
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from typing import ClassVar, Dict
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from securesystemslib import hash as sslib_hash
|
||||
|
|
@ -28,6 +29,7 @@
|
|||
from tuf import exceptions
|
||||
from tuf.api.metadata import (
|
||||
DelegatedRole,
|
||||
Delegations,
|
||||
Key,
|
||||
Metadata,
|
||||
MetaFile,
|
||||
|
|
@ -47,8 +49,13 @@
|
|||
class TestMetadata(unittest.TestCase):
|
||||
"""Tests for public API of all classes in 'tuf/api/metadata.py'."""
|
||||
|
||||
temporary_directory: ClassVar[str]
|
||||
repo_dir: ClassVar[str]
|
||||
keystore_dir: ClassVar[str]
|
||||
keystore: ClassVar[Dict[str, str]]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
def setUpClass(cls) -> None:
|
||||
# Create a temporary directory to store the repository, metadata, and
|
||||
# target files. 'temporary_directory' must be deleted in
|
||||
# TearDownClass() so that temporary files are always removed, even when
|
||||
|
|
@ -78,12 +85,12 @@ def setUpClass(cls):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
def tearDownClass(cls) -> None:
|
||||
# Remove the temporary repository directory, which should contain all
|
||||
# the metadata, targets, and key files generated for the test cases.
|
||||
shutil.rmtree(cls.temporary_directory)
|
||||
|
||||
def test_generic_read(self):
|
||||
def test_generic_read(self) -> None:
|
||||
for metadata, inner_metadata_cls in [
|
||||
("root", Root),
|
||||
("snapshot", Snapshot),
|
||||
|
|
@ -120,7 +127,7 @@ def test_generic_read(self):
|
|||
|
||||
os.remove(bad_metadata_path)
|
||||
|
||||
def test_compact_json(self):
|
||||
def test_compact_json(self) -> None:
|
||||
path = os.path.join(self.repo_dir, "metadata", "targets.json")
|
||||
md_obj = Metadata.from_file(path)
|
||||
self.assertTrue(
|
||||
|
|
@ -128,7 +135,7 @@ def test_compact_json(self):
|
|||
< len(JSONSerializer().serialize(md_obj))
|
||||
)
|
||||
|
||||
def test_read_write_read_compare(self):
|
||||
def test_read_write_read_compare(self) -> None:
|
||||
for metadata in ["root", "snapshot", "timestamp", "targets"]:
|
||||
path = os.path.join(self.repo_dir, "metadata", metadata + ".json")
|
||||
md_obj = Metadata.from_file(path)
|
||||
|
|
@ -140,7 +147,7 @@ def test_read_write_read_compare(self):
|
|||
|
||||
os.remove(path_2)
|
||||
|
||||
def test_to_from_bytes(self):
|
||||
def test_to_from_bytes(self) -> None:
|
||||
for metadata in ["root", "snapshot", "timestamp", "targets"]:
|
||||
path = os.path.join(self.repo_dir, "metadata", metadata + ".json")
|
||||
with open(path, "rb") as f:
|
||||
|
|
@ -157,7 +164,7 @@ def test_to_from_bytes(self):
|
|||
metadata_obj_2 = Metadata.from_bytes(obj_bytes)
|
||||
self.assertEqual(metadata_obj_2.to_bytes(), obj_bytes)
|
||||
|
||||
def test_sign_verify(self):
|
||||
def test_sign_verify(self) -> None:
|
||||
root_path = os.path.join(self.repo_dir, "metadata", "root.json")
|
||||
root = Metadata[Root].from_file(root_path).signed
|
||||
|
||||
|
|
@ -183,7 +190,7 @@ def test_sign_verify(self):
|
|||
# Test verifying with explicitly set serializer
|
||||
targets_key.verify_signature(md_obj, CanonicalJSONSerializer())
|
||||
with self.assertRaises(exceptions.UnsignedMetadataError):
|
||||
targets_key.verify_signature(md_obj, JSONSerializer())
|
||||
targets_key.verify_signature(md_obj, JSONSerializer()) # type: ignore[arg-type]
|
||||
|
||||
sslib_signer = SSlibSigner(self.keystore["snapshot"])
|
||||
# Append a new signature with the unrelated key and assert that ...
|
||||
|
|
@ -206,7 +213,7 @@ def test_sign_verify(self):
|
|||
with self.assertRaises(exceptions.UnsignedMetadataError):
|
||||
targets_key.verify_signature(md_obj)
|
||||
|
||||
def test_verify_failures(self):
|
||||
def test_verify_failures(self) -> None:
|
||||
root_path = os.path.join(self.repo_dir, "metadata", "root.json")
|
||||
root = Metadata[Root].from_file(root_path).signed
|
||||
|
||||
|
|
@ -248,7 +255,7 @@ def test_verify_failures(self):
|
|||
timestamp_key.verify_signature(md_obj)
|
||||
sig.signature = correct_sig
|
||||
|
||||
def test_metadata_base(self):
|
||||
def test_metadata_base(self) -> None:
|
||||
# Use of Snapshot is arbitrary, we're just testing the base class
|
||||
# features with real data
|
||||
snapshot_path = os.path.join(self.repo_dir, "metadata", "snapshot.json")
|
||||
|
|
@ -290,7 +297,7 @@ def test_metadata_base(self):
|
|||
with self.assertRaises(ValueError):
|
||||
Metadata.from_dict(data)
|
||||
|
||||
def test_metadata_snapshot(self):
|
||||
def test_metadata_snapshot(self) -> None:
|
||||
snapshot_path = os.path.join(self.repo_dir, "metadata", "snapshot.json")
|
||||
snapshot = Metadata[Snapshot].from_file(snapshot_path)
|
||||
|
||||
|
|
@ -309,7 +316,7 @@ def test_metadata_snapshot(self):
|
|||
snapshot.signed.meta["role1.json"].to_dict(), fileinfo.to_dict()
|
||||
)
|
||||
|
||||
def test_metadata_timestamp(self):
|
||||
def test_metadata_timestamp(self) -> None:
|
||||
timestamp_path = os.path.join(
|
||||
self.repo_dir, "metadata", "timestamp.json"
|
||||
)
|
||||
|
|
@ -349,7 +356,7 @@ def test_metadata_timestamp(self):
|
|||
timestamp.signed.snapshot_meta.to_dict(), fileinfo.to_dict()
|
||||
)
|
||||
|
||||
def test_metadata_verify_delegate(self):
|
||||
def test_metadata_verify_delegate(self) -> None:
|
||||
root_path = os.path.join(self.repo_dir, "metadata", "root.json")
|
||||
root = Metadata[Root].from_file(root_path)
|
||||
snapshot_path = os.path.join(self.repo_dir, "metadata", "snapshot.json")
|
||||
|
|
@ -410,14 +417,14 @@ def test_metadata_verify_delegate(self):
|
|||
snapshot.sign(SSlibSigner(self.keystore["timestamp"]), append=True)
|
||||
root.verify_delegate("snapshot", snapshot)
|
||||
|
||||
def test_key_class(self):
|
||||
def test_key_class(self) -> None:
|
||||
# Test if from_securesystemslib_key removes the private key from keyval
|
||||
# of a securesystemslib key dictionary.
|
||||
sslib_key = generate_ed25519_key()
|
||||
key = Key.from_securesystemslib_key(sslib_key)
|
||||
self.assertFalse("private" in key.keyval.keys())
|
||||
|
||||
def test_root_add_key_and_remove_key(self):
|
||||
def test_root_add_key_and_remove_key(self) -> None:
|
||||
root_path = os.path.join(self.repo_dir, "metadata", "root.json")
|
||||
root = Metadata[Root].from_file(root_path)
|
||||
|
||||
|
|
@ -475,7 +482,7 @@ def test_root_add_key_and_remove_key(self):
|
|||
with self.assertRaises(ValueError):
|
||||
root.signed.remove_key("nosuchrole", keyid)
|
||||
|
||||
def test_is_target_in_pathpattern(self):
|
||||
def test_is_target_in_pathpattern(self) -> None:
|
||||
# pylint: disable=protected-access
|
||||
supported_use_cases = [
|
||||
("foo.tgz", "foo.tgz"),
|
||||
|
|
@ -507,7 +514,7 @@ def test_is_target_in_pathpattern(self):
|
|||
DelegatedRole._is_target_in_pathpattern(targetpath, pathpattern)
|
||||
)
|
||||
|
||||
def test_metadata_targets(self):
|
||||
def test_metadata_targets(self) -> None:
|
||||
targets_path = os.path.join(self.repo_dir, "metadata", "targets.json")
|
||||
targets = Metadata[Targets].from_file(targets_path)
|
||||
|
||||
|
|
@ -531,7 +538,7 @@ def test_metadata_targets(self):
|
|||
targets.signed.targets[filename].to_dict(), fileinfo.to_dict()
|
||||
)
|
||||
|
||||
def test_targets_key_api(self):
|
||||
def test_targets_key_api(self) -> None:
|
||||
targets_path = os.path.join(self.repo_dir, "metadata", "targets.json")
|
||||
targets: Targets = Metadata[Targets].from_file(targets_path).signed
|
||||
|
||||
|
|
@ -545,6 +552,7 @@ def test_targets_key_api(self):
|
|||
"threshold": 1,
|
||||
}
|
||||
)
|
||||
assert isinstance(targets.delegations, Delegations)
|
||||
targets.delegations.roles["role2"] = delegated_role
|
||||
|
||||
key_dict = {
|
||||
|
|
@ -608,7 +616,7 @@ def test_targets_key_api(self):
|
|||
targets.remove_key("role1", key.keyid)
|
||||
self.assertTrue(targets.delegations is None)
|
||||
|
||||
def test_length_and_hash_validation(self):
|
||||
def test_length_and_hash_validation(self) -> None:
|
||||
|
||||
# Test metadata files' hash and length verification.
|
||||
# Use timestamp to get a MetaFile object and snapshot
|
||||
|
|
@ -648,7 +656,7 @@ def test_length_and_hash_validation(self):
|
|||
|
||||
# Test wrong algorithm format (sslib.FormatError)
|
||||
snapshot_metafile.hashes = {
|
||||
256: "8f88e2ba48b412c3843e9bb26e1b6f8fc9e98aceb0fbaa97ba37b4c98717d7ab"
|
||||
256: "8f88e2ba48b412c3843e9bb26e1b6f8fc9e98aceb0fbaa97ba37b4c98717d7ab" # type: ignore[dict-item]
|
||||
}
|
||||
with self.assertRaises(exceptions.LengthOrHashMismatchError):
|
||||
snapshot_metafile.verify_length_and_hashes(data)
|
||||
|
|
@ -678,7 +686,7 @@ def test_length_and_hash_validation(self):
|
|||
with self.assertRaises(exceptions.LengthOrHashMismatchError):
|
||||
file1_targetfile.verify_length_and_hashes(file1)
|
||||
|
||||
def test_targetfile_from_file(self):
|
||||
def test_targetfile_from_file(self) -> None:
|
||||
# Test with an existing file and valid hash algorithm
|
||||
file_path = os.path.join(self.repo_dir, "targets", "file1.txt")
|
||||
targetfile_from_file = TargetFile.from_file(
|
||||
|
|
@ -700,7 +708,7 @@ def test_targetfile_from_file(self):
|
|||
with self.assertRaises(exceptions.UnsupportedAlgorithmError):
|
||||
TargetFile.from_file(file_path, file_path, ["123"])
|
||||
|
||||
def test_targetfile_from_data(self):
|
||||
def test_targetfile_from_data(self) -> None:
|
||||
data = b"Inline test content"
|
||||
target_file_path = os.path.join(self.repo_dir, "targets", "file1.txt")
|
||||
|
||||
|
|
@ -714,7 +722,7 @@ def test_targetfile_from_data(self):
|
|||
targetfile_from_data = TargetFile.from_data(target_file_path, data)
|
||||
targetfile_from_data.verify_length_and_hashes(data)
|
||||
|
||||
def test_is_delegated_role(self):
|
||||
def test_is_delegated_role(self) -> None:
|
||||
# test path matches
|
||||
# see more extensive tests in test_is_target_in_pathpattern()
|
||||
for paths in [
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Any, ClassVar, Iterator
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import requests
|
||||
|
|
@ -28,17 +29,19 @@
|
|||
class TestFetcher(unittest_toolbox.Modified_TestCase):
|
||||
"""Test RequestsFetcher class."""
|
||||
|
||||
server_process_handler: ClassVar[utils.TestServerProcess]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
def setUpClass(cls) -> None:
|
||||
# Launch a SimpleHTTPServer (serves files in the current dir).
|
||||
cls.server_process_handler = utils.TestServerProcess(log=logger)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
def tearDownClass(cls) -> None:
|
||||
# Stop server process and perform clean up.
|
||||
cls.server_process_handler.clean()
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Create a temporary file and launch a simple server in the
|
||||
current working directory.
|
||||
|
|
@ -64,12 +67,12 @@ def setUp(self):
|
|||
# Instantiate a concrete instance of FetcherInterface
|
||||
self.fetcher = RequestsFetcher()
|
||||
|
||||
def tearDown(self):
|
||||
def tearDown(self) -> None:
|
||||
# Remove temporary directory
|
||||
unittest_toolbox.Modified_TestCase.tearDown(self)
|
||||
|
||||
# Simple fetch.
|
||||
def test_fetch(self):
|
||||
def test_fetch(self) -> None:
|
||||
with tempfile.TemporaryFile() as temp_file:
|
||||
for chunk in self.fetcher.fetch(self.url):
|
||||
temp_file.write(chunk)
|
||||
|
|
@ -80,7 +83,7 @@ def test_fetch(self):
|
|||
)
|
||||
|
||||
# URL data downloaded in more than one chunk
|
||||
def test_fetch_in_chunks(self):
|
||||
def test_fetch_in_chunks(self) -> None:
|
||||
# Set a smaller chunk size to ensure that the file will be downloaded
|
||||
# in more than one chunk
|
||||
self.fetcher.chunk_size = 4
|
||||
|
|
@ -105,12 +108,12 @@ def test_fetch_in_chunks(self):
|
|||
self.assertEqual(chunks_count, expected_chunks_count)
|
||||
|
||||
# Incorrect URL parsing
|
||||
def test_url_parsing(self):
|
||||
def test_url_parsing(self) -> None:
|
||||
with self.assertRaises(exceptions.URLParsingError):
|
||||
self.fetcher.fetch(self.random_string())
|
||||
|
||||
# File not found error
|
||||
def test_http_error(self):
|
||||
def test_http_error(self) -> None:
|
||||
with self.assertRaises(exceptions.FetcherHTTPError) as cm:
|
||||
self.url = f"{self.url_prefix}/non-existing-path"
|
||||
self.fetcher.fetch(self.url)
|
||||
|
|
@ -118,7 +121,7 @@ def test_http_error(self):
|
|||
|
||||
# Response read timeout error
|
||||
@patch.object(requests.Session, "get")
|
||||
def test_response_read_timeout(self, mock_session_get):
|
||||
def test_response_read_timeout(self, mock_session_get: Any) -> None:
|
||||
mock_response = Mock()
|
||||
attr = {
|
||||
"raw.read.side_effect": urllib3.exceptions.ReadTimeoutError(
|
||||
|
|
@ -136,28 +139,28 @@ def test_response_read_timeout(self, mock_session_get):
|
|||
@patch.object(
|
||||
requests.Session, "get", side_effect=urllib3.exceptions.TimeoutError
|
||||
)
|
||||
def test_session_get_timeout(self, mock_session_get):
|
||||
def test_session_get_timeout(self, mock_session_get: Any) -> None:
|
||||
with self.assertRaises(exceptions.SlowRetrievalError):
|
||||
self.fetcher.fetch(self.url)
|
||||
mock_session_get.assert_called_once()
|
||||
|
||||
# Simple bytes download
|
||||
def test_download_bytes(self):
|
||||
def test_download_bytes(self) -> None:
|
||||
data = self.fetcher.download_bytes(self.url, self.file_length)
|
||||
self.assertEqual(self.file_contents, data.decode("utf-8"))
|
||||
|
||||
# Download file smaller than required max_length
|
||||
def test_download_bytes_upper_length(self):
|
||||
def test_download_bytes_upper_length(self) -> None:
|
||||
data = self.fetcher.download_bytes(self.url, self.file_length + 4)
|
||||
self.assertEqual(self.file_contents, data.decode("utf-8"))
|
||||
|
||||
# Download a file bigger than expected
|
||||
def test_download_bytes_length_mismatch(self):
|
||||
def test_download_bytes_length_mismatch(self) -> None:
|
||||
with self.assertRaises(exceptions.DownloadLengthMismatchError):
|
||||
self.fetcher.download_bytes(self.url, self.file_length - 4)
|
||||
|
||||
# Simple file download
|
||||
def test_download_file(self):
|
||||
def test_download_file(self) -> None:
|
||||
with self.fetcher.download_file(
|
||||
self.url, self.file_length
|
||||
) as temp_file:
|
||||
|
|
@ -165,7 +168,7 @@ def test_download_file(self):
|
|||
self.assertEqual(self.file_length, temp_file.tell())
|
||||
|
||||
# Download file smaller than required max_length
|
||||
def test_download_file_upper_length(self):
|
||||
def test_download_file_upper_length(self) -> None:
|
||||
with self.fetcher.download_file(
|
||||
self.url, self.file_length + 4
|
||||
) as temp_file:
|
||||
|
|
@ -173,8 +176,10 @@ def test_download_file_upper_length(self):
|
|||
self.assertEqual(self.file_length, temp_file.tell())
|
||||
|
||||
# Download a file bigger than expected
|
||||
def test_download_file_length_mismatch(self):
|
||||
def test_download_file_length_mismatch(self) -> Iterator[Any]:
|
||||
with self.assertRaises(exceptions.DownloadLengthMismatchError):
|
||||
# Force download_file to execute and raise the error since it is a
|
||||
# context manager and returns Iterator[IO]
|
||||
yield self.fetcher.download_file(self.url, self.file_length - 4)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class TestSerialization(unittest.TestCase):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(invalid_signed)
|
||||
def test_invalid_signed_serialization(self, test_case_data: str):
|
||||
def test_invalid_signed_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
with self.assertRaises((KeyError, ValueError, TypeError)):
|
||||
Snapshot.from_dict(copy.deepcopy(case_dict))
|
||||
|
|
@ -68,7 +68,7 @@ def test_invalid_signed_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(valid_keys)
|
||||
def test_valid_key_serialization(self, test_case_data: str):
|
||||
def test_valid_key_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
key = Key.from_dict("id", copy.copy(case_dict))
|
||||
self.assertDictEqual(case_dict, key.to_dict())
|
||||
|
|
@ -85,7 +85,7 @@ def test_valid_key_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(invalid_keys)
|
||||
def test_invalid_key_serialization(self, test_case_data: str):
|
||||
def test_invalid_key_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
with self.assertRaises((TypeError, KeyError)):
|
||||
keyid = case_dict.pop("keyid")
|
||||
|
|
@ -100,7 +100,7 @@ def test_invalid_key_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(invalid_roles)
|
||||
def test_invalid_role_serialization(self, test_case_data: str):
|
||||
def test_invalid_role_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
with self.assertRaises((KeyError, TypeError, ValueError)):
|
||||
Role.from_dict(copy.deepcopy(case_dict))
|
||||
|
|
@ -113,7 +113,7 @@ def test_invalid_role_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(valid_roles)
|
||||
def test_role_serialization(self, test_case_data: str):
|
||||
def test_role_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
role = Role.from_dict(copy.deepcopy(case_dict))
|
||||
self.assertDictEqual(case_dict, role.to_dict())
|
||||
|
|
@ -161,7 +161,7 @@ def test_role_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(valid_roots)
|
||||
def test_root_serialization(self, test_case_data: str):
|
||||
def test_root_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
root = Root.from_dict(copy.deepcopy(case_dict))
|
||||
self.assertDictEqual(case_dict, root.to_dict())
|
||||
|
|
@ -203,7 +203,7 @@ def test_root_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(invalid_roots)
|
||||
def test_invalid_root_serialization(self, test_case_data: str):
|
||||
def test_invalid_root_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
with self.assertRaises(ValueError):
|
||||
Root.from_dict(copy.deepcopy(case_dict))
|
||||
|
|
@ -218,7 +218,7 @@ def test_invalid_root_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(invalid_metafiles)
|
||||
def test_invalid_metafile_serialization(self, test_case_data: str):
|
||||
def test_invalid_metafile_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
with self.assertRaises((TypeError, ValueError, AttributeError)):
|
||||
MetaFile.from_dict(copy.deepcopy(case_dict))
|
||||
|
|
@ -232,7 +232,7 @@ def test_invalid_metafile_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(valid_metafiles)
|
||||
def test_metafile_serialization(self, test_case_data: str):
|
||||
def test_metafile_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
metafile = MetaFile.from_dict(copy.copy(case_dict))
|
||||
self.assertDictEqual(case_dict, metafile.to_dict())
|
||||
|
|
@ -242,7 +242,7 @@ def test_metafile_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(invalid_timestamps)
|
||||
def test_invalid_timestamp_serialization(self, test_case_data: str):
|
||||
def test_invalid_timestamp_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
with self.assertRaises((ValueError, KeyError)):
|
||||
Timestamp.from_dict(copy.deepcopy(case_dict))
|
||||
|
|
@ -255,7 +255,7 @@ def test_invalid_timestamp_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(valid_timestamps)
|
||||
def test_timestamp_serialization(self, test_case_data: str):
|
||||
def test_timestamp_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
timestamp = Timestamp.from_dict(copy.deepcopy(case_dict))
|
||||
self.assertDictEqual(case_dict, timestamp.to_dict())
|
||||
|
|
@ -274,7 +274,7 @@ def test_timestamp_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(valid_snapshots)
|
||||
def test_snapshot_serialization(self, test_case_data: str):
|
||||
def test_snapshot_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
snapshot = Snapshot.from_dict(copy.deepcopy(case_dict))
|
||||
self.assertDictEqual(case_dict, snapshot.to_dict())
|
||||
|
|
@ -295,7 +295,7 @@ def test_snapshot_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(valid_delegated_roles)
|
||||
def test_delegated_role_serialization(self, test_case_data: str):
|
||||
def test_delegated_role_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
deserialized_role = DelegatedRole.from_dict(copy.copy(case_dict))
|
||||
self.assertDictEqual(case_dict, deserialized_role.to_dict())
|
||||
|
|
@ -312,7 +312,9 @@ def test_delegated_role_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(invalid_delegated_roles)
|
||||
def test_invalid_delegated_role_serialization(self, test_case_data: str):
|
||||
def test_invalid_delegated_role_serialization(
|
||||
self, test_case_data: str
|
||||
) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
with self.assertRaises(ValueError):
|
||||
DelegatedRole.from_dict(copy.copy(case_dict))
|
||||
|
|
@ -339,7 +341,9 @@ def test_invalid_delegated_role_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(invalid_delegations)
|
||||
def test_invalid_delegation_serialization(self, test_case_data: str):
|
||||
def test_invalid_delegation_serialization(
|
||||
self, test_case_data: str
|
||||
) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
with self.assertRaises((ValueError, KeyError, AttributeError)):
|
||||
Delegations.from_dict(copy.deepcopy(case_dict))
|
||||
|
|
@ -361,7 +365,7 @@ def test_invalid_delegation_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(valid_delegations)
|
||||
def test_delegation_serialization(self, test_case_data: str):
|
||||
def test_delegation_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
delegation = Delegations.from_dict(copy.deepcopy(case_dict))
|
||||
self.assertDictEqual(case_dict, delegation.to_dict())
|
||||
|
|
@ -374,7 +378,9 @@ def test_delegation_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(invalid_targetfiles)
|
||||
def test_invalid_targetfile_serialization(self, test_case_data: str):
|
||||
def test_invalid_targetfile_serialization(
|
||||
self, test_case_data: str
|
||||
) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
with self.assertRaises(KeyError):
|
||||
TargetFile.from_dict(copy.deepcopy(case_dict), "file1.txt")
|
||||
|
|
@ -388,7 +394,7 @@ def test_invalid_targetfile_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(valid_targetfiles)
|
||||
def test_targetfile_serialization(self, test_case_data: str):
|
||||
def test_targetfile_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
target_file = TargetFile.from_dict(copy.copy(case_dict), "file1.txt")
|
||||
self.assertDictEqual(case_dict, target_file.to_dict())
|
||||
|
|
@ -420,7 +426,7 @@ def test_targetfile_serialization(self, test_case_data: str):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(valid_targets)
|
||||
def test_targets_serialization(self, test_case_data: str):
|
||||
def test_targets_serialization(self, test_case_data: str) -> None:
|
||||
case_dict = json.loads(test_case_data)
|
||||
targets = Targets.from_dict(copy.deepcopy(case_dict))
|
||||
self.assertDictEqual(case_dict, targets.to_dict())
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
import sys
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, ClassVar, Dict, List, Optional, Tuple
|
||||
|
||||
from securesystemslib.interface import (
|
||||
import_ed25519_privatekey_from_file,
|
||||
|
|
@ -18,7 +18,6 @@
|
|||
Metadata,
|
||||
MetaFile,
|
||||
Root,
|
||||
Signed,
|
||||
Snapshot,
|
||||
Targets,
|
||||
Timestamp,
|
||||
|
|
@ -31,8 +30,13 @@
|
|||
class TestTrustedMetadataSet(unittest.TestCase):
|
||||
"""Tests for all public API of the TrustedMetadataSet class."""
|
||||
|
||||
keystore: ClassVar[Dict[str, SSlibSigner]]
|
||||
metadata: ClassVar[Dict[str, bytes]]
|
||||
repo_dir: ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def modify_metadata(
|
||||
self, rolename: str, modification_func: Callable[[Signed], None]
|
||||
cls, rolename: str, modification_func: Callable
|
||||
) -> bytes:
|
||||
"""Instantiate metadata from rolename type, call modification_func and
|
||||
sign it again with self.keystore[rolename] signer.
|
||||
|
|
@ -42,13 +46,13 @@ def modify_metadata(
|
|||
modification_func: Function that will be called to modify the signed
|
||||
portion of metadata bytes.
|
||||
"""
|
||||
metadata = Metadata.from_bytes(self.metadata[rolename])
|
||||
metadata = Metadata.from_bytes(cls.metadata[rolename])
|
||||
modification_func(metadata.signed)
|
||||
metadata.sign(self.keystore[rolename])
|
||||
metadata.sign(cls.keystore[rolename])
|
||||
return metadata.to_bytes()
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
def setUpClass(cls) -> None:
|
||||
cls.repo_dir = os.path.join(
|
||||
os.getcwd(), "repository_data", "repository", "metadata"
|
||||
)
|
||||
|
|
@ -81,7 +85,7 @@ def hashes_length_modifier(timestamp: Timestamp) -> None:
|
|||
timestamp.snapshot_meta.length = None
|
||||
|
||||
cls.metadata["timestamp"] = cls.modify_metadata(
|
||||
cls, "timestamp", hashes_length_modifier
|
||||
"timestamp", hashes_length_modifier
|
||||
)
|
||||
|
||||
def setUp(self) -> None:
|
||||
|
|
@ -91,7 +95,7 @@ def _update_all_besides_targets(
|
|||
self,
|
||||
timestamp_bytes: Optional[bytes] = None,
|
||||
snapshot_bytes: Optional[bytes] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Update all metadata roles besides targets.
|
||||
|
||||
Args:
|
||||
|
|
@ -109,7 +113,7 @@ def _update_all_besides_targets(
|
|||
snapshot_bytes = snapshot_bytes or self.metadata["snapshot"]
|
||||
self.trusted_set.update_snapshot(snapshot_bytes)
|
||||
|
||||
def test_update(self):
|
||||
def test_update(self) -> None:
|
||||
self.trusted_set.update_timestamp(self.metadata["timestamp"])
|
||||
self.trusted_set.update_snapshot(self.metadata["snapshot"])
|
||||
self.trusted_set.update_targets(self.metadata["targets"])
|
||||
|
|
@ -129,8 +133,10 @@ def test_update(self):
|
|||
|
||||
self.assertTrue(count, 6)
|
||||
|
||||
def test_update_metadata_output(self):
|
||||
timestamp = self.trusted_set.update_timestamp(self.metadata["timestamp"])
|
||||
def test_update_metadata_output(self) -> None:
|
||||
timestamp = self.trusted_set.update_timestamp(
|
||||
self.metadata["timestamp"]
|
||||
)
|
||||
snapshot = self.trusted_set.update_snapshot(self.metadata["snapshot"])
|
||||
targets = self.trusted_set.update_targets(self.metadata["targets"])
|
||||
delegeted_targets_1 = self.trusted_set.update_delegated_targets(
|
||||
|
|
@ -145,7 +151,7 @@ def test_update_metadata_output(self):
|
|||
self.assertIsInstance(delegeted_targets_1.signed, Targets)
|
||||
self.assertIsInstance(delegeted_targets_2.signed, Targets)
|
||||
|
||||
def test_out_of_order_ops(self):
|
||||
def test_out_of_order_ops(self) -> None:
|
||||
# Update snapshot before timestamp
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.trusted_set.update_snapshot(self.metadata["snapshot"])
|
||||
|
|
@ -182,7 +188,7 @@ def test_out_of_order_ops(self):
|
|||
self.metadata["role1"], "role1", "targets"
|
||||
)
|
||||
|
||||
def test_root_with_invalid_json(self):
|
||||
def test_root_with_invalid_json(self) -> None:
|
||||
# Test loading initial root and root update
|
||||
for test_func in [TrustedMetadataSet, self.trusted_set.update_root]:
|
||||
# root is not json
|
||||
|
|
@ -199,8 +205,8 @@ def test_root_with_invalid_json(self):
|
|||
with self.assertRaises(exceptions.RepositoryError):
|
||||
test_func(self.metadata["snapshot"])
|
||||
|
||||
def test_top_level_md_with_invalid_json(self):
|
||||
top_level_md = [
|
||||
def test_top_level_md_with_invalid_json(self) -> None:
|
||||
top_level_md: List[Tuple[bytes, Callable[[bytes], Metadata]]] = [
|
||||
(self.metadata["timestamp"], self.trusted_set.update_timestamp),
|
||||
(self.metadata["snapshot"], self.trusted_set.update_snapshot),
|
||||
(self.metadata["targets"], self.trusted_set.update_targets),
|
||||
|
|
@ -222,7 +228,7 @@ def test_top_level_md_with_invalid_json(self):
|
|||
|
||||
update_func(metadata)
|
||||
|
||||
def test_update_root_new_root(self):
|
||||
def test_update_root_new_root(self) -> None:
|
||||
# test that root can be updated with a new valid version
|
||||
def root_new_version_modifier(root: Root) -> None:
|
||||
root.version += 1
|
||||
|
|
@ -230,19 +236,19 @@ def root_new_version_modifier(root: Root) -> None:
|
|||
root = self.modify_metadata("root", root_new_version_modifier)
|
||||
self.trusted_set.update_root(root)
|
||||
|
||||
def test_update_root_new_root_cannot_be_verified_with_threshold(self):
|
||||
def test_update_root_new_root_fail_threshold_verification(self) -> None:
|
||||
# new_root data with threshold which cannot be verified.
|
||||
root = Metadata.from_bytes(self.metadata["root"])
|
||||
# remove root role keyids representing root signatures
|
||||
root.signed.roles["root"].keyids = []
|
||||
root.signed.roles["root"].keyids = set()
|
||||
with self.assertRaises(exceptions.UnsignedMetadataError):
|
||||
self.trusted_set.update_root(root.to_bytes())
|
||||
|
||||
def test_update_root_new_root_ver_same_as_trusted_root_ver(self):
|
||||
def test_update_root_new_root_ver_same_as_trusted_root_ver(self) -> None:
|
||||
with self.assertRaises(exceptions.ReplayedMetadataError):
|
||||
self.trusted_set.update_root(self.metadata["root"])
|
||||
|
||||
def test_root_expired_final_root(self):
|
||||
def test_root_expired_final_root(self) -> None:
|
||||
def root_expired_modifier(root: Root) -> None:
|
||||
root.expires = datetime(1970, 1, 1)
|
||||
|
||||
|
|
@ -253,7 +259,7 @@ def root_expired_modifier(root: Root) -> None:
|
|||
with self.assertRaises(exceptions.ExpiredMetadataError):
|
||||
tmp_trusted_set.update_timestamp(self.metadata["timestamp"])
|
||||
|
||||
def test_update_timestamp_new_timestamp_ver_below_trusted_ver(self):
|
||||
def test_update_timestamp_new_timestamp_ver_below_trusted_ver(self) -> None:
|
||||
# new_timestamp.version < trusted_timestamp.version
|
||||
def version_modifier(timestamp: Timestamp) -> None:
|
||||
timestamp.version = 3
|
||||
|
|
@ -263,7 +269,7 @@ def version_modifier(timestamp: Timestamp) -> None:
|
|||
with self.assertRaises(exceptions.ReplayedMetadataError):
|
||||
self.trusted_set.update_timestamp(self.metadata["timestamp"])
|
||||
|
||||
def test_update_timestamp_snapshot_ver_below_current(self):
|
||||
def test_update_timestamp_snapshot_ver_below_current(self) -> None:
|
||||
def bump_snapshot_version(timestamp: Timestamp) -> None:
|
||||
timestamp.snapshot_meta.version = 2
|
||||
|
||||
|
|
@ -275,7 +281,7 @@ def bump_snapshot_version(timestamp: Timestamp) -> None:
|
|||
with self.assertRaises(exceptions.ReplayedMetadataError):
|
||||
self.trusted_set.update_timestamp(self.metadata["timestamp"])
|
||||
|
||||
def test_update_timestamp_expired(self):
|
||||
def test_update_timestamp_expired(self) -> None:
|
||||
# new_timestamp has expired
|
||||
def timestamp_expired_modifier(timestamp: Timestamp) -> None:
|
||||
timestamp.expires = datetime(1970, 1, 1)
|
||||
|
|
@ -291,7 +297,7 @@ def timestamp_expired_modifier(timestamp: Timestamp) -> None:
|
|||
with self.assertRaises(exceptions.ExpiredMetadataError):
|
||||
self.trusted_set.update_snapshot(self.metadata["snapshot"])
|
||||
|
||||
def test_update_snapshot_length_or_hash_mismatch(self):
|
||||
def test_update_snapshot_length_or_hash_mismatch(self) -> None:
|
||||
def modify_snapshot_length(timestamp: Timestamp) -> None:
|
||||
timestamp.snapshot_meta.length = 1
|
||||
|
||||
|
|
@ -302,14 +308,16 @@ def modify_snapshot_length(timestamp: Timestamp) -> None:
|
|||
with self.assertRaises(exceptions.RepositoryError):
|
||||
self.trusted_set.update_snapshot(self.metadata["snapshot"])
|
||||
|
||||
def test_update_snapshot_cannot_verify_snapshot_with_threshold(self):
|
||||
def test_update_snapshot_fail_threshold_verification(self) -> None:
|
||||
self.trusted_set.update_timestamp(self.metadata["timestamp"])
|
||||
snapshot = Metadata.from_bytes(self.metadata["snapshot"])
|
||||
snapshot.signatures.clear()
|
||||
with self.assertRaises(exceptions.UnsignedMetadataError):
|
||||
self.trusted_set.update_snapshot(snapshot.to_bytes())
|
||||
|
||||
def test_update_snapshot_version_different_timestamp_snapshot_version(self):
|
||||
def test_update_snapshot_version_diverge_timestamp_snapshot_version(
|
||||
self,
|
||||
) -> None:
|
||||
def timestamp_version_modifier(timestamp: Timestamp) -> None:
|
||||
timestamp.snapshot_meta.version = 2
|
||||
|
||||
|
|
@ -326,7 +334,7 @@ def timestamp_version_modifier(timestamp: Timestamp) -> None:
|
|||
with self.assertRaises(exceptions.BadVersionNumberError):
|
||||
self.trusted_set.update_targets(self.metadata["targets"])
|
||||
|
||||
def test_update_snapshot_file_removed_from_meta(self):
|
||||
def test_update_snapshot_file_removed_from_meta(self) -> None:
|
||||
self._update_all_besides_targets(self.metadata["timestamp"])
|
||||
|
||||
def remove_file_from_meta(snapshot: Snapshot) -> None:
|
||||
|
|
@ -337,7 +345,7 @@ def remove_file_from_meta(snapshot: Snapshot) -> None:
|
|||
with self.assertRaises(exceptions.RepositoryError):
|
||||
self.trusted_set.update_snapshot(snapshot)
|
||||
|
||||
def test_update_snapshot_meta_version_decreases(self):
|
||||
def test_update_snapshot_meta_version_decreases(self) -> None:
|
||||
self.trusted_set.update_timestamp(self.metadata["timestamp"])
|
||||
|
||||
def version_meta_modifier(snapshot: Snapshot) -> None:
|
||||
|
|
@ -349,7 +357,7 @@ def version_meta_modifier(snapshot: Snapshot) -> None:
|
|||
with self.assertRaises(exceptions.BadVersionNumberError):
|
||||
self.trusted_set.update_snapshot(self.metadata["snapshot"])
|
||||
|
||||
def test_update_snapshot_expired_new_snapshot(self):
|
||||
def test_update_snapshot_expired_new_snapshot(self) -> None:
|
||||
self.trusted_set.update_timestamp(self.metadata["timestamp"])
|
||||
|
||||
def snapshot_expired_modifier(snapshot: Snapshot) -> None:
|
||||
|
|
@ -364,7 +372,7 @@ def snapshot_expired_modifier(snapshot: Snapshot) -> None:
|
|||
with self.assertRaises(exceptions.ExpiredMetadataError):
|
||||
self.trusted_set.update_targets(self.metadata["targets"])
|
||||
|
||||
def test_update_snapshot_successful_rollback_checks(self):
|
||||
def test_update_snapshot_successful_rollback_checks(self) -> None:
|
||||
def meta_version_bump(timestamp: Timestamp) -> None:
|
||||
timestamp.snapshot_meta.version += 1
|
||||
|
||||
|
|
@ -386,7 +394,7 @@ def version_bump(snapshot: Snapshot) -> None:
|
|||
# update targets to trigger final snapshot meta version check
|
||||
self.trusted_set.update_targets(self.metadata["targets"])
|
||||
|
||||
def test_update_targets_no_meta_in_snapshot(self):
|
||||
def test_update_targets_no_meta_in_snapshot(self) -> None:
|
||||
def no_meta_modifier(snapshot: Snapshot) -> None:
|
||||
snapshot.meta = {}
|
||||
|
||||
|
|
@ -396,7 +404,7 @@ def no_meta_modifier(snapshot: Snapshot) -> None:
|
|||
with self.assertRaises(exceptions.RepositoryError):
|
||||
self.trusted_set.update_targets(self.metadata["targets"])
|
||||
|
||||
def test_update_targets_hash_different_than_snapshot_meta_hash(self):
|
||||
def test_update_targets_hash_diverge_from_snapshot_meta_hash(self) -> None:
|
||||
def meta_length_modifier(snapshot: Snapshot) -> None:
|
||||
for metafile_path in snapshot.meta:
|
||||
snapshot.meta[metafile_path] = MetaFile(version=1, length=1)
|
||||
|
|
@ -407,7 +415,7 @@ def meta_length_modifier(snapshot: Snapshot) -> None:
|
|||
with self.assertRaises(exceptions.RepositoryError):
|
||||
self.trusted_set.update_targets(self.metadata["targets"])
|
||||
|
||||
def test_update_targets_version_different_snapshot_meta_version(self):
|
||||
def test_update_targets_version_diverge_snapshot_meta_version(self) -> None:
|
||||
def meta_modifier(snapshot: Snapshot) -> None:
|
||||
for metafile_path in snapshot.meta:
|
||||
snapshot.meta[metafile_path] = MetaFile(version=2)
|
||||
|
|
@ -418,7 +426,7 @@ def meta_modifier(snapshot: Snapshot) -> None:
|
|||
with self.assertRaises(exceptions.BadVersionNumberError):
|
||||
self.trusted_set.update_targets(self.metadata["targets"])
|
||||
|
||||
def test_update_targets_expired_new_target(self):
|
||||
def test_update_targets_expired_new_target(self) -> None:
|
||||
self._update_all_besides_targets()
|
||||
# new_delegated_target has expired
|
||||
def target_expired_modifier(target: Targets) -> None:
|
||||
|
|
|
|||
|
|
@ -37,8 +37,8 @@ class TestUpdaterKeyRotations(unittest.TestCase):
|
|||
dump_dir: Optional[str] = None
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.sim = None
|
||||
self.metadata_dir = None
|
||||
self.sim: RepositorySimulator
|
||||
self.metadata_dir: str
|
||||
self.subtest_count = 0
|
||||
# pylint: disable-next=consider-using-with
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
|
|
|
|||
|
|
@ -12,14 +12,14 @@
|
|||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
from typing import Callable, ClassVar, List
|
||||
|
||||
from securesystemslib.interface import import_rsa_privatekey_from_file
|
||||
from securesystemslib.signer import SSlibSigner
|
||||
|
||||
from tests import utils
|
||||
from tuf import exceptions, ngclient, unittest_toolbox
|
||||
from tuf.api.metadata import Metadata, TargetFile
|
||||
from tuf.api.metadata import Metadata, Root, TargetFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -27,8 +27,11 @@
|
|||
class TestUpdater(unittest_toolbox.Modified_TestCase):
|
||||
"""Test the Updater class from 'tuf/ngclient/updater.py'."""
|
||||
|
||||
temporary_directory: ClassVar[str]
|
||||
server_process_handler: ClassVar[utils.TestServerProcess]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
def setUpClass(cls) -> None:
|
||||
# Create a temporary directory to store the repository, metadata, and
|
||||
# target files. 'temporary_directory' must be deleted in
|
||||
# TearDownModule() so that temporary files are always removed, even when
|
||||
|
|
@ -38,18 +41,18 @@ def setUpClass(cls):
|
|||
# Needed because in some tests simple_server.py cannot be found.
|
||||
# The reason is that the current working directory
|
||||
# has been changed when executing a subprocess.
|
||||
cls.SIMPLE_SERVER_PATH = os.path.join(os.getcwd(), "simple_server.py")
|
||||
SIMPLE_SERVER_PATH = os.path.join(os.getcwd(), "simple_server.py")
|
||||
|
||||
# Launch a SimpleHTTPServer (serves files in the current directory).
|
||||
# Test cases will request metadata and target files that have been
|
||||
# pre-generated in 'tuf/tests/repository_data', which will be served
|
||||
# by the SimpleHTTPServer launched here.
|
||||
cls.server_process_handler = utils.TestServerProcess(
|
||||
log=logger, server=cls.SIMPLE_SERVER_PATH
|
||||
log=logger, server=SIMPLE_SERVER_PATH
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
def tearDownClass(cls) -> None:
|
||||
# Cleans the resources and flush the logged lines (if any).
|
||||
cls.server_process_handler.clean()
|
||||
|
||||
|
|
@ -57,7 +60,7 @@ def tearDownClass(cls):
|
|||
# the metadata, targets, and key files generated for the test cases
|
||||
shutil.rmtree(cls.temporary_directory)
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
# We are inheriting from custom class.
|
||||
unittest_toolbox.Modified_TestCase.setUp(self)
|
||||
|
||||
|
|
@ -124,7 +127,7 @@ def setUp(self):
|
|||
target_base_url=self.targets_url,
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
def tearDown(self) -> None:
|
||||
# We are inheriting from custom class.
|
||||
unittest_toolbox.Modified_TestCase.tearDown(self)
|
||||
|
||||
|
|
@ -132,7 +135,9 @@ def tearDown(self):
|
|||
self.server_process_handler.flush_log()
|
||||
|
||||
def _modify_repository_root(
|
||||
self, modification_func, bump_version=False
|
||||
self,
|
||||
modification_func: Callable[[Metadata], None],
|
||||
bump_version: bool = False,
|
||||
) -> None:
|
||||
"""Apply 'modification_func' to root and persist it."""
|
||||
role_path = os.path.join(
|
||||
|
|
@ -159,13 +164,13 @@ def _modify_repository_root(
|
|||
)
|
||||
)
|
||||
|
||||
def _assert_files(self, roles: List[str]):
|
||||
def _assert_files(self, roles: List[str]) -> None:
|
||||
"""Assert that local metadata files exist for 'roles'"""
|
||||
expected_files = [f"{role}.json" for role in roles]
|
||||
client_files = sorted(os.listdir(self.client_directory))
|
||||
self.assertEqual(client_files, expected_files)
|
||||
|
||||
def test_refresh_and_download(self):
|
||||
def test_refresh_and_download(self) -> None:
|
||||
# Test refresh without consistent targets - targets without hash prefix.
|
||||
|
||||
# top-level targets are already in local cache (but remove others)
|
||||
|
|
@ -179,10 +184,12 @@ def test_refresh_and_download(self):
|
|||
|
||||
# Get targetinfos, assert that cache does not contain files
|
||||
info1 = self.updater.get_targetinfo("file1.txt")
|
||||
assert isinstance(info1, TargetFile)
|
||||
self._assert_files(["root", "snapshot", "targets", "timestamp"])
|
||||
|
||||
# Get targetinfo for 'file3.txt' listed in the delegated role1
|
||||
info3 = self.updater.get_targetinfo("file3.txt")
|
||||
assert isinstance(info3, TargetFile)
|
||||
expected_files = ["role1", "root", "snapshot", "targets", "timestamp"]
|
||||
self._assert_files(expected_files)
|
||||
self.assertIsNone(self.updater.find_cached_target(info1))
|
||||
|
|
@ -200,7 +207,7 @@ def test_refresh_and_download(self):
|
|||
path = self.updater.find_cached_target(info3)
|
||||
self.assertEqual(path, os.path.join(self.dl_dir, info3.path))
|
||||
|
||||
def test_refresh_with_only_local_root(self):
|
||||
def test_refresh_with_only_local_root(self) -> None:
|
||||
os.remove(os.path.join(self.client_directory, "timestamp.json"))
|
||||
os.remove(os.path.join(self.client_directory, "snapshot.json"))
|
||||
os.remove(os.path.join(self.client_directory, "targets.json"))
|
||||
|
|
@ -217,7 +224,7 @@ def test_refresh_with_only_local_root(self):
|
|||
expected_files = ["role1", "root", "snapshot", "targets", "timestamp"]
|
||||
self._assert_files(expected_files)
|
||||
|
||||
def test_implicit_refresh_with_only_local_root(self):
|
||||
def test_implicit_refresh_with_only_local_root(self) -> None:
|
||||
os.remove(os.path.join(self.client_directory, "timestamp.json"))
|
||||
os.remove(os.path.join(self.client_directory, "snapshot.json"))
|
||||
os.remove(os.path.join(self.client_directory, "targets.json"))
|
||||
|
|
@ -231,7 +238,7 @@ def test_implicit_refresh_with_only_local_root(self):
|
|||
expected_files = ["role1", "root", "snapshot", "targets", "timestamp"]
|
||||
self._assert_files(expected_files)
|
||||
|
||||
def test_both_target_urls_not_set(self):
|
||||
def test_both_target_urls_not_set(self) -> None:
|
||||
# target_base_url = None and Updater._target_base_url = None
|
||||
updater = ngclient.Updater(
|
||||
self.client_directory, self.metadata_url, self.dl_dir
|
||||
|
|
@ -240,7 +247,7 @@ def test_both_target_urls_not_set(self):
|
|||
with self.assertRaises(ValueError):
|
||||
updater.download_target(info)
|
||||
|
||||
def test_no_target_dir_no_filepath(self):
|
||||
def test_no_target_dir_no_filepath(self) -> None:
|
||||
# filepath = None and Updater.target_dir = None
|
||||
updater = ngclient.Updater(self.client_directory, self.metadata_url)
|
||||
info = TargetFile(1, {"sha256": ""}, "targetpath")
|
||||
|
|
@ -249,15 +256,17 @@ def test_no_target_dir_no_filepath(self):
|
|||
with self.assertRaises(ValueError):
|
||||
updater.download_target(info)
|
||||
|
||||
def test_external_targets_url(self):
|
||||
def test_external_targets_url(self) -> None:
|
||||
self.updater.refresh()
|
||||
info = self.updater.get_targetinfo("file1.txt")
|
||||
assert isinstance(info, TargetFile)
|
||||
|
||||
self.updater.download_target(info, target_base_url=self.targets_url)
|
||||
|
||||
def test_length_hash_mismatch(self):
|
||||
def test_length_hash_mismatch(self) -> None:
|
||||
self.updater.refresh()
|
||||
targetinfo = self.updater.get_targetinfo("file1.txt")
|
||||
assert isinstance(targetinfo, TargetFile)
|
||||
|
||||
length = targetinfo.length
|
||||
with self.assertRaises(exceptions.RepositoryError):
|
||||
|
|
@ -270,13 +279,13 @@ def test_length_hash_mismatch(self):
|
|||
self.updater.download_target(targetinfo)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def test_updating_root(self):
|
||||
def test_updating_root(self) -> None:
|
||||
# Bump root version, resign and refresh
|
||||
self._modify_repository_root(lambda root: None, bump_version=True)
|
||||
self.updater.refresh()
|
||||
self.assertEqual(self.updater._trusted_set.root.signed.version, 2)
|
||||
|
||||
def test_missing_targetinfo(self):
|
||||
def test_missing_targetinfo(self) -> None:
|
||||
self.updater.refresh()
|
||||
|
||||
# Get targetinfo for non-existing file
|
||||
|
|
|
|||
|
|
@ -349,7 +349,7 @@ def test_new_snapshot_unsigned(self) -> None:
|
|||
|
||||
self._assert_files_exist(["root", "timestamp"])
|
||||
|
||||
def test_new_snapshot_version_mismatch(self):
|
||||
def test_new_snapshot_version_mismatch(self) -> None:
|
||||
# Check against timestamp role’s snapshot version
|
||||
|
||||
# Increase snapshot version without updating timestamp
|
||||
|
|
@ -414,7 +414,7 @@ def test_new_targets_unsigned(self) -> None:
|
|||
|
||||
self._assert_files_exist(["root", "timestamp", "snapshot"])
|
||||
|
||||
def test_new_targets_version_mismatch(self):
|
||||
def test_new_targets_version_mismatch(self) -> None:
|
||||
# Check against snapshot role’s targets version
|
||||
|
||||
# Increase targets version without updating snapshot
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
from tests import utils
|
||||
from tests.repository_simulator import RepositorySimulator
|
||||
from tuf.api.metadata import SPECIFICATION_VERSION, Targets
|
||||
from tuf.api.metadata import SPECIFICATION_VERSION, TargetFile, Targets
|
||||
from tuf.exceptions import BadVersionNumberError, UnsignedMetadataError
|
||||
from tuf.ngclient import Updater
|
||||
|
||||
|
|
@ -27,7 +27,7 @@ class TestUpdater(unittest.TestCase):
|
|||
# set dump_dir to trigger repository state dumps
|
||||
dump_dir: Optional[str] = None
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable-next=consider-using-with
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.metadata_dir = os.path.join(self.temp_dir.name, "metadata")
|
||||
|
|
@ -49,7 +49,7 @@ def setUp(self):
|
|||
self.sim.dump_dir = os.path.join(self.dump_dir, name)
|
||||
os.mkdir(self.sim.dump_dir)
|
||||
|
||||
def tearDown(self):
|
||||
def tearDown(self) -> None:
|
||||
self.temp_dir.cleanup()
|
||||
|
||||
def _run_refresh(self) -> Updater:
|
||||
|
|
@ -67,7 +67,7 @@ def _run_refresh(self) -> Updater:
|
|||
updater.refresh()
|
||||
return updater
|
||||
|
||||
def test_refresh(self):
|
||||
def test_refresh(self) -> None:
|
||||
# Update top level metadata
|
||||
self._run_refresh()
|
||||
|
||||
|
|
@ -99,7 +99,7 @@ def test_refresh(self):
|
|||
}
|
||||
|
||||
@utils.run_sub_tests_with_dataset(targets)
|
||||
def test_targets(self, test_case_data: Tuple[str, bytes, str]):
|
||||
def test_targets(self, test_case_data: Tuple[str, bytes, str]) -> None:
|
||||
targetpath, content, encoded_path = test_case_data
|
||||
path = os.path.join(self.targets_dir, encoded_path)
|
||||
|
||||
|
|
@ -117,7 +117,7 @@ def test_targets(self, test_case_data: Tuple[str, bytes, str]):
|
|||
updater = self._run_refresh()
|
||||
# target now exists, is not in cache yet
|
||||
info = updater.get_targetinfo(targetpath)
|
||||
self.assertIsNotNone(info)
|
||||
assert info is not None
|
||||
# Test without and with explicit local filepath
|
||||
self.assertIsNone(updater.find_cached_target(info))
|
||||
self.assertIsNone(updater.find_cached_target(info, path))
|
||||
|
|
@ -136,7 +136,7 @@ def test_targets(self, test_case_data: Tuple[str, bytes, str]):
|
|||
self.assertEqual(path, updater.find_cached_target(info))
|
||||
self.assertEqual(path, updater.find_cached_target(info, path))
|
||||
|
||||
def test_fishy_rolenames(self):
|
||||
def test_fishy_rolenames(self) -> None:
|
||||
roles_to_filenames = {
|
||||
"../a": "..%2Fa.json",
|
||||
"": ".json",
|
||||
|
|
@ -162,7 +162,7 @@ def test_fishy_rolenames(self):
|
|||
for fname in roles_to_filenames.values():
|
||||
self.assertTrue(fname in local_metadata)
|
||||
|
||||
def test_keys_and_signatures(self):
|
||||
def test_keys_and_signatures(self) -> None:
|
||||
"""Example of the two trickiest test areas: keys and root updates"""
|
||||
|
||||
# Update top level metadata
|
||||
|
|
@ -202,7 +202,7 @@ def test_keys_and_signatures(self):
|
|||
|
||||
self._run_refresh()
|
||||
|
||||
def test_snapshot_rollback_with_local_snapshot_hash_mismatch(self):
|
||||
def test_snapshot_rollback_with_local_snapshot_hash_mismatch(self) -> None:
|
||||
# Test triggering snapshot rollback check on a newly downloaded snapshot
|
||||
# when the local snapshot is loaded even when there is a hash mismatch
|
||||
# with timestamp.snapshot_meta.
|
||||
|
|
@ -233,7 +233,7 @@ def test_snapshot_rollback_with_local_snapshot_hash_mismatch(self):
|
|||
self._run_refresh()
|
||||
|
||||
@patch.object(builtins, "open", wraps=builtins.open)
|
||||
def test_not_loading_targets_twice(self, wrapped_open: MagicMock):
|
||||
def test_not_loading_targets_twice(self, wrapped_open: MagicMock) -> None:
|
||||
# Do not load targets roles more than once when traversing
|
||||
# the delegations tree
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@
|
|||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Any, Callable
|
||||
from typing import Any, Callable, Dict, IO, Optional, Callable, List, Iterator
|
||||
import unittest
|
||||
import argparse
|
||||
import errno
|
||||
|
|
@ -48,9 +48,13 @@
|
|||
# Test runner decorator: Runs the test as a set of N SubTests,
|
||||
# (where N is number of items in dataset), feeding the actual test
|
||||
# function one test case at a time
|
||||
def run_sub_tests_with_dataset(dataset: DataSet):
|
||||
def real_decorator(function: Callable[[unittest.TestCase, Any], None]):
|
||||
def wrapper(test_cls: unittest.TestCase):
|
||||
def run_sub_tests_with_dataset(
|
||||
dataset: DataSet
|
||||
) -> Callable[[Callable], Callable]:
|
||||
def real_decorator(
|
||||
function: Callable[[unittest.TestCase, Any], None]
|
||||
) -> Callable[[unittest.TestCase], None]:
|
||||
def wrapper(test_cls: unittest.TestCase) -> None:
|
||||
for case, data in dataset.items():
|
||||
with test_cls.subTest(case=case):
|
||||
function(test_cls, data)
|
||||
|
|
@ -60,15 +64,15 @@ def wrapper(test_cls: unittest.TestCase):
|
|||
|
||||
class TestServerProcessError(Exception):
|
||||
|
||||
def __init__(self, value="TestServerProcess"):
|
||||
def __init__(self, value: str="TestServerProcess") -> None:
|
||||
self.value = value
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return repr(self.value)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def ignore_deprecation_warnings(module):
|
||||
def ignore_deprecation_warnings(module: str) -> Iterator[None]:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore',
|
||||
category=DeprecationWarning,
|
||||
|
|
@ -82,13 +86,16 @@ def ignore_deprecation_warnings(module):
|
|||
# but the current blocking connect() seems to work fast on Linux and seems
|
||||
# to at least work on Windows (ECONNREFUSED unfortunately has a 2 second
|
||||
# timeout on Windows)
|
||||
def wait_for_server(host, server, port, timeout=10):
|
||||
def wait_for_server(host: str, server: str, port: int, timeout: int=10) -> None:
|
||||
start = time.time()
|
||||
remaining_timeout = timeout
|
||||
succeeded = False
|
||||
while not succeeded and remaining_timeout > 0:
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock: Optional[socket.socket] = socket.socket(
|
||||
socket.AF_INET, socket.SOCK_STREAM
|
||||
)
|
||||
assert sock is not None
|
||||
sock.settimeout(remaining_timeout)
|
||||
sock.connect((host, port))
|
||||
succeeded = True
|
||||
|
|
@ -104,14 +111,14 @@ def wait_for_server(host, server, port, timeout=10):
|
|||
if sock:
|
||||
sock.close()
|
||||
sock = None
|
||||
remaining_timeout = timeout - (time.time() - start)
|
||||
remaining_timeout = int(timeout - (time.time() - start))
|
||||
|
||||
if not succeeded:
|
||||
raise TimeoutError("Could not connect to the " + server \
|
||||
+ " on port " + str(port) + "!")
|
||||
|
||||
|
||||
def configure_test_logging(argv):
|
||||
def configure_test_logging(argv: List[str]) -> None:
|
||||
# parse arguments but only handle '-v': argv may contain
|
||||
# other things meant for unittest argument parser
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
|
|
@ -165,13 +172,14 @@ class TestServerProcess():
|
|||
"""
|
||||
|
||||
|
||||
def __init__(self, log, server='simple_server.py',
|
||||
timeout=10, popen_cwd=".", extra_cmd_args=None):
|
||||
def __init__(self, log: logging.Logger, server: str='simple_server.py',
|
||||
timeout: int=10, popen_cwd: str=".", extra_cmd_args: Optional[List[str]]=None
|
||||
):
|
||||
|
||||
self.server = server
|
||||
self.__logger = log
|
||||
# Stores popped messages from the queue.
|
||||
self.__logged_messages = []
|
||||
self.__logged_messages: List[str] = []
|
||||
if extra_cmd_args is None:
|
||||
extra_cmd_args = []
|
||||
|
||||
|
|
@ -185,7 +193,9 @@ def __init__(self, log, server='simple_server.py',
|
|||
|
||||
|
||||
|
||||
def _start_server(self, timeout, extra_cmd_args, popen_cwd):
|
||||
def _start_server(
|
||||
self, timeout: int, extra_cmd_args: List[str], popen_cwd: str
|
||||
) -> None:
|
||||
"""
|
||||
Start the server subprocess and a thread
|
||||
responsible to redirect stdout/stderr to the Queue.
|
||||
|
|
@ -201,7 +211,7 @@ def _start_server(self, timeout, extra_cmd_args, popen_cwd):
|
|||
|
||||
|
||||
|
||||
def _start_process(self, extra_cmd_args, popen_cwd):
|
||||
def _start_process(self, extra_cmd_args: List[str], popen_cwd: str) -> None:
|
||||
"""Starts the process running the server."""
|
||||
|
||||
# The "-u" option forces stdin, stdout and stderr to be unbuffered.
|
||||
|
|
@ -213,7 +223,7 @@ def _start_process(self, extra_cmd_args, popen_cwd):
|
|||
|
||||
|
||||
|
||||
def _start_redirect_thread(self):
|
||||
def _start_redirect_thread(self) -> None:
|
||||
"""Starts a thread responsible to redirect stdout/stderr to the Queue."""
|
||||
|
||||
# Run log_queue_worker() in a thread.
|
||||
|
|
@ -228,7 +238,7 @@ def _start_redirect_thread(self):
|
|||
|
||||
|
||||
@staticmethod
|
||||
def _log_queue_worker(stream, line_queue):
|
||||
def _log_queue_worker(stream: IO, line_queue: queue.Queue) -> None:
|
||||
"""
|
||||
Worker function to run in a seprate thread.
|
||||
Reads from 'stream', puts lines in a Queue (Queue is thread-safe).
|
||||
|
|
@ -247,7 +257,7 @@ def _log_queue_worker(stream, line_queue):
|
|||
|
||||
|
||||
|
||||
def _wait_for_port(self, timeout):
|
||||
def _wait_for_port(self, timeout: int) -> None:
|
||||
"""
|
||||
Validates the first item from the Queue against the port message.
|
||||
If validation is successful, self.port is set.
|
||||
|
|
@ -279,7 +289,7 @@ def _wait_for_port(self, timeout):
|
|||
|
||||
|
||||
|
||||
def _kill_server_process(self):
|
||||
def _kill_server_process(self) -> None:
|
||||
"""Kills the server subprocess if it's running."""
|
||||
|
||||
if self.is_process_running():
|
||||
|
|
@ -290,7 +300,7 @@ def _kill_server_process(self):
|
|||
|
||||
|
||||
|
||||
def flush_log(self):
|
||||
def flush_log(self) -> None:
|
||||
"""Flushes the log lines from the logging queue."""
|
||||
|
||||
while True:
|
||||
|
|
@ -311,7 +321,7 @@ def flush_log(self):
|
|||
|
||||
|
||||
|
||||
def clean(self):
|
||||
def clean(self) -> None:
|
||||
"""
|
||||
Kills the subprocess and closes the TempFile.
|
||||
Calls flush_log to check for logged information, but not yet flushed.
|
||||
|
|
@ -324,5 +334,5 @@ def clean(self):
|
|||
|
||||
|
||||
|
||||
def is_process_running(self):
|
||||
def is_process_running(self) -> bool:
|
||||
return True if self.__server_process.poll() is None else False
|
||||
|
|
|
|||
|
|
@ -182,7 +182,7 @@ def filter(self, record):
|
|||
|
||||
|
||||
|
||||
def set_log_level(log_level=_DEFAULT_LOG_LEVEL):
|
||||
def set_log_level(log_level: int=_DEFAULT_LOG_LEVEL):
|
||||
"""
|
||||
<Purpose>
|
||||
Allow the default log level to be overridden. If 'log_level' is not
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@
|
|||
import random
|
||||
import string
|
||||
|
||||
from typing import Optional
|
||||
|
||||
class Modified_TestCase(unittest.TestCase):
|
||||
"""
|
||||
|
|
@ -70,12 +71,12 @@ def setUp():
|
|||
"""
|
||||
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self._cleanup = []
|
||||
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
def tearDown(self) -> None:
|
||||
for cleanup_function in self._cleanup:
|
||||
# Perform clean up by executing clean-up functions.
|
||||
try:
|
||||
|
|
@ -87,7 +88,7 @@ def tearDown(self):
|
|||
|
||||
|
||||
|
||||
def make_temp_directory(self, directory=None):
|
||||
def make_temp_directory(self, directory: Optional[str]=None) -> str:
|
||||
"""Creates and returns an absolute path of a directory."""
|
||||
|
||||
prefix = self.__class__.__name__+'_'
|
||||
|
|
@ -102,7 +103,9 @@ def _destroy_temp_directory():
|
|||
|
||||
|
||||
|
||||
def make_temp_file(self, suffix='.txt', directory=None):
|
||||
def make_temp_file(
|
||||
self,suffix: str='.txt', directory: Optional[str]=None
|
||||
) -> str:
|
||||
"""Creates and returns an absolute path of an empty file."""
|
||||
prefix='tmp_file_'+self.__class__.__name__+'_'
|
||||
temp_file = tempfile.mkstemp(suffix=suffix, prefix=prefix, dir=directory)
|
||||
|
|
@ -113,7 +116,9 @@ def _destroy_temp_file():
|
|||
|
||||
|
||||
|
||||
def make_temp_data_file(self, suffix='', directory=None, data = 'junk data'):
|
||||
def make_temp_data_file(
|
||||
self, suffix: str='', directory: Optional[str]=None, data: str = 'junk data'
|
||||
) -> str:
|
||||
"""Returns an absolute path of a temp file containing data."""
|
||||
temp_file_path = self.make_temp_file(suffix=suffix, directory=directory)
|
||||
temp_file = open(temp_file_path, 'wt', encoding='utf8')
|
||||
|
|
@ -123,7 +128,7 @@ def make_temp_data_file(self, suffix='', directory=None, data = 'junk data'):
|
|||
|
||||
|
||||
|
||||
def random_path(self, length = 7):
|
||||
def random_path(self, length: int = 7) -> str:
|
||||
"""Generate a 'random' path consisting of random n-length strings."""
|
||||
|
||||
rand_path = '/' + self.random_string(length)
|
||||
|
|
@ -136,7 +141,7 @@ def random_path(self, length = 7):
|
|||
|
||||
|
||||
@staticmethod
|
||||
def random_string(length=15):
|
||||
def random_string(length: int=15) -> str:
|
||||
"""Generate a random string of specified length."""
|
||||
|
||||
rand_str = ''
|
||||
|
|
|
|||
Loading…
Reference in a new issue