Avoid leaving unclosed file objects

* move code to only create objects after potential raises
* Use 'with' when possible
* close manually if those did not help

Signed-off-by: Jussi Kukkonen <jkukkonen@vmware.com>
This commit is contained in:
Jussi Kukkonen 2020-08-11 13:25:44 +03:00
parent 87e92d589f
commit b5a3c705db
3 changed files with 46 additions and 48 deletions

View file

@ -105,13 +105,11 @@ def tearDown(self):
def test_download_url_to_tempfileobj(self):
download_file = download.safe_download
temp_fileobj = download_file(self.url, self.target_data_length)
temp_fileobj.seek(0)
temp_file_data = temp_fileobj.read().decode('utf-8')
self.assertEqual(self.target_data, temp_file_data)
self.assertEqual(self.target_data_length, len(temp_file_data))
temp_fileobj.close()
with download_file(self.url, self.target_data_length) as temp_fileobj:
temp_fileobj.seek(0)
temp_file_data = temp_fileobj.read().decode('utf-8')
self.assertEqual(self.target_data, temp_file_data)
self.assertEqual(self.target_data_length, len(temp_file_data))
@ -344,16 +342,16 @@ def test_https_connection(self):
# TODO: Confirm necessity of this session clearing and lay out mechanics.
tuf.download._sessions = {}
logger.info('Trying HTTPS download of target file: ' + good_https_url)
download.safe_download(good_https_url, target_data_length)
download.unsafe_download(good_https_url, target_data_length)
download.safe_download(good_https_url, target_data_length).close()
download.unsafe_download(good_https_url, target_data_length).close()
os.environ['REQUESTS_CA_BUNDLE'] = good2_cert_fname
# Clear sessions to ensure that the certificate we just specified is used.
# TODO: Confirm necessity of this session clearing and lay out mechanics.
tuf.download._sessions = {}
logger.info('Trying HTTPS download of target file: ' + good2_https_url)
download.safe_download(good2_https_url, target_data_length)
download.unsafe_download(good2_https_url, target_data_length)
download.safe_download(good2_https_url, target_data_length).close()
download.unsafe_download(good2_https_url, target_data_length).close()
finally:
for proc in [

View file

@ -1627,12 +1627,12 @@ def test_9__get_target_hash(self):
def test_10__hard_check_file_length(self):
# Test for exception if file object is not equal to trusted file length.
temp_file_object = tempfile.TemporaryFile()
temp_file_object.write(b'X')
temp_file_object.seek(0)
self.assertRaises(tuf.exceptions.DownloadLengthMismatchError,
self.repository_updater._hard_check_file_length,
temp_file_object, 10)
with tempfile.TemporaryFile() as temp_file_object:
temp_file_object.write(b'X')
temp_file_object.seek(0)
self.assertRaises(tuf.exceptions.DownloadLengthMismatchError,
self.repository_updater._hard_check_file_length,
temp_file_object, 10)
@ -1640,19 +1640,19 @@ def test_10__hard_check_file_length(self):
def test_10__soft_check_file_length(self):
# Test for exception if file object is not equal to trusted file length.
temp_file_object = tempfile.TemporaryFile()
temp_file_object.write(b'XXX')
temp_file_object.seek(0)
self.assertRaises(tuf.exceptions.DownloadLengthMismatchError,
self.repository_updater._soft_check_file_length,
temp_file_object, 1)
with tempfile.TemporaryFile() as temp_file_object:
temp_file_object.write(b'XXX')
temp_file_object.seek(0)
self.assertRaises(tuf.exceptions.DownloadLengthMismatchError,
self.repository_updater._soft_check_file_length,
temp_file_object, 1)
# Verify that an exception is not raised if the file length <= the observed
# file length.
temp_file_object.seek(0)
self.repository_updater._soft_check_file_length(temp_file_object, 3)
temp_file_object.seek(0)
self.repository_updater._soft_check_file_length(temp_file_object, 4)
# Verify that an exception is not raised if the file length <= the observed
# file length.
temp_file_object.seek(0)
self.repository_updater._soft_check_file_length(temp_file_object, 3)
temp_file_object.seek(0)
self.repository_updater._soft_check_file_length(temp_file_object, 4)
@ -1763,14 +1763,13 @@ def test_10__visit_child_role(self):
def test_11__verify_metadata_file(self):
# Test for invalid metadata content.
metadata_file_object = tempfile.TemporaryFile()
metadata_file_object.write(b'X')
metadata_file_object.seek(0)
self.assertRaises(tuf.exceptions.InvalidMetadataJSONError,
self.repository_updater._verify_metadata_file,
metadata_file_object, 'root')
with tempfile.TemporaryFile() as metadata_file_object:
metadata_file_object.write(b'X')
metadata_file_object.seek(0)
self.assertRaises(tuf.exceptions.InvalidMetadataJSONError,
self.repository_updater._verify_metadata_file,
metadata_file_object, 'root')
def test_12__get_file(self):
@ -1788,10 +1787,10 @@ def verify_target_file(targets_path):
self.repository_updater._check_hashes(targets_path, file_hashes)
self.repository_updater._get_file('targets.json', verify_target_file,
file_type, file_size, download_safely=True)
file_type, file_size, download_safely=True).close()
self.repository_updater._get_file('targets.json', verify_target_file,
file_type, file_size, download_safely=False)
file_type, file_size, download_safely=False).close()
def test_13__targets_of_role(self):
# Test case where a list of targets is given. By default, the 'targets'

View file

@ -1637,7 +1637,9 @@ def _get_metadata_file(self, metadata_role, remote_filename,
# Remember the error from this mirror, and "reset" the target file.
logger.exception('Update failed from ' + file_mirror + '.')
file_mirror_errors[file_mirror] = exception
file_object = None
if file_object:
file_object.close()
file_object = None
else:
break
@ -3281,15 +3283,9 @@ def download_target(self, target, destination_directory,
trusted_length = target['fileinfo']['length']
trusted_hashes = target['fileinfo']['hashes']
# '_get_target_file()' checks every mirror and returns the first target
# that passes verification.
target_file_object = self._get_target_file(target_filepath, trusted_length,
trusted_hashes, prefix_filename_with_hash)
# We acquired a target file object from a mirror. Move the file into place
# (i.e., locally to 'destination_directory'). Note: join() discards
# 'destination_directory' if 'target_path' contains a leading path
# separator (i.e., is treated as an absolute path).
# Build absolute 'destination' file path.
# Note: join() discards 'destination_directory' if 'target_path' contains
# a leading path separator (i.e., is treated as an absolute path).
destination = os.path.join(destination_directory,
target_filepath.lstrip(os.sep))
destination = os.path.abspath(destination)
@ -3310,4 +3306,9 @@ def download_target(self, target, destination_directory,
else:
raise
# '_get_target_file()' checks every mirror and returns the first target
# that passes verification.
target_file_object = self._get_target_file(target_filepath, trusted_length,
trusted_hashes, prefix_filename_with_hash)
securesystemslib.util.persist_temp_file(target_file_object, destination)