Address mypy warnings

This commit includes manual fixes for a lot of mypy warnings.
When there were warnings that we are calling non-annotated function
in annotated context I decided to add annotations instead of ignoring
those warnings.
That's how I end up adding annotations in the whole tests/utils.py
module.

Signed-off-by: Martin Vrachev <mvrachev@vmware.com>
This commit is contained in:
Martin Vrachev 2021-11-18 18:58:16 +02:00
parent 0d4d7f820c
commit e2deff3148
12 changed files with 224 additions and 177 deletions

View file

@ -93,11 +93,7 @@ class RepositorySimulator(FetcherInterface):
"""Simulates a repository that can be used for testing."""
# pylint: disable=too-many-instance-attributes
def __init__(self):
self.md_root: Metadata[Root] = None
self.md_timestamp: Metadata[Timestamp] = None
self.md_snapshot: Metadata[Snapshot] = None
self.md_targets: Metadata[Targets] = None
def __init__(self) -> None:
self.md_delegates: Dict[str, Metadata[Targets]] = {}
# other metadata is signed on-demand (when fetched) but roots must be
@ -117,7 +113,7 @@ def __init__(self):
# Enable hash-prefixed target file names
self.prefix_targets_with_hash = True
self.dump_dir = None
self.dump_dir: Optional[str] = None
self.dump_version = 0
now = datetime.utcnow()
@ -152,12 +148,12 @@ def create_key() -> Tuple[Key, SSlibSigner]:
sslib_key = generate_ed25519_key()
return Key.from_securesystemslib_key(sslib_key), SSlibSigner(sslib_key)
def add_signer(self, role: str, signer: SSlibSigner):
def add_signer(self, role: str, signer: SSlibSigner) -> None:
if role not in self.signers:
self.signers[role] = {}
self.signers[role][signer.key_dict["keyid"]] = signer
def _initialize(self):
def _initialize(self) -> None:
"""Setup a minimal valid repository."""
targets = Targets(1, SPEC_VER, self.safe_expiry, {}, None)
@ -182,7 +178,7 @@ def _initialize(self):
self.md_root = Metadata(root, OrderedDict())
self.publish_root()
def publish_root(self):
def publish_root(self) -> None:
"""Sign and store a new serialized version of root."""
self.md_root.signatures.clear()
for signer in self.signers["root"].values():
@ -199,12 +195,12 @@ def fetch(self, url: str) -> Iterator[bytes]:
if path.startswith("/metadata/") and path.endswith(".json"):
# figure out rolename and version
ver_and_name = path[len("/metadata/") :][: -len(".json")]
version, _, role = ver_and_name.partition(".")
version_str, _, role = ver_and_name.partition(".")
# root is always version-prefixed while timestamp is always NOT
if role == "root" or (
self.root.consistent_snapshot and ver_and_name != "timestamp"
):
version = int(version)
version: Optional[int] = int(version_str)
else:
# the file is not version-prefixed
role = ver_and_name
@ -216,11 +212,10 @@ def fetch(self, url: str) -> Iterator[bytes]:
target_path = path[len("/targets/") :]
dir_parts, sep, prefixed_filename = target_path.rpartition("/")
# extract the hash prefix, if any
prefix: Optional[str] = None
filename = prefixed_filename
if self.root.consistent_snapshot and self.prefix_targets_with_hash:
prefix, _, filename = prefixed_filename.partition(".")
else:
filename = prefixed_filename
prefix = None
target_path = f"{dir_parts}{sep}{filename}"
yield self._fetch_target(target_path, prefix)
@ -261,8 +256,9 @@ def _fetch_metadata(
return self.signed_roots[version - 1]
# sign and serialize the requested metadata
md: Optional[Metadata]
if role == "timestamp":
md: Metadata = self.md_timestamp
md = self.md_timestamp
elif role == "snapshot":
md = self.md_snapshot
elif role == "targets":
@ -294,7 +290,7 @@ def _compute_hashes_and_length(
hashes = {sslib_hash.DEFAULT_HASH_ALGORITHM: digest_object.hexdigest()}
return hashes, len(data)
def update_timestamp(self):
def update_timestamp(self) -> None:
"""Update timestamp and assign snapshot version to snapshot_meta
version.
"""
@ -307,7 +303,7 @@ def update_timestamp(self):
self.timestamp.version += 1
def update_snapshot(self):
def update_snapshot(self) -> None:
"""Update snapshot, assign targets versions and update timestamp."""
for role, delegate in self.all_targets():
hashes = None
@ -322,7 +318,7 @@ def update_snapshot(self):
self.snapshot.version += 1
self.update_timestamp()
def add_target(self, role: str, data: bytes, path: str):
def add_target(self, role: str, data: bytes, path: str) -> None:
"""Create a target from data and add it to the target_files."""
if role == "targets":
targets = self.targets
@ -341,7 +337,7 @@ def add_delegation(
terminating: bool,
paths: Optional[List[str]],
hash_prefixes: Optional[List[str]],
):
) -> None:
"""Add delegated target role to the repository."""
if delegator_name == "targets":
delegator = self.targets
@ -351,7 +347,7 @@ def add_delegation(
# Create delegation
role = DelegatedRole(name, [], 1, terminating, paths, hash_prefixes)
if delegator.delegations is None:
delegator.delegations = Delegations({}, {})
delegator.delegations = Delegations({}, OrderedDict())
# put delegation last by default
delegator.delegations.roles[role.name] = role
@ -363,7 +359,7 @@ def add_delegation(
# Add metadata for the role
self.md_delegates[role.name] = Metadata(targets, OrderedDict())
def write(self):
def write(self) -> None:
"""Dump current repository metadata to self.dump_dir
This is a debugging tool: dumping repository state before running

View file

@ -14,6 +14,7 @@
import tempfile
import unittest
from datetime import datetime, timedelta
from typing import ClassVar, Dict
from dateutil.relativedelta import relativedelta
from securesystemslib import hash as sslib_hash
@ -28,6 +29,7 @@
from tuf import exceptions
from tuf.api.metadata import (
DelegatedRole,
Delegations,
Key,
Metadata,
MetaFile,
@ -47,8 +49,13 @@
class TestMetadata(unittest.TestCase):
"""Tests for public API of all classes in 'tuf/api/metadata.py'."""
temporary_directory: ClassVar[str]
repo_dir: ClassVar[str]
keystore_dir: ClassVar[str]
keystore: ClassVar[Dict[str, str]]
@classmethod
def setUpClass(cls):
def setUpClass(cls) -> None:
# Create a temporary directory to store the repository, metadata, and
# target files. 'temporary_directory' must be deleted in
# TearDownClass() so that temporary files are always removed, even when
@ -78,12 +85,12 @@ def setUpClass(cls):
)
@classmethod
def tearDownClass(cls):
def tearDownClass(cls) -> None:
# Remove the temporary repository directory, which should contain all
# the metadata, targets, and key files generated for the test cases.
shutil.rmtree(cls.temporary_directory)
def test_generic_read(self):
def test_generic_read(self) -> None:
for metadata, inner_metadata_cls in [
("root", Root),
("snapshot", Snapshot),
@ -120,7 +127,7 @@ def test_generic_read(self):
os.remove(bad_metadata_path)
def test_compact_json(self):
def test_compact_json(self) -> None:
path = os.path.join(self.repo_dir, "metadata", "targets.json")
md_obj = Metadata.from_file(path)
self.assertTrue(
@ -128,7 +135,7 @@ def test_compact_json(self):
< len(JSONSerializer().serialize(md_obj))
)
def test_read_write_read_compare(self):
def test_read_write_read_compare(self) -> None:
for metadata in ["root", "snapshot", "timestamp", "targets"]:
path = os.path.join(self.repo_dir, "metadata", metadata + ".json")
md_obj = Metadata.from_file(path)
@ -140,7 +147,7 @@ def test_read_write_read_compare(self):
os.remove(path_2)
def test_to_from_bytes(self):
def test_to_from_bytes(self) -> None:
for metadata in ["root", "snapshot", "timestamp", "targets"]:
path = os.path.join(self.repo_dir, "metadata", metadata + ".json")
with open(path, "rb") as f:
@ -157,7 +164,7 @@ def test_to_from_bytes(self):
metadata_obj_2 = Metadata.from_bytes(obj_bytes)
self.assertEqual(metadata_obj_2.to_bytes(), obj_bytes)
def test_sign_verify(self):
def test_sign_verify(self) -> None:
root_path = os.path.join(self.repo_dir, "metadata", "root.json")
root = Metadata[Root].from_file(root_path).signed
@ -183,7 +190,7 @@ def test_sign_verify(self):
# Test verifying with explicitly set serializer
targets_key.verify_signature(md_obj, CanonicalJSONSerializer())
with self.assertRaises(exceptions.UnsignedMetadataError):
targets_key.verify_signature(md_obj, JSONSerializer())
targets_key.verify_signature(md_obj, JSONSerializer()) # type: ignore[arg-type]
sslib_signer = SSlibSigner(self.keystore["snapshot"])
# Append a new signature with the unrelated key and assert that ...
@ -206,7 +213,7 @@ def test_sign_verify(self):
with self.assertRaises(exceptions.UnsignedMetadataError):
targets_key.verify_signature(md_obj)
def test_verify_failures(self):
def test_verify_failures(self) -> None:
root_path = os.path.join(self.repo_dir, "metadata", "root.json")
root = Metadata[Root].from_file(root_path).signed
@ -248,7 +255,7 @@ def test_verify_failures(self):
timestamp_key.verify_signature(md_obj)
sig.signature = correct_sig
def test_metadata_base(self):
def test_metadata_base(self) -> None:
# Use of Snapshot is arbitrary, we're just testing the base class
# features with real data
snapshot_path = os.path.join(self.repo_dir, "metadata", "snapshot.json")
@ -290,7 +297,7 @@ def test_metadata_base(self):
with self.assertRaises(ValueError):
Metadata.from_dict(data)
def test_metadata_snapshot(self):
def test_metadata_snapshot(self) -> None:
snapshot_path = os.path.join(self.repo_dir, "metadata", "snapshot.json")
snapshot = Metadata[Snapshot].from_file(snapshot_path)
@ -309,7 +316,7 @@ def test_metadata_snapshot(self):
snapshot.signed.meta["role1.json"].to_dict(), fileinfo.to_dict()
)
def test_metadata_timestamp(self):
def test_metadata_timestamp(self) -> None:
timestamp_path = os.path.join(
self.repo_dir, "metadata", "timestamp.json"
)
@ -349,7 +356,7 @@ def test_metadata_timestamp(self):
timestamp.signed.snapshot_meta.to_dict(), fileinfo.to_dict()
)
def test_metadata_verify_delegate(self):
def test_metadata_verify_delegate(self) -> None:
root_path = os.path.join(self.repo_dir, "metadata", "root.json")
root = Metadata[Root].from_file(root_path)
snapshot_path = os.path.join(self.repo_dir, "metadata", "snapshot.json")
@ -410,14 +417,14 @@ def test_metadata_verify_delegate(self):
snapshot.sign(SSlibSigner(self.keystore["timestamp"]), append=True)
root.verify_delegate("snapshot", snapshot)
def test_key_class(self):
def test_key_class(self) -> None:
# Test if from_securesystemslib_key removes the private key from keyval
# of a securesystemslib key dictionary.
sslib_key = generate_ed25519_key()
key = Key.from_securesystemslib_key(sslib_key)
self.assertFalse("private" in key.keyval.keys())
def test_root_add_key_and_remove_key(self):
def test_root_add_key_and_remove_key(self) -> None:
root_path = os.path.join(self.repo_dir, "metadata", "root.json")
root = Metadata[Root].from_file(root_path)
@ -475,7 +482,7 @@ def test_root_add_key_and_remove_key(self):
with self.assertRaises(ValueError):
root.signed.remove_key("nosuchrole", keyid)
def test_is_target_in_pathpattern(self):
def test_is_target_in_pathpattern(self) -> None:
# pylint: disable=protected-access
supported_use_cases = [
("foo.tgz", "foo.tgz"),
@ -507,7 +514,7 @@ def test_is_target_in_pathpattern(self):
DelegatedRole._is_target_in_pathpattern(targetpath, pathpattern)
)
def test_metadata_targets(self):
def test_metadata_targets(self) -> None:
targets_path = os.path.join(self.repo_dir, "metadata", "targets.json")
targets = Metadata[Targets].from_file(targets_path)
@ -531,7 +538,7 @@ def test_metadata_targets(self):
targets.signed.targets[filename].to_dict(), fileinfo.to_dict()
)
def test_targets_key_api(self):
def test_targets_key_api(self) -> None:
targets_path = os.path.join(self.repo_dir, "metadata", "targets.json")
targets: Targets = Metadata[Targets].from_file(targets_path).signed
@ -545,6 +552,7 @@ def test_targets_key_api(self):
"threshold": 1,
}
)
assert isinstance(targets.delegations, Delegations)
targets.delegations.roles["role2"] = delegated_role
key_dict = {
@ -608,7 +616,7 @@ def test_targets_key_api(self):
targets.remove_key("role1", key.keyid)
self.assertTrue(targets.delegations is None)
def test_length_and_hash_validation(self):
def test_length_and_hash_validation(self) -> None:
# Test metadata files' hash and length verification.
# Use timestamp to get a MetaFile object and snapshot
@ -648,7 +656,7 @@ def test_length_and_hash_validation(self):
# Test wrong algorithm format (sslib.FormatError)
snapshot_metafile.hashes = {
256: "8f88e2ba48b412c3843e9bb26e1b6f8fc9e98aceb0fbaa97ba37b4c98717d7ab"
256: "8f88e2ba48b412c3843e9bb26e1b6f8fc9e98aceb0fbaa97ba37b4c98717d7ab" # type: ignore[dict-item]
}
with self.assertRaises(exceptions.LengthOrHashMismatchError):
snapshot_metafile.verify_length_and_hashes(data)
@ -678,7 +686,7 @@ def test_length_and_hash_validation(self):
with self.assertRaises(exceptions.LengthOrHashMismatchError):
file1_targetfile.verify_length_and_hashes(file1)
def test_targetfile_from_file(self):
def test_targetfile_from_file(self) -> None:
# Test with an existing file and valid hash algorithm
file_path = os.path.join(self.repo_dir, "targets", "file1.txt")
targetfile_from_file = TargetFile.from_file(
@ -700,7 +708,7 @@ def test_targetfile_from_file(self):
with self.assertRaises(exceptions.UnsupportedAlgorithmError):
TargetFile.from_file(file_path, file_path, ["123"])
def test_targetfile_from_data(self):
def test_targetfile_from_data(self) -> None:
data = b"Inline test content"
target_file_path = os.path.join(self.repo_dir, "targets", "file1.txt")
@ -714,7 +722,7 @@ def test_targetfile_from_data(self):
targetfile_from_data = TargetFile.from_data(target_file_path, data)
targetfile_from_data.verify_length_and_hashes(data)
def test_is_delegated_role(self):
def test_is_delegated_role(self) -> None:
# test path matches
# see more extensive tests in test_is_target_in_pathpattern()
for paths in [

View file

@ -13,6 +13,7 @@
import sys
import tempfile
import unittest
from typing import Any, ClassVar, Iterator
from unittest.mock import Mock, patch
import requests
@ -28,17 +29,19 @@
class TestFetcher(unittest_toolbox.Modified_TestCase):
"""Test RequestsFetcher class."""
server_process_handler: ClassVar[utils.TestServerProcess]
@classmethod
def setUpClass(cls):
def setUpClass(cls) -> None:
# Launch a SimpleHTTPServer (serves files in the current dir).
cls.server_process_handler = utils.TestServerProcess(log=logger)
@classmethod
def tearDownClass(cls):
def tearDownClass(cls) -> None:
# Stop server process and perform clean up.
cls.server_process_handler.clean()
def setUp(self):
def setUp(self) -> None:
"""
Create a temporary file and launch a simple server in the
current working directory.
@ -64,12 +67,12 @@ def setUp(self):
# Instantiate a concrete instance of FetcherInterface
self.fetcher = RequestsFetcher()
def tearDown(self):
def tearDown(self) -> None:
# Remove temporary directory
unittest_toolbox.Modified_TestCase.tearDown(self)
# Simple fetch.
def test_fetch(self):
def test_fetch(self) -> None:
with tempfile.TemporaryFile() as temp_file:
for chunk in self.fetcher.fetch(self.url):
temp_file.write(chunk)
@ -80,7 +83,7 @@ def test_fetch(self):
)
# URL data downloaded in more than one chunk
def test_fetch_in_chunks(self):
def test_fetch_in_chunks(self) -> None:
# Set a smaller chunk size to ensure that the file will be downloaded
# in more than one chunk
self.fetcher.chunk_size = 4
@ -105,12 +108,12 @@ def test_fetch_in_chunks(self):
self.assertEqual(chunks_count, expected_chunks_count)
# Incorrect URL parsing
def test_url_parsing(self):
def test_url_parsing(self) -> None:
with self.assertRaises(exceptions.URLParsingError):
self.fetcher.fetch(self.random_string())
# File not found error
def test_http_error(self):
def test_http_error(self) -> None:
with self.assertRaises(exceptions.FetcherHTTPError) as cm:
self.url = f"{self.url_prefix}/non-existing-path"
self.fetcher.fetch(self.url)
@ -118,7 +121,7 @@ def test_http_error(self):
# Response read timeout error
@patch.object(requests.Session, "get")
def test_response_read_timeout(self, mock_session_get):
def test_response_read_timeout(self, mock_session_get: Any) -> None:
mock_response = Mock()
attr = {
"raw.read.side_effect": urllib3.exceptions.ReadTimeoutError(
@ -136,28 +139,28 @@ def test_response_read_timeout(self, mock_session_get):
@patch.object(
requests.Session, "get", side_effect=urllib3.exceptions.TimeoutError
)
def test_session_get_timeout(self, mock_session_get):
def test_session_get_timeout(self, mock_session_get: Any) -> None:
with self.assertRaises(exceptions.SlowRetrievalError):
self.fetcher.fetch(self.url)
mock_session_get.assert_called_once()
# Simple bytes download
def test_download_bytes(self):
def test_download_bytes(self) -> None:
data = self.fetcher.download_bytes(self.url, self.file_length)
self.assertEqual(self.file_contents, data.decode("utf-8"))
# Download file smaller than required max_length
def test_download_bytes_upper_length(self):
def test_download_bytes_upper_length(self) -> None:
data = self.fetcher.download_bytes(self.url, self.file_length + 4)
self.assertEqual(self.file_contents, data.decode("utf-8"))
# Download a file bigger than expected
def test_download_bytes_length_mismatch(self):
def test_download_bytes_length_mismatch(self) -> None:
with self.assertRaises(exceptions.DownloadLengthMismatchError):
self.fetcher.download_bytes(self.url, self.file_length - 4)
# Simple file download
def test_download_file(self):
def test_download_file(self) -> None:
with self.fetcher.download_file(
self.url, self.file_length
) as temp_file:
@ -165,7 +168,7 @@ def test_download_file(self):
self.assertEqual(self.file_length, temp_file.tell())
# Download file smaller than required max_length
def test_download_file_upper_length(self):
def test_download_file_upper_length(self) -> None:
with self.fetcher.download_file(
self.url, self.file_length + 4
) as temp_file:
@ -173,8 +176,10 @@ def test_download_file_upper_length(self):
self.assertEqual(self.file_length, temp_file.tell())
# Download a file bigger than expected
def test_download_file_length_mismatch(self):
def test_download_file_length_mismatch(self) -> Iterator[Any]:
with self.assertRaises(exceptions.DownloadLengthMismatchError):
# Force download_file to execute and raise the error since it is a
# context manager and returns Iterator[IO]
yield self.fetcher.download_file(self.url, self.file_length - 4)

View file

@ -53,7 +53,7 @@ class TestSerialization(unittest.TestCase):
}
@utils.run_sub_tests_with_dataset(invalid_signed)
def test_invalid_signed_serialization(self, test_case_data: str):
def test_invalid_signed_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
with self.assertRaises((KeyError, ValueError, TypeError)):
Snapshot.from_dict(copy.deepcopy(case_dict))
@ -68,7 +68,7 @@ def test_invalid_signed_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(valid_keys)
def test_valid_key_serialization(self, test_case_data: str):
def test_valid_key_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
key = Key.from_dict("id", copy.copy(case_dict))
self.assertDictEqual(case_dict, key.to_dict())
@ -85,7 +85,7 @@ def test_valid_key_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(invalid_keys)
def test_invalid_key_serialization(self, test_case_data: str):
def test_invalid_key_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
with self.assertRaises((TypeError, KeyError)):
keyid = case_dict.pop("keyid")
@ -100,7 +100,7 @@ def test_invalid_key_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(invalid_roles)
def test_invalid_role_serialization(self, test_case_data: str):
def test_invalid_role_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
with self.assertRaises((KeyError, TypeError, ValueError)):
Role.from_dict(copy.deepcopy(case_dict))
@ -113,7 +113,7 @@ def test_invalid_role_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(valid_roles)
def test_role_serialization(self, test_case_data: str):
def test_role_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
role = Role.from_dict(copy.deepcopy(case_dict))
self.assertDictEqual(case_dict, role.to_dict())
@ -161,7 +161,7 @@ def test_role_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(valid_roots)
def test_root_serialization(self, test_case_data: str):
def test_root_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
root = Root.from_dict(copy.deepcopy(case_dict))
self.assertDictEqual(case_dict, root.to_dict())
@ -203,7 +203,7 @@ def test_root_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(invalid_roots)
def test_invalid_root_serialization(self, test_case_data: str):
def test_invalid_root_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
with self.assertRaises(ValueError):
Root.from_dict(copy.deepcopy(case_dict))
@ -218,7 +218,7 @@ def test_invalid_root_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(invalid_metafiles)
def test_invalid_metafile_serialization(self, test_case_data: str):
def test_invalid_metafile_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
with self.assertRaises((TypeError, ValueError, AttributeError)):
MetaFile.from_dict(copy.deepcopy(case_dict))
@ -232,7 +232,7 @@ def test_invalid_metafile_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(valid_metafiles)
def test_metafile_serialization(self, test_case_data: str):
def test_metafile_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
metafile = MetaFile.from_dict(copy.copy(case_dict))
self.assertDictEqual(case_dict, metafile.to_dict())
@ -242,7 +242,7 @@ def test_metafile_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(invalid_timestamps)
def test_invalid_timestamp_serialization(self, test_case_data: str):
def test_invalid_timestamp_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
with self.assertRaises((ValueError, KeyError)):
Timestamp.from_dict(copy.deepcopy(case_dict))
@ -255,7 +255,7 @@ def test_invalid_timestamp_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(valid_timestamps)
def test_timestamp_serialization(self, test_case_data: str):
def test_timestamp_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
timestamp = Timestamp.from_dict(copy.deepcopy(case_dict))
self.assertDictEqual(case_dict, timestamp.to_dict())
@ -274,7 +274,7 @@ def test_timestamp_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(valid_snapshots)
def test_snapshot_serialization(self, test_case_data: str):
def test_snapshot_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
snapshot = Snapshot.from_dict(copy.deepcopy(case_dict))
self.assertDictEqual(case_dict, snapshot.to_dict())
@ -295,7 +295,7 @@ def test_snapshot_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(valid_delegated_roles)
def test_delegated_role_serialization(self, test_case_data: str):
def test_delegated_role_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
deserialized_role = DelegatedRole.from_dict(copy.copy(case_dict))
self.assertDictEqual(case_dict, deserialized_role.to_dict())
@ -312,7 +312,9 @@ def test_delegated_role_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(invalid_delegated_roles)
def test_invalid_delegated_role_serialization(self, test_case_data: str):
def test_invalid_delegated_role_serialization(
self, test_case_data: str
) -> None:
case_dict = json.loads(test_case_data)
with self.assertRaises(ValueError):
DelegatedRole.from_dict(copy.copy(case_dict))
@ -339,7 +341,9 @@ def test_invalid_delegated_role_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(invalid_delegations)
def test_invalid_delegation_serialization(self, test_case_data: str):
def test_invalid_delegation_serialization(
self, test_case_data: str
) -> None:
case_dict = json.loads(test_case_data)
with self.assertRaises((ValueError, KeyError, AttributeError)):
Delegations.from_dict(copy.deepcopy(case_dict))
@ -361,7 +365,7 @@ def test_invalid_delegation_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(valid_delegations)
def test_delegation_serialization(self, test_case_data: str):
def test_delegation_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
delegation = Delegations.from_dict(copy.deepcopy(case_dict))
self.assertDictEqual(case_dict, delegation.to_dict())
@ -374,7 +378,9 @@ def test_delegation_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(invalid_targetfiles)
def test_invalid_targetfile_serialization(self, test_case_data: str):
def test_invalid_targetfile_serialization(
self, test_case_data: str
) -> None:
case_dict = json.loads(test_case_data)
with self.assertRaises(KeyError):
TargetFile.from_dict(copy.deepcopy(case_dict), "file1.txt")
@ -388,7 +394,7 @@ def test_invalid_targetfile_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(valid_targetfiles)
def test_targetfile_serialization(self, test_case_data: str):
def test_targetfile_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
target_file = TargetFile.from_dict(copy.copy(case_dict), "file1.txt")
self.assertDictEqual(case_dict, target_file.to_dict())
@ -420,7 +426,7 @@ def test_targetfile_serialization(self, test_case_data: str):
}
@utils.run_sub_tests_with_dataset(valid_targets)
def test_targets_serialization(self, test_case_data: str):
def test_targets_serialization(self, test_case_data: str) -> None:
case_dict = json.loads(test_case_data)
targets = Targets.from_dict(copy.deepcopy(case_dict))
self.assertDictEqual(case_dict, targets.to_dict())

View file

@ -4,7 +4,7 @@
import sys
import unittest
from datetime import datetime
from typing import Callable, Optional
from typing import Callable, ClassVar, Dict, List, Optional, Tuple
from securesystemslib.interface import (
import_ed25519_privatekey_from_file,
@ -18,7 +18,6 @@
Metadata,
MetaFile,
Root,
Signed,
Snapshot,
Targets,
Timestamp,
@ -31,8 +30,13 @@
class TestTrustedMetadataSet(unittest.TestCase):
"""Tests for all public API of the TrustedMetadataSet class."""
keystore: ClassVar[Dict[str, SSlibSigner]]
metadata: ClassVar[Dict[str, bytes]]
repo_dir: ClassVar[str]
@classmethod
def modify_metadata(
self, rolename: str, modification_func: Callable[[Signed], None]
cls, rolename: str, modification_func: Callable
) -> bytes:
"""Instantiate metadata from rolename type, call modification_func and
sign it again with self.keystore[rolename] signer.
@ -42,13 +46,13 @@ def modify_metadata(
modification_func: Function that will be called to modify the signed
portion of metadata bytes.
"""
metadata = Metadata.from_bytes(self.metadata[rolename])
metadata = Metadata.from_bytes(cls.metadata[rolename])
modification_func(metadata.signed)
metadata.sign(self.keystore[rolename])
metadata.sign(cls.keystore[rolename])
return metadata.to_bytes()
@classmethod
def setUpClass(cls):
def setUpClass(cls) -> None:
cls.repo_dir = os.path.join(
os.getcwd(), "repository_data", "repository", "metadata"
)
@ -81,7 +85,7 @@ def hashes_length_modifier(timestamp: Timestamp) -> None:
timestamp.snapshot_meta.length = None
cls.metadata["timestamp"] = cls.modify_metadata(
cls, "timestamp", hashes_length_modifier
"timestamp", hashes_length_modifier
)
def setUp(self) -> None:
@ -91,7 +95,7 @@ def _update_all_besides_targets(
self,
timestamp_bytes: Optional[bytes] = None,
snapshot_bytes: Optional[bytes] = None,
):
) -> None:
"""Update all metadata roles besides targets.
Args:
@ -109,7 +113,7 @@ def _update_all_besides_targets(
snapshot_bytes = snapshot_bytes or self.metadata["snapshot"]
self.trusted_set.update_snapshot(snapshot_bytes)
def test_update(self):
def test_update(self) -> None:
self.trusted_set.update_timestamp(self.metadata["timestamp"])
self.trusted_set.update_snapshot(self.metadata["snapshot"])
self.trusted_set.update_targets(self.metadata["targets"])
@ -129,8 +133,10 @@ def test_update(self):
self.assertTrue(count, 6)
def test_update_metadata_output(self):
timestamp = self.trusted_set.update_timestamp(self.metadata["timestamp"])
def test_update_metadata_output(self) -> None:
timestamp = self.trusted_set.update_timestamp(
self.metadata["timestamp"]
)
snapshot = self.trusted_set.update_snapshot(self.metadata["snapshot"])
targets = self.trusted_set.update_targets(self.metadata["targets"])
delegeted_targets_1 = self.trusted_set.update_delegated_targets(
@ -145,7 +151,7 @@ def test_update_metadata_output(self):
self.assertIsInstance(delegeted_targets_1.signed, Targets)
self.assertIsInstance(delegeted_targets_2.signed, Targets)
def test_out_of_order_ops(self):
def test_out_of_order_ops(self) -> None:
# Update snapshot before timestamp
with self.assertRaises(RuntimeError):
self.trusted_set.update_snapshot(self.metadata["snapshot"])
@ -182,7 +188,7 @@ def test_out_of_order_ops(self):
self.metadata["role1"], "role1", "targets"
)
def test_root_with_invalid_json(self):
def test_root_with_invalid_json(self) -> None:
# Test loading initial root and root update
for test_func in [TrustedMetadataSet, self.trusted_set.update_root]:
# root is not json
@ -199,8 +205,8 @@ def test_root_with_invalid_json(self):
with self.assertRaises(exceptions.RepositoryError):
test_func(self.metadata["snapshot"])
def test_top_level_md_with_invalid_json(self):
top_level_md = [
def test_top_level_md_with_invalid_json(self) -> None:
top_level_md: List[Tuple[bytes, Callable[[bytes], Metadata]]] = [
(self.metadata["timestamp"], self.trusted_set.update_timestamp),
(self.metadata["snapshot"], self.trusted_set.update_snapshot),
(self.metadata["targets"], self.trusted_set.update_targets),
@ -222,7 +228,7 @@ def test_top_level_md_with_invalid_json(self):
update_func(metadata)
def test_update_root_new_root(self):
def test_update_root_new_root(self) -> None:
# test that root can be updated with a new valid version
def root_new_version_modifier(root: Root) -> None:
root.version += 1
@ -230,19 +236,19 @@ def root_new_version_modifier(root: Root) -> None:
root = self.modify_metadata("root", root_new_version_modifier)
self.trusted_set.update_root(root)
def test_update_root_new_root_cannot_be_verified_with_threshold(self):
def test_update_root_new_root_fail_threshold_verification(self) -> None:
# new_root data with threshold which cannot be verified.
root = Metadata.from_bytes(self.metadata["root"])
# remove root role keyids representing root signatures
root.signed.roles["root"].keyids = []
root.signed.roles["root"].keyids = set()
with self.assertRaises(exceptions.UnsignedMetadataError):
self.trusted_set.update_root(root.to_bytes())
def test_update_root_new_root_ver_same_as_trusted_root_ver(self):
def test_update_root_new_root_ver_same_as_trusted_root_ver(self) -> None:
with self.assertRaises(exceptions.ReplayedMetadataError):
self.trusted_set.update_root(self.metadata["root"])
def test_root_expired_final_root(self):
def test_root_expired_final_root(self) -> None:
def root_expired_modifier(root: Root) -> None:
root.expires = datetime(1970, 1, 1)
@ -253,7 +259,7 @@ def root_expired_modifier(root: Root) -> None:
with self.assertRaises(exceptions.ExpiredMetadataError):
tmp_trusted_set.update_timestamp(self.metadata["timestamp"])
def test_update_timestamp_new_timestamp_ver_below_trusted_ver(self):
def test_update_timestamp_new_timestamp_ver_below_trusted_ver(self) -> None:
# new_timestamp.version < trusted_timestamp.version
def version_modifier(timestamp: Timestamp) -> None:
timestamp.version = 3
@ -263,7 +269,7 @@ def version_modifier(timestamp: Timestamp) -> None:
with self.assertRaises(exceptions.ReplayedMetadataError):
self.trusted_set.update_timestamp(self.metadata["timestamp"])
def test_update_timestamp_snapshot_ver_below_current(self):
def test_update_timestamp_snapshot_ver_below_current(self) -> None:
def bump_snapshot_version(timestamp: Timestamp) -> None:
timestamp.snapshot_meta.version = 2
@ -275,7 +281,7 @@ def bump_snapshot_version(timestamp: Timestamp) -> None:
with self.assertRaises(exceptions.ReplayedMetadataError):
self.trusted_set.update_timestamp(self.metadata["timestamp"])
def test_update_timestamp_expired(self):
def test_update_timestamp_expired(self) -> None:
# new_timestamp has expired
def timestamp_expired_modifier(timestamp: Timestamp) -> None:
timestamp.expires = datetime(1970, 1, 1)
@ -291,7 +297,7 @@ def timestamp_expired_modifier(timestamp: Timestamp) -> None:
with self.assertRaises(exceptions.ExpiredMetadataError):
self.trusted_set.update_snapshot(self.metadata["snapshot"])
def test_update_snapshot_length_or_hash_mismatch(self):
def test_update_snapshot_length_or_hash_mismatch(self) -> None:
def modify_snapshot_length(timestamp: Timestamp) -> None:
timestamp.snapshot_meta.length = 1
@ -302,14 +308,16 @@ def modify_snapshot_length(timestamp: Timestamp) -> None:
with self.assertRaises(exceptions.RepositoryError):
self.trusted_set.update_snapshot(self.metadata["snapshot"])
def test_update_snapshot_cannot_verify_snapshot_with_threshold(self):
def test_update_snapshot_fail_threshold_verification(self) -> None:
self.trusted_set.update_timestamp(self.metadata["timestamp"])
snapshot = Metadata.from_bytes(self.metadata["snapshot"])
snapshot.signatures.clear()
with self.assertRaises(exceptions.UnsignedMetadataError):
self.trusted_set.update_snapshot(snapshot.to_bytes())
def test_update_snapshot_version_different_timestamp_snapshot_version(self):
def test_update_snapshot_version_diverge_timestamp_snapshot_version(
self,
) -> None:
def timestamp_version_modifier(timestamp: Timestamp) -> None:
timestamp.snapshot_meta.version = 2
@ -326,7 +334,7 @@ def timestamp_version_modifier(timestamp: Timestamp) -> None:
with self.assertRaises(exceptions.BadVersionNumberError):
self.trusted_set.update_targets(self.metadata["targets"])
def test_update_snapshot_file_removed_from_meta(self):
def test_update_snapshot_file_removed_from_meta(self) -> None:
self._update_all_besides_targets(self.metadata["timestamp"])
def remove_file_from_meta(snapshot: Snapshot) -> None:
@ -337,7 +345,7 @@ def remove_file_from_meta(snapshot: Snapshot) -> None:
with self.assertRaises(exceptions.RepositoryError):
self.trusted_set.update_snapshot(snapshot)
def test_update_snapshot_meta_version_decreases(self):
def test_update_snapshot_meta_version_decreases(self) -> None:
self.trusted_set.update_timestamp(self.metadata["timestamp"])
def version_meta_modifier(snapshot: Snapshot) -> None:
@ -349,7 +357,7 @@ def version_meta_modifier(snapshot: Snapshot) -> None:
with self.assertRaises(exceptions.BadVersionNumberError):
self.trusted_set.update_snapshot(self.metadata["snapshot"])
def test_update_snapshot_expired_new_snapshot(self):
def test_update_snapshot_expired_new_snapshot(self) -> None:
self.trusted_set.update_timestamp(self.metadata["timestamp"])
def snapshot_expired_modifier(snapshot: Snapshot) -> None:
@ -364,7 +372,7 @@ def snapshot_expired_modifier(snapshot: Snapshot) -> None:
with self.assertRaises(exceptions.ExpiredMetadataError):
self.trusted_set.update_targets(self.metadata["targets"])
def test_update_snapshot_successful_rollback_checks(self):
def test_update_snapshot_successful_rollback_checks(self) -> None:
def meta_version_bump(timestamp: Timestamp) -> None:
timestamp.snapshot_meta.version += 1
@ -386,7 +394,7 @@ def version_bump(snapshot: Snapshot) -> None:
# update targets to trigger final snapshot meta version check
self.trusted_set.update_targets(self.metadata["targets"])
def test_update_targets_no_meta_in_snapshot(self):
def test_update_targets_no_meta_in_snapshot(self) -> None:
def no_meta_modifier(snapshot: Snapshot) -> None:
snapshot.meta = {}
@ -396,7 +404,7 @@ def no_meta_modifier(snapshot: Snapshot) -> None:
with self.assertRaises(exceptions.RepositoryError):
self.trusted_set.update_targets(self.metadata["targets"])
def test_update_targets_hash_different_than_snapshot_meta_hash(self):
def test_update_targets_hash_diverge_from_snapshot_meta_hash(self) -> None:
def meta_length_modifier(snapshot: Snapshot) -> None:
for metafile_path in snapshot.meta:
snapshot.meta[metafile_path] = MetaFile(version=1, length=1)
@ -407,7 +415,7 @@ def meta_length_modifier(snapshot: Snapshot) -> None:
with self.assertRaises(exceptions.RepositoryError):
self.trusted_set.update_targets(self.metadata["targets"])
def test_update_targets_version_different_snapshot_meta_version(self):
def test_update_targets_version_diverge_snapshot_meta_version(self) -> None:
def meta_modifier(snapshot: Snapshot) -> None:
for metafile_path in snapshot.meta:
snapshot.meta[metafile_path] = MetaFile(version=2)
@ -418,7 +426,7 @@ def meta_modifier(snapshot: Snapshot) -> None:
with self.assertRaises(exceptions.BadVersionNumberError):
self.trusted_set.update_targets(self.metadata["targets"])
def test_update_targets_expired_new_target(self):
def test_update_targets_expired_new_target(self) -> None:
self._update_all_besides_targets()
# new_delegated_target has expired
def target_expired_modifier(target: Targets) -> None:

View file

@ -37,8 +37,8 @@ class TestUpdaterKeyRotations(unittest.TestCase):
dump_dir: Optional[str] = None
def setUp(self) -> None:
self.sim = None
self.metadata_dir = None
self.sim: RepositorySimulator
self.metadata_dir: str
self.subtest_count = 0
# pylint: disable-next=consider-using-with
self.temp_dir = tempfile.TemporaryDirectory()

View file

@ -12,14 +12,14 @@
import sys
import tempfile
import unittest
from typing import List
from typing import Callable, ClassVar, List
from securesystemslib.interface import import_rsa_privatekey_from_file
from securesystemslib.signer import SSlibSigner
from tests import utils
from tuf import exceptions, ngclient, unittest_toolbox
from tuf.api.metadata import Metadata, TargetFile
from tuf.api.metadata import Metadata, Root, TargetFile
logger = logging.getLogger(__name__)
@ -27,8 +27,11 @@
class TestUpdater(unittest_toolbox.Modified_TestCase):
"""Test the Updater class from 'tuf/ngclient/updater.py'."""
temporary_directory: ClassVar[str]
server_process_handler: ClassVar[utils.TestServerProcess]
@classmethod
def setUpClass(cls):
def setUpClass(cls) -> None:
# Create a temporary directory to store the repository, metadata, and
# target files. 'temporary_directory' must be deleted in
# TearDownModule() so that temporary files are always removed, even when
@ -38,18 +41,18 @@ def setUpClass(cls):
# Needed because in some tests simple_server.py cannot be found.
# The reason is that the current working directory
# has been changed when executing a subprocess.
cls.SIMPLE_SERVER_PATH = os.path.join(os.getcwd(), "simple_server.py")
SIMPLE_SERVER_PATH = os.path.join(os.getcwd(), "simple_server.py")
# Launch a SimpleHTTPServer (serves files in the current directory).
# Test cases will request metadata and target files that have been
# pre-generated in 'tuf/tests/repository_data', which will be served
# by the SimpleHTTPServer launched here.
cls.server_process_handler = utils.TestServerProcess(
log=logger, server=cls.SIMPLE_SERVER_PATH
log=logger, server=SIMPLE_SERVER_PATH
)
@classmethod
def tearDownClass(cls):
def tearDownClass(cls) -> None:
# Cleans the resources and flush the logged lines (if any).
cls.server_process_handler.clean()
@ -57,7 +60,7 @@ def tearDownClass(cls):
# the metadata, targets, and key files generated for the test cases
shutil.rmtree(cls.temporary_directory)
def setUp(self):
def setUp(self) -> None:
# We are inheriting from custom class.
unittest_toolbox.Modified_TestCase.setUp(self)
@ -124,7 +127,7 @@ def setUp(self):
target_base_url=self.targets_url,
)
def tearDown(self):
def tearDown(self) -> None:
# We are inheriting from custom class.
unittest_toolbox.Modified_TestCase.tearDown(self)
@ -132,7 +135,9 @@ def tearDown(self):
self.server_process_handler.flush_log()
def _modify_repository_root(
self, modification_func, bump_version=False
self,
modification_func: Callable[[Metadata], None],
bump_version: bool = False,
) -> None:
"""Apply 'modification_func' to root and persist it."""
role_path = os.path.join(
@ -159,13 +164,13 @@ def _modify_repository_root(
)
)
def _assert_files(self, roles: List[str]):
def _assert_files(self, roles: List[str]) -> None:
"""Assert that local metadata files exist for 'roles'"""
expected_files = [f"{role}.json" for role in roles]
client_files = sorted(os.listdir(self.client_directory))
self.assertEqual(client_files, expected_files)
def test_refresh_and_download(self):
def test_refresh_and_download(self) -> None:
# Test refresh without consistent targets - targets without hash prefix.
# top-level targets are already in local cache (but remove others)
@ -179,10 +184,12 @@ def test_refresh_and_download(self):
# Get targetinfos, assert that cache does not contain files
info1 = self.updater.get_targetinfo("file1.txt")
assert isinstance(info1, TargetFile)
self._assert_files(["root", "snapshot", "targets", "timestamp"])
# Get targetinfo for 'file3.txt' listed in the delegated role1
info3 = self.updater.get_targetinfo("file3.txt")
assert isinstance(info3, TargetFile)
expected_files = ["role1", "root", "snapshot", "targets", "timestamp"]
self._assert_files(expected_files)
self.assertIsNone(self.updater.find_cached_target(info1))
@ -200,7 +207,7 @@ def test_refresh_and_download(self):
path = self.updater.find_cached_target(info3)
self.assertEqual(path, os.path.join(self.dl_dir, info3.path))
def test_refresh_with_only_local_root(self):
def test_refresh_with_only_local_root(self) -> None:
os.remove(os.path.join(self.client_directory, "timestamp.json"))
os.remove(os.path.join(self.client_directory, "snapshot.json"))
os.remove(os.path.join(self.client_directory, "targets.json"))
@ -217,7 +224,7 @@ def test_refresh_with_only_local_root(self):
expected_files = ["role1", "root", "snapshot", "targets", "timestamp"]
self._assert_files(expected_files)
def test_implicit_refresh_with_only_local_root(self):
def test_implicit_refresh_with_only_local_root(self) -> None:
os.remove(os.path.join(self.client_directory, "timestamp.json"))
os.remove(os.path.join(self.client_directory, "snapshot.json"))
os.remove(os.path.join(self.client_directory, "targets.json"))
@ -231,7 +238,7 @@ def test_implicit_refresh_with_only_local_root(self):
expected_files = ["role1", "root", "snapshot", "targets", "timestamp"]
self._assert_files(expected_files)
def test_both_target_urls_not_set(self):
def test_both_target_urls_not_set(self) -> None:
# target_base_url = None and Updater._target_base_url = None
updater = ngclient.Updater(
self.client_directory, self.metadata_url, self.dl_dir
@ -240,7 +247,7 @@ def test_both_target_urls_not_set(self):
with self.assertRaises(ValueError):
updater.download_target(info)
def test_no_target_dir_no_filepath(self):
def test_no_target_dir_no_filepath(self) -> None:
# filepath = None and Updater.target_dir = None
updater = ngclient.Updater(self.client_directory, self.metadata_url)
info = TargetFile(1, {"sha256": ""}, "targetpath")
@ -249,15 +256,17 @@ def test_no_target_dir_no_filepath(self):
with self.assertRaises(ValueError):
updater.download_target(info)
def test_external_targets_url(self):
def test_external_targets_url(self) -> None:
self.updater.refresh()
info = self.updater.get_targetinfo("file1.txt")
assert isinstance(info, TargetFile)
self.updater.download_target(info, target_base_url=self.targets_url)
def test_length_hash_mismatch(self):
def test_length_hash_mismatch(self) -> None:
self.updater.refresh()
targetinfo = self.updater.get_targetinfo("file1.txt")
assert isinstance(targetinfo, TargetFile)
length = targetinfo.length
with self.assertRaises(exceptions.RepositoryError):
@ -270,13 +279,13 @@ def test_length_hash_mismatch(self):
self.updater.download_target(targetinfo)
# pylint: disable=protected-access
def test_updating_root(self):
def test_updating_root(self) -> None:
# Bump root version, resign and refresh
self._modify_repository_root(lambda root: None, bump_version=True)
self.updater.refresh()
self.assertEqual(self.updater._trusted_set.root.signed.version, 2)
def test_missing_targetinfo(self):
def test_missing_targetinfo(self) -> None:
self.updater.refresh()
# Get targetinfo for non-existing file

View file

@ -349,7 +349,7 @@ def test_new_snapshot_unsigned(self) -> None:
self._assert_files_exist(["root", "timestamp"])
def test_new_snapshot_version_mismatch(self):
def test_new_snapshot_version_mismatch(self) -> None:
# Check against timestamp roles snapshot version
# Increase snapshot version without updating timestamp
@ -414,7 +414,7 @@ def test_new_targets_unsigned(self) -> None:
self._assert_files_exist(["root", "timestamp", "snapshot"])
def test_new_targets_version_mismatch(self):
def test_new_targets_version_mismatch(self) -> None:
# Check against snapshot roles targets version
# Increase targets version without updating snapshot

View file

@ -16,7 +16,7 @@
from tests import utils
from tests.repository_simulator import RepositorySimulator
from tuf.api.metadata import SPECIFICATION_VERSION, Targets
from tuf.api.metadata import SPECIFICATION_VERSION, TargetFile, Targets
from tuf.exceptions import BadVersionNumberError, UnsignedMetadataError
from tuf.ngclient import Updater
@ -27,7 +27,7 @@ class TestUpdater(unittest.TestCase):
# set dump_dir to trigger repository state dumps
dump_dir: Optional[str] = None
def setUp(self):
def setUp(self) -> None:
# pylint: disable-next=consider-using-with
self.temp_dir = tempfile.TemporaryDirectory()
self.metadata_dir = os.path.join(self.temp_dir.name, "metadata")
@ -49,7 +49,7 @@ def setUp(self):
self.sim.dump_dir = os.path.join(self.dump_dir, name)
os.mkdir(self.sim.dump_dir)
def tearDown(self):
def tearDown(self) -> None:
self.temp_dir.cleanup()
def _run_refresh(self) -> Updater:
@ -67,7 +67,7 @@ def _run_refresh(self) -> Updater:
updater.refresh()
return updater
def test_refresh(self):
def test_refresh(self) -> None:
# Update top level metadata
self._run_refresh()
@ -99,7 +99,7 @@ def test_refresh(self):
}
@utils.run_sub_tests_with_dataset(targets)
def test_targets(self, test_case_data: Tuple[str, bytes, str]):
def test_targets(self, test_case_data: Tuple[str, bytes, str]) -> None:
targetpath, content, encoded_path = test_case_data
path = os.path.join(self.targets_dir, encoded_path)
@ -117,7 +117,7 @@ def test_targets(self, test_case_data: Tuple[str, bytes, str]):
updater = self._run_refresh()
# target now exists, is not in cache yet
info = updater.get_targetinfo(targetpath)
self.assertIsNotNone(info)
assert info is not None
# Test without and with explicit local filepath
self.assertIsNone(updater.find_cached_target(info))
self.assertIsNone(updater.find_cached_target(info, path))
@ -136,7 +136,7 @@ def test_targets(self, test_case_data: Tuple[str, bytes, str]):
self.assertEqual(path, updater.find_cached_target(info))
self.assertEqual(path, updater.find_cached_target(info, path))
def test_fishy_rolenames(self):
def test_fishy_rolenames(self) -> None:
roles_to_filenames = {
"../a": "..%2Fa.json",
"": ".json",
@ -162,7 +162,7 @@ def test_fishy_rolenames(self):
for fname in roles_to_filenames.values():
self.assertTrue(fname in local_metadata)
def test_keys_and_signatures(self):
def test_keys_and_signatures(self) -> None:
"""Example of the two trickiest test areas: keys and root updates"""
# Update top level metadata
@ -202,7 +202,7 @@ def test_keys_and_signatures(self):
self._run_refresh()
def test_snapshot_rollback_with_local_snapshot_hash_mismatch(self):
def test_snapshot_rollback_with_local_snapshot_hash_mismatch(self) -> None:
# Test triggering snapshot rollback check on a newly downloaded snapshot
# when the local snapshot is loaded even when there is a hash mismatch
# with timestamp.snapshot_meta.
@ -233,7 +233,7 @@ def test_snapshot_rollback_with_local_snapshot_hash_mismatch(self):
self._run_refresh()
@patch.object(builtins, "open", wraps=builtins.open)
def test_not_loading_targets_twice(self, wrapped_open: MagicMock):
def test_not_loading_targets_twice(self, wrapped_open: MagicMock) -> None:
# Do not load targets roles more than once when traversing
# the delegations tree

View file

@ -21,7 +21,7 @@
"""
from contextlib import contextmanager
from typing import Dict, Any, Callable
from typing import Any, Callable, Dict, IO, Optional, Callable, List, Iterator
import unittest
import argparse
import errno
@ -48,9 +48,13 @@
# Test runner decorator: Runs the test as a set of N SubTests,
# (where N is number of items in dataset), feeding the actual test
# function one test case at a time
def run_sub_tests_with_dataset(dataset: DataSet):
def real_decorator(function: Callable[[unittest.TestCase, Any], None]):
def wrapper(test_cls: unittest.TestCase):
def run_sub_tests_with_dataset(
dataset: DataSet
) -> Callable[[Callable], Callable]:
def real_decorator(
function: Callable[[unittest.TestCase, Any], None]
) -> Callable[[unittest.TestCase], None]:
def wrapper(test_cls: unittest.TestCase) -> None:
for case, data in dataset.items():
with test_cls.subTest(case=case):
function(test_cls, data)
@ -60,15 +64,15 @@ def wrapper(test_cls: unittest.TestCase):
class TestServerProcessError(Exception):
def __init__(self, value="TestServerProcess"):
def __init__(self, value: str="TestServerProcess") -> None:
self.value = value
def __str__(self):
def __str__(self) -> str:
return repr(self.value)
@contextmanager
def ignore_deprecation_warnings(module):
def ignore_deprecation_warnings(module: str) -> Iterator[None]:
with warnings.catch_warnings():
warnings.filterwarnings('ignore',
category=DeprecationWarning,
@ -82,13 +86,16 @@ def ignore_deprecation_warnings(module):
# but the current blocking connect() seems to work fast on Linux and seems
# to at least work on Windows (ECONNREFUSED unfortunately has a 2 second
# timeout on Windows)
def wait_for_server(host, server, port, timeout=10):
def wait_for_server(host: str, server: str, port: int, timeout: int=10) -> None:
start = time.time()
remaining_timeout = timeout
succeeded = False
while not succeeded and remaining_timeout > 0:
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock: Optional[socket.socket] = socket.socket(
socket.AF_INET, socket.SOCK_STREAM
)
assert sock is not None
sock.settimeout(remaining_timeout)
sock.connect((host, port))
succeeded = True
@ -104,14 +111,14 @@ def wait_for_server(host, server, port, timeout=10):
if sock:
sock.close()
sock = None
remaining_timeout = timeout - (time.time() - start)
remaining_timeout = int(timeout - (time.time() - start))
if not succeeded:
raise TimeoutError("Could not connect to the " + server \
+ " on port " + str(port) + "!")
def configure_test_logging(argv):
def configure_test_logging(argv: List[str]) -> None:
# parse arguments but only handle '-v': argv may contain
# other things meant for unittest argument parser
parser = argparse.ArgumentParser(add_help=False)
@ -165,13 +172,14 @@ class TestServerProcess():
"""
def __init__(self, log, server='simple_server.py',
timeout=10, popen_cwd=".", extra_cmd_args=None):
def __init__(self, log: logging.Logger, server: str='simple_server.py',
timeout: int=10, popen_cwd: str=".", extra_cmd_args: Optional[List[str]]=None
):
self.server = server
self.__logger = log
# Stores popped messages from the queue.
self.__logged_messages = []
self.__logged_messages: List[str] = []
if extra_cmd_args is None:
extra_cmd_args = []
@ -185,7 +193,9 @@ def __init__(self, log, server='simple_server.py',
def _start_server(self, timeout, extra_cmd_args, popen_cwd):
def _start_server(
self, timeout: int, extra_cmd_args: List[str], popen_cwd: str
) -> None:
"""
Start the server subprocess and a thread
responsible to redirect stdout/stderr to the Queue.
@ -201,7 +211,7 @@ def _start_server(self, timeout, extra_cmd_args, popen_cwd):
def _start_process(self, extra_cmd_args, popen_cwd):
def _start_process(self, extra_cmd_args: List[str], popen_cwd: str) -> None:
"""Starts the process running the server."""
# The "-u" option forces stdin, stdout and stderr to be unbuffered.
@ -213,7 +223,7 @@ def _start_process(self, extra_cmd_args, popen_cwd):
def _start_redirect_thread(self):
def _start_redirect_thread(self) -> None:
"""Starts a thread responsible to redirect stdout/stderr to the Queue."""
# Run log_queue_worker() in a thread.
@ -228,7 +238,7 @@ def _start_redirect_thread(self):
@staticmethod
def _log_queue_worker(stream, line_queue):
def _log_queue_worker(stream: IO, line_queue: queue.Queue) -> None:
"""
Worker function to run in a seprate thread.
Reads from 'stream', puts lines in a Queue (Queue is thread-safe).
@ -247,7 +257,7 @@ def _log_queue_worker(stream, line_queue):
def _wait_for_port(self, timeout):
def _wait_for_port(self, timeout: int) -> None:
"""
Validates the first item from the Queue against the port message.
If validation is successful, self.port is set.
@ -279,7 +289,7 @@ def _wait_for_port(self, timeout):
def _kill_server_process(self):
def _kill_server_process(self) -> None:
"""Kills the server subprocess if it's running."""
if self.is_process_running():
@ -290,7 +300,7 @@ def _kill_server_process(self):
def flush_log(self):
def flush_log(self) -> None:
"""Flushes the log lines from the logging queue."""
while True:
@ -311,7 +321,7 @@ def flush_log(self):
def clean(self):
def clean(self) -> None:
"""
Kills the subprocess and closes the TempFile.
Calls flush_log to check for logged information, but not yet flushed.
@ -324,5 +334,5 @@ def clean(self):
def is_process_running(self):
def is_process_running(self) -> bool:
return True if self.__server_process.poll() is None else False

View file

@ -182,7 +182,7 @@ def filter(self, record):
def set_log_level(log_level=_DEFAULT_LOG_LEVEL):
def set_log_level(log_level: int=_DEFAULT_LOG_LEVEL):
"""
<Purpose>
Allow the default log level to be overridden. If 'log_level' is not

View file

@ -29,6 +29,7 @@
import random
import string
from typing import Optional
class Modified_TestCase(unittest.TestCase):
"""
@ -70,12 +71,12 @@ def setUp():
"""
def setUp(self):
def setUp(self) -> None:
self._cleanup = []
def tearDown(self):
def tearDown(self) -> None:
for cleanup_function in self._cleanup:
# Perform clean up by executing clean-up functions.
try:
@ -87,7 +88,7 @@ def tearDown(self):
def make_temp_directory(self, directory=None):
def make_temp_directory(self, directory: Optional[str]=None) -> str:
"""Creates and returns an absolute path of a directory."""
prefix = self.__class__.__name__+'_'
@ -102,7 +103,9 @@ def _destroy_temp_directory():
def make_temp_file(self, suffix='.txt', directory=None):
def make_temp_file(
self,suffix: str='.txt', directory: Optional[str]=None
) -> str:
"""Creates and returns an absolute path of an empty file."""
prefix='tmp_file_'+self.__class__.__name__+'_'
temp_file = tempfile.mkstemp(suffix=suffix, prefix=prefix, dir=directory)
@ -113,7 +116,9 @@ def _destroy_temp_file():
def make_temp_data_file(self, suffix='', directory=None, data = 'junk data'):
def make_temp_data_file(
self, suffix: str='', directory: Optional[str]=None, data: str = 'junk data'
) -> str:
"""Returns an absolute path of a temp file containing data."""
temp_file_path = self.make_temp_file(suffix=suffix, directory=directory)
temp_file = open(temp_file_path, 'wt', encoding='utf8')
@ -123,7 +128,7 @@ def make_temp_data_file(self, suffix='', directory=None, data = 'junk data'):
def random_path(self, length = 7):
def random_path(self, length: int = 7) -> str:
"""Generate a 'random' path consisting of random n-length strings."""
rand_path = '/' + self.random_string(length)
@ -136,7 +141,7 @@ def random_path(self, length = 7):
@staticmethod
def random_string(length=15):
def random_string(length: int=15) -> str:
"""Generate a random string of specified length."""
rand_str = ''