From 88e5611a09a86147bcd609fa148dff65d7d24909 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Diot?= Date: Sun, 10 Mar 2024 17:28:01 +0000 Subject: [PATCH] Add SQLite PRAGMA settings for foreign keys and journal mode in an event + Fix shenanigans with service_id and job_cache --- src/common/db/Database.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/common/db/Database.py b/src/common/db/Database.py index 6dfe0ff3e..09ba162b5 100644 --- a/src/common/db/Database.py +++ b/src/common/db/Database.py @@ -39,7 +39,8 @@ for deps_path in [join(sep, "usr", "share", "bunkerweb", *paths) for paths in (( from common_utils import file_hash # type: ignore from pymysql import install_as_MySQLdb -from sqlalchemy import create_engine, MetaData as sql_metadata, text, inspect +from sqlalchemy import create_engine, event, MetaData as sql_metadata, text, inspect +from sqlalchemy.engine import Engine from sqlalchemy.exc import ( ArgumentError, DatabaseError, @@ -49,10 +50,20 @@ from sqlalchemy.exc import ( ) from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.pool import QueuePool +from sqlite3 import Connection as SQLiteConnection install_as_MySQLdb() +@event.listens_for(Engine, "connect") +def set_sqlite_pragma(dbapi_connection, _): + if isinstance(dbapi_connection, SQLiteConnection): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.execute("PRAGMA journal_mode=WAL") + cursor.close() + + class Database: DB_STRING_RX = re_compile(r"^(?P(mariadb|mysql)(\+pymysql)?|sqlite(\+pysqlite)?|postgresql(\+psycopg)?):/+(?P/[^\s]+)") @@ -154,14 +165,8 @@ class Database: self.logger.error(f"Error when trying to connect to the database: {format_exc()}") exit(1) - self.logger.info("✅ Database connection established") - self.suffix_rx = re_compile(r"_\d+$") - - if sqlalchemy_string.startswith("sqlite"): - with self.__db_session() as session: - session.execute(text("PRAGMA journal_mode=WAL")) - session.commit() + self.logger.info("✅ Database connection established") def __del__(self) -> None: """Close the database""" @@ -1186,6 +1191,7 @@ class Database: ) -> str: """Update the plugin cache in the database""" job_name = job_name or basename(getsourcefile(_getframe(1))).replace(".py", "") + service_id = service_id or None with self.__db_session() as session: cache = session.query(Jobs_cache).filter_by(job_name=job_name, service_id=service_id, file_name=file_name).first()