Add check when plugins are configured + Add Semaphore to accelerate jobs execution + Code optimization

This commit is contained in:
Théophile Diot 2023-05-17 20:37:54 -04:00
parent 4c4fa44fbc
commit 3f51f59bcb
No known key found for this signature in database
GPG key ID: E752C80DB72BB014
11 changed files with 419 additions and 172 deletions

View file

@ -6,6 +6,7 @@ from hashlib import sha1
from os import _exit, getenv
from pathlib import Path
from sys import exit as sys_exit, path as sys_path
from threading import Lock
from traceback import format_exc
sys_path.extend(
@ -25,6 +26,7 @@ from jobs import cache_file, cache_hash, file_hash, is_cached_file
logger = setup_logger("JOBS.mmdb-asn", getenv("LOG_LEVEL", "INFO"))
status = 0
lock = Lock()
try:
dl_mmdb = True
@ -33,7 +35,8 @@ try:
# Don't go further if the cache match the latest version
if Path("/var/tmp/bunkerweb/asn.mmdb").exists():
response = get("https://db-ip.com/db/download/ip-to-asn-lite")
with lock:
response = get("https://db-ip.com/db/download/ip-to-asn-lite")
if response.status_code == 200:
_sha1 = sha1()

View file

@ -6,6 +6,7 @@ from hashlib import sha1
from os import _exit, getenv
from pathlib import Path
from sys import exit as sys_exit, path as sys_path
from threading import Lock
from traceback import format_exc
sys_path.extend(
@ -25,6 +26,7 @@ from jobs import cache_file, cache_hash, file_hash, is_cached_file
logger = setup_logger("JOBS.mmdb-country", getenv("LOG_LEVEL", "INFO"))
status = 0
lock = Lock()
try:
dl_mmdb = True
@ -33,7 +35,8 @@ try:
# Don't go further if the cache match the latest version
if Path("/var/tmp/bunkerweb/country.mmdb").exists():
response = get("https://db-ip.com/db/download/ip-to-country-lite")
with lock:
response = get("https://db-ip.com/db/download/ip-to-country-lite")
if response.status_code == 200:
_sha1 = sha1()

View file

@ -251,7 +251,7 @@ class Database:
return ""
def init_tables(self, default_settings: List[dict]) -> Tuple[bool, str]:
def init_tables(self, default_plugins: List[dict]) -> Tuple[bool, str]:
"""Initialize the database tables and return the result"""
inspector = inspect(self.__sql_engine)
if len(Base.metadata.tables.keys()) <= len(inspector.get_table_names()):
@ -269,7 +269,7 @@ class Database:
to_put = []
with self.__db_session() as session:
for plugins in default_settings:
for plugins in default_plugins:
if not isinstance(plugins, list):
plugins = [plugins]

View file

@ -6,11 +6,11 @@ from logging import Logger
from os import listdir
from os.path import basename, dirname
from pathlib import Path
from re import search as re_search
from re import compile as re_compile, search as re_search
from sys import path as sys_path
from tarfile import open as tar_open
from traceback import format_exc
from typing import Optional, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
if "/usr/share/bunkerweb/utils" not in sys_path:
sys_path.append("/usr/share/bunkerweb/utils")
@ -20,27 +20,24 @@ class Configurator:
def __init__(
self,
settings: str,
core: Union[str, dict],
plugins: Union[str, dict],
variables: Union[str, dict],
core: str,
external_plugins: Union[str, List[Dict[str, Any]]],
variables: Union[str, Dict[str, Any]],
logger: Logger,
*,
plugins_settings: Optional[list] = None,
):
self.__logger = logger
self.__plugin_id_rx = re_compile(r"^[\w.-]{1,64}$")
self.__plugin_version_rx = re_compile(r"^\d+\.\d+(\.\d+)?$")
self.__setting_id_rx = re_compile(r"^[A-Z0-9_]{1,256}$")
self.__name_rx = re_compile(r"^[\w.-]{1,128}$")
self.__job_file_rx = re_compile(r"^[\w./-]{1,256}$")
self.__settings = self.__load_settings(settings)
self.__core_plugins = self.__load_plugins(core)
if isinstance(core, str):
self.__core = self.__load_plugins(core)
if isinstance(external_plugins, str):
self.__external_plugins = self.__load_plugins(external_plugins, "external")
else:
self.__core = core
self.__plugins_settings = plugins_settings or []
if isinstance(plugins, str):
self.__plugins = self.__load_plugins(plugins, "plugins")
else:
self.__plugins = plugins
self.__external_plugins = external_plugins
if isinstance(variables, str):
self.__variables = self.__load_variables(variables)
@ -50,17 +47,33 @@ class Configurator:
self.__multisite = self.__variables.get("MULTISITE", "no") == "yes"
self.__servers = self.__map_servers()
def get_settings(self):
def get_settings(self) -> Dict[str, Any]:
return self.__settings
def get_plugins_settings(self):
return self.__plugins_settings
def get_plugins(
self, _type: Union[Literal["core"], Literal["external"]]
) -> List[Dict[str, Any]]:
return self.__core_plugins if _type == "core" else self.__external_plugins
def __map_servers(self):
def get_plugins_settings(
self, _type: Union[Literal["core"], Literal["external"]]
) -> Dict[str, Any]:
if _type == "core":
plugins = self.__core_plugins
else:
plugins = self.__external_plugins
plugins_settings = {}
for plugin in plugins:
plugins_settings.update(plugin["settings"])
return plugins_settings
def __map_servers(self) -> Dict[str, List[str]]:
if not self.__multisite or not "SERVER_NAME" in self.__variables:
return {}
servers = {}
for server_name in self.__variables["SERVER_NAME"].split(" "):
for server_name in self.__variables["SERVER_NAME"].strip().split(" "):
if not re_search(self.__settings["SERVER_NAME"]["regex"], server_name):
self.__logger.warning(
f"Ignoring server name {server_name} because regex is not valid",
@ -76,21 +89,44 @@ class Configurator:
f"Ignoring {server_name}_SERVER_NAME because regex is not valid",
)
else:
names = self.__variables[f"{server_name}_SERVER_NAME"].split(" ")
names = (
self.__variables[f"{server_name}_SERVER_NAME"]
.strip()
.split(" ")
)
servers[server_name] = names
return servers
def __load_settings(self, path) -> dict:
def __load_settings(self, path: str) -> Dict[str, Any]:
return loads(Path(path).read_text())
def __load_plugins(self, path, _type: str = "other"):
plugins = {}
def __load_plugins(self, path: str, _type: str = "core") -> List[Dict[str, Any]]:
orders = {}
plugins = []
files = glob(f"{path}/*/plugin.json")
for file in files:
try:
data = self.__load_settings(file)
if _type == "plugins":
resp, msg = self.__validate_plugin(data)
if not resp:
self.__logger.warning(
f"Ignoring plugin {file} : {msg}",
)
continue
if data["order"] not in orders:
orders[data["order"]] = [data["id"]]
else:
if len(orders[data["order"]]) > 1 and data["order"] != 999:
self.__logger.warning(
f"Plugin {data['id']} have the same order than {', '.join(orders[data['order']])}. Therefor, the execution order will be random."
)
orders[data["order"]].append(data["id"])
if _type == "external":
plugin_content = BytesIO()
with tar_open(fileobj=plugin_content, mode="w:gz") as tar:
tar.add(
@ -101,9 +137,8 @@ class Configurator:
plugin_content.seek(0)
value = plugin_content.getvalue()
self.__plugins_settings.append(
data
| {
data.update(
{
"external": path.startswith("/etc/bunkerweb/plugins"),
"page": "ui" in listdir(dirname(file)),
"method": "manual",
@ -112,7 +147,7 @@ class Configurator:
}
)
plugins.update(data["settings"])
plugins.append(data)
except:
self.__logger.error(
f"Exception while loading JSON from {file} : {format_exc()}",
@ -120,7 +155,7 @@ class Configurator:
return plugins
def __load_variables(self, path):
def __load_variables(self, path: str) -> Dict[str, Any]:
variables = {}
with open(path) as f:
lines = f.readlines()
@ -128,15 +163,18 @@ class Configurator:
line = line.strip()
if not line or line.startswith("#") or not "=" in line:
continue
var = line.split("=")[0]
value = line[len(var) + 1 :]
variables[var] = value
splitted = line.split("=", 1)
variables[splitted[0]] = splitted[1]
return variables
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {}
# Extract default settings
default_settings = [self.__settings, self.__core, self.__plugins]
default_settings = [
self.__settings,
self.get_plugins_settings("core"),
self.get_plugins_settings("external"),
]
for settings in default_settings:
for setting, data in settings.items():
config[setting] = data["default"]
@ -166,12 +204,15 @@ class Configurator:
# Expand variables to each sites if MULTISITE=yes and if not present
if config.get("MULTISITE", "no") == "yes":
for server_name in config["SERVER_NAME"].split(" "):
if server_name == "":
server_name = server_name.strip()
if not server_name:
continue
for settings in default_settings:
for setting, data in settings.items():
if data["context"] == "global":
continue
key = f"{server_name}_{setting}"
if key not in config:
if setting == "SERVER_NAME":
@ -180,15 +221,13 @@ class Configurator:
config[key] = config[setting]
return config
def __check_var(self, variable):
def __check_var(self, variable: str) -> Tuple[bool, str]:
value = self.__variables[variable]
# MULTISITE=no
if not self.__multisite:
where, real_var = self.__find_var(variable)
if not where:
return False, f"variable name {variable} doesn't exist"
elif not "regex" in where[real_var]:
return False, f"missing regex for variable {variable}"
elif not re_search(where[real_var]["regex"], value):
return (
False,
@ -200,17 +239,21 @@ class Configurator:
where, real_var = self.__find_var(real_var)
if not where:
return False, f"variable name {variable} doesn't exist"
if prefixed and where[real_var]["context"] != "multisite":
elif prefixed and where[real_var]["context"] != "multisite":
return False, f"context of {variable} isn't multisite"
if not re_search(where[real_var]["regex"], value):
elif not re_search(where[real_var]["regex"], value):
return (
False,
f"value {value} doesn't match regex {where[real_var]['regex']}",
)
return True, "ok"
def __find_var(self, variable):
targets = [self.__settings, self.__core, self.__plugins]
def __find_var(self, variable: str) -> Tuple[Optional[Dict[str, Any]], str]:
targets = [
self.__settings,
self.get_plugins_settings("core"),
self.get_plugins_settings("external"),
]
for target in targets:
if variable in target:
return target, variable
@ -219,10 +262,162 @@ class Configurator:
f"^{real_var}_[0-9]+$", variable
):
return target, real_var
return False, variable
return None, variable
def __var_is_prefixed(self, variable):
def __var_is_prefixed(self, variable: str) -> Tuple[bool, str]:
for server in self.__servers:
if variable.startswith(f"{server}_"):
return True, variable.replace(f"{server}_", "", 1)
return False, variable
def __validate_plugin(self, plugin: dict) -> Tuple[bool, str]:
if not all(
key in plugin.keys()
for key in [
"id",
"order",
"name",
"description",
"version",
"stream",
"settings",
]
):
return (
False,
f"Missing mandatory keys for plugin {plugin.get('id', 'unknown')} (id, order, name, description, version, stream, settings)",
)
if not self.__plugin_id_rx.match(plugin["id"]):
return (
False,
f"Invalid id for plugin {plugin['id']} (Can only contain numbers, letters, underscores and hyphens (min 1 characters and max 64))",
)
elif not isinstance(plugin["order"], int):
return False, f"Invalid order for plugin {plugin['id']}, must be a number"
elif len(plugin["name"]) > 128:
return (
False,
f"Invalid name for plugin {plugin['id']} (Max 128 characters)",
)
elif len(plugin["description"]) > 256:
return (
False,
f"Invalid description for plugin {plugin['id']} (Max 256 characters)",
)
elif not self.__plugin_version_rx.match(plugin["version"]):
return (
False,
f"Invalid version for plugin {plugin['id']} (Must be in format \d+\.\d+(\.\d+)?)",
)
elif plugin["stream"] not in ["yes", "no", "partial"]:
return (
False,
f"Invalid stream for plugin {plugin['id']} (Must be yes, no or partial)",
)
for setting, data in plugin["settings"].items():
if not all(
key in data.keys()
for key in [
"context",
"default",
"help",
"id",
"label",
"regex",
"type",
]
):
return (
False,
f"missing keys for setting {setting} in plugin {plugin['id']}, must have context, default, help, id, label, regex and type",
)
if not self.__setting_id_rx.match(setting):
return (
False,
f"Invalid setting name for setting {setting} in plugin {plugin['id']} (Can only contain capital letters and underscores (min 1 characters and max 256))",
)
elif data["context"] not in ["global", "multisite"]:
return (
False,
f"Invalid context for setting {setting} in plugin {plugin['id']} (Must be global or multisite)",
)
elif len(data["default"]) > 4096:
return (
False,
f"Invalid default for setting {setting} in plugin {plugin['id']} (Max 4096 characters)",
)
elif len(data["help"]) > 512:
return (
False,
f"Invalid help for setting {setting} in plugin {plugin['id']} (Max 512 characters)",
)
elif len(data["label"]) > 256:
return (
False,
f"Invalid label for setting {setting} in plugin {plugin['id']} (Max 256 characters)",
)
elif len(data["regex"]) > 1024:
return (
False,
f"Invalid regex for setting {setting} in plugin {plugin['id']} (Max 1024 characters)",
)
elif data["type"] not in ["password", "text", "check", "select"]:
return (
False,
f"Invalid type for setting {setting} in plugin {plugin['id']} (Must be password, text, check or select)",
)
if "multiple" in data:
if not self.__name_rx.match(data["multiple"]):
return (
False,
f"Invalid multiple for setting {setting} in plugin {plugin['id']} (Can only contain numbers, letters, underscores and hyphens (min 1 characters and max 128))",
)
for select in data.get("select", []):
if len(select) > 256:
return (
False,
f"Invalid select value {select} for setting {setting} in plugin {plugin['id']} (Max 256 characters)",
)
for job in plugin.get("jobs", []):
if not all(
key in job.keys()
for key in [
"name",
"file",
"every",
"reload",
]
):
return (
False,
f"missing keys for job {job['name']} in plugin {plugin['id']}, must have name, file, every and reload",
)
if not self.__name_rx.match(job["name"]):
return (
False,
f"Invalid name for job {job['name']} in plugin {plugin['id']}",
)
elif not self.__job_file_rx.match(job["file"]):
return (
False,
f"Invalid file for job {job['name']} in plugin {plugin['id']} (Can only contain numbers, letters, underscores, hyphens and slashes (min 1 characters and max 256))",
)
elif job["every"] not in ["once", "minute", "hour", "day", "week"]:
return (
False,
f"Invalid every for job {job['name']} in plugin {plugin['id']} (Must be once, minute, hour, day or week)",
)
elif job["reload"] is not True and job["reload"] is not False:
return (
False,
f"Invalid reload for job {job['name']} in plugin {plugin['id']} (Must be true or false)",
)
return True, "ok"

View file

@ -1,16 +1,24 @@
from copy import deepcopy
from glob import glob
from importlib import import_module
from os.path import basename, dirname
from pathlib import Path
from random import choice
from string import ascii_letters, digits
from typing import Any, Dict, List, Optional
from jinja2 import Environment, FileSystemLoader
class Templator:
def __init__(self, templates, core, plugins, output, target, config):
def __init__(
self,
templates: str,
core: str,
plugins: str,
output: str,
target: str,
config: Dict[str, Any],
):
self.__templates = templates
self.__core = core
self.__plugins = plugins
@ -25,13 +33,13 @@ class Templator:
def render(self):
self.__render_global()
servers = [self.__config.get("SERVER_NAME", "")]
servers = [self.__config.get("SERVER_NAME", "").strip()]
if self.__config.get("MULTISITE", "no") == "yes":
servers = self.__config.get("SERVER_NAME", "").split(" ")
servers = self.__config.get("SERVER_NAME", "").strip().split(" ")
for server in servers:
self.__render_server(server)
def __load_jinja_env(self):
def __load_jinja_env(self) -> Environment:
searchpath = [self.__templates]
for subpath in glob(f"{self.__core}/*") + glob(f"{self.__plugins}/*"):
if Path(subpath).is_dir():
@ -42,7 +50,7 @@ class Templator:
trim_blocks=True,
)
def __find_templates(self, contexts):
def __find_templates(self, contexts) -> List[str]:
templates = []
for template in self.__jinja_env.list_templates():
if "global" in contexts and "/" not in template:
@ -53,14 +61,11 @@ class Templator:
templates.append(template)
return templates
def __write_config(self, subpath=None, config=None):
real_path = self.__output
if subpath != None:
real_path += f"{subpath}/"
real_path += "variables.env"
real_config = self.__config
if config != None:
real_config = config
def __write_config(
self, subpath: Optional[str] = None, config: Optional[Dict[str, Any]] = None
):
real_path = self.__output + (f"{subpath}/" if subpath else "") + "variables.env"
real_config = config or self.__config
Path(dirname(real_path)).mkdir(parents=True, exist_ok=True)
Path(real_path).write_text(
"\n".join(f"{k}={v}" for k, v in real_config.items())
@ -74,23 +79,24 @@ class Templator:
for template in templates:
self.__render_template(template)
def __render_server(self, server):
def __render_server(self, server: str):
templates = self.__find_templates(
["modsec", "modsec-crs", "server-http", "server-stream"]
)
if self.__config.get("MULTISITE", "no") == "yes":
config = deepcopy(self.__config)
config = self.__config.copy()
for variable, value in self.__config.items():
if variable.startswith(f"{server}_"):
config[variable.replace(f"{server}_", "", 1)] = value
self.__write_config(subpath=server, config=config)
for template in templates:
subpath = None
config = None
name = None
if self.__config.get("MULTISITE", "no") == "yes":
subpath = server
config = deepcopy(self.__config)
config = self.__config.copy()
for variable, value in self.__config.items():
if variable.startswith(f"{server}_"):
config[variable.replace(f"{server}_", "", 1)] = value
@ -98,6 +104,7 @@ class Templator:
server_key = f"{server}_SERVER_NAME"
if server_key not in self.__config:
config["SERVER_NAME"] = server
root_confs = [
"server.conf",
"access-lua.conf",
@ -114,53 +121,50 @@ class Templator:
break
self.__render_template(template, subpath=subpath, config=config, name=name)
def __render_template(self, template, subpath=None, config=None, name=None):
def __render_template(
self,
template: str,
subpath: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
):
# Get real config and output folder in case it's a server config and we are in multisite mode
real_config = deepcopy(self.__config)
if config:
real_config = deepcopy(config)
real_config["all"] = deepcopy(real_config)
real_config = config.copy() if config else self.__config.copy()
real_config["all"] = real_config.copy()
real_config["import"] = import_module
real_config["is_custom_conf"] = Templator.is_custom_conf
real_config["has_variable"] = Templator.has_variable
real_config["random"] = Templator.random
real_config["read_lines"] = Templator.read_lines
real_output = self.__output
if subpath:
real_output += f"/{subpath}/"
real_name = template
if name:
real_name = name
real_path = (
self.__output + (f"/{subpath}/" if subpath else "") + (name or template)
)
jinja_template = self.__jinja_env.get_template(template)
Path(dirname(f"{real_output}{real_name}")).mkdir(parents=True, exist_ok=True)
Path(f"{real_output}{real_name}").write_text(jinja_template.render(real_config))
Path(dirname(real_path)).mkdir(parents=True, exist_ok=True)
Path(real_path).write_text(jinja_template.render(real_config))
@staticmethod
def is_custom_conf(path):
return glob(f"{path}/*.conf")
def is_custom_conf(path: str) -> bool:
return bool(glob(f"{path}/*.conf"))
@staticmethod
def has_variable(all_vars, variable, value):
if variable in all_vars and all_vars[variable] == value:
def has_variable(all_vars: Dict[str, Any], variable: str, value: Any) -> bool:
if all_vars.get(variable) == value:
return True
if all_vars.get("MULTISITE", "no") == "yes":
for server_name in all_vars["SERVER_NAME"].split(" "):
if (
f"{server_name}_{variable}" in all_vars
and all_vars[f"{server_name}_{variable}"] == value
):
elif all_vars.get("MULTISITE", "no") == "yes":
for server_name in all_vars["SERVER_NAME"].strip().split(" "):
if all_vars.get(f"{server_name}_{variable}") == value:
return True
return False
@staticmethod
def random(nb):
def random(nb: int) -> str:
characters = ascii_letters + digits
return "".join(choice(characters) for _ in range(nb))
@staticmethod
def read_lines(file):
def read_lines(file: str) -> List[str]:
try:
with open(file, "r") as f:
return f.readlines()
return Path(file).read_text().splitlines()
except:
return []

View file

@ -9,6 +9,7 @@ from subprocess import DEVNULL, STDOUT, run
from sys import exit as sys_exit, path as sys_path
from time import sleep
from traceback import format_exc
from typing import Any, Dict
if "/usr/share/bunkerweb/deps/python" not in sys_path:
sys_path.append("/usr/share/bunkerweb/deps/python")
@ -131,10 +132,9 @@ if __name__ == "__main__":
# Compute the config
logger.info("Computing config ...")
config = Configurator(
config: Dict[str, Any] = Configurator(
args.settings, args.core, args.plugins, args.variables, logger
)
config = config.get_config()
).get_config()
else:
if "/usr/share/bunkerweb/db" not in sys_path:
sys_path.append("/usr/share/bunkerweb/db")
@ -145,7 +145,7 @@ if __name__ == "__main__":
logger,
sqlalchemy_string=getenv("DATABASE_URI", None),
)
config = db.get_config()
config: Dict[str, Any] = db.get_config()
# Remove old files
logger.info("Removing old files ...")

View file

@ -2,7 +2,6 @@
from argparse import ArgumentParser
from glob import glob
from itertools import chain
from json import loads
from os import R_OK, X_OK, access, environ, getenv, listdir, walk
from os.path import join
@ -160,15 +159,12 @@ if __name__ == "__main__":
db = None
apis = []
plugins = args.plugins
plugins_settings = None
external_plugins = args.plugins
if not Path("/usr/sbin/nginx").exists() and args.method == "ui":
db = Database(logger)
plugins = {}
plugins_settings = []
external_plugins = []
for plugin in db.get_plugins():
plugins_settings.append(plugin)
plugins.update(plugin["settings"])
external_plugins.append(plugin)
# Check existences and permissions
logger.info("Checking arguments ...")
@ -208,16 +204,6 @@ if __name__ == "__main__":
f"Exception while loading JSON from {file} : {format_exc()}",
)
core_settings = {}
for order in core_plugins:
if len(core_plugins[order]) > 1 and order != 999:
logger.warning(
f"Multiple plugins have the same order ({order}) : {', '.join(plugin['id'] for plugin in core_plugins[order])}. Therefor, the execution order will be random.",
)
for plugin in core_plugins[order]:
core_settings.update(plugin["settings"])
if args.variables:
logger.info(f"Variables : {args.variables}")
@ -225,11 +211,10 @@ if __name__ == "__main__":
logger.info("Computing config ...")
config = Configurator(
args.settings,
core_settings,
plugins,
args.core,
external_plugins,
args.variables,
logger,
plugins_settings=plugins_settings,
)
config_files = config.get_config()
custom_confs = []
@ -329,11 +314,10 @@ if __name__ == "__main__":
logger.info("Computing config ...")
config = Configurator(
args.settings,
core_settings,
plugins,
args.core,
external_plugins,
tmp_config,
logger,
plugins_settings=plugins_settings,
)
config_files = config.get_config()
@ -344,8 +328,8 @@ if __name__ == "__main__":
ret, err = db.init_tables(
[
config.get_settings(),
list(chain.from_iterable(core_plugins.values())),
config.get_plugins_settings(),
config.get_plugins("core"),
config.get_plugins("external"),
]
)

View file

@ -13,7 +13,7 @@ from stat import (
from typing import List
def has_permissions(path: str, need_permissions: List[str]):
def has_permissions(path: str, need_permissions: List[str]) -> bool:
uid = geteuid()
gid = getegid()
statinfo = stat(path)

View file

@ -1,17 +1,22 @@
from copy import deepcopy
from functools import partial
from glob import glob
from json import loads
from logging import Logger
from os import cpu_count, environ, getenv
from os.path import basename, dirname
from pathlib import Path
from subprocess import DEVNULL, PIPE, STDOUT, run
from threading import Lock, Thread
from re import match
from typing import Any, Dict, Optional
from schedule import (
Job,
clear as schedule_clear,
every as schedule_every,
jobs as schedule_jobs,
)
from subprocess import DEVNULL, PIPE, STDOUT, run
from sys import path as sys_path
from threading import Lock, Semaphore, Thread
from traceback import format_exc
sys_path.extend(("/usr/share/bunkerweb/utils", "/usr/share/bunkerweb/db"))
@ -24,14 +29,14 @@ from ApiCaller import ApiCaller
class JobScheduler(ApiCaller):
def __init__(
self,
env=None,
lock=None,
apis=[],
logger: Logger = setup_logger("Scheduler", getenv("LOG_LEVEL", "INFO")),
env: Optional[Dict[str, Any]] = None,
lock: Optional[Lock] = None,
apis: Optional[list] = None,
logger: Optional[Logger] = None,
integration: str = "Linux",
):
super().__init__(apis)
self.__logger = logger
super().__init__(apis or [])
self.__logger = logger or setup_logger("Scheduler", getenv("LOG_LEVEL", "INFO"))
self.__integration = integration
self.__db = Database(self.__logger)
self.__env = env or {}
@ -40,21 +45,68 @@ class JobScheduler(ApiCaller):
self.__lock = lock
self.__thread_lock = Lock()
self.__job_success = True
self.__semaphore = Semaphore(cpu_count() or 1)
def __get_jobs(self):
jobs = {}
plugins_core = [folder for folder in glob("/usr/share/bunkerweb/core/*/")]
plugins_external = [folder for folder in glob("/etc/bunkerweb/plugins/*/")]
for plugin in plugins_core + plugins_external:
plugin_name = plugin.split("/")[-2]
for plugin_file in list(
glob("/usr/share/bunkerweb/core/*/plugin.json") # core plugins
) + list(
glob("/etc/bunkerweb/plugins/*/plugin.json") # external plugins
):
plugin_name = basename(dirname(plugin_file))
jobs[plugin_name] = []
try:
plugin_data = loads(Path(f"{plugin}/plugin.json").read_text())
plugin_data = loads(Path(plugin_file).read_text())
if not "jobs" in plugin_data:
continue
for job in plugin_data["jobs"]:
job["path"] = plugin
jobs[plugin_name] = plugin_data["jobs"]
plugin_jobs = plugin_data["jobs"]
for x, job in enumerate(deepcopy(plugin_jobs)):
if not all(
key in job.keys()
for key in [
"name",
"file",
"every",
"reload",
]
):
self.__logger.warning(
f"missing keys for job {job['name']} in plugin {plugin_name}, must have name, file, every and reload, ignoring job"
)
plugin_jobs.pop(x)
continue
if not match(r"^[\w.-]{1,128}$", job["name"]):
self.__logger.warning(
f"Invalid name for job {job['name']} in plugin {plugin_name} (Can only contain numbers, letters, underscores and hyphens (min 1 characters and max 128)), ignoring job"
)
plugin_jobs.pop(x)
continue
elif not match(r"^[\w./-]{1,256}$", job["file"]):
self.__logger.warning(
f"Invalid file for job {job['name']} in plugin {plugin_name} (Can only contain numbers, letters, underscores, hyphens and slashes (min 1 characters and max 256)), ignoring job"
)
plugin_jobs.pop(x)
continue
elif job["every"] not in ["once", "minute", "hour", "day", "week"]:
self.__logger.warning(
f"Invalid every for job {job['name']} in plugin {plugin_name} (Must be once, minute, hour, day or week), ignoring job"
)
plugin_jobs.pop(x)
continue
elif job["reload"] is not True and job["reload"] is not False:
self.__logger.warning(
f"Invalid reload for job {job['name']} in plugin {plugin_name} (Must be true or false), ignoring job"
)
plugin_jobs.pop(x)
continue
plugin_jobs[x]["path"] = f"{dirname(plugin_file)}/"
jobs[plugin_name] = plugin_jobs
except FileNotFoundError:
pass
except:
@ -63,18 +115,18 @@ class JobScheduler(ApiCaller):
)
return jobs
def __str_to_schedule(self, every):
def __str_to_schedule(self, every: str) -> Job:
if every == "minute":
return schedule_every().minute
if every == "hour":
elif every == "hour":
return schedule_every().hour
if every == "day":
elif every == "day":
return schedule_every().day
if every == "week":
elif every == "week":
return schedule_every().week
raise Exception(f"can't convert every string {every} to schedule")
raise Exception(f"can't convert string {every} to schedule")
def __reload(self):
def __reload(self) -> bool:
reload = True
if self.__integration not in ("Autoconf", "Swarm", "Kubernetes", "Docker"):
self.__logger.info("Reloading nginx ...")
@ -89,7 +141,7 @@ class JobScheduler(ApiCaller):
self.__logger.info("Successfully reloaded nginx")
else:
self.__logger.error(
f"Error while reloading nginx - returncode: {proc.returncode} - error: {proc.stderr.decode('utf-8')}",
f"Error while reloading nginx - returncode: {proc.returncode} - error: {proc.stderr.decode()}",
)
else:
self.__logger.info("Reloading nginx ...")
@ -100,7 +152,7 @@ class JobScheduler(ApiCaller):
self.__logger.error("Error while reloading nginx")
return reload
def __job_wrapper(self, path, plugin, name, file):
def __job_wrapper(self, path: str, plugin: str, name: str, file: str) -> int:
self.__logger.info(
f"Executing job {name} from plugin {plugin} ...",
)
@ -119,7 +171,7 @@ class JobScheduler(ApiCaller):
with self.__thread_lock:
self.__job_success = False
if self.__job_success and proc.returncode >= 2:
if self.__job_success and ret >= 2:
success = False
self.__logger.error(
f"Error while executing job {name} from plugin {plugin}",
@ -157,18 +209,24 @@ class JobScheduler(ApiCaller):
f"Exception while scheduling jobs for plugin {plugin} : {format_exc()}",
)
def run_pending(self):
if self.__lock is not None:
def run_pending(self) -> bool:
if self.__lock:
self.__lock.acquire()
jobs = [job for job in schedule_jobs if job.should_run]
success = True
reload = False
for job in jobs:
ret = job.run()
if not isinstance(ret, int):
ret = -1
if ret == 1:
reload = True
elif ret < 0 or ret >= 2:
success = False
if reload:
try:
if self._get_apis():
@ -189,11 +247,12 @@ class JobScheduler(ApiCaller):
self.__logger.error(
f"Exception while reloading after job scheduling : {format_exc()}",
)
if self.__lock is not None:
if self.__lock:
self.__lock.release()
return success
def run_once(self):
def run_once(self) -> bool:
threads = []
for plugin, jobs in self.__jobs.items():
jobs_jobs = []
@ -207,35 +266,33 @@ class JobScheduler(ApiCaller):
jobs_jobs.append(partial(self.__job_wrapper, path, plugin, name, file))
# Create a thread for each plugin
threads.append(
Thread(
target=lambda jobs_jobs: [job() for job in jobs_jobs],
args=(jobs_jobs,),
)
)
threads.append(Thread(target=self.__run_in_thread, args=(jobs_jobs,)))
# Split the list of threads into sublists of the max cpu count
nbr_cpu = cpu_count() or 1
for i in range(0, len(threads), nbr_cpu):
sublist = threads[i : i + nbr_cpu]
for t in sublist:
t.start()
for t in sublist:
t.join()
for thread in threads:
thread.start()
ret = self.__job_success
for thread in threads:
thread.join()
ret = self.__job_success is True
self.__job_success = True
return ret
def __run_in_thread(self, jobs: list):
self.__semaphore.acquire()
for job in jobs:
job()
self.__semaphore.release()
def clear(self):
schedule_clear()
def reload(self, env, apis=[]):
def reload(self, env: Dict[str, Any], apis: Optional[list] = None) -> bool:
ret = True
try:
self.__env = env
super().__init__(apis)
super().__init__(apis or [])
self.clear()
self.__jobs = self.__get_jobs()
ret = self.run_once()

View file

@ -163,6 +163,7 @@ PLUGIN_KEYS = [
"name",
"description",
"version",
"stream",
"settings",
]

View file

@ -26,7 +26,7 @@ def generate_custom_configs(
class ConfigFiles:
def __init__(self, logger, db):
self.__name_regex = re_compile(r"^[a-zA-Z0-9_\-.]{1,64}$")
self.__name_regex = re_compile(r"^[\w.-]{1,64}$")
self.__root_dirs = [
child["name"]
for child in path_to_dict("/etc/bunkerweb/configs")["children"]