From b5a3c705dbbf50de118c1a116ce9142cf751ef63 Mon Sep 17 00:00:00 2001 From: Jussi Kukkonen Date: Tue, 11 Aug 2020 13:25:44 +0300 Subject: [PATCH] Avoid leaving unclosed file objects * move code to only create objects after potential raises * Use 'with' when possible * close manually if those did not help Signed-off-by: Jussi Kukkonen --- tests/test_download.py | 20 +++++++--------- tests/test_updater.py | 53 +++++++++++++++++++++--------------------- tuf/client/updater.py | 21 +++++++++-------- 3 files changed, 46 insertions(+), 48 deletions(-) diff --git a/tests/test_download.py b/tests/test_download.py index 2f97048a..4d12c41f 100755 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -105,13 +105,11 @@ def tearDown(self): def test_download_url_to_tempfileobj(self): download_file = download.safe_download - - temp_fileobj = download_file(self.url, self.target_data_length) - temp_fileobj.seek(0) - temp_file_data = temp_fileobj.read().decode('utf-8') - self.assertEqual(self.target_data, temp_file_data) - self.assertEqual(self.target_data_length, len(temp_file_data)) - temp_fileobj.close() + with download_file(self.url, self.target_data_length) as temp_fileobj: + temp_fileobj.seek(0) + temp_file_data = temp_fileobj.read().decode('utf-8') + self.assertEqual(self.target_data, temp_file_data) + self.assertEqual(self.target_data_length, len(temp_file_data)) @@ -344,16 +342,16 @@ def test_https_connection(self): # TODO: Confirm necessity of this session clearing and lay out mechanics. tuf.download._sessions = {} logger.info('Trying HTTPS download of target file: ' + good_https_url) - download.safe_download(good_https_url, target_data_length) - download.unsafe_download(good_https_url, target_data_length) + download.safe_download(good_https_url, target_data_length).close() + download.unsafe_download(good_https_url, target_data_length).close() os.environ['REQUESTS_CA_BUNDLE'] = good2_cert_fname # Clear sessions to ensure that the certificate we just specified is used. # TODO: Confirm necessity of this session clearing and lay out mechanics. tuf.download._sessions = {} logger.info('Trying HTTPS download of target file: ' + good2_https_url) - download.safe_download(good2_https_url, target_data_length) - download.unsafe_download(good2_https_url, target_data_length) + download.safe_download(good2_https_url, target_data_length).close() + download.unsafe_download(good2_https_url, target_data_length).close() finally: for proc in [ diff --git a/tests/test_updater.py b/tests/test_updater.py index 3ea0206a..f16aca63 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -1627,12 +1627,12 @@ def test_9__get_target_hash(self): def test_10__hard_check_file_length(self): # Test for exception if file object is not equal to trusted file length. - temp_file_object = tempfile.TemporaryFile() - temp_file_object.write(b'X') - temp_file_object.seek(0) - self.assertRaises(tuf.exceptions.DownloadLengthMismatchError, - self.repository_updater._hard_check_file_length, - temp_file_object, 10) + with tempfile.TemporaryFile() as temp_file_object: + temp_file_object.write(b'X') + temp_file_object.seek(0) + self.assertRaises(tuf.exceptions.DownloadLengthMismatchError, + self.repository_updater._hard_check_file_length, + temp_file_object, 10) @@ -1640,19 +1640,19 @@ def test_10__hard_check_file_length(self): def test_10__soft_check_file_length(self): # Test for exception if file object is not equal to trusted file length. - temp_file_object = tempfile.TemporaryFile() - temp_file_object.write(b'XXX') - temp_file_object.seek(0) - self.assertRaises(tuf.exceptions.DownloadLengthMismatchError, - self.repository_updater._soft_check_file_length, - temp_file_object, 1) + with tempfile.TemporaryFile() as temp_file_object: + temp_file_object.write(b'XXX') + temp_file_object.seek(0) + self.assertRaises(tuf.exceptions.DownloadLengthMismatchError, + self.repository_updater._soft_check_file_length, + temp_file_object, 1) - # Verify that an exception is not raised if the file length <= the observed - # file length. - temp_file_object.seek(0) - self.repository_updater._soft_check_file_length(temp_file_object, 3) - temp_file_object.seek(0) - self.repository_updater._soft_check_file_length(temp_file_object, 4) + # Verify that an exception is not raised if the file length <= the observed + # file length. + temp_file_object.seek(0) + self.repository_updater._soft_check_file_length(temp_file_object, 3) + temp_file_object.seek(0) + self.repository_updater._soft_check_file_length(temp_file_object, 4) @@ -1763,14 +1763,13 @@ def test_10__visit_child_role(self): def test_11__verify_metadata_file(self): # Test for invalid metadata content. - metadata_file_object = tempfile.TemporaryFile() - metadata_file_object.write(b'X') - metadata_file_object.seek(0) - - self.assertRaises(tuf.exceptions.InvalidMetadataJSONError, - self.repository_updater._verify_metadata_file, - metadata_file_object, 'root') + with tempfile.TemporaryFile() as metadata_file_object: + metadata_file_object.write(b'X') + metadata_file_object.seek(0) + self.assertRaises(tuf.exceptions.InvalidMetadataJSONError, + self.repository_updater._verify_metadata_file, + metadata_file_object, 'root') def test_12__get_file(self): @@ -1788,10 +1787,10 @@ def verify_target_file(targets_path): self.repository_updater._check_hashes(targets_path, file_hashes) self.repository_updater._get_file('targets.json', verify_target_file, - file_type, file_size, download_safely=True) + file_type, file_size, download_safely=True).close() self.repository_updater._get_file('targets.json', verify_target_file, - file_type, file_size, download_safely=False) + file_type, file_size, download_safely=False).close() def test_13__targets_of_role(self): # Test case where a list of targets is given. By default, the 'targets' diff --git a/tuf/client/updater.py b/tuf/client/updater.py index 564e285e..0ae76ee3 100755 --- a/tuf/client/updater.py +++ b/tuf/client/updater.py @@ -1637,7 +1637,9 @@ def _get_metadata_file(self, metadata_role, remote_filename, # Remember the error from this mirror, and "reset" the target file. logger.exception('Update failed from ' + file_mirror + '.') file_mirror_errors[file_mirror] = exception - file_object = None + if file_object: + file_object.close() + file_object = None else: break @@ -3281,15 +3283,9 @@ def download_target(self, target, destination_directory, trusted_length = target['fileinfo']['length'] trusted_hashes = target['fileinfo']['hashes'] - # '_get_target_file()' checks every mirror and returns the first target - # that passes verification. - target_file_object = self._get_target_file(target_filepath, trusted_length, - trusted_hashes, prefix_filename_with_hash) - - # We acquired a target file object from a mirror. Move the file into place - # (i.e., locally to 'destination_directory'). Note: join() discards - # 'destination_directory' if 'target_path' contains a leading path - # separator (i.e., is treated as an absolute path). + # Build absolute 'destination' file path. + # Note: join() discards 'destination_directory' if 'target_path' contains + # a leading path separator (i.e., is treated as an absolute path). destination = os.path.join(destination_directory, target_filepath.lstrip(os.sep)) destination = os.path.abspath(destination) @@ -3310,4 +3306,9 @@ def download_target(self, target, destination_directory, else: raise + # '_get_target_file()' checks every mirror and returns the first target + # that passes verification. + target_file_object = self._get_target_file(target_filepath, trusted_length, + trusted_hashes, prefix_filename_with_hash) + securesystemslib.util.persist_temp_file(target_file_object, destination)