From b5a5bd2dd14386c514d93864d280cb8d79dcaeab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Diot?= Date: Fri, 7 Jun 2024 12:00:59 +0100 Subject: [PATCH] chore: Fix database backup potential issues when testing write access at the same time --- src/common/core/backup/utils.py | 88 ++++++++++++++++++--------------- 1 file changed, 48 insertions(+), 40 deletions(-) diff --git a/src/common/core/backup/utils.py b/src/common/core/backup/utils.py index e3f1c47f5..f7ecbca65 100755 --- a/src/common/core/backup/utils.py +++ b/src/common/core/backup/utils.py @@ -45,49 +45,57 @@ def backup_database(current_time: datetime, db: Database = None, backup_dir: Pat database: Literal["sqlite", "mariadb", "mysql", "postgresql"] = db.database_uri.split(":")[0].split("+")[0] # type: ignore backup_file = backup_dir.joinpath(f"backup-{database}-{current_time.strftime('%Y-%m-%d_%H-%M-%S')}.zip") LOGGER.debug(f"Backup file path: {backup_file}") + stderr = "Table 'db.test_" + current_time = datetime.now() - if database == "sqlite": - match = DB_STRING_RX.search(db.database_uri) - if not match: - LOGGER.error(f"Invalid database string provided: {db.database_uri}, skipping backup ...") + while "Table 'db.test_" in stderr and (datetime.now() - current_time).total_seconds() < 10: + if database == "sqlite": + match = DB_STRING_RX.search(db.database_uri) + if not match: + LOGGER.error(f"Invalid database string provided: {db.database_uri}, skipping backup ...") + sys_exit(1) + + db_path = Path(match.group("path")) + + LOGGER.info("Creating a backup for the SQLite database ...") + + proc = run(["sqlite3", db_path.as_posix(), ".dump"], stdout=PIPE, stderr=PIPE) + else: + db_host = db.database_uri.rsplit("@", 1)[1].split("/")[0].split(":") + db_port = None + if len(db_host) == 1: + db_host = db_host[0] + else: + db_host, db_port = db_host + + db_user = db.database_uri.split("://")[1].split(":")[0] + db_password = db.database_uri.split("://")[1].split(":")[1].rsplit("@", 1)[0] + db_database_name = db.database_uri.split("/")[-1].split("?")[0] + + if database in ("mariadb", "mysql"): + LOGGER.info("Creating a backup for the MariaDB/MySQL database ...") + + cmd = ["mysqldump", "-h", db_host, "-u", db_user, db_database_name] + if db_port: + cmd.extend(["-P", db_port]) + + proc = run(cmd, stdout=PIPE, stderr=PIPE, env=environ | {"MYSQL_PWD": db_password}) + elif database == "postgresql": + LOGGER.info("Creating a backup for the PostgreSQL database ...") + + cmd = ["pg_dump", "-h", db_host, "-U", db_user, db_database_name, "-w"] + if db_port: + cmd.extend(["-p", db_port]) + + proc = run(cmd, stdout=PIPE, stderr=PIPE, env=environ | {"PGPASSWORD": db_password}) + + stderr = proc.stderr.decode() + if "Table 'db.test_" not in stderr and proc.returncode != 0: + LOGGER.error(f"Failed to dump the database: {stderr}") sys_exit(1) - db_path = Path(match.group("path")) - - LOGGER.info("Creating a backup for the SQLite database ...") - - proc = run(["sqlite3", db_path.as_posix(), ".dump"], stdout=PIPE, stderr=PIPE) - else: - db_host = db.database_uri.rsplit("@", 1)[1].split("/")[0].split(":") - db_port = None - if len(db_host) == 1: - db_host = db_host[0] - else: - db_host, db_port = db_host - - db_user = db.database_uri.split("://")[1].split(":")[0] - db_password = db.database_uri.split("://")[1].split(":")[1].rsplit("@", 1)[0] - db_database_name = db.database_uri.split("/")[-1].split("?")[0] - - if database in ("mariadb", "mysql"): - LOGGER.info("Creating a backup for the MariaDB/MySQL database ...") - - cmd = ["mysqldump", "-h", db_host, "-u", db_user, db_database_name] - if db_port: - cmd.extend(["-P", db_port]) - - proc = run(cmd, stdout=PIPE, stderr=PIPE, env=environ | {"MYSQL_PWD": db_password}) - elif database == "postgresql": - LOGGER.info("Creating a backup for the PostgreSQL database ...") - - cmd = ["pg_dump", "-h", db_host, "-U", db_user, db_database_name, "-w"] - if db_port: - cmd.extend(["-p", db_port]) - - proc = run(cmd, stdout=PIPE, stderr=PIPE, env=environ | {"PGPASSWORD": db_password}) - - if proc.returncode != 0: - LOGGER.error(f"Failed to dump the database: {proc.stderr.decode()}") + if (datetime.now() - current_time).total_seconds() >= 10: + LOGGER.error("Failed to dump the database: Timeout reached") sys_exit(1) with ZipFile(backup_file, "w", compression=ZIP_DEFLATED) as zipf: