diff --git a/tests/test_fetcher_ng.py b/tests/test_fetcher_ng.py index 37eb67dd..99fcf258 100644 --- a/tests/test_fetcher_ng.py +++ b/tests/test_fetcher_ng.py @@ -13,10 +13,13 @@ import unittest import tempfile import math +import urllib3.exceptions +import requests from tests import utils from tuf import exceptions, unittest_toolbox from tuf.ngclient._internal.requests_fetcher import RequestsFetcher +from unittest.mock import patch logger = logging.getLogger(__name__) @@ -121,6 +124,13 @@ def test_read_timeout(self): slow_server_process_handler.clean() + # Read/connect session timeout error + @patch.object(requests.Session, 'get', side_effect=urllib3.exceptions.TimeoutError) + def test_session_get_timeout(self, mock_session_get): + with self.assertRaises(exceptions.SlowRetrievalError): + self.fetcher.fetch(self.url) + mock_session_get.assert_called_once() + # Simple bytes download def test_download_bytes(self): data = self.fetcher.download_bytes(self.url, self.file_length) diff --git a/tuf/ngclient/_internal/requests_fetcher.py b/tuf/ngclient/_internal/requests_fetcher.py index 3647e58f..60f293e8 100644 --- a/tuf/ngclient/_internal/requests_fetcher.py +++ b/tuf/ngclient/_internal/requests_fetcher.py @@ -77,7 +77,13 @@ def fetch(self, url: str) -> Iterator[bytes]: # requests as: # - connect timeout (max delay before first byte is received) # - read (gap) timeout (max delay between bytes received) - response = session.get(url, stream=True, timeout=self.socket_timeout) + try: + response = session.get( + url, stream=True, timeout=self.socket_timeout + ) + except urllib3.exceptions.TimeoutError as e: + raise exceptions.SlowRetrievalError from e + # Check response status. try: response.raise_for_status()