Fix shenanigans with separated model by joining them into one

This commit is contained in:
Théophile Diot 2024-10-29 14:55:44 +01:00
parent b71cf63cb0
commit c1fe1a5483
No known key found for this signature in database
GPG key ID: FA995104A0BA376A
6 changed files with 123 additions and 251 deletions

View file

@ -1216,6 +1216,7 @@ class Database:
self.logger.warning(f'Restoring data for table "{table_name}"')
self.logger.debug(f"Data: {data}")
for row in data:
two_factor_enabled = getattr(row, "is_two_factor_enabled", None)
external_column = getattr(row, "external", None)
row = {column: getattr(row, column) for column in Base.metadata.tables[table_name].columns.keys() if hasattr(row, column)}
@ -1225,6 +1226,12 @@ class Database:
elif table_name in ("bw_services", "bw_instances") and "creation_date" not in row:
row["creation_date"] = datetime.now().astimezone()
row["last_update" if table_name == "bw_services" else "last_seen"] = datetime.now().astimezone()
elif table_name == "bw_ui_users" and two_factor_enabled is not None:
if two_factor_enabled:
self.logger.warning(
"Detected old user model, as we implemented advanced security in the new model (custom salt for passwords, totp, etc.), you will have to re set the two factor authentication for the admin user."
)
row["admin"] = True
with self._db_session() as session:
try:

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python3
from sqlalchemy import TEXT, Boolean, Column, DateTime, Enum, ForeignKey, Identity, Integer, LargeBinary, String
from sqlalchemy import TEXT, Boolean, Column, DateTime, Enum, ForeignKey, Identity, Integer, LargeBinary, String, UnicodeText
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.schema import UniqueConstraint
@ -305,3 +305,94 @@ class Metadata(Base):
integration = Column(INTEGRATIONS_ENUM, default="Unknown", nullable=False)
version = Column(String(32), default="1.6.0-beta", nullable=False)
ui_version = Column(String(32), default="1.6.0-beta", nullable=False)
## UI Models
THEMES_ENUM = Enum("light", "dark", name="themes_enum")
class Users(Base):
__tablename__ = "bw_ui_users"
username = Column(String(256), primary_key=True)
email = Column(String(256), unique=True, nullable=True)
password = Column(String(60), nullable=False)
method = Column(METHODS_ENUM, nullable=False, default="manual")
admin = Column(Boolean, nullable=False, default=False)
theme = Column(THEMES_ENUM, nullable=False, default="light")
# 2FA
totp_secret = Column(String(256), nullable=True)
creation_date = Column(DateTime(timezone=True), nullable=False)
update_date = Column(DateTime(timezone=True), nullable=False)
roles = relationship("RolesUsers", back_populates="user", cascade="all")
recovery_codes = relationship("UserRecoveryCodes", back_populates="user", cascade="all")
sessions = relationship("UserSessions", back_populates="user", cascade="all")
list_roles: list[str] = []
list_permissions: list[str] = []
list_recovery_codes: list[str] = []
class Roles(Base):
__tablename__ = "bw_ui_roles"
name = Column(String(64), primary_key=True)
description = Column(String(256), nullable=False)
update_datetime = Column(DateTime(timezone=True), nullable=False)
users = relationship("RolesUsers", back_populates="role", cascade="all")
permissions = relationship("RolesPermissions", back_populates="role", cascade="all")
class RolesUsers(Base):
__tablename__ = "bw_ui_roles_users"
user_name = Column(String(256), ForeignKey("bw_ui_users.username", onupdate="cascade", ondelete="cascade"), primary_key=True)
role_name = Column(String(64), ForeignKey("bw_ui_roles.name", onupdate="cascade", ondelete="cascade"), primary_key=True)
user = relationship("Users", back_populates="roles")
role = relationship("Roles", back_populates="users")
class UserRecoveryCodes(Base):
__tablename__ = "bw_ui_user_recovery_codes"
id = Column(Integer, Identity(start=1, increment=1), primary_key=True)
user_name = Column(String(256), ForeignKey("bw_ui_users.username", onupdate="cascade", ondelete="cascade"), nullable=False)
code = Column(UnicodeText, nullable=False)
user = relationship("Users", back_populates="recovery_codes")
class RolesPermissions(Base):
__tablename__ = "bw_ui_roles_permissions"
role_name = Column(String(64), ForeignKey("bw_ui_roles.name", onupdate="cascade", ondelete="cascade"), primary_key=True)
permission_name = Column(String(64), ForeignKey("bw_ui_permissions.name", onupdate="cascade", ondelete="cascade"), primary_key=True)
role = relationship("Roles", back_populates="permissions")
permission = relationship("Permissions", back_populates="roles")
class Permissions(Base):
__tablename__ = "bw_ui_permissions"
name = Column(String(64), primary_key=True)
roles = relationship("RolesPermissions", back_populates="permission", cascade="all")
class UserSessions(Base):
__tablename__ = "bw_ui_user_sessions"
id = Column(Integer, Identity(start=1, increment=1), primary_key=True)
user_name = Column(String(256), ForeignKey("bw_ui_users.username", onupdate="cascade", ondelete="cascade"), nullable=False)
ip = Column(String(39), nullable=False)
user_agent = Column(TEXT, nullable=False)
creation_date = Column(DateTime(timezone=True), nullable=False)
last_activity = Column(DateTime(timezone=True), nullable=False)
user = relationship("Users", back_populates="sessions")

View file

@ -8,14 +8,8 @@ for deps_path in [join(sep, "usr", "share", "bunkerweb", *paths) for paths in ((
from bcrypt import checkpw
from flask_login import AnonymousUserMixin, UserMixin
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy import TEXT, Boolean, DateTime, Column, Enum, Identity, Integer, String, ForeignKey, UnicodeText
from model import METHODS_ENUM # type: ignore
THEMES_ENUM = Enum("light", "dark", name="themes_enum")
Base = declarative_base()
from model import Users # type: ignore
class AnonymousUser(AnonymousUserMixin):
@ -38,93 +32,9 @@ class AnonymousUser(AnonymousUserMixin):
return False
class Users(Base, UserMixin):
__tablename__ = "bw_ui_users"
username = Column(String(256), primary_key=True)
email = Column(String(256), unique=True, nullable=True)
password = Column(String(60), nullable=False)
method = Column(METHODS_ENUM, nullable=False, default="manual")
admin = Column(Boolean, nullable=False, default=False)
theme = Column(THEMES_ENUM, nullable=False, default="light")
# 2FA
totp_secret = Column(String(256), nullable=True)
creation_date = Column(DateTime(timezone=True), nullable=False)
update_date = Column(DateTime(timezone=True), nullable=False)
roles = relationship("RolesUsers", back_populates="user", cascade="all")
recovery_codes = relationship("UserRecoveryCodes", back_populates="user", cascade="all")
sessions = relationship("UserSessions", back_populates="user", cascade="all")
list_roles: list[str] = []
list_permissions: list[str] = []
list_recovery_codes: list[str] = []
class UiUsers(Users, UserMixin):
def get_id(self):
return self.username
def check_password(self, password: str) -> bool:
return checkpw(password.encode("utf-8"), self.password.encode("utf-8"))
class Roles(Base):
__tablename__ = "bw_ui_roles"
name = Column(String(64), primary_key=True)
description = Column(String(256), nullable=False)
update_datetime = Column(DateTime(timezone=True), nullable=False)
users = relationship("RolesUsers", back_populates="role", cascade="all")
permissions = relationship("RolesPermissions", back_populates="role", cascade="all")
class RolesUsers(Base):
__tablename__ = "bw_ui_roles_users"
user_name = Column(String(256), ForeignKey("bw_ui_users.username", onupdate="cascade", ondelete="cascade"), primary_key=True)
role_name = Column(String(64), ForeignKey("bw_ui_roles.name", onupdate="cascade", ondelete="cascade"), primary_key=True)
user = relationship("Users", back_populates="roles")
role = relationship("Roles", back_populates="users")
class UserRecoveryCodes(Base):
__tablename__ = "bw_ui_user_recovery_codes"
id = Column(Integer, Identity(start=1, increment=1), primary_key=True)
user_name = Column(String(256), ForeignKey("bw_ui_users.username", onupdate="cascade", ondelete="cascade"), nullable=False)
code = Column(UnicodeText, nullable=False)
user = relationship("Users", back_populates="recovery_codes")
class RolesPermissions(Base):
__tablename__ = "bw_ui_roles_permissions"
role_name = Column(String(64), ForeignKey("bw_ui_roles.name", onupdate="cascade", ondelete="cascade"), primary_key=True)
permission_name = Column(String(64), ForeignKey("bw_ui_permissions.name", onupdate="cascade", ondelete="cascade"), primary_key=True)
role = relationship("Roles", back_populates="permissions")
permission = relationship("Permissions", back_populates="roles")
class Permissions(Base):
__tablename__ = "bw_ui_permissions"
name = Column(String(64), primary_key=True)
roles = relationship("RolesPermissions", back_populates="permission", cascade="all")
class UserSessions(Base):
__tablename__ = "bw_ui_user_sessions"
id = Column(Integer, Identity(start=1, increment=1), primary_key=True)
user_name = Column(String(256), ForeignKey("bw_ui_users.username", onupdate="cascade", ondelete="cascade"), nullable=False)
ip = Column(String(39), nullable=False)
user_agent = Column(TEXT, nullable=False)
creation_date = Column(DateTime(timezone=True), nullable=False)
last_activity = Column(DateTime(timezone=True), nullable=False)
user = relationship("Users", back_populates="sessions")

View file

@ -9,7 +9,7 @@ from passlib.pwd import genword
from qrcode import make
from qrcode.image.svg import SvgImage
from app.models.models import Users
from app.models.models import UiUsers
from app.dependencies import DATA
from app.utils import LIB_DIR, LOGGER, stop
@ -56,7 +56,7 @@ class Totp:
def generate_recovery_codes(self) -> List[str]:
return ["-".join([pwd[i : i + 4] for i in range(0, len(pwd), 4)]) for pwd in genword(length=16, charset="hex", returns=6)] # noqa: E203
def verify_recovery_code(self, code: str, user: Users) -> Optional[str]:
def verify_recovery_code(self, code: str, user: UiUsers) -> Optional[str]:
"""Check if recovery code is valid for user."""
if not user.list_recovery_codes:
return
@ -65,7 +65,7 @@ class Totp:
if checkpw(code.encode("utf-8"), encrypted_code.encode("utf-8")):
return user.list_recovery_codes.pop(i)
def verify_totp(self, token: str, *, totp_secret: Optional[str] = None, user: Optional[Users] = None) -> bool:
def verify_totp(self, token: str, *, totp_secret: Optional[str] = None, user: Optional[UiUsers] = None) -> bool:
"""Verifies token for specific user."""
if not totp_secret and not user:
raise ValueError("Either totp_secret or user must be provided")
@ -93,12 +93,12 @@ class Totp:
return f"data:image/svg+xml;base64,{image_as_str}"
def get_last_counter(self, user: Users) -> Optional[int]:
def get_last_counter(self, user: UiUsers) -> Optional[int]:
"""Fetch stored last_counter from cache."""
DATA.load_from_file()
return DATA.get("totp_last_counter", {}).get(user.get_id())
def set_last_counter(self, user: Users, tmatch: TotpMatch) -> None:
def set_last_counter(self, user: UiUsers, tmatch: TotpMatch) -> None:
"""Cache last_counter."""
DATA.load_from_file()
if "totp_last_counter" not in DATA:

View file

@ -3,8 +3,7 @@ from logging import Logger
from os import sep
from os.path import join
from sys import path as sys_path
from time import sleep
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Literal, Optional, Union
for deps_path in [join(sep, "usr", "share", "bunkerweb", *paths) for paths in (("deps", "python"), ("utils",), ("api",), ("db",))]:
@ -12,144 +11,26 @@ for deps_path in [join(sep, "usr", "share", "bunkerweb", *paths) for paths in ((
sys_path.append(deps_path)
from bcrypt import gensalt, hashpw
from sqlalchemy import MetaData, inspect, text
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import IntegrityError
from Database import Database # type: ignore
from model import Metadata # type: ignore
from model import Permissions, Roles, RolesPermissions, RolesUsers, UserRecoveryCodes, UserSessions # type: ignore
from app.models.models import Base, Permissions, Roles, RolesPermissions, RolesUsers, Users, UserRecoveryCodes, UserSessions
from app.models.models import UiUsers
class UIDatabase(Database):
def __init__(self, logger: Logger, sqlalchemy_string: Optional[str] = None, *, pool: Optional[bool] = None, log: bool = True, **kwargs) -> None:
super().__init__(logger, sqlalchemy_string, ui=True, pool=pool, log=log, **kwargs)
def init_ui_tables(self, bunkerweb_version: str) -> Tuple[bool, str]:
"""Initialize the database ui tables and return the result"""
if self.readonly:
return False, "The database is read-only, the changes will not be saved"
assert self.sql_engine is not None, "The database engine is not initialized"
inspector = inspect(self.sql_engine)
db_version = None
has_all_tables = True
old_data = {}
if inspector and len(inspector.get_table_names()):
metadata = self.get_metadata()
db_version = metadata["ui_version"]
if metadata["default"]:
db_version = "error"
if db_version != bunkerweb_version:
self.logger.warning(f"UI tables version ({db_version}) is different from BunkerWeb version ({bunkerweb_version}), migrating them ...")
current_time = datetime.now().astimezone()
error = True
while error:
try:
metadata = MetaData()
metadata.reflect(self.sql_engine)
error = False
except BaseException as e:
if (datetime.now().astimezone() - current_time).total_seconds() > 10:
raise e
sleep(1)
assert isinstance(metadata, MetaData)
for table_name in Base.metadata.tables.keys():
if not inspector.has_table(table_name):
self.logger.warning(f'UI table "{table_name}" is missing, creating it')
has_all_tables = False
continue
with self._db_session() as session:
old_data[table_name] = session.query(metadata.tables[table_name]).all()
# Rename the old tables
db_version_id = db_version.replace(".", "_")
with self._db_session() as session:
for table_name in metadata.tables.keys():
if table_name in Base.metadata.tables:
if inspector.has_table(f"{table_name}_{db_version_id}"):
self.logger.warning(f'UI table "{table_name}" already exists, dropping it to make room for the new one')
session.execute(text(f"DROP TABLE {table_name}_{db_version_id}"))
session.execute(text(f"ALTER TABLE {table_name} RENAME TO {table_name}_{db_version_id}"))
session.commit()
Base.metadata.drop_all(self.sql_engine)
else:
for table_name in Base.metadata.tables.keys():
if not inspector.has_table(table_name):
self.logger.warning(f'UI table "{table_name}" is missing, creating it')
has_all_tables = False
continue
if has_all_tables and db_version and db_version == bunkerweb_version:
return False, ""
self.logger.info("Creating UI tables ...")
try:
Base.metadata.create_all(self.sql_engine, checkfirst=True)
except BaseException as e:
return False, str(e)
if db_version and db_version != bunkerweb_version:
for table_name, data in old_data.items():
if not data:
continue
self.logger.warning(f'Restoring data for ui table "{table_name}"')
self.logger.debug(f"Data: {data}")
for row in data:
two_factor_enabled = getattr(row, "is_two_factor_enabled", None)
row = {column: getattr(row, column) for column in Base.metadata.tables[table_name].columns.keys() if hasattr(row, column)}
if table_name == "bw_ui_users" and two_factor_enabled is not None:
if two_factor_enabled:
self.logger.warning(
"Detected old user model, as we implemented advanced security in the new model (custom salt for passwords, totp, etc.), you will have to re set the two factor authentication for the admin user."
)
row["admin"] = True
with self._db_session() as session:
try:
# Check if the row already exists in the table
existing_row = session.query(Base.metadata.tables[table_name]).filter_by(**row).first()
if not existing_row:
session.execute(Base.metadata.tables[table_name].insert().values(row))
session.commit()
except IntegrityError as e:
session.rollback()
if "Duplicate entry" not in str(e):
self.logger.error(f"Error when trying to restore data for table {table_name}: {e}")
continue
self.logger.debug(e)
with self._db_session() as session:
try:
metadata = session.query(Metadata).get(1)
if metadata:
metadata.ui_version = bunkerweb_version
session.commit()
except BaseException as e:
self.logger.error(f"Error when trying to update ui_version field in metadata: {e}")
return True, ""
def get_ui_user(self, *, username: Optional[str] = None, as_dict: bool = False) -> Optional[Union[Users, dict]]:
def get_ui_user(self, *, username: Optional[str] = None, as_dict: bool = False) -> Optional[Union[UiUsers, dict]]:
"""Get ui user. If username is None, return the first admin user."""
with self._db_session() as session:
if username:
query = session.query(Users).filter_by(username=username)
query = session.query(UiUsers).filter_by(username=username)
else:
query = session.query(Users).filter_by(admin=True)
query = query.options(joinedload(Users.roles), joinedload(Users.recovery_codes))
query = session.query(UiUsers).filter_by(admin=True)
query = query.options(joinedload(UiUsers.roles), joinedload(UiUsers.recovery_codes))
ui_user = query.first()
@ -191,10 +72,10 @@ class UIDatabase(Database):
if self.readonly:
return "The database is read-only, the changes will not be saved"
if admin and session.query(Users).with_entities(Users.username).filter_by(admin=True).first():
if admin and session.query(UiUsers).with_entities(UiUsers.username).filter_by(admin=True).first():
return "An admin user already exists"
user = session.query(Users).with_entities(Users.username).filter_by(username=username).first()
user = session.query(UiUsers).with_entities(UiUsers.username).filter_by(username=username).first()
if user:
return f"User {username} already exists"
@ -205,7 +86,7 @@ class UIDatabase(Database):
current_time = datetime.now().astimezone()
session.add(
Users(
UiUsers(
username=username,
email=email,
password=password.decode("utf-8"),
@ -247,12 +128,12 @@ class UIDatabase(Database):
if self.readonly:
return "The database is read-only, the changes will not be saved"
user = session.query(Users).filter_by(username=old_username).first()
user = session.query(UiUsers).filter_by(username=old_username).first()
if not user:
return f"User {old_username} doesn't exist"
if username != old_username:
if session.query(Users).with_entities(Users.username).filter_by(username=username).first():
if session.query(UiUsers).with_entities(UiUsers.username).filter_by(username=username).first():
return f"User {username} already exists"
user.username = username
@ -289,7 +170,7 @@ class UIDatabase(Database):
if self.readonly:
return "The database is read-only, the changes will not be saved"
user = session.query(Users).filter_by(username=username).first()
user = session.query(UiUsers).filter_by(username=username).first()
if not user:
return f"User {username} doesn't exist"
@ -310,7 +191,7 @@ class UIDatabase(Database):
if self.readonly:
return "The database is read-only, the changes will not be saved"
user = session.query(Users).filter_by(username=username).first()
user = session.query(UiUsers).filter_by(username=username).first()
if not user:
return f"User {username} doesn't exist"
@ -405,7 +286,7 @@ class UIDatabase(Database):
if not codes:
return "No recovery codes provided"
user = session.query(Users).filter_by(username=username).first()
user = session.query(UiUsers).filter_by(username=username).first()
if not user:
return f"User {username} doesn't exist"
@ -467,7 +348,7 @@ class UIDatabase(Database):
def use_ui_user_recovery_code(self, username: str, hashed_code: str) -> str:
"""Use ui user recovery code."""
with self._db_session() as session:
user = session.query(Users).filter_by(username=username).first()
user = session.query(UiUsers).filter_by(username=username).first()
if not user:
return f"User {username} doesn't exist"
@ -519,7 +400,7 @@ class UIDatabase(Database):
if self.readonly:
return "The database is read-only, the changes will not be saved"
user = session.query(Users).filter_by(username=username).first()
user = session.query(UiUsers).filter_by(username=username).first()
if not user:
return f"User {username} doesn't exist"

View file

@ -18,7 +18,6 @@ patch_all()
from passlib import totp
from common_utils import get_version # type: ignore
from logger import setup_logger # type: ignore
from app.models.ui_database import UIDatabase
@ -109,23 +108,7 @@ def on_starting(server):
continue
sleep(5)
BW_VERSION = get_version()
ret, err = DB.init_ui_tables(BW_VERSION)
if not ret and err:
if err.startswith("The database is read-only"):
LOGGER.warning(err)
else:
LOGGER.error(f"Exception while checking database tables : {err}")
exit(1)
elif not ret:
LOGGER.info("Database ui tables didn't change, skipping update ...")
else:
LOGGER.info("Database ui tables successfully updated")
if not DB.get_ui_roles(as_dict=True):
ret = DB.create_ui_role("admin", "Admins can create new users, edit and read the data.", ["manage", "write", "read"])
if ret:
LOGGER.error(f"Couldn't create the admin role in the database: {ret}")