mirror of
https://github.com/jmagar/unraid-mcp
synced 2026-04-21 13:37:53 +00:00
feat: improve auth, server, subscriptions, tools, and add regression tests
This commit is contained in:
parent
943877a802
commit
ae55b5b7e0
22 changed files with 409 additions and 95 deletions
|
|
@ -11,6 +11,21 @@
|
|||
"homepage": "https://github.com/jmagar/unraid-mcp",
|
||||
"license": "MIT",
|
||||
"keywords": ["unraid", "homelab", "mcp", "graphql", "docker", "nas", "monitoring"],
|
||||
"mcpServers": {
|
||||
"unraid": {
|
||||
"type": "stdio",
|
||||
"command": "uv",
|
||||
"args": [
|
||||
"run",
|
||||
"--directory",
|
||||
"${CLAUDE_PLUGIN_ROOT}",
|
||||
"unraid-mcp-server"
|
||||
],
|
||||
"env": {
|
||||
"UNRAID_MCP_TRANSPORT": "stdio"
|
||||
}
|
||||
}
|
||||
},
|
||||
"userConfig": {
|
||||
"unraid_mcp_url": {
|
||||
"type": "string",
|
||||
|
|
@ -23,7 +38,7 @@
|
|||
"type": "string",
|
||||
"title": "MCP Server Bearer Token",
|
||||
"description": "Bearer token for authenticating with the MCP server. Must match UNRAID_MCP_BEARER_TOKEN in the server's .env. Generate with: openssl rand -hex 32",
|
||||
"sensitive": false
|
||||
"sensitive": true
|
||||
},
|
||||
"unraid_api_url": {
|
||||
"type": "string",
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@ All notable changes to this project are documented here.
|
|||
|
||||
### Added
|
||||
- **HTTP bearer token auth**: ASGI-level `BearerAuthMiddleware` (pure `__call__` pattern, no BaseHTTPMiddleware overhead) enforces `Authorization: Bearer <token>` on all HTTP requests. RFC 6750 compliant — missing header returns `WWW-Authenticate: Bearer realm="unraid-mcp"`, invalid token adds `error="invalid_token"`.
|
||||
- **Auto token generation**: On first HTTP startup with no token configured, a `secrets.token_urlsafe(32)` token is generated, written to `~/.unraid-mcp/.env` (mode 600), printed once to STDERR, and removed from `os.environ` so subprocesses cannot inherit it.
|
||||
- **Auto token generation**: On first HTTP startup with no token configured, a `secrets.token_urlsafe(32)` token is generated, written to `~/.unraid-mcp/.env` (mode 600), announced once on STDERR without printing the secret, and removed from `os.environ` so subprocesses cannot inherit it.
|
||||
- **Per-IP rate limiting**: 60 failed auth attempts per 60 seconds → 429 with `Retry-After: 60` header.
|
||||
- **Gateway escape hatch**: `UNRAID_MCP_DISABLE_HTTP_AUTH=true` bypasses bearer auth for users who handle authentication at a reverse proxy / gateway layer.
|
||||
- **Startup guard**: Server refuses to start in HTTP mode (`streamable-http`/`sse`) if no token is set and `DISABLE_HTTP_AUTH` is not explicitly enabled.
|
||||
- **Startup guard**: Server refuses to start in HTTP mode (`streamable-http`/`sse`) if no token is set and `UNRAID_MCP_DISABLE_HTTP_AUTH` is not explicitly enabled.
|
||||
- **Tests**: 23 new tests in `tests/test_auth.py` covering pass-through scopes, 401/429 responses, RFC 6750 header differentiation, per-IP rate limiting, window expiry, token generation, and startup guard.
|
||||
|
||||
### Changed
|
||||
|
|
@ -22,10 +22,12 @@ All notable changes to this project are documented here.
|
|||
## [1.1.6] - 2026-03-30
|
||||
|
||||
### Security
|
||||
|
||||
- **Path traversal**: `flash_backup` source path now validated after `posixpath.normpath` (not before) — raw-string `..` check was bypassable via encoded sequences like `foo/bar/../..`; null byte guard added
|
||||
- **Key validation**: `DANGEROUS_KEY_PATTERN` now blocks space (0x20) and DEL (0x7f) in addition to existing shell metacharacters; applies to both rclone and settings key validation
|
||||
|
||||
### Fixed
|
||||
|
||||
- **Settings validation**: `configure_ups` input now validated via `_validate_settings_input` before mutation — was previously passing unvalidated dict directly to GraphQL
|
||||
- **Subscription locks**: `_start_one` `last_error` write and `stop_all()` keys snapshot both now take `_task_lock` to prevent concurrent write/read races
|
||||
- **Keepalive handling**: Removed `"ping"` from keepalive `elif` — ping messages require a pong response, not silent discard; only `"ka"` and `"pong"` are silently dropped
|
||||
|
|
@ -33,11 +35,13 @@ All notable changes to this project are documented here.
|
|||
- **Health reverse map**: `_STATUS_FROM_SEVERITY` dict hoisted to module level — was being rebuilt on every `_comprehensive_health_check` call
|
||||
|
||||
### Changed
|
||||
|
||||
- **Log content cap**: `_cap_log_content` now skipped for non-log subscriptions (only `log_tail`/`logFileSubscription` have `content` fields) — reduces unnecessary dict key lookups on every WebSocket message
|
||||
- **Live assertion**: `_handle_live` now raises `RuntimeError` at import time if `COLLECT_ACTIONS` contains keys not in `_HANDLED_COLLECT_SUBACTIONS` — catches handler omissions before runtime
|
||||
- **Subscription name guard**: `start_subscription` validates name matches `^[a-zA-Z0-9_]+$` before use as WebSocket message ID
|
||||
|
||||
### Added
|
||||
|
||||
- **Tests**: 27 parametrized tests for `DANGEROUS_KEY_PATTERN` covering all documented dangerous characters and safe key examples (`tests/test_validation.py`)
|
||||
- **Tests**: `test_check_api_error_wrapped_tool_error` — verifies health check returns `{status: unhealthy}` when `make_graphql_request` raises `ToolError` wrapping `httpx.ConnectError`
|
||||
|
||||
|
|
|
|||
|
|
@ -50,6 +50,6 @@ EXPOSE 6970
|
|||
USER mcp
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
|
||||
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:6970/health')" || exit 1
|
||||
CMD python -c "import os, urllib.request; urllib.request.urlopen('http://localhost:%s/health' % os.getenv('UNRAID_MCP_PORT', '6970'))" || exit 1
|
||||
|
||||
ENTRYPOINT ["unraid-mcp-server"]
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ services:
|
|||
container_name: unraid-mcp
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "${UNRAID_MCP_PORT:-6970}:6970"
|
||||
- "${UNRAID_MCP_PORT:-6970}:${UNRAID_MCP_PORT:-6970}"
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
|
|
@ -12,7 +12,7 @@ services:
|
|||
- ./backups:/app/backups
|
||||
- unraid-mcp-credentials:/home/mcp/.unraid-mcp
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:6970/health')"]
|
||||
test: ["CMD-SHELL", "python -c \"import os, urllib.request; urllib.request.urlopen('http://localhost:%s/health' % os.getenv('UNRAID_MCP_PORT', '6970'))\""]
|
||||
interval: 30s
|
||||
timeout: 5s
|
||||
start_period: 10s
|
||||
|
|
|
|||
|
|
@ -16,8 +16,46 @@ REQUIRED=(
|
|||
|
||||
touch "$GITIGNORE"
|
||||
|
||||
existing="$(cat "$GITIGNORE")"
|
||||
for pattern in "${REQUIRED[@]}"; do
|
||||
if ! grep -qxF "$pattern" "$GITIGNORE" 2>/dev/null; then
|
||||
echo "$pattern" >> "$GITIGNORE"
|
||||
existing+=$'\n'"$pattern"
|
||||
fi
|
||||
done
|
||||
|
||||
printf '%s\n' "$existing" | awk '
|
||||
BEGIN {
|
||||
want[".env"]=1
|
||||
want[".env.*"]=1
|
||||
want["!.env.example"]=1
|
||||
want["backups/*"]=1
|
||||
want["!backups/.gitkeep"]=1
|
||||
want["logs/*"]=1
|
||||
want["!logs/.gitkeep"]=1
|
||||
want["__pycache__/"]=1
|
||||
}
|
||||
{ lines[++n]=$0 }
|
||||
END {
|
||||
emitted[""] = 1
|
||||
for (i = 1; i <= n; i++) {
|
||||
if (!want[lines[i]] && !emitted[lines[i]]) {
|
||||
print lines[i]
|
||||
emitted[lines[i]] = 1
|
||||
}
|
||||
}
|
||||
ordered[1]=".env"
|
||||
ordered[2]=".env.*"
|
||||
ordered[3]="!.env.example"
|
||||
ordered[4]="backups/*"
|
||||
ordered[5]="!backups/.gitkeep"
|
||||
ordered[6]="logs/*"
|
||||
ordered[7]="!logs/.gitkeep"
|
||||
ordered[8]="__pycache__/"
|
||||
for (i = 1; i <= 8; i++) {
|
||||
if (!emitted[ordered[i]]) {
|
||||
print ordered[i]
|
||||
emitted[ordered[i]] = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
' > "$GITIGNORE"
|
||||
|
|
|
|||
|
|
@ -21,6 +21,10 @@ fi
|
|||
for key in "${!MANAGED[@]}"; do
|
||||
value="${MANAGED[$key]}"
|
||||
[ -z "$value" ] && continue
|
||||
if [[ "$value" == *$'\n'* || "$value" == *$'\r'* || "$value" == *$'\t'* ]]; then
|
||||
echo "sync-env: refusing ${key} with control characters" >&2
|
||||
exit 1
|
||||
fi
|
||||
escaped_value=$(printf '%s\n' "$value" | sed 's/[&/\|]/\\&/g')
|
||||
if grep -q "^${key}=" "$ENV_FILE" 2>/dev/null; then
|
||||
sed -i "s|^${key}=.*|${key}=${escaped_value}|" "$ENV_FILE"
|
||||
|
|
@ -30,7 +34,8 @@ for key in "${!MANAGED[@]}"; do
|
|||
done
|
||||
|
||||
# Auto-generate UNRAID_MCP_BEARER_TOKEN if not yet set
|
||||
if ! grep -q "^UNRAID_MCP_BEARER_TOKEN=" "$ENV_FILE" 2>/dev/null; then
|
||||
if ! grep -Eq '^UNRAID_MCP_BEARER_TOKEN=.+$' "$ENV_FILE" 2>/dev/null; then
|
||||
sed -i '/^UNRAID_MCP_BEARER_TOKEN=/d' "$ENV_FILE"
|
||||
generated=$(openssl rand -hex 32)
|
||||
echo "UNRAID_MCP_BEARER_TOKEN=${generated}" >> "$ENV_FILE"
|
||||
echo "sync-env: generated UNRAID_MCP_BEARER_TOKEN (update plugin userConfig to match)" >&2
|
||||
|
|
|
|||
|
|
@ -328,5 +328,5 @@ curl -s "$CLAUDE_PLUGIN_OPTION_UNRAID_API_URL" \
|
|||
curl -s "$CLAUDE_PLUGIN_OPTION_UNRAID_API_URL" \
|
||||
-H "x-api-key: $CLAUDE_PLUGIN_OPTION_UNRAID_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"query":"{ array { state capacity { disks { name status temp } } } }"}'
|
||||
-d '{"query":"{ array { state capacity { kilobytes { free used total } } disks { name status temp } } }"}'
|
||||
```
|
||||
|
|
|
|||
|
|
@ -132,6 +132,11 @@ class TestSubscriptionManagerInit:
|
|||
|
||||
|
||||
class TestConnectionLifecycle:
|
||||
async def test_start_subscription_rejects_trailing_newline_in_name(self) -> None:
|
||||
mgr = SubscriptionManager()
|
||||
with pytest.raises(ValueError, match="only \\[a-zA-Z0-9_\\]"):
|
||||
await mgr.start_subscription("test_sub\n", SAMPLE_QUERY)
|
||||
|
||||
async def test_start_subscription_creates_task(self) -> None:
|
||||
mgr = SubscriptionManager()
|
||||
ws = FakeWebSocket([{"type": "connection_ack"}])
|
||||
|
|
@ -186,6 +191,21 @@ class TestConnectionLifecycle:
|
|||
mgr = SubscriptionManager()
|
||||
await mgr.stop_subscription("nonexistent")
|
||||
|
||||
async def test_stop_subscription_reasserts_stopped_after_cancel(self) -> None:
|
||||
mgr = SubscriptionManager()
|
||||
|
||||
async def _runner() -> None:
|
||||
try:
|
||||
await asyncio.Future()
|
||||
except asyncio.CancelledError:
|
||||
mgr.connection_states["test_sub"] = "reconnecting"
|
||||
raise
|
||||
|
||||
task = asyncio.create_task(_runner())
|
||||
mgr.active_subscriptions["test_sub"] = task
|
||||
await mgr.stop_subscription("test_sub")
|
||||
assert mgr.connection_states["test_sub"] == "stopped"
|
||||
|
||||
async def test_connection_state_transitions(self) -> None:
|
||||
mgr = SubscriptionManager()
|
||||
ws = FakeWebSocket([{"type": "connection_ack"}])
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ class TestUnauthorized:
|
|||
return asyncio.get_event_loop().run_until_complete(_collect_response(self.mw, scope))
|
||||
|
||||
def test_missing_header_returns_401(self):
|
||||
status, headers, body = self._run()
|
||||
status, _, _ = self._run()
|
||||
assert status == 401
|
||||
assert not self.called["value"]
|
||||
|
||||
|
|
@ -143,6 +143,10 @@ class TestUnauthorized:
|
|||
_, _, body = self._run()
|
||||
assert b'"error":"unauthorized"' in body
|
||||
|
||||
def test_missing_header_counts_toward_rate_limit(self):
|
||||
self._run(ip="10.10.10.10")
|
||||
assert len(self.mw._ip_failures["10.10.10.10"]) == 1
|
||||
|
||||
def test_wrong_token_returns_401(self):
|
||||
status, _, _ = self._run(headers=[(b"authorization", b"Bearer wrong-token")])
|
||||
assert status == 401
|
||||
|
|
@ -254,6 +258,16 @@ class TestRateLimiting:
|
|||
scope = _make_http_scope(headers=[(b"authorization", b"Bearer wrong")], client_ip=ip)
|
||||
status, _, _ = asyncio.get_event_loop().run_until_complete(_collect_response(mw, scope))
|
||||
assert status == 401
|
||||
assert len(mw._ip_failures[ip]) == 1
|
||||
|
||||
def test_stale_warning_timestamps_are_evicted(self):
|
||||
app, _ = _app_called_flag()
|
||||
mw = BearerAuthMiddleware(app, token="secret")
|
||||
ip = "172.16.0.2"
|
||||
mw._ip_last_warn[ip] = time.monotonic() - (_RATE_WINDOW_SECS + 1)
|
||||
scope = _make_http_scope(headers=[(b"authorization", b"Bearer wrong")], client_ip=ip)
|
||||
asyncio.get_event_loop().run_until_complete(_collect_response(mw, scope))
|
||||
assert ip in mw._ip_last_warn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -347,6 +361,31 @@ class TestEnsureTokenExists:
|
|||
s.UNRAID_MCP_DISABLE_HTTP_AUTH = orig_disabled
|
||||
os.environ.pop("UNRAID_MCP_BEARER_TOKEN", None)
|
||||
|
||||
def test_generated_token_is_not_printed_to_stderr(self, tmp_path, capsys):
|
||||
import unraid_mcp.config.settings as s
|
||||
|
||||
orig_token = s.UNRAID_MCP_BEARER_TOKEN
|
||||
orig_disabled = s.UNRAID_MCP_DISABLE_HTTP_AUTH
|
||||
try:
|
||||
s.UNRAID_MCP_BEARER_TOKEN = None
|
||||
s.UNRAID_MCP_DISABLE_HTTP_AUTH = False
|
||||
|
||||
env_path = tmp_path / ".env"
|
||||
with (
|
||||
patch("unraid_mcp.server.CREDENTIALS_DIR", tmp_path),
|
||||
patch("unraid_mcp.server.CREDENTIALS_ENV_PATH", env_path),
|
||||
):
|
||||
from unraid_mcp.server import ensure_token_exists
|
||||
|
||||
ensure_token_exists()
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "UNRAID_MCP_BEARER_TOKEN=" not in captured.err
|
||||
assert s.UNRAID_MCP_BEARER_TOKEN not in captured.err
|
||||
finally:
|
||||
s.UNRAID_MCP_BEARER_TOKEN = orig_token
|
||||
s.UNRAID_MCP_DISABLE_HTTP_AUTH = orig_disabled
|
||||
|
||||
|
||||
class TestStartupGuard:
|
||||
def test_startup_guard_exits_when_http_no_token_no_disable(self):
|
||||
|
|
|
|||
|
|
@ -148,3 +148,15 @@ def test_collect_actions_all_handled():
|
|||
f"COLLECT_ACTIONS keys without handlers in _handle_live: {unhandled}. "
|
||||
"Add an if-branch in unraid_mcp/tools/_live.py and update _HANDLED_COLLECT_SUBACTIONS."
|
||||
)
|
||||
|
||||
|
||||
def test_collect_actions_rejects_stale_handled_keys(monkeypatch):
|
||||
import unraid_mcp.tools._live as live_module
|
||||
|
||||
monkeypatch.setattr(
|
||||
live_module,
|
||||
"_HANDLED_COLLECT_SUBACTIONS",
|
||||
frozenset({"log_tail", "notification_feed", "stale_key"}),
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="stale"):
|
||||
live_module._assert_collect_subactions_complete()
|
||||
|
|
|
|||
|
|
@ -11,9 +11,9 @@ PASS=0
|
|||
FAIL=0
|
||||
SKIP=0
|
||||
|
||||
pass() { echo " PASS: $1"; ((PASS++)); }
|
||||
fail() { echo " FAIL: $1 — $2"; ((FAIL++)); }
|
||||
skip() { echo " SKIP: $1 — $2"; ((SKIP++)); }
|
||||
pass() { echo " PASS: $1"; ((++PASS)); }
|
||||
fail() { echo " FAIL: $1 — $2"; ((++FAIL)); }
|
||||
skip() { echo " SKIP: $1 — $2"; ((++SKIP)); }
|
||||
|
||||
header() { echo; echo "=== $1 ==="; }
|
||||
|
||||
|
|
@ -105,9 +105,16 @@ npx mcporter call "${SERVER_NAME}.unraid" \
|
|||
# ── Resources (server-level, no tool name needed) ────────────────────────────
|
||||
header "Resources"
|
||||
|
||||
npx mcporter call "${SERVER_NAME}" --http-url "$MCP_URL" --header "$AUTH_HEADER" \
|
||||
--list-resources > /dev/null 2>&1 \
|
||||
&& pass "resources/list" || skip "resources/list" "no resources defined"
|
||||
resources_output="$(
|
||||
npx mcporter call "${SERVER_NAME}" --http-url "$MCP_URL" --header "$AUTH_HEADER" \
|
||||
--list-resources 2>&1
|
||||
)" && pass "resources/list" || {
|
||||
if printf '%s' "$resources_output" | grep -qi "no resources defined"; then
|
||||
skip "resources/list" "no resources defined"
|
||||
else
|
||||
fail "resources/list" "$resources_output"
|
||||
fi
|
||||
}
|
||||
|
||||
# ── Bearer token enforcement ─────────────────────────────────────────────────
|
||||
header "Bearer token enforcement"
|
||||
|
|
|
|||
92
tests/test_review_regressions.py
Normal file
92
tests/test_review_regressions.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
"""Regression coverage for packaging, manifests, and hook scripts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
def test_plugin_manifest_restores_stdio_server_definition() -> None:
|
||||
plugin = json.loads((PROJECT_ROOT / ".claude-plugin" / "plugin.json").read_text())
|
||||
assert plugin["userConfig"]["unraid_mcp_token"]["sensitive"] is True
|
||||
assert "mcpServers" in plugin
|
||||
assert plugin["mcpServers"]["unraid"]["type"] == "stdio"
|
||||
|
||||
|
||||
def test_container_configs_use_runtime_port_variable() -> None:
|
||||
compose = (PROJECT_ROOT / "docker-compose.yaml").read_text()
|
||||
dockerfile = (PROJECT_ROOT / "Dockerfile").read_text()
|
||||
assert "${UNRAID_MCP_PORT:-6970}:${UNRAID_MCP_PORT:-6970}" in compose
|
||||
assert "os.getenv('UNRAID_MCP_PORT', '6970')" in compose
|
||||
assert "os.getenv('UNRAID_MCP_PORT', '6970')" in dockerfile
|
||||
|
||||
|
||||
def test_test_live_script_uses_safe_counters_and_resource_failures() -> None:
|
||||
script = (PROJECT_ROOT / "tests" / "test_live.sh").read_text()
|
||||
assert "((++PASS))" in script
|
||||
assert "((++FAIL))" in script
|
||||
assert "((++SKIP))" in script
|
||||
assert 'fail "resources/list" "$resources_output"' in script
|
||||
|
||||
|
||||
def test_sync_env_rejects_multiline_values(tmp_path: Path) -> None:
|
||||
env = os.environ.copy()
|
||||
env["CLAUDE_PLUGIN_ROOT"] = str(tmp_path)
|
||||
env["CLAUDE_PLUGIN_OPTION_UNRAID_API_URL"] = "https://tower.local\nINJECT=1"
|
||||
|
||||
result = subprocess.run( # noqa: S603
|
||||
["/usr/bin/bash", str(PROJECT_ROOT / "hooks" / "scripts" / "sync-env.sh")],
|
||||
cwd=PROJECT_ROOT,
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
assert result.returncode != 0
|
||||
assert "control characters" in result.stderr
|
||||
|
||||
|
||||
def test_sync_env_regenerates_empty_bearer_token(tmp_path: Path) -> None:
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text("UNRAID_MCP_BEARER_TOKEN=\n")
|
||||
|
||||
env = os.environ.copy()
|
||||
env["CLAUDE_PLUGIN_ROOT"] = str(tmp_path)
|
||||
|
||||
result = subprocess.run( # noqa: S603
|
||||
["/usr/bin/bash", str(PROJECT_ROOT / "hooks" / "scripts" / "sync-env.sh")],
|
||||
cwd=PROJECT_ROOT,
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
assert result.returncode == 0
|
||||
lines = env_file.read_text().splitlines()
|
||||
token_line = next(line for line in lines if line.startswith("UNRAID_MCP_BEARER_TOKEN="))
|
||||
assert token_line != "UNRAID_MCP_BEARER_TOKEN="
|
||||
|
||||
|
||||
def test_ensure_gitignore_preserves_ignore_before_negation(tmp_path: Path) -> None:
|
||||
gitignore = tmp_path / ".gitignore"
|
||||
gitignore.write_text("!backups/.gitkeep\n")
|
||||
|
||||
env = os.environ.copy()
|
||||
env["CLAUDE_PLUGIN_ROOT"] = str(tmp_path)
|
||||
|
||||
result = subprocess.run( # noqa: S603
|
||||
["/usr/bin/bash", str(PROJECT_ROOT / "hooks" / "scripts" / "ensure-gitignore.sh")],
|
||||
cwd=PROJECT_ROOT,
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
assert result.returncode == 0
|
||||
lines = gitignore.read_text().splitlines()
|
||||
assert lines.index("backups/*") < lines.index("!backups/.gitkeep")
|
||||
|
|
@ -89,6 +89,21 @@ class TestSettingsUpdate:
|
|||
assert result["success"] is True
|
||||
assert result["subaction"] == "update"
|
||||
|
||||
async def test_update_allows_nested_json_values(self, _mock_graphql: AsyncMock) -> None:
|
||||
_mock_graphql.return_value = {
|
||||
"updateSettings": {"restartRequired": False, "values": {}, "warnings": []}
|
||||
}
|
||||
tool_fn = _make_tool()
|
||||
payload = {
|
||||
"themeOverrides": {"sidebar": None, "panels": ["cpu", "memory"]},
|
||||
"advanced": [1, True, {"nested": "ok"}],
|
||||
}
|
||||
result = await tool_fn(action="setting", subaction="update", settings_input=payload)
|
||||
assert result["success"] is True
|
||||
_mock_graphql.assert_awaited_once()
|
||||
sent_payload = _mock_graphql.await_args.args[1]["input"]
|
||||
assert sent_payload == payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# configure_ups
|
||||
|
|
@ -114,3 +129,13 @@ class TestUpsConfig:
|
|||
)
|
||||
assert result["success"] is True
|
||||
assert result["subaction"] == "configure_ups"
|
||||
|
||||
async def test_configure_ups_rejects_nested_values(self, _mock_graphql: AsyncMock) -> None:
|
||||
tool_fn = _make_tool()
|
||||
with pytest.raises(ToolError, match="must be a string, number, or boolean"):
|
||||
await tool_fn(
|
||||
action="setting",
|
||||
subaction="configure_ups",
|
||||
confirm=True,
|
||||
ups_config={"mode": {"nested": "invalid"}},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -238,6 +238,12 @@ class TestStorageActions:
|
|||
with pytest.raises(ToolError, match="Log file not found or inaccessible"):
|
||||
await tool_fn(action="disk", subaction="logs", log_path="/var/log/syslog")
|
||||
|
||||
async def test_logs_empty_log_file_payload_raises(self, _mock_graphql: AsyncMock) -> None:
|
||||
_mock_graphql.return_value = {"logFile": {}}
|
||||
tool_fn = _make_tool()
|
||||
with pytest.raises(ToolError, match="Log file not found or inaccessible"):
|
||||
await tool_fn(action="disk", subaction="logs", log_path="/var/log/syslog")
|
||||
|
||||
async def test_disk_details_not_found(self, _mock_graphql: AsyncMock) -> None:
|
||||
_mock_graphql.return_value = {"disk": None}
|
||||
tool_fn = _make_tool()
|
||||
|
|
|
|||
|
|
@ -175,7 +175,7 @@ def get_config_summary() -> dict[str, Any]:
|
|||
"missing_config": missing if not is_valid else None,
|
||||
# Auth fields only meaningful in HTTP mode
|
||||
"http_auth_enabled": is_http and not UNRAID_MCP_DISABLE_HTTP_AUTH,
|
||||
"http_auth_token_set": bool(UNRAID_MCP_BEARER_TOKEN) if is_http else None,
|
||||
"http_auth_token_set": bool(UNRAID_MCP_BEARER_TOKEN) if is_http else False,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -93,6 +93,8 @@ class BearerAuthMiddleware:
|
|||
|
||||
if auth_header is None:
|
||||
# No Authorization header at all — prompt client per RFC 6750
|
||||
self._record_failure(client_ip)
|
||||
self._maybe_warn(client_ip, "missing authorization header")
|
||||
await self._send_response(
|
||||
send,
|
||||
status=401,
|
||||
|
|
@ -156,15 +158,12 @@ class BearerAuthMiddleware:
|
|||
"""Return True if this IP has hit the failure rate limit."""
|
||||
if ip not in self._ip_failures:
|
||||
return False
|
||||
now = time.monotonic()
|
||||
q = self._ip_failures[ip]
|
||||
cutoff = now - _RATE_WINDOW_SECS
|
||||
while q and q[0] < cutoff:
|
||||
q.popleft()
|
||||
return len(q) >= _RATE_MAX_FAILURES
|
||||
self._prune_ip_state(ip)
|
||||
return len(self._ip_failures.get(ip, ())) >= _RATE_MAX_FAILURES
|
||||
|
||||
def _record_failure(self, ip: str) -> None:
|
||||
"""Record one failed auth attempt for this IP."""
|
||||
self._prune_ip_state(ip)
|
||||
if ip not in self._ip_failures:
|
||||
self._ip_failures[ip] = deque()
|
||||
self._ip_failures[ip].append(time.monotonic())
|
||||
|
|
@ -179,6 +178,21 @@ class BearerAuthMiddleware:
|
|||
self._ip_last_warn[ip] = now
|
||||
logger.warning("Bearer auth rejected (%s) from %s", reason, ip)
|
||||
|
||||
def _prune_ip_state(self, ip: str) -> None:
|
||||
"""Drop stale failure and warning-tracking state for one IP."""
|
||||
now = time.monotonic()
|
||||
q = self._ip_failures.get(ip)
|
||||
if q is not None:
|
||||
cutoff = now - _RATE_WINDOW_SECS
|
||||
while q and q[0] < cutoff:
|
||||
q.popleft()
|
||||
if not q:
|
||||
self._ip_failures.pop(ip, None)
|
||||
|
||||
last_warn = self._ip_last_warn.get(ip)
|
||||
if last_warn is not None and (now - last_warn) >= _RATE_WINDOW_SECS:
|
||||
self._ip_last_warn.pop(ip, None)
|
||||
|
||||
@staticmethod
|
||||
async def _send_response(
|
||||
send: Send,
|
||||
|
|
|
|||
|
|
@ -37,11 +37,13 @@ from .subscriptions.resources import register_subscription_resources
|
|||
from .tools.unraid import register_unraid_tool
|
||||
|
||||
|
||||
def _chmod_safe(path: object, mode: int) -> None:
|
||||
"""Chmod with graceful fallback for volume mounts owned by root."""
|
||||
def _chmod_safe(path: object, mode: int, *, strict: bool = False) -> None:
|
||||
"""Best-effort chmod, with optional fail-closed behavior for secrets."""
|
||||
try:
|
||||
path.chmod(mode) # type: ignore[union-attr]
|
||||
except PermissionError:
|
||||
except PermissionError as exc:
|
||||
if strict:
|
||||
raise RuntimeError(f"Failed to secure permissions on {path}") from exc
|
||||
logger.debug("Could not chmod %s (volume mount?) — skipping", path)
|
||||
|
||||
|
||||
|
|
@ -148,7 +150,7 @@ def ensure_token_exists() -> None:
|
|||
|
||||
# Ensure credentials dir exists with restricted permissions.
|
||||
CREDENTIALS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
_chmod_safe(CREDENTIALS_DIR, 0o700)
|
||||
_chmod_safe(CREDENTIALS_DIR, 0o700, strict=True)
|
||||
|
||||
# Touch the file first so set_key has a target (no-op if already exists)
|
||||
if not CREDENTIALS_ENV_PATH.exists():
|
||||
|
|
@ -156,13 +158,11 @@ def ensure_token_exists() -> None:
|
|||
|
||||
# In-place .env write — preserves comments and existing keys
|
||||
set_key(str(CREDENTIALS_ENV_PATH), "UNRAID_MCP_BEARER_TOKEN", token, quote_mode="auto")
|
||||
_chmod_safe(CREDENTIALS_ENV_PATH, 0o600)
|
||||
_chmod_safe(CREDENTIALS_ENV_PATH, 0o600, strict=True)
|
||||
|
||||
# Print once to STDERR so the user can copy it into their MCP client config
|
||||
print(
|
||||
f"\n[unraid-mcp] Generated HTTP bearer token (saved to {CREDENTIALS_ENV_PATH}):\n"
|
||||
f" UNRAID_MCP_BEARER_TOKEN={token}\n"
|
||||
" Add this to your MCP client's Authorization header: Bearer <token>\n",
|
||||
f"\n[unraid-mcp] Generated HTTP bearer token and saved it to {CREDENTIALS_ENV_PATH}.\n"
|
||||
"Configure your MCP client to send Authorization: Bearer <token> using that stored value.\n",
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
|
|
@ -178,64 +178,58 @@ def run_server() -> None:
|
|||
"""Run the MCP server with the configured transport."""
|
||||
is_http = _settings.UNRAID_MCP_TRANSPORT in ("streamable-http", "sse")
|
||||
|
||||
# Auto-generate token before the startup guard so a fresh install
|
||||
# can start without manual intervention.
|
||||
if is_http:
|
||||
ensure_token_exists()
|
||||
try:
|
||||
if is_http:
|
||||
ensure_token_exists()
|
||||
|
||||
# Hard stop: HTTP mode with no token and auth not explicitly disabled.
|
||||
# We deliberately do NOT mention DISABLE_HTTP_AUTH in the error message to
|
||||
# avoid inadvertently guiding users toward disabling auth.
|
||||
if (
|
||||
is_http
|
||||
and not _settings.UNRAID_MCP_DISABLE_HTTP_AUTH
|
||||
and not _settings.UNRAID_MCP_BEARER_TOKEN
|
||||
):
|
||||
print(
|
||||
"FATAL: HTTP transport requires a bearer token. "
|
||||
"Set UNRAID_MCP_BEARER_TOKEN in ~/.unraid-mcp/.env or restart to auto-generate.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
if (
|
||||
is_http
|
||||
and not _settings.UNRAID_MCP_DISABLE_HTTP_AUTH
|
||||
and not _settings.UNRAID_MCP_BEARER_TOKEN
|
||||
):
|
||||
print(
|
||||
"FATAL: HTTP transport requires a bearer token. "
|
||||
"Set UNRAID_MCP_BEARER_TOKEN in ~/.unraid-mcp/.env or restart to auto-generate.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Validate Unraid API credentials
|
||||
is_valid, missing = validate_required_config()
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
"Missing configuration: %s. "
|
||||
"Server will prompt for credentials on first tool call via elicitation.",
|
||||
", ".join(missing),
|
||||
)
|
||||
|
||||
log_configuration_status(logger)
|
||||
|
||||
if UNRAID_VERIFY_SSL is False:
|
||||
logger.warning(
|
||||
"SSL VERIFICATION DISABLED (UNRAID_VERIFY_SSL=false). "
|
||||
"Connections to Unraid API are vulnerable to man-in-the-middle attacks. "
|
||||
"Only use this in trusted networks or for development."
|
||||
)
|
||||
|
||||
if is_http:
|
||||
if _settings.UNRAID_MCP_DISABLE_HTTP_AUTH:
|
||||
is_valid, missing = validate_required_config()
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
"HTTP auth disabled (UNRAID_MCP_DISABLE_HTTP_AUTH=true). "
|
||||
"Ensure an upstream gateway enforces authentication."
|
||||
"Missing configuration: %s. "
|
||||
"Server will prompt for credentials on first tool call via elicitation.",
|
||||
", ".join(missing),
|
||||
)
|
||||
|
||||
log_configuration_status(logger)
|
||||
|
||||
if UNRAID_VERIFY_SSL is False:
|
||||
logger.warning(
|
||||
"SSL VERIFICATION DISABLED (UNRAID_VERIFY_SSL=false). "
|
||||
"Connections to Unraid API are vulnerable to man-in-the-middle attacks. "
|
||||
"Only use this in trusted networks or for development."
|
||||
)
|
||||
|
||||
if is_http:
|
||||
if _settings.UNRAID_MCP_DISABLE_HTTP_AUTH:
|
||||
logger.warning(
|
||||
"HTTP auth disabled (UNRAID_MCP_DISABLE_HTTP_AUTH=true). "
|
||||
"Ensure an upstream gateway enforces authentication."
|
||||
)
|
||||
else:
|
||||
logger.info("HTTP bearer token authentication enabled.")
|
||||
|
||||
if is_http:
|
||||
logger.info(
|
||||
"Starting Unraid MCP Server on %s:%s using %s transport...",
|
||||
UNRAID_MCP_HOST,
|
||||
UNRAID_MCP_PORT,
|
||||
_settings.UNRAID_MCP_TRANSPORT,
|
||||
)
|
||||
else:
|
||||
logger.info("HTTP bearer token authentication enabled.")
|
||||
logger.info("Starting Unraid MCP Server using stdio transport...")
|
||||
|
||||
if is_http:
|
||||
logger.info(
|
||||
"Starting Unraid MCP Server on %s:%s using %s transport...",
|
||||
UNRAID_MCP_HOST,
|
||||
UNRAID_MCP_PORT,
|
||||
_settings.UNRAID_MCP_TRANSPORT,
|
||||
)
|
||||
else:
|
||||
logger.info("Starting Unraid MCP Server using stdio transport...")
|
||||
|
||||
try:
|
||||
if is_http:
|
||||
if _settings.UNRAID_MCP_TRANSPORT == "sse":
|
||||
logger.warning(
|
||||
|
|
|
|||
|
|
@ -35,6 +35,12 @@ _last_graphql_error: dict[str, str] = {}
|
|||
_graphql_error_count: dict[str, int] = {}
|
||||
|
||||
|
||||
def _clear_graphql_error_burst(subscription_name: str) -> None:
|
||||
"""Reset deduplicated GraphQL error tracking for one subscription."""
|
||||
_last_graphql_error.pop(subscription_name, None)
|
||||
_graphql_error_count.pop(subscription_name, None)
|
||||
|
||||
|
||||
def _preview(message: str | bytes, n: int = 200) -> str:
|
||||
"""Return the first *n* characters of *message* as a UTF-8 string.
|
||||
|
||||
|
|
@ -193,9 +199,21 @@ class SubscriptionManager:
|
|||
self.last_error[name] = str(e)
|
||||
start_errors.append((name, e))
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
for name, config in auto_start_configs:
|
||||
tg.create_task(_start_one(name, config))
|
||||
started_names: list[str] = []
|
||||
|
||||
async def _tracked_start(name: str, config: dict[str, Any]) -> None:
|
||||
await _start_one(name, config)
|
||||
started_names.append(name)
|
||||
|
||||
try:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
for name, config in auto_start_configs:
|
||||
tg.create_task(_tracked_start(name, config))
|
||||
except asyncio.CancelledError:
|
||||
for name in started_names:
|
||||
if name in self.active_subscriptions:
|
||||
await self.stop_subscription(name)
|
||||
raise
|
||||
|
||||
started = len(auto_start_configs) - len(start_errors)
|
||||
logger.info(
|
||||
|
|
@ -211,10 +229,11 @@ class SubscriptionManager:
|
|||
self, subscription_name: str, query: str, variables: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
"""Start a GraphQL subscription and maintain it as a resource."""
|
||||
if not re.match(r"^[a-zA-Z0-9_]+$", subscription_name):
|
||||
if not re.fullmatch(r"[a-zA-Z0-9_]+", subscription_name):
|
||||
raise ValueError(
|
||||
f"subscription_name must contain only [a-zA-Z0-9_], got: {subscription_name!r}"
|
||||
)
|
||||
_clear_graphql_error_burst(subscription_name)
|
||||
logger.info(f"[SUBSCRIPTION:{subscription_name}] Starting subscription...")
|
||||
|
||||
# Guard must be inside the lock to prevent a TOCTOU race where two
|
||||
|
|
@ -275,6 +294,8 @@ class SubscriptionManager:
|
|||
await task
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully")
|
||||
self.connection_states[subscription_name] = "stopped"
|
||||
_clear_graphql_error_burst(subscription_name)
|
||||
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped")
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
|
|
@ -461,6 +482,7 @@ class SubscriptionManager:
|
|||
logger.info(
|
||||
f"[DATA:{subscription_name}] Received subscription data update"
|
||||
)
|
||||
_clear_graphql_error_burst(subscription_name)
|
||||
capped_data = (
|
||||
_cap_log_content(payload["data"])
|
||||
if isinstance(payload["data"], dict)
|
||||
|
|
|
|||
|
|
@ -183,7 +183,7 @@ async def _handle_disk(
|
|||
return {"log_files": data.get("logFiles", [])}
|
||||
if subaction == "logs":
|
||||
result = data.get("logFile")
|
||||
if result is None:
|
||||
if not result:
|
||||
raise ToolError(f"Log file not found or inaccessible: {log_path}")
|
||||
return dict(result)
|
||||
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ async def _comprehensive_health_check() -> dict[str, Any]:
|
|||
"timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
|
||||
"error": str(e),
|
||||
}
|
||||
except Exception as e:
|
||||
except _client.ToolError as e:
|
||||
# make_graphql_request wraps httpx network errors in ToolError; catch them
|
||||
# here so health/check returns {"status": "unhealthy"} on real outages
|
||||
# rather than propagating an unhandled ToolError to the caller.
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ _HANDLED_COLLECT_SUBACTIONS: frozenset[str] = frozenset({"log_tail", "notificati
|
|||
|
||||
|
||||
def _assert_collect_subactions_complete() -> None:
|
||||
"""Raise AssertionError at import time if COLLECT_ACTIONS has an unhandled key.
|
||||
"""Raise RuntimeError at import time if collect subactions drift.
|
||||
|
||||
Every key in COLLECT_ACTIONS must appear in _HANDLED_COLLECT_SUBACTIONS AND
|
||||
have a matching if-branch in _handle_live. This assertion catches the former
|
||||
|
|
@ -32,11 +32,16 @@ def _assert_collect_subactions_complete() -> None:
|
|||
from ..subscriptions.queries import COLLECT_ACTIONS
|
||||
|
||||
missing = set(COLLECT_ACTIONS) - _HANDLED_COLLECT_SUBACTIONS
|
||||
stale = _HANDLED_COLLECT_SUBACTIONS - set(COLLECT_ACTIONS)
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"_HANDLED_COLLECT_SUBACTIONS is missing keys from COLLECT_ACTIONS: {missing}. "
|
||||
"Add a handler branch in _handle_live and update _HANDLED_COLLECT_SUBACTIONS."
|
||||
)
|
||||
if stale:
|
||||
raise RuntimeError(
|
||||
f"_HANDLED_COLLECT_SUBACTIONS contains stale keys not present in COLLECT_ACTIONS: {stale}."
|
||||
)
|
||||
|
||||
|
||||
_assert_collect_subactions_complete()
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@ from ..core.validation import DANGEROUS_KEY_PATTERN, MAX_VALUE_LENGTH
|
|||
_MAX_SETTINGS_KEYS = 100
|
||||
|
||||
|
||||
def _validate_settings_input(settings_input: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate settings_input before forwarding to the Unraid API.
|
||||
def _validate_settings_mapping(settings_input: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate flat scalar settings data before forwarding to the Unraid API.
|
||||
|
||||
Enforces a key count cap and rejects dangerous key names and oversized values
|
||||
to prevent unvalidated bulk input from reaching the API. Modeled on
|
||||
|
|
@ -60,6 +60,22 @@ def _validate_settings_input(settings_input: dict[str, Any]) -> dict[str, Any]:
|
|||
return validated
|
||||
|
||||
|
||||
def _validate_json_settings_input(settings_input: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate JSON-typed settings input without narrowing valid JSON members."""
|
||||
if len(settings_input) > _MAX_SETTINGS_KEYS:
|
||||
raise ToolError(f"settings_input has {len(settings_input)} keys (max {_MAX_SETTINGS_KEYS})")
|
||||
validated: dict[str, Any] = {}
|
||||
for key, value in settings_input.items():
|
||||
if not isinstance(key, str) or not key.strip():
|
||||
raise ToolError(
|
||||
f"settings_input keys must be non-empty strings, got: {type(key).__name__}"
|
||||
)
|
||||
if DANGEROUS_KEY_PATTERN.search(key):
|
||||
raise ToolError(f"settings_input key '{key}' contains disallowed characters")
|
||||
validated[key] = value
|
||||
return validated
|
||||
|
||||
|
||||
_SETTING_MUTATIONS: dict[str, str] = {
|
||||
"update": "mutation UpdateSettings($input: JSON!) { updateSettings(input: $input) { restartRequired values warnings } }",
|
||||
"configure_ups": "mutation ConfigureUps($config: UPSConfigInput!) { configureUps(config: $config) }",
|
||||
|
|
@ -95,7 +111,7 @@ async def _handle_setting(
|
|||
if subaction == "update":
|
||||
if settings_input is None:
|
||||
raise ToolError("settings_input is required for setting/update")
|
||||
validated_input = _validate_settings_input(settings_input)
|
||||
validated_input = _validate_json_settings_input(settings_input)
|
||||
data = await _client.make_graphql_request(
|
||||
_SETTING_MUTATIONS["update"], {"input": validated_input}
|
||||
)
|
||||
|
|
@ -107,7 +123,7 @@ async def _handle_setting(
|
|||
# Validate ups_config with the same rules as settings_input — key count
|
||||
# cap, scalar-only values, MAX_VALUE_LENGTH — to prevent unvalidated bulk
|
||||
# input from reaching the GraphQL mutation.
|
||||
validated_ups = _validate_settings_input(ups_config)
|
||||
validated_ups = _validate_settings_mapping(ups_config)
|
||||
data = await _client.make_graphql_request(
|
||||
_SETTING_MUTATIONS["configure_ups"], {"config": validated_ups}
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue