Refactor dotenv variable loading to improve readability and error handling; enhance server name validation logic and optimize plugin loading with threading

This commit is contained in:
Théophile Diot 2025-01-08 11:58:33 +01:00
parent b44492c685
commit 1f9393f8d0
No known key found for this signature in database
GPG key ID: FA995104A0BA376A
6 changed files with 77 additions and 64 deletions

View file

@ -15,7 +15,6 @@ for deps_path in [join(sep, "usr", "share", "bunkerweb", *paths) for paths in ((
if deps_path not in sys_path:
sys_path.append(deps_path)
from dotenv import dotenv_values
from redis import StrictRedis, Sentinel
from API import API # type: ignore
@ -59,7 +58,8 @@ class CLI(ApiCaller):
self.__variables = {}
self.__db = None
if variables_path.is_file():
self.__variables = dotenv_values(variables_path)
with variables_path.open() as f:
self.__variables = dict(line.strip().split("=", 1) for line in f if line.strip() and not line.startswith("#"))
if Path(sep, "usr", "share", "bunkerweb", "db").exists():
from Database import Database # type: ignore

View file

@ -63,7 +63,7 @@ try:
sent, err, status, resp = api.request(
"POST",
f"/reload?test={'no' if getenv('DISABLE_CONFIGURATION_TESTING', 'no').lower() == 'yes' else 'yes'}",
timeout=max(reload_min_timeout, 2 * len(services)),
timeout=max(reload_min_timeout, 3 * len(services)),
)
if not sent:
status = 1

View file

@ -1,5 +1,6 @@
#!/usr/bin/env python3
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import cache
from io import BytesIO
@ -67,10 +68,7 @@ class Configurator:
@cache
def get_plugins_settings(self, _type: Literal["core", "external", "pro"]) -> Dict[str, str]:
plugins_settings = {}
for plugin in self.get_plugins(_type):
plugins_settings.update(plugin.get("settings", {}))
return plugins_settings
return {k: v for plugin in self.get_plugins(_type) for k, v in plugin.get("settings", {}).items()}
@cache
def __map_servers(self) -> Dict[str, List[str]]:
@ -78,35 +76,42 @@ class Configurator:
return {}
servers = {}
for server_name in self.__variables["SERVER_NAME"].strip().split(" "):
if not server_name:
continue
server_regex = re_compile(self.__settings["SERVER_NAME"]["regex"])
server_names = [s for s in self.__variables["SERVER_NAME"].strip().split(" ") if s]
if re_search(self.__settings["SERVER_NAME"]["regex"], server_name) is None:
for server_name in server_names:
if not server_regex.search(server_name):
self.__logger.warning(f"Ignoring server name {server_name} because regex is not valid")
continue
names = [server_name]
if f"{server_name}_SERVER_NAME" in self.__variables:
if re_search(self.__settings["SERVER_NAME"]["regex"], self.__variables[f"{server_name}_SERVER_NAME"]) is None:
self.__logger.warning(f"Ignoring {server_name}_SERVER_NAME because regex is not valid")
server_name_var = f"{server_name}_SERVER_NAME"
if server_name_var in self.__variables:
names_str = self.__variables[server_name_var].strip()
if not server_regex.search(names_str):
self.__logger.warning(f"Ignoring {server_name_var} because regex is not valid")
servers[server_name] = [server_name]
else:
names = self.__variables[f"{server_name}_SERVER_NAME"].strip().split(" ")
servers[server_name] = [n for n in names_str.split(" ") if n]
else:
servers[server_name] = [server_name]
servers[server_name] = names
return servers
def __load_settings(self, path: Path) -> Dict[str, str]:
return loads(path.read_text())
def __load_plugins(self, path: Path, _type: Literal["core", "external", "pro"] = "core"):
x = 0
for file in path.glob("*/plugin.json"):
self.__logger.debug(f"Loading {_type} plugin {file}")
self.__load_plugin(file, _type)
x += 1
self.__logger.info(f"Computed {x} {_type} plugin{'s' if x > 1 else ''}")
plugin_files = list(path.glob("*/plugin.json"))
with ThreadPoolExecutor() as executor:
futures = [executor.submit(self.__load_plugin, file, _type) for file in plugin_files]
for file, future in zip(plugin_files, futures):
self.__logger.debug(f"Loading {_type} plugin {file}")
future.result()
count = len(plugin_files)
self.__logger.info(f"Computed {count} {_type} plugin{'s' if count > 1 else ''}")
def __load_plugin(self, file: Path, _type: Literal["core", "external", "pro"] = "core"):
try:
@ -141,16 +146,15 @@ class Configurator:
self.__logger.error(f"Exception while loading JSON from {file} : {e}")
def __load_variables(self, path: Path) -> Dict[str, str]:
variables = {}
with path.open("r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
split = line.split("=", 1)
variables[split[0]] = split[1]
return variables
try:
return dict(
line.strip().split("=", 1)
for line in path.read_text(encoding="utf-8").splitlines()
if line.strip() and not line.strip().startswith("#") and "=" in line
)
except Exception as e:
self.__logger.error(f"Failed to load variables from {path}: {e}")
return {}
def get_config(self, db=None, *, first_run: bool = False) -> Dict[str, str]:
config = {}

View file

@ -64,10 +64,10 @@ class Templator:
raise TypeError("config must be a dictionary")
self._templates = templates
self._global_templates = [basename(template) for template in glob(join(self._templates, "*", "*.conf"))]
self._core = core
self._plugins = plugins
self._pro_plugins = pro_plugins
self._global_templates = [template.name for template in Path(self._templates).glob("**/*.conf")]
self._core = Path(core)
self._plugins = Path(plugins)
self._pro_plugins = Path(pro_plugins)
self._output = output
self._target = target
self._config = config
@ -78,7 +78,7 @@ class Templator:
self._render_global()
servers = [self._config.get("SERVER_NAME", "").strip()]
if self._config.get("MULTISITE", "no") == "yes":
servers = self._config.get("SERVER_NAME", "").strip().split()
servers = self._config.get("SERVER_NAME", "").strip().split(" ")
for server in servers:
self._render_server(server)
@ -89,9 +89,7 @@ class Templator:
Environment: The Jinja2 environment.
"""
searchpath = [self._templates]
for subpath in glob(join(self._core, "*", "confs")) + glob(join(self._plugins, "*", "confs")) + glob(join(self._pro_plugins, "*", "confs")):
if Path(subpath).is_dir():
searchpath.append(subpath)
searchpath.extend(p.as_posix() for p in (*self._core.glob("*/confs"), *self._plugins.glob("*/confs"), *self._pro_plugins.glob("*/confs")) if p.is_dir())
return Environment(
loader=FileSystemLoader(searchpath=searchpath),
lstrip_blocks=True,
@ -108,13 +106,20 @@ class Templator:
Returns:
List[str]: List of template names.
"""
context_set = set(contexts)
templates = [
template
for template in self._jinja_env.list_templates()
if any(template.startswith(context + "/") or (context == "global" and "/" not in template) for context in context_set)
]
return templates
templates = set()
all_templates = frozenset(self._jinja_env.list_templates())
# Handle global context specially for better performance
if "global" in contexts:
templates.update(t for t in all_templates if "/" not in t)
contexts.remove("global")
# Process remaining contexts
if contexts:
prefix_set = tuple(context + "/" for context in contexts)
templates.update(t for t in all_templates if any(t.startswith(prefix) for prefix in prefix_set))
return list(templates)
def _write_config(self, subpath: Optional[str] = None, config: Optional[Dict[str, Any]] = None) -> None:
"""Write the configuration to a variables.env file.
@ -127,8 +132,7 @@ class Templator:
real_path = Path(self._output, subpath or "", "variables.env")
try:
real_path.parent.mkdir(parents=True, exist_ok=True)
with real_path.open("w") as f:
f.write("\n".join(f"{k}={v}" for k, v in real_config.items()))
real_path.write_text("".join(f"{k}={v}\n" for k, v in real_config.items()))
except IOError as e:
logger.error(f"Error writing configuration to {real_path}: {e}")
@ -141,23 +145,30 @@ class Templator:
Returns:
Dict[str, Any]: Configuration dictionary for the server.
"""
config = self._config.copy()
config = {}
prefix = f"{server}_"
for variable, value in self._config.items():
if variable.startswith(prefix):
config[variable[len(prefix) :]] = value # noqa: E203
prefix_len = len(prefix)
# Pre-populate with base config and handle NGINX_PREFIX
config.update(self._config)
config["NGINX_PREFIX"] = join(self._target, server) + "/"
server_key = f"{server}_SERVER_NAME"
if server_key not in self._config:
# Efficient single-pass override of server-specific values
for key, value in ((k, v) for k, v in self._config.items() if k.startswith(prefix)):
config[key[prefix_len:]] = value
# Set default SERVER_NAME if not explicitly defined
if f"{prefix}SERVER_NAME" not in self._config:
config["SERVER_NAME"] = server
return config
def _render_global(self) -> None:
"""Render global templates."""
self._write_config()
templates = self._find_templates(["global", "http", "stream", "default-server-http"])
for template in templates:
self._render_template(template)
with ThreadPoolExecutor(max_workers=min(32, len(templates))) as executor:
executor.map(self._render_template, templates)
def _render_server(self, server: str) -> None:
"""Render templates for a specific server.

View file

@ -14,8 +14,6 @@ for deps_path in [join(sep, "usr", "share", "bunkerweb", *paths) for paths in ((
if deps_path not in sys_path:
sys_path.append(deps_path)
from dotenv import dotenv_values
from logger import setup_logger # type: ignore
from Configurator import Configurator
from Templator import Templator
@ -75,7 +73,8 @@ if __name__ == "__main__":
if args.variables:
variables_path = Path(args.variables)
LOGGER.info(f"Variables : {variables_path}")
dotenv_env = dotenv_values(variables_path.as_posix())
with variables_path.open() as f:
dotenv_env = dict(line.strip().split("=", 1) for line in f if line.strip() and not line.startswith("#"))
db = None
if DB_PATH.is_dir():

View file

@ -12,8 +12,6 @@ for deps_path in [join(sep, "usr", "share", "bunkerweb", *paths) for paths in ((
if deps_path not in sys_path:
sys_path.append(deps_path)
from dotenv import dotenv_values
from common_utils import get_integration, get_version # type: ignore
from logger import setup_logger # type: ignore
from Database import Database # type: ignore
@ -67,7 +65,8 @@ if __name__ == "__main__":
if args.variables:
variables_path = Path(args.variables)
LOGGER.info(f"Variables : {variables_path}")
dotenv_env = dotenv_values(variables_path.as_posix())
with variables_path.open() as f:
dotenv_env = dict(line.strip().split("=", 1) for line in f if line.strip() and not line.startswith("#"))
# Check existences and permissions
LOGGER.info("Checking arguments ...")