diff --git a/tests/test_updater_delegation_graphs.py b/tests/test_updater_delegation_graphs.py index 770a1b3d..f917bfed 100644 --- a/tests/test_updater_delegation_graphs.py +++ b/tests/test_updater_delegation_graphs.py @@ -136,7 +136,7 @@ def _assert_files_exist(self, roles: Iterable[str]) -> None: """Assert that local metadata files match 'roles'""" expected_files = [f"{role}.json" for role in roles] found_files = [ - e.name for e in os.scandir(self.metadata_dir) if e.is_file() + e.name for e in os.scandir(self.metadata_dir) if e.is_file() and e.name != ".lock" ] self.assertListEqual(sorted(found_files), sorted(expected_files)) diff --git a/tests/test_updater_ng.py b/tests/test_updater_ng.py index aa6de64f..151dc86d 100644 --- a/tests/test_updater_ng.py +++ b/tests/test_updater_ng.py @@ -8,9 +8,9 @@ import logging import os import shutil +import subprocess import sys import tempfile -import subprocess import unittest from collections.abc import Iterable from typing import TYPE_CHECKING, Callable, ClassVar @@ -158,7 +158,7 @@ def _assert_files_exist(self, roles: Iterable[str]) -> None: """Assert that local metadata files match 'roles'""" expected_files = [f"{role}.json" for role in roles] found_files = [ - e.name for e in os.scandir(self.client_directory) if e.is_file() + e.name for e in os.scandir(self.client_directory) if e.is_file() and e.name != ".lock" ] self.assertListEqual(sorted(found_files), sorted(expected_files)) diff --git a/tests/test_updater_top_level_update.py b/tests/test_updater_top_level_update.py index 76c74d4b..161858f4 100644 --- a/tests/test_updater_top_level_update.py +++ b/tests/test_updater_top_level_update.py @@ -94,7 +94,7 @@ def _assert_files_exist(self, roles: Iterable[str]) -> None: """Assert that local metadata files match 'roles'""" expected_files = [f"{role}.json" for role in roles] found_files = [ - e.name for e in os.scandir(self.metadata_dir) if e.is_file() + e.name for e in os.scandir(self.metadata_dir) if e.is_file() and e.name != ".lock" ] self.assertListEqual(sorted(found_files), sorted(expected_files)) @@ -644,14 +644,16 @@ def test_not_loading_targets_twice(self, wrapped_open: MagicMock) -> None: wrapped_open.reset_mock() # First time looking for "somepath", only 'role1' must be loaded + # (and ".lock" for metadata locking) updater.get_targetinfo("somepath") - wrapped_open.assert_called_once_with( + self.assertEqual(wrapped_open.call_count, 2) + wrapped_open.assert_called_with( os.path.join(self.metadata_dir, "role1.json"), "rb" ) wrapped_open.reset_mock() # Second call to get_targetinfo, all metadata is already loaded updater.get_targetinfo("somepath") - wrapped_open.assert_not_called() + self.assertEqual(wrapped_open.call_count, 1) def test_snapshot_rollback_with_local_snapshot_hash_mismatch(self) -> None: # Test triggering snapshot rollback check on a newly downloaded snapshot @@ -709,6 +711,7 @@ def test_load_metadata_from_cache(self, wrapped_open: MagicMock) -> None: root_dir = os.path.join(self.metadata_dir, "root_history") wrapped_open.assert_has_calls( [ + call(os.path.join(self.metadata_dir, ".lock"), "wb"), call(os.path.join(root_dir, "2.root.json"), "rb"), call(os.path.join(self.metadata_dir, "timestamp.json"), "rb"), call(os.path.join(self.metadata_dir, "snapshot.json"), "rb"), diff --git a/tests/utils.py b/tests/utils.py index bbfb07db..727abadb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -161,7 +161,7 @@ def cleanup_metadata_dir(path: str) -> None: for entry in it: if entry.name == "root_history": cleanup_metadata_dir(entry.path) - elif entry.name.endswith(".json"): + elif entry.name.endswith(".json") or entry.name == ".lock": os.remove(entry.path) else: raise ValueError(f"Unexpected local metadata file {entry.path}") diff --git a/tuf/ngclient/updater.py b/tuf/ngclient/updater.py index a98e799c..a3477321 100644 --- a/tuf/ngclient/updater.py +++ b/tuf/ngclient/updater.py @@ -59,7 +59,7 @@ import shutil import tempfile from pathlib import Path -from typing import TYPE_CHECKING, cast +from typing import IO, TYPE_CHECKING, cast from urllib import parse from tuf.api import exceptions @@ -69,10 +69,30 @@ from tuf.ngclient.urllib3_fetcher import Urllib3Fetcher if TYPE_CHECKING: + from collections.abc import Iterator + from tuf.ngclient.fetcher import FetcherInterface logger = logging.getLogger(__name__) +try: + # advisory file locking for posix + import fcntl + def _lock_file(f: IO) -> None: + if f.writable(): + fcntl.lockf(f, fcntl.LOCK_EX) + +except ModuleNotFoundError: + # Windows file locking + import msvcrt + + def _lock_file(f: IO) -> None: + # On Windows we lock bytes, not the file + f.write(b"\0") + f.flush() + f.seek(0) + msvcrt.locking(f.fileno(), msvcrt.LK_LOCK, 1) + class Updater: """Creates a new ``Updater`` instance and loads trusted root metadata. @@ -139,8 +159,23 @@ def __init__( self._trusted_set = TrustedMetadataSet( bootstrap, self.config.envelope_type ) - self._persist_root(self._trusted_set.root.version, bootstrap) - self._update_root_symlink() + with self._lock_metadata(): + self._persist_root(self._trusted_set.root.version, bootstrap) + self._update_root_symlink() + + + @contextlib.contextmanager + def _lock_metadata(self) -> Iterator[None]: + """Context manager for locking the metadata directory.""" + # Ensure the whole metadata directory structure exists + rootdir = Path(self._dir, "root_history") + rootdir.mkdir(exist_ok=True, parents=True) + + with open(os.path.join(self._dir, ".lock"), "wb") as f: + logger.debug("Getting metadata lock...") + _lock_file(f) + yield + logger.debug("Releasing metadata lock") def refresh(self) -> None: """Refresh top-level metadata. @@ -166,10 +201,11 @@ def refresh(self) -> None: DownloadError: Download of a metadata file failed in some way """ - self._load_root() - self._load_timestamp() - self._load_snapshot() - self._load_targets(Targets.type, Root.type) + with self._lock_metadata(): + self._load_root() + self._load_timestamp() + self._load_snapshot() + self._load_targets(Targets.type, Root.type) def _generate_target_file_path(self, targetinfo: TargetFile) -> str: if self.target_dir is None: @@ -205,9 +241,14 @@ def get_targetinfo(self, target_path: str) -> TargetFile | None: ``TargetFile`` instance or ``None``. """ - if Targets.type not in self._trusted_set: - self.refresh() - return self._preorder_depth_first_walk(target_path) + with self._lock_metadata(): + if Targets.type not in self._trusted_set: + # refresh + self._load_root() + self._load_timestamp() + self._load_snapshot() + self._load_targets(Targets.type, Root.type) + return self._preorder_depth_first_walk(target_path) def find_cached_target( self, @@ -335,7 +376,6 @@ def _persist_root(self, version: int, data: bytes) -> None: "root_history/1.root.json"). """ rootdir = Path(self._dir, "root_history") - rootdir.mkdir(exist_ok=True, parents=True) self._persist_file(str(rootdir / f"{version}.root.json"), data) def _persist_file(self, filename: str, data: bytes) -> None: