Update UI data handling and TOTP verification

Refactor UI data handling to use a dedicated class and improve TOTP verification logic for better security and usability.
This commit is contained in:
Théophile Diot 2024-08-09 13:10:04 +01:00
parent 57c458d504
commit 4ad0596ad6
No known key found for this signature in database
GPG key ID: FA995104A0BA376A
5 changed files with 243 additions and 226 deletions

View file

@ -1,15 +1,28 @@
from json import JSONDecodeError, dumps, loads
from os import cpu_count, getenv, getpid, sep
from os.path import join
from pathlib import Path
from sys import path as sys_path
from random import randint
from secrets import token_urlsafe
from sys import exit, path as sys_path
from time import sleep
for deps_path in [join(sep, "usr", "share", "bunkerweb", *paths) for paths in (("deps", "python"), ("utils",), ("api",), ("db",))]:
if deps_path not in sys_path:
sys_path.append(deps_path)
from passlib import totp
from common_utils import get_version # type: ignore
from logger import setup_logger # type: ignore
from ui_database import UIDatabase
from utils import USER_PASSWORD_RX, check_password, gen_password_hash
TMP_DIR = Path(sep, "var", "tmp", "bunkerweb")
RUN_DIR = Path(sep, "var", "run", "bunkerweb")
LIB_DIR = Path(sep, "var", "lib", "bunkerweb")
MAX_WORKERS = int(getenv("MAX_WORKERS", max((cpu_count() or 1) - 1, 1)))
LOG_LEVEL = getenv("CUSTOM_LOG_LEVEL", getenv("LOG_LEVEL", "info"))
@ -19,7 +32,6 @@ proc_name = "bunkerweb-ui"
accesslog = "/var/log/bunkerweb/ui-access.log"
access_log_format = '%({x-forwarded-for}i)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'
errorlog = "/var/log/bunkerweb/ui.log"
preload_app = True
reuse_port = True
worker_tmp_dir = join(sep, "dev", "shm")
tmp_upload_dir = join(sep, "var", "tmp", "bunkerweb", "ui")
@ -39,8 +51,138 @@ if DEBUG:
reload_extra_files = [file.as_posix() for file in Path(sep, "usr", "share", "bunkerweb", "ui", "templates").iterdir()]
def when_ready(server):
def on_starting(server):
TMP_DIR.mkdir(parents=True, exist_ok=True)
RUN_DIR.mkdir(parents=True, exist_ok=True)
LIB_DIR.mkdir(parents=True, exist_ok=True)
LOGGER = setup_logger("UI", getenv("CUSTOM_LOG_LEVEL", getenv("LOG_LEVEL", "INFO")))
FLASK_SECRET = getenv("FLASK_SECRET")
if not FLASK_SECRET and not TMP_DIR.joinpath(".flask_secret").is_file():
LOGGER.warning("The FLASK_SECRET environment variable is missing, generating a random one ...")
TMP_DIR.joinpath(".flask_secret").write_text(token_urlsafe(32), encoding="utf-8")
TOTP_SECRETS = getenv("TOTP_SECRETS", "")
if TOTP_SECRETS:
try:
TOTP_SECRETS = loads(TOTP_SECRETS)
except JSONDecodeError:
x = 1
tmp_secrets = {}
for secret in TOTP_SECRETS.strip().split(" "):
if secret:
tmp_secrets[x] = secret
x += 1
TOTP_SECRETS = tmp_secrets.copy()
del tmp_secrets
if not TOTP_SECRETS:
LOGGER.warning("The TOTP_SECRETS environment variable is missing, generating a random one ...")
LIB_DIR.joinpath(".totp_secrets.json").write_text(dumps({k: totp.generate_secret() for k in range(randint(1, 5))}), encoding="utf-8")
DB = UIDatabase(LOGGER)
ready = False
while not ready:
db_metadata = DB.get_metadata()
if isinstance(db_metadata, str) or not db_metadata["is_initialized"]:
LOGGER.warning("Database is not initialized, retrying in 5s ...")
else:
ready = True
continue
sleep(5)
BW_VERSION = get_version()
ret, err = DB.init_ui_tables(BW_VERSION)
if not ret and err:
LOGGER.error(f"Exception while checking database tables : {err}")
exit(1)
elif not ret:
LOGGER.info("Database ui tables didn't change, skipping update ...")
else:
LOGGER.info("Database ui tables successfully updated")
if not DB.get_ui_roles(as_dict=True):
ret = DB.create_ui_role("admin", "Admin can create account, manager software and read data.", ["manage", "write", "read"])
if ret:
LOGGER.error(f"Couldn't create the admin role in the database: {ret}")
exit(1)
ret = DB.create_ui_role("writer", "Write can manage software and read data but can't create account.", ["write", "read"])
if ret:
LOGGER.error(f"Couldn't create the admin role in the database: {ret}")
exit(1)
ret = DB.create_ui_role("reader", "Reader can read data but can't proceed to any actions.", ["read"])
if ret:
LOGGER.error(f"Couldn't create the admin role in the database: {ret}")
exit(1)
ADMIN_USER = "Error"
while ADMIN_USER == "Error":
try:
ADMIN_USER = DB.get_ui_user(as_dict=True)
except BaseException as e:
LOGGER.debug(f"Couldn't get the admin user: {e}")
sleep(1)
env_admin_username = getenv("ADMIN_USERNAME", "")
env_admin_password = getenv("ADMIN_PASSWORD", "")
if ADMIN_USER:
LOGGER.debug(f"Admin user: {ADMIN_USER}")
if env_admin_username or env_admin_password:
override_admin_creds = getenv("OVERRIDE_ADMIN_CREDS", "no").lower() == "yes"
if ADMIN_USER["method"] == "manual" or override_admin_creds:
updated = False
if env_admin_username and ADMIN_USER["username"] != env_admin_username:
ADMIN_USER["username"] = env_admin_username
updated = True
if env_admin_password and not check_password(env_admin_password, ADMIN_USER["password"]):
if not USER_PASSWORD_RX.match(env_admin_password):
LOGGER.warning(
"The admin password is not strong enough. It must contain at least 8 characters, including at least 1 uppercase letter, 1 lowercase letter, 1 number and 1 special character (#@?!$%^&*-). It will not be updated."
)
else:
ADMIN_USER["password"] = gen_password_hash(env_admin_password)
updated = True
if updated:
if override_admin_creds:
LOGGER.warning("Overriding the admin user credentials, as the OVERRIDE_ADMIN_CREDS environment variable is set to 'yes'.")
err = DB.update_ui_user(ADMIN_USER["username"], ADMIN_USER["password"], ADMIN_USER["totp_secret"], method="manual")
if err:
LOGGER.error(f"Couldn't update the admin user in the database: {err}")
else:
LOGGER.info("The admin user was updated successfully")
else:
LOGGER.warning("The admin user wasn't created manually. You can't change it from the environment variables.")
elif env_admin_username and env_admin_password:
user_name = env_admin_username or "admin"
if not DEBUG:
if len(user_name) > 256:
LOGGER.error("The admin username is too long. It must be less than 256 characters.")
exit(1)
elif not USER_PASSWORD_RX.match(env_admin_password):
LOGGER.error(
"The admin password is not strong enough. It must contain at least 8 characters, including at least 1 uppercase letter, 1 lowercase letter, 1 number and 1 special character (#@?!$%^&*-)."
)
exit(1)
ret = DB.create_ui_user(user_name, gen_password_hash(env_admin_password), ["admin"], admin=True)
if ret:
LOGGER.error(f"Couldn't create the admin user in the database: {ret}")
exit(1)
LOGGER.info("UI is ready")
def when_ready(server):
RUN_DIR.joinpath("ui.pid").write_text(str(getpid()), encoding="utf-8")
TMP_DIR.joinpath("ui.healthy").write_text("ok", encoding="utf-8")

View file

@ -3,10 +3,8 @@ import json
import base64
from contextlib import suppress
from math import floor
from multiprocessing import Manager
from os import _exit, getenv, listdir, sep
from os.path import basename, dirname, isabs, join
from random import randint
from secrets import choice, token_urlsafe
from string import ascii_letters, digits
from sys import path as sys_path, modules as sys_modules
@ -21,7 +19,6 @@ for deps_path in [join(sep, "usr", "share", "bunkerweb", *paths) for paths in ((
from bs4 import BeautifulSoup
from copy import deepcopy
from cryptography.fernet import Fernet
from datetime import datetime, timedelta, timezone
from flask import Flask, Response, flash, jsonify, make_response, redirect, render_template, request, send_file, session, url_for
from flask_login import current_user, LoginManager, login_required, login_user, logout_user
@ -32,7 +29,6 @@ from importlib.machinery import SourceFileLoader
from io import BytesIO
from json import JSONDecodeError, dumps, loads as json_loads
from jinja2 import Environment, FileSystemLoader, select_autoescape
from passlib import totp
from redis import Redis, Sentinel
from regex import compile as re_compile
from requests import get
@ -50,6 +46,7 @@ from src.custom_config import CustomConfig
from src.config import Config
from src.reverse_proxied import ReverseProxied
from src.totp import Totp
from src.ui_data import UIData
from builder.home import home_builder
from builder.instances import instances_builder
@ -66,13 +63,10 @@ from logger import setup_logger # type: ignore
from models import AnonymousUser
from ui_database import UIDatabase
from utils import USER_PASSWORD_RX, PLUGIN_KEYS, PLUGIN_ID_RX, check_password, check_settings, gen_password_hash, path_to_dict, get_remain
from utils import USER_PASSWORD_RX, PLUGIN_KEYS, PLUGIN_ID_RX, check_settings, gen_password_hash, path_to_dict, get_remain
TMP_DIR = Path(sep, "var", "tmp", "bunkerweb")
TMP_DIR.mkdir(parents=True, exist_ok=True)
LIB_DIR = Path(sep, "var", "lib", "bunkerweb")
LIB_DIR.mkdir(parents=True, exist_ok=True)
def stop_gunicorn():
@ -112,8 +106,8 @@ with app.app_context():
FLASK_SECRET = getenv("FLASK_SECRET")
if not FLASK_SECRET:
if not TMP_DIR.joinpath(".flask_secret").is_file():
app.logger.warning("The FLASK_SECRET environment variable is missing or the .flask_secret file is missing, generating a random one ...")
TMP_DIR.joinpath(".flask_secret").write_text(token_urlsafe(32), encoding="utf-8")
app.logger.error("The FLASK_SECRET environment variable is missing and the .flask_secret file is missing, exiting ...")
stop(1)
FLASK_SECRET = TMP_DIR.joinpath(".flask_secret").read_text(encoding="utf-8").strip()
TOTP_SECRETS = getenv("TOTP_SECRETS", "")
@ -121,40 +115,21 @@ with app.app_context():
try:
TOTP_SECRETS = json_loads(TOTP_SECRETS)
except JSONDecodeError:
app.logger.warning(
"The TOTP_SECRETS environment variable is invalid, generating a random one ... (check the format via the documentation: https://passlib.readthedocs.io/en/stable/narr/totp-tutorial.html#application-secrets)"
)
TOTP_SECRETS = None
x = 1
tmp_secrets = {}
for secret in TOTP_SECRETS.strip().split(" "):
if secret:
tmp_secrets[x] = secret
x += 1
TOTP_SECRETS = tmp_secrets.copy()
del tmp_secrets
if not TOTP_SECRETS:
if not LIB_DIR.joinpath(".totp_secrets.json").is_file():
if TOTP_SECRETS is not None:
app.logger.warning("The TOTP_SECRETS environment variable is missing or the .totp_secrets.json file is missing, generating a random one ...")
LIB_DIR.joinpath(".totp_secrets.json").write_text(dumps({k: totp.generate_secret() for k in range(randint(1, 5))}), encoding="utf-8")
app.logger.error("The TOTP_SECRETS environment variable is missing and the .totp_secrets.json file is missing, exiting ...")
stop(1)
TOTP_SECRETS = json_loads(LIB_DIR.joinpath(".totp_secrets.json").read_text(encoding="utf-8"))
MF_RECOVERY_CODES_KEYS = []
if getenv("MF_ENCRYPT_RECOVERY_CODES", "yes").lower() != "no":
MF_RECOVERY_CODES_KEYS = getenv("MF_RECOVERY_CODES_KEYS", "")
if MF_RECOVERY_CODES_KEYS:
try:
MF_RECOVERY_CODES_KEYS = json_loads(MF_RECOVERY_CODES_KEYS)
except JSONDecodeError:
app.logger.warning(
"The MF_RECOVERY_CODES_KEYS environment variable is invalid, generating a random one ... (check the format via the documentation: https://cryptography.io/en/latest/fernet/#fernet-symmetric-encryption)"
)
MF_RECOVERY_CODES_KEYS = None
if not MF_RECOVERY_CODES_KEYS:
if MF_RECOVERY_CODES_KEYS is not None and not LIB_DIR.joinpath(".mf_recovery_codes_keys.json").is_file():
app.logger.warning("The MF_RECOVERY_CODES_KEYS environment variable is missing, generating a random one ...")
LIB_DIR.joinpath(".mf_recovery_codes_keys.json").write_text(
dumps([Fernet.generate_key().decode() for _ in range(randint(1, 5))]), encoding="utf-8"
)
MF_RECOVERY_CODES_KEYS = json_loads(LIB_DIR.joinpath(".mf_recovery_codes_keys.json").read_text(encoding="utf-8"))
else:
app.logger.warning("MF_ENCRYPT_RECOVERY_CODES is set to 'no', multi-factor recovery codes will not be encrypted")
app.config["SECRET_KEY"] = FLASK_SECRET
app.config["SESSION_COOKIE_NAME"] = "__Host-bw_ui_session"
@ -181,102 +156,7 @@ with app.app_context():
login_manager.login_view = "login"
login_manager.anonymous_user = AnonymousUser
app.db = UIDatabase(app.logger)
ready = False
while not ready:
db_metadata = app.db.get_metadata()
if isinstance(db_metadata, str) or not db_metadata["is_initialized"]:
app.logger.warning("Database is not initialized, retrying in 5s ...")
else:
ready = True
continue
sleep(5)
BW_VERSION = get_version()
ret, err = app.db.init_ui_tables(BW_VERSION)
if not ret and err:
app.logger.error(f"Exception while checking database tables : {err}")
exit(1)
elif not ret:
app.logger.info("Database ui tables didn't change, skipping update ...")
else:
app.logger.info("Database ui tables successfully updated")
if not app.db.get_ui_roles(as_dict=True):
ret = app.db.create_ui_role("admin", "Admin can create account, manager software and read data.", ["manage", "write", "read"])
if ret:
app.logger.error(f"Couldn't create the admin role in the database: {ret}")
exit(1)
ret = app.db.create_ui_role("writer", "Write can manage software and read data but can't create account.", ["write", "read"])
if ret:
app.logger.error(f"Couldn't create the admin role in the database: {ret}")
exit(1)
ret = app.db.create_ui_role("reader", "Reader can read data but can't proceed to any actions.", ["read"])
if ret:
app.logger.error(f"Couldn't create the admin role in the database: {ret}")
exit(1)
ADMIN_USER = "Error"
while ADMIN_USER == "Error":
try:
ADMIN_USER = app.db.get_ui_user(as_dict=True)
except BaseException as e:
app.logger.debug(f"Couldn't get the admin user: {e}")
sleep(1)
env_admin_username = getenv("ADMIN_USERNAME", "")
env_admin_password = getenv("ADMIN_PASSWORD", "")
if ADMIN_USER:
if env_admin_username or env_admin_password:
override_admin_creds = getenv("OVERRIDE_ADMIN_CREDS", "no").lower() == "yes"
if ADMIN_USER["method"] == "manual" or override_admin_creds:
updated = False
if env_admin_username and ADMIN_USER["username"] != env_admin_username:
ADMIN_USER["username"] = env_admin_username
updated = True
if env_admin_password and not check_password(env_admin_password, ADMIN_USER["password"]):
if not USER_PASSWORD_RX.match(env_admin_password):
app.logger.warning(
"The admin password is not strong enough. It must contain at least 8 characters, including at least 1 uppercase letter, 1 lowercase letter, 1 number and 1 special character (#@?!$%^&*-). It will not be updated."
)
else:
ADMIN_USER["password"] = gen_password_hash(env_admin_password)
updated = True
if updated:
if override_admin_creds:
app.logger.warning("Overriding the admin user credentials, as the OVERRIDE_ADMIN_CREDS environment variable is set to 'yes'.")
err = app.db.update_ui_user(ADMIN_USER["username"], ADMIN_USER["password"], ADMIN_USER["totp_secret"], method="manual")
if err:
app.logger.error(f"Couldn't update the admin user in the database: {err}")
else:
app.logger.info("The admin user was updated successfully")
else:
app.logger.warning("The admin user wasn't created manually. You can't change it from the environment variables.")
elif env_admin_username and env_admin_password:
user_name = env_admin_username or "admin"
if not getenv("FLASK_DEBUG", False):
if len(user_name) > 256:
app.logger.error("The admin username is too long. It must be less than 256 characters.")
exit(1)
elif not USER_PASSWORD_RX.match(env_admin_password):
app.logger.error(
"The admin password is not strong enough. It must contain at least 8 characters, including at least 1 uppercase letter, 1 lowercase letter, 1 number and 1 special character (#@?!$%^&*-)."
)
exit(1)
ret = app.db.create_ui_user(user_name, gen_password_hash(env_admin_password), ["admin"], admin=True)
if ret:
app.logger.error(f"Couldn't create the admin user in the database: {ret}")
exit(1)
app.db = UIDatabase(app.logger, log=False)
# Declare functions for jinja2
app.jinja_env.globals.update(check_settings=check_settings)
@ -288,14 +168,12 @@ with app.app_context():
app.bw_instances_utils = InstancesUtils(app.db)
app.bw_config = Config(app.db)
app.bw_custom_configs = CustomConfig()
app.data = Manager().dict()
app.totp = Totp(app, TOTP_SECRETS, [key.encode("utf-8") for key in MF_RECOVERY_CODES_KEYS])
app.data = UIData(TMP_DIR.joinpath("ui_data.json"))
app.totp = Totp(app, TOTP_SECRETS)
LOG_RX = re_compile(r"^(?P<date>\d+/\d+/\d+\s\d+:\d+:\d+)\s\[(?P<level>[a-z]+)\]\s\d+#\d+:\s(?P<message>[^\n]+)$")
REVERSE_PROXY_PATH = re_compile(r"^(?P<host>https?://.{1,255}(:((6553[0-5])|(655[0-2]\d)|(65[0-4]\d{2})|(6[0-4]\d{3})|([1-5]\d{4})|([0-5]{0,5})|(\d{1,4})))?)$")
app.logger.info("UI is ready")
def wait_applying():
current_time = datetime.now()
@ -520,6 +398,7 @@ def error_message(msg: str):
@app.context_processor
def inject_variables():
app.data.load_from_file()
metadata = app.db.get_metadata()
changes_ongoing = any(
@ -593,9 +472,6 @@ def set_security_headers(response):
# * Referrer-Policy header to prevent leaking of sensitive data
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
if current_user.totp_refreshed:
app.db.set_ui_user_recovery_code_refreshed(current_user.get_id(), False)
return response
@ -606,12 +482,12 @@ def load_user(username):
app.logger.warning(f"Couldn't get the user {username} from the database.")
return None
ui_user.list_roles = app.db.get_ui_user_roles(username)
ui_user.list_roles = [role.role_name for role in ui_user.roles]
for role in ui_user.list_roles:
ui_user.list_permissions.extend(app.db.get_ui_role_permissions(role))
if ui_user.totp_secret:
ui_user.list_recovery_codes = app.db.get_ui_user_recovery_codes(username)
ui_user.list_recovery_codes = [recovery_code.code for recovery_code in ui_user.recovery_codes]
return ui_user
@ -651,6 +527,7 @@ def handle_csrf_error(_):
@app.before_request
def before_request():
app.data.load_from_file()
if app.data.get("SERVER_STOPPING", False):
response = make_response(jsonify({"message": "Server is shutting down, try again later."}), 503)
response.headers["Retry-After"] = 30 # Clients should retry after 30 seconds # type: ignore
@ -919,7 +796,7 @@ def home():
metadata = app.db.get_metadata()
data = {
"check_version": not remote_version or BW_VERSION == remote_version,
"check_version": not remote_version or get_version() == remote_version,
"remote_version": remote_version,
"version": metadata["version"],
"instances_number": len(instances),
@ -1054,11 +931,9 @@ def account():
if totp_secret and totp_secret != current_user.totp_secret:
totp_recovery_codes = app.totp.generate_recovery_codes()
current_user.totp_refreshed = True
current_user.list_recovery_codes = totp_recovery_codes
flash(
"The recovery codes have been refreshed.\nPlease save them in a safe place. They will not be displayed again."
+ "\n".join(app.totp.decrypt_recovery_codes(current_user)),
+ "\n".join(totp_recovery_codes),
"info",
) # TODO: Remove this when we have a way to display the recovery codes
@ -2476,6 +2351,8 @@ def login():
@app.route("/check_reloading")
@login_required
def check_reloading():
app.data.load_from_file()
if not app.data.get("RELOADING", False) or app.data.get("LAST_RELOAD", 0) + 60 < time():
if app.data.get("RELOADING", False):
app.logger.warning("Reloading took too long, forcing the state to be reloaded")

View file

@ -1,7 +1,6 @@
from base64 import b64encode
from contextlib import suppress
from io import BytesIO
from cryptography.fernet import Fernet, InvalidToken, MultiFernet
from bcrypt import checkpw
from typing import Dict, List, Optional, Union
from flask import Flask
from passlib.totp import TOTP, MalformedTokenError, TokenError, TotpMatch
@ -13,7 +12,7 @@ from models import Users
class Totp:
def __init__(self, app: Flask, secrets: Dict[Union[str, int], str], recovery_codes_keys: List[bytes]):
def __init__(self, app: Flask, secrets: Dict[Union[str, int], str]):
"""Initialize a totp factory.
secrets are used to encrypt the per-user totp_secret on disk.
recovery_codes_keys are used to encrypt the per-user recovery codes on disk.
@ -22,10 +21,6 @@ class Totp:
self.app = app
self._totp = TOTP.using(secrets=secrets, issuer="BunkerWeb UI")
self.cryptor: Optional[MultiFernet] = None
if recovery_codes_keys:
self.cryptor = MultiFernet([Fernet(key) for key in recovery_codes_keys])
def generate_totp_secret(self) -> str:
"""Create new user-unique totp_secret."""
return self._totp.new().to_json(encrypt=True)
@ -37,28 +32,16 @@ class Totp:
return self._totp.from_source(totp_secret).pretty_key()
def generate_recovery_codes(self) -> List[str]:
codes = ["-".join([pwd[i : i + 4] for i in range(0, len(pwd), 4)]) for pwd in genword(length=16, charset="hex", returns=5)] # noqa: E203
if not self.cryptor:
return codes
return [self.cryptor.encrypt(code.encode()).decode() for code in codes]
def decrypt_recovery_code(self, code: str) -> Optional[str]:
if not self.cryptor:
return code
return self.cryptor.decrypt(code.encode()).decode()
def decrypt_recovery_codes(self, user: Users) -> List[str]:
return [self.decrypt_recovery_code(code) for code in user.list_recovery_codes]
return ["-".join([pwd[i : i + 4] for i in range(0, len(pwd), 4)]) for pwd in genword(length=16, charset="hex", returns=5)] # noqa: E203
def verify_recovery_code(self, code: str, user: Users) -> Optional[str]:
"""Check if recovery code is valid for user."""
if not user.list_recovery_codes:
return
with suppress(InvalidToken):
for i, decrypted_code in enumerate(self.decrypt_recovery_codes(user)):
if code == decrypted_code:
return user.list_recovery_codes.pop(i)
for i, encrypted_code in enumerate(user.list_recovery_codes):
if checkpw(code.encode("utf-8"), encrypted_code.encode("utf-8")):
return user.list_recovery_codes.pop(i)
def verify_totp(self, token: str, *, totp_secret: Optional[str] = None, user: Optional[Users] = None) -> bool:
"""Verifies token for specific user."""
@ -68,7 +51,7 @@ class Totp:
totp_secret = user.totp_secret
try:
tmatch = self._totp.verify(token, totp_secret, last_counter=self.get_last_counter(user))
tmatch = self._totp.verify(token, totp_secret, window=3, last_counter=self.get_last_counter(user))
if user:
self.set_last_counter(user, tmatch)
return True

29
src/ui/src/ui_data.py Normal file
View file

@ -0,0 +1,29 @@
from json import dumps, loads
from multiprocessing import Lock
from pathlib import Path
class UIData(dict):
def __init__(self, file_path: Path):
super().__init__()
self.file_path = file_path
self.__lock = Lock()
self.load_from_file()
def _write_to_file(self):
with self.__lock:
self.file_path.write_text(dumps(self))
def load_from_file(self):
if self.file_path.is_file():
with self.__lock:
for key, value in loads(self.file_path.read_text()).items():
super().__setitem__(key, value)
def __setitem__(self, key, value):
super().__setitem__(key, value)
self._write_to_file()
def __delitem__(self, key):
super().__delitem__(key)
self._write_to_file()

View file

@ -11,7 +11,9 @@ for deps_path in [join(sep, "usr", "share", "bunkerweb", *paths) for paths in ((
if deps_path not in sys_path:
sys_path.append(deps_path)
from bcrypt import gensalt, hashpw
from sqlalchemy import MetaData, inspect, text
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import IntegrityError
from Database import Database # type: ignore
@ -70,9 +72,9 @@ class UIDatabase(Database):
# Rename the old tables
db_version_id = db_version.replace(".", "_")
for table_name in metadata.tables.keys():
if table_name in Base.metadata.tables:
with self._db_session() as session:
with self._db_session() as session:
for table_name in metadata.tables.keys():
if table_name in Base.metadata.tables:
if inspector.has_table(f"{table_name}_{db_version_id}"):
self.logger.warning(f'UI table "{table_name}" already exists, dropping it to make room for the new one')
session.execute(text(f"DROP TABLE {table_name}_{db_version_id}"))
@ -144,9 +146,12 @@ class UIDatabase(Database):
"""Get ui user. If username is None, return the first admin user."""
with self._db_session() as session:
if username:
ui_user = session.query(Users).filter_by(username=username).first()
query = session.query(Users).filter_by(username=username)
else:
ui_user = session.query(Users).filter_by(admin=True).first()
query = session.query(Users).filter_by(admin=True)
query = query.options(joinedload(Users.roles), joinedload(Users.recovery_codes))
ui_user = query.first()
if not ui_user:
return None
@ -164,15 +169,10 @@ class UIDatabase(Database):
"totp_secret": ui_user.totp_secret,
"creation_date": ui_user.creation_date,
"update_date": ui_user.update_date,
"roles": [],
"recovery_codes": [],
"roles": [role.role_name for role in ui_user.roles],
"recovery_codes": [recovery_code.code for recovery_code in ui_user.recovery_codes],
}
for role in session.query(RolesUsers).filter_by(user_name=ui_user.username).all():
ui_user_data["roles"].append(role.role_name)
for recovery_code in session.query(UserRecoveryCodes).filter_by(user_name=ui_user.username).all():
ui_user_data["recovery_codes"].append(recovery_code.code)
return ui_user_data
def create_ui_user(
@ -192,15 +192,15 @@ class UIDatabase(Database):
if self.readonly:
return "The database is read-only, the changes will not be saved"
if admin and session.query(Users).filter_by(admin=True).first():
if admin and session.query(Users).with_entities(Users.username).filter_by(admin=True).first():
return "An admin user already exists"
user = session.query(Users).filter_by(username=username).first()
user = session.query(Users).with_entities(Users.username).filter_by(username=username).first()
if user:
return f"User {username} already exists"
for role in roles:
if not session.query(Roles).filter_by(name=role).first():
if not session.query(Roles).with_entities(Roles.name).filter_by(name=role).first():
return f"Role {role} doesn't exist"
session.add(RolesUsers(user_name=username, role_name=role))
@ -212,12 +212,11 @@ class UIDatabase(Database):
method=method,
admin=admin,
totp_secret=totp_secret,
totp_refreshed=bool(totp_secret),
)
)
for code in totp_recovery_codes or []:
session.add(UserRecoveryCodes(user_name=username, code=code))
session.add(UserRecoveryCodes(user_name=username, code=hashpw(code.encode("utf-8"), gensalt(rounds=8)).decode("utf-8")))
try:
session.commit()
@ -230,6 +229,7 @@ class UIDatabase(Database):
self, username: str, password: bytes, totp_secret: Optional[str], *, totp_recovery_codes: Optional[List[str]] = None, method: str = "manual"
) -> str:
"""Update ui user."""
totp_changed = False
with self._db_session() as session:
if self.readonly:
return "The database is read-only, the changes will not be saved"
@ -238,8 +238,7 @@ class UIDatabase(Database):
if not user:
return f"User {username} doesn't exist"
if user.totp_secret != totp_secret:
user.totp_refreshed = True
totp_changed = user.totp_secret != totp_secret
user.password = password.decode("utf-8")
user.totp_secret = totp_secret
@ -250,7 +249,7 @@ class UIDatabase(Database):
except BaseException as e:
return str(e)
if user.totp_refreshed:
if totp_changed:
if totp_recovery_codes:
self.refresh_ui_user_recovery_codes(username, totp_recovery_codes or [])
else:
@ -303,13 +302,13 @@ class UIDatabase(Database):
if self.readonly:
return "The database is read-only, the changes will not be saved"
if session.query(Roles).filter_by(name=name).first():
if session.query(Roles).with_entities(Roles.name).filter_by(name=name).first():
return f"Role {name} already exists"
session.add(Roles(name=name, description=description))
for permission in permissions:
if not session.query(Permissions).filter_by(name=permission).first():
if not session.query(Permissions).with_entities(Permissions.name).filter_by(name=permission).first():
session.add(Permissions(name=permission))
session.add(RolesPermissions(role_name=name, permission_name=permission))
@ -323,7 +322,7 @@ class UIDatabase(Database):
def get_ui_roles(self, *, as_dict: bool = False) -> List[Union[Roles, dict]]:
"""Get ui roles."""
with self._db_session() as session:
roles = session.query(Roles).all()
roles = session.query(Roles).with_entities(Roles.name, Roles.description, Roles.update_datetime).all()
if not as_dict:
return roles
@ -336,7 +335,7 @@ class UIDatabase(Database):
"permissions": [],
}
for permission in session.query(RolesPermissions).filter_by(role_name=role.name).all():
for permission in session.query(RolesPermissions).with_entities(RolesPermissions.permission_name).filter_by(role_name=role.name):
role_data["permissions"].append(permission.permission_name)
roles_data.append(role_data)
@ -349,19 +348,17 @@ class UIDatabase(Database):
if self.readonly:
return "The database is read-only, the changes will not be saved"
if not codes:
return "No recovery codes provided"
user = session.query(Users).filter_by(username=username).first()
if not user:
return f"User {username} doesn't exist"
if not codes:
return "No recovery codes provided"
session.query(UserRecoveryCodes).filter_by(user_name=username).delete()
for code in codes:
session.add(UserRecoveryCodes(user_name=username, code=code))
user.totp_refreshed = True
session.add(UserRecoveryCodes(user_name=username, code=hashpw(code.encode("utf-8"), gensalt(rounds=8)).decode("utf-8")))
try:
session.commit()
@ -388,36 +385,39 @@ class UIDatabase(Database):
def get_ui_user_roles(self, username: str) -> List[str]:
"""Get ui user roles."""
with self._db_session() as session:
return [role.role_name for role in session.query(RolesUsers).filter_by(user_name=username).all()]
return [role.role_name for role in session.query(RolesUsers).with_entities(RolesUsers.role_name).filter_by(user_name=username)]
def get_ui_role_permissions(self, role_name: str) -> List[str]:
"""Get ui role permissions."""
with self._db_session() as session:
return [permission.permission_name for permission in session.query(RolesPermissions).filter_by(role_name=role_name).all()]
return [
permission.permission_name
for permission in session.query(RolesPermissions).with_entities(RolesPermissions.permission_name).filter_by(role_name=role_name)
]
def get_ui_user_recovery_codes(self, username: str) -> List[str]:
"""Get ui user recovery codes."""
with self._db_session() as session:
return [code.code for code in session.query(UserRecoveryCodes).filter_by(user_name=username).all()]
return [code.code for code in session.query(UserRecoveryCodes).with_entities(UserRecoveryCodes.code).filter_by(user_name=username)]
def get_ui_user_permissions(self, username: str) -> List[str]:
"""Get ui user permissions."""
with self._db_session() as session:
roles = session.query(RolesUsers).filter_by(user_name=username).all()
query = session.query(RolesUsers).with_entities(RolesUsers.role_name).filter_by(user_name=username)
permissions = []
for role in roles:
for role in query:
permissions.extend(self.get_ui_role_permissions(role.role_name))
return permissions
def use_ui_user_recovery_code(self, username: str, code: str) -> str:
def use_ui_user_recovery_code(self, username: str, hashed_code: str) -> str:
"""Use ui user recovery code."""
with self._db_session() as session:
user = session.query(Users).filter_by(username=username).first()
if not user:
return f"User {username} doesn't exist"
recovery_code = session.query(UserRecoveryCodes).filter_by(user_name=username, code=code).first()
recovery_code = session.query(UserRecoveryCodes).filter_by(user_name=username, code=hashed_code).first()
if not recovery_code:
return "Invalid recovery code"
@ -428,18 +428,4 @@ class UIDatabase(Database):
except BaseException as e:
return str(e)
def set_ui_user_recovery_code_refreshed(self, username: str, value: bool) -> str:
"""Set ui user recovery code refreshed."""
with self._db_session() as session:
user = session.query(Users).filter_by(username=username).first()
if not user:
return f"User {username} doesn't exist"
user.totp_refreshed = value
try:
session.commit()
except BaseException as e:
return str(e)
return ""