Move network IO logic to RequestsFetcher

Abstract the network IO. Move the network operations from
tuf.download to the RequestsFercher class which is TUF's
implementation of the abstract FetcherInterface.

Signed-off-by: Teodora Sechkova <tsechkova@vmware.com>
This commit is contained in:
Teodora Sechkova 2020-12-14 13:55:32 +02:00
parent 41ffe7aab1
commit 815fe24f00
No known key found for this signature in database
GPG key ID: 65F78F613EA1914E

View file

@ -32,42 +32,23 @@
from __future__ import unicode_literals
import logging
import time
import timeit
import tempfile
import tuf
import requests
import securesystemslib
import securesystemslib.util
import six
import tuf
import tuf.fetcher
import tuf.exceptions
import tuf.formats
import urllib3.exceptions
# See 'log.py' to learn how logging is handled in TUF.
logger = logging.getLogger(__name__)
# From http://docs.python-requests.org/en/master/user/advanced/#session-objects:
#
# "The Session object allows you to persist certain parameters across requests.
# It also persists cookies across all requests made from the Session instance,
# and will use urllib3's connection pooling. So if you're making several
# requests to the same host, the underlying TCP connection will be reused,
# which can result in a significant performance increase (see HTTP persistent
# connection)."
#
# NOTE: We use a separate requests.Session per scheme+hostname combination, in
# order to reuse connections to the same hostname to improve efficiency, but
# avoiding sharing state between different hosts-scheme combinations to
# minimize subtle security issues. Some cookies may not be HTTP-safe.
_sessions = {}
def safe_download(url, required_length):
def safe_download(url, required_length, fetcher):
"""
<Purpose>
Given the 'url' and 'required_length' of the desired file, open a connection
@ -84,6 +65,10 @@ def safe_download(url, required_length):
An integer value representing the length of the file. This is an exact
limit.
fetcher:
An object implementing FetcherInterface that performs the network IO
operations.
<Side Effects>
A file object is created on disk to store the contents of 'url'.
@ -105,13 +90,13 @@ def safe_download(url, required_length):
securesystemslib.formats.URL_SCHEMA.check_match(url)
tuf.formats.LENGTH_SCHEMA.check_match(required_length)
return _download_file(url, required_length, STRICT_REQUIRED_LENGTH=True)
return _download_file(url, required_length, fetcher, STRICT_REQUIRED_LENGTH=True)
def unsafe_download(url, required_length):
def unsafe_download(url, required_length, fetcher):
"""
<Purpose>
Given the 'url' and 'required_length' of the desired file, open a connection
@ -128,6 +113,10 @@ def unsafe_download(url, required_length):
An integer value representing the length of the file. This is an upper
limit.
fetcher:
An object implementing FetcherInterface that performs the network IO
operations.
<Side Effects>
A file object is created on disk to store the contents of 'url'.
@ -149,13 +138,13 @@ def unsafe_download(url, required_length):
securesystemslib.formats.URL_SCHEMA.check_match(url)
tuf.formats.LENGTH_SCHEMA.check_match(required_length)
return _download_file(url, required_length, STRICT_REQUIRED_LENGTH=False)
return _download_file(url, required_length, fetcher, STRICT_REQUIRED_LENGTH=False)
def _download_file(url, required_length, STRICT_REQUIRED_LENGTH=True):
def _download_file(url, required_length, fetcher, STRICT_REQUIRED_LENGTH=True):
"""
<Purpose>
Given the url and length of the desired file, this function opens a
@ -192,12 +181,6 @@ def _download_file(url, required_length, STRICT_REQUIRED_LENGTH=True):
<Returns>
A file object that points to the contents of 'url'.
"""
# Do all of the arguments have the appropriate format?
# Raise 'securesystemslib.exceptions.FormatError' if there is a mismatch.
securesystemslib.formats.URL_SCHEMA.check_match(url)
tuf.formats.LENGTH_SCHEMA.check_match(required_length)
# 'url.replace('\\', '/')' is needed for compatibility with Windows-based
# systems, because they might use back-slashes in place of forward-slashes.
# This converts it to the common format. unquote() replaces %xx escapes in a
@ -209,62 +192,35 @@ def _download_file(url, required_length, STRICT_REQUIRED_LENGTH=True):
# This is the temporary file that we will return to contain the contents of
# the downloaded file.
temp_file = tempfile.TemporaryFile()
start_time = timeit.default_timer()
average_download_speed = 0
number_of_bytes_received = 0
try:
# Use a different requests.Session per schema+hostname combination, to
# reuse connections while minimizing subtle security issues.
parsed_url = six.moves.urllib.parse.urlparse(url)
if not parsed_url.scheme or not parsed_url.hostname:
raise tuf.exceptions.URLParsingError(
'Could not get scheme and hostname from URL: ' + url)
for chunk in fetcher.fetch(url, required_length):
session_index = parsed_url.scheme + '+' + parsed_url.hostname
stop_time = timeit.default_timer()
seconds_spent_receiving = stop_time - start_time
# Measure the average download speed.
number_of_bytes_received += len(chunk)
average_download_speed = number_of_bytes_received / seconds_spent_receiving
logger.debug('url: ' + url)
logger.debug('session index: ' + session_index)
if average_download_speed < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED:
logger.debug('The average download speed dropped below the minimum'
' average download speed set in tuf.settings.py.')
break
session = _sessions.get(session_index)
else:
logger.debug('The average download speed has not dipped below the'
' minimum average download speed set in tuf.settings.py.')
if not session:
session = requests.Session()
_sessions[session_index] = session
# Attach some default headers to every Session.
requests_user_agent = session.headers['User-Agent']
# Follows the RFC: https://tools.ietf.org/html/rfc7231#section-5.5.3
tuf_user_agent = 'tuf/' + tuf.__version__ + ' ' + requests_user_agent
session.headers.update({
# Tell the server not to compress or modify anything.
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding#Directives
'Accept-Encoding': 'identity',
# The TUF user agent.
'User-Agent': tuf_user_agent})
logger.debug('Made new session for ' + session_index)
else:
logger.debug('Reusing session for ' + session_index)
# Get the requests.Response object for this URL.
#
# Defer downloading the response body with stream=True.
# Always set the timeout. This timeout value is interpreted by requests as:
# - connect timeout (max delay before first byte is received)
# - read (gap) timeout (max delay between bytes received)
with session.get(url, stream=True,
timeout=tuf.settings.SOCKET_TIMEOUT) as response:
# Check response status.
response.raise_for_status()
# Download the contents of the URL, up to the required length, to a
# temporary file, and get the total number of downloaded bytes.
total_downloaded, average_download_speed = \
_download_fixed_amount_of_data(response, temp_file, required_length)
temp_file.write(chunk)
# Does the total number of downloaded bytes match the required length?
_check_downloaded_length(total_downloaded, required_length,
_check_downloaded_length(number_of_bytes_received, required_length,
STRICT_REQUIRED_LENGTH=STRICT_REQUIRED_LENGTH,
average_download_speed=average_download_speed)
@ -280,107 +236,6 @@ def _download_file(url, required_length, STRICT_REQUIRED_LENGTH=True):
def _download_fixed_amount_of_data(response, temp_file, required_length):
"""
<Purpose>
This is a helper function, where the download really happens. While-block
reads data from response a fixed chunk of data at a time, or less, until
'required_length' is reached.
<Arguments>
response:
The object for communicating with the server about the contents of a URL.
temp_file:
A temporary file where the contents at the URL specified by the
'response' object will be stored.
required_length:
The number of bytes that we must download for the file. This is almost
always specified by the TUF metadata for the data file in question
(except in the case of timestamp metadata, in which case we would fix a
reasonable upper bound).
<Side Effects>
Data from the server will be written to 'temp_file'.
<Exceptions>
tuf.exceptions.SlowRetrievalError
will be raised if urllib3.exceptions.ReadTimeoutError is caught (if the
download times out).
Otherwise, runtime or network exceptions will be raised without question.
<Returns>
A (total_downloaded, average_download_speed) tuple, where
'total_downloaded' is the total number of bytes downloaded for the desired
file and the 'average_download_speed' calculated for the download
attempt.
"""
# Keep track of total bytes downloaded.
number_of_bytes_received = 0
average_download_speed = 0
start_time = timeit.default_timer()
try:
while True:
# We download a fixed chunk of data in every round. This is so that we
# can defend against slow retrieval attacks. Furthermore, we do not wish
# to download an extremely large file in one shot.
# Before beginning the round, sleep (if set) for a short amount of time
# so that the CPU is not hogged in the while loop.
if tuf.settings.SLEEP_BEFORE_ROUND:
time.sleep(tuf.settings.SLEEP_BEFORE_ROUND)
read_amount = min(
tuf.settings.CHUNK_SIZE, required_length - number_of_bytes_received)
# NOTE: This may not handle some servers adding a Content-Encoding
# header, which may cause urllib3 to misbehave:
# https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
data = response.raw.read(read_amount)
number_of_bytes_received = number_of_bytes_received + len(data)
# Data successfully read from the response. Store it.
temp_file.write(data)
if number_of_bytes_received == required_length:
break
stop_time = timeit.default_timer()
seconds_spent_receiving = stop_time - start_time
# Measure the average download speed.
average_download_speed = number_of_bytes_received / seconds_spent_receiving
if average_download_speed < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED:
logger.debug('The average download speed dropped below the minimum'
' average download speed set in tuf.settings.py.')
break
else:
logger.debug('The average download speed has not dipped below the'
' minimum average download speed set in tuf.settings.py.')
# We might have no more data to read. Check number of bytes downloaded.
if not data:
logger.debug('Downloaded ' + repr(number_of_bytes_received) + '/' +
repr(required_length) + ' bytes.')
# Finally, we signal that the download is complete.
break
except urllib3.exceptions.ReadTimeoutError as e:
raise tuf.exceptions.SlowRetrievalError(str(e))
return number_of_bytes_received, average_download_speed
def _check_downloaded_length(total_downloaded, required_length,
STRICT_REQUIRED_LENGTH=True,
average_download_speed=None):