diff --git a/.claude-plugin/plugin.json b/.claude-plugin/plugin.json index 3dcfd7a..4764ebf 100644 --- a/.claude-plugin/plugin.json +++ b/.claude-plugin/plugin.json @@ -11,6 +11,21 @@ "homepage": "https://github.com/jmagar/unraid-mcp", "license": "MIT", "keywords": ["unraid", "homelab", "mcp", "graphql", "docker", "nas", "monitoring"], + "mcpServers": { + "unraid": { + "type": "stdio", + "command": "uv", + "args": [ + "run", + "--directory", + "${CLAUDE_PLUGIN_ROOT}", + "unraid-mcp-server" + ], + "env": { + "UNRAID_MCP_TRANSPORT": "stdio" + } + } + }, "userConfig": { "unraid_mcp_url": { "type": "string", @@ -23,7 +38,7 @@ "type": "string", "title": "MCP Server Bearer Token", "description": "Bearer token for authenticating with the MCP server. Must match UNRAID_MCP_BEARER_TOKEN in the server's .env. Generate with: openssl rand -hex 32", - "sensitive": false + "sensitive": true }, "unraid_api_url": { "type": "string", diff --git a/CHANGELOG.md b/CHANGELOG.md index 365301b..71d8b71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,10 +6,10 @@ All notable changes to this project are documented here. ### Added - **HTTP bearer token auth**: ASGI-level `BearerAuthMiddleware` (pure `__call__` pattern, no BaseHTTPMiddleware overhead) enforces `Authorization: Bearer ` on all HTTP requests. RFC 6750 compliant — missing header returns `WWW-Authenticate: Bearer realm="unraid-mcp"`, invalid token adds `error="invalid_token"`. -- **Auto token generation**: On first HTTP startup with no token configured, a `secrets.token_urlsafe(32)` token is generated, written to `~/.unraid-mcp/.env` (mode 600), printed once to STDERR, and removed from `os.environ` so subprocesses cannot inherit it. +- **Auto token generation**: On first HTTP startup with no token configured, a `secrets.token_urlsafe(32)` token is generated, written to `~/.unraid-mcp/.env` (mode 600), announced once on STDERR without printing the secret, and removed from `os.environ` so subprocesses cannot inherit it. - **Per-IP rate limiting**: 60 failed auth attempts per 60 seconds → 429 with `Retry-After: 60` header. - **Gateway escape hatch**: `UNRAID_MCP_DISABLE_HTTP_AUTH=true` bypasses bearer auth for users who handle authentication at a reverse proxy / gateway layer. -- **Startup guard**: Server refuses to start in HTTP mode (`streamable-http`/`sse`) if no token is set and `DISABLE_HTTP_AUTH` is not explicitly enabled. +- **Startup guard**: Server refuses to start in HTTP mode (`streamable-http`/`sse`) if no token is set and `UNRAID_MCP_DISABLE_HTTP_AUTH` is not explicitly enabled. - **Tests**: 23 new tests in `tests/test_auth.py` covering pass-through scopes, 401/429 responses, RFC 6750 header differentiation, per-IP rate limiting, window expiry, token generation, and startup guard. ### Changed @@ -22,10 +22,12 @@ All notable changes to this project are documented here. ## [1.1.6] - 2026-03-30 ### Security + - **Path traversal**: `flash_backup` source path now validated after `posixpath.normpath` (not before) — raw-string `..` check was bypassable via encoded sequences like `foo/bar/../..`; null byte guard added - **Key validation**: `DANGEROUS_KEY_PATTERN` now blocks space (0x20) and DEL (0x7f) in addition to existing shell metacharacters; applies to both rclone and settings key validation ### Fixed + - **Settings validation**: `configure_ups` input now validated via `_validate_settings_input` before mutation — was previously passing unvalidated dict directly to GraphQL - **Subscription locks**: `_start_one` `last_error` write and `stop_all()` keys snapshot both now take `_task_lock` to prevent concurrent write/read races - **Keepalive handling**: Removed `"ping"` from keepalive `elif` — ping messages require a pong response, not silent discard; only `"ka"` and `"pong"` are silently dropped @@ -33,11 +35,13 @@ All notable changes to this project are documented here. - **Health reverse map**: `_STATUS_FROM_SEVERITY` dict hoisted to module level — was being rebuilt on every `_comprehensive_health_check` call ### Changed + - **Log content cap**: `_cap_log_content` now skipped for non-log subscriptions (only `log_tail`/`logFileSubscription` have `content` fields) — reduces unnecessary dict key lookups on every WebSocket message - **Live assertion**: `_handle_live` now raises `RuntimeError` at import time if `COLLECT_ACTIONS` contains keys not in `_HANDLED_COLLECT_SUBACTIONS` — catches handler omissions before runtime - **Subscription name guard**: `start_subscription` validates name matches `^[a-zA-Z0-9_]+$` before use as WebSocket message ID ### Added + - **Tests**: 27 parametrized tests for `DANGEROUS_KEY_PATTERN` covering all documented dangerous characters and safe key examples (`tests/test_validation.py`) - **Tests**: `test_check_api_error_wrapped_tool_error` — verifies health check returns `{status: unhealthy}` when `make_graphql_request` raises `ToolError` wrapping `httpx.ConnectError` diff --git a/Dockerfile b/Dockerfile index 03e7ed2..5207d2e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -50,6 +50,6 @@ EXPOSE 6970 USER mcp HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \ - CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:6970/health')" || exit 1 + CMD python -c "import os, urllib.request; urllib.request.urlopen('http://localhost:%s/health' % os.getenv('UNRAID_MCP_PORT', '6970'))" || exit 1 ENTRYPOINT ["unraid-mcp-server"] diff --git a/docker-compose.yaml b/docker-compose.yaml index 9d2501a..2addc36 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -4,7 +4,7 @@ services: container_name: unraid-mcp restart: unless-stopped ports: - - "${UNRAID_MCP_PORT:-6970}:6970" + - "${UNRAID_MCP_PORT:-6970}:${UNRAID_MCP_PORT:-6970}" env_file: - .env volumes: @@ -12,7 +12,7 @@ services: - ./backups:/app/backups - unraid-mcp-credentials:/home/mcp/.unraid-mcp healthcheck: - test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:6970/health')"] + test: ["CMD-SHELL", "python -c \"import os, urllib.request; urllib.request.urlopen('http://localhost:%s/health' % os.getenv('UNRAID_MCP_PORT', '6970'))\""] interval: 30s timeout: 5s start_period: 10s diff --git a/hooks/scripts/ensure-gitignore.sh b/hooks/scripts/ensure-gitignore.sh index 8ac55d7..b995c6c 100755 --- a/hooks/scripts/ensure-gitignore.sh +++ b/hooks/scripts/ensure-gitignore.sh @@ -16,8 +16,46 @@ REQUIRED=( touch "$GITIGNORE" +existing="$(cat "$GITIGNORE")" for pattern in "${REQUIRED[@]}"; do if ! grep -qxF "$pattern" "$GITIGNORE" 2>/dev/null; then - echo "$pattern" >> "$GITIGNORE" + existing+=$'\n'"$pattern" fi done + +printf '%s\n' "$existing" | awk ' + BEGIN { + want[".env"]=1 + want[".env.*"]=1 + want["!.env.example"]=1 + want["backups/*"]=1 + want["!backups/.gitkeep"]=1 + want["logs/*"]=1 + want["!logs/.gitkeep"]=1 + want["__pycache__/"]=1 + } + { lines[++n]=$0 } + END { + emitted[""] = 1 + for (i = 1; i <= n; i++) { + if (!want[lines[i]] && !emitted[lines[i]]) { + print lines[i] + emitted[lines[i]] = 1 + } + } + ordered[1]=".env" + ordered[2]=".env.*" + ordered[3]="!.env.example" + ordered[4]="backups/*" + ordered[5]="!backups/.gitkeep" + ordered[6]="logs/*" + ordered[7]="!logs/.gitkeep" + ordered[8]="__pycache__/" + for (i = 1; i <= 8; i++) { + if (!emitted[ordered[i]]) { + print ordered[i] + emitted[ordered[i]] = 1 + } + } + } +' > "$GITIGNORE" diff --git a/hooks/scripts/sync-env.sh b/hooks/scripts/sync-env.sh index 173ef16..c67dc2d 100755 --- a/hooks/scripts/sync-env.sh +++ b/hooks/scripts/sync-env.sh @@ -21,6 +21,10 @@ fi for key in "${!MANAGED[@]}"; do value="${MANAGED[$key]}" [ -z "$value" ] && continue + if [[ "$value" == *$'\n'* || "$value" == *$'\r'* || "$value" == *$'\t'* ]]; then + echo "sync-env: refusing ${key} with control characters" >&2 + exit 1 + fi escaped_value=$(printf '%s\n' "$value" | sed 's/[&/\|]/\\&/g') if grep -q "^${key}=" "$ENV_FILE" 2>/dev/null; then sed -i "s|^${key}=.*|${key}=${escaped_value}|" "$ENV_FILE" @@ -30,7 +34,8 @@ for key in "${!MANAGED[@]}"; do done # Auto-generate UNRAID_MCP_BEARER_TOKEN if not yet set -if ! grep -q "^UNRAID_MCP_BEARER_TOKEN=" "$ENV_FILE" 2>/dev/null; then +if ! grep -Eq '^UNRAID_MCP_BEARER_TOKEN=.+$' "$ENV_FILE" 2>/dev/null; then + sed -i '/^UNRAID_MCP_BEARER_TOKEN=/d' "$ENV_FILE" generated=$(openssl rand -hex 32) echo "UNRAID_MCP_BEARER_TOKEN=${generated}" >> "$ENV_FILE" echo "sync-env: generated UNRAID_MCP_BEARER_TOKEN (update plugin userConfig to match)" >&2 diff --git a/skills/unraid/SKILL.md b/skills/unraid/SKILL.md index 4fc6a5c..2ecac26 100644 --- a/skills/unraid/SKILL.md +++ b/skills/unraid/SKILL.md @@ -328,5 +328,5 @@ curl -s "$CLAUDE_PLUGIN_OPTION_UNRAID_API_URL" \ curl -s "$CLAUDE_PLUGIN_OPTION_UNRAID_API_URL" \ -H "x-api-key: $CLAUDE_PLUGIN_OPTION_UNRAID_API_KEY" \ -H "Content-Type: application/json" \ - -d '{"query":"{ array { state capacity { disks { name status temp } } } }"}' + -d '{"query":"{ array { state capacity { kilobytes { free used total } } disks { name status temp } } }"}' ``` diff --git a/tests/integration/test_subscriptions.py b/tests/integration/test_subscriptions.py index 87c6736..2be74a2 100644 --- a/tests/integration/test_subscriptions.py +++ b/tests/integration/test_subscriptions.py @@ -132,6 +132,11 @@ class TestSubscriptionManagerInit: class TestConnectionLifecycle: + async def test_start_subscription_rejects_trailing_newline_in_name(self) -> None: + mgr = SubscriptionManager() + with pytest.raises(ValueError, match="only \\[a-zA-Z0-9_\\]"): + await mgr.start_subscription("test_sub\n", SAMPLE_QUERY) + async def test_start_subscription_creates_task(self) -> None: mgr = SubscriptionManager() ws = FakeWebSocket([{"type": "connection_ack"}]) @@ -186,6 +191,21 @@ class TestConnectionLifecycle: mgr = SubscriptionManager() await mgr.stop_subscription("nonexistent") + async def test_stop_subscription_reasserts_stopped_after_cancel(self) -> None: + mgr = SubscriptionManager() + + async def _runner() -> None: + try: + await asyncio.Future() + except asyncio.CancelledError: + mgr.connection_states["test_sub"] = "reconnecting" + raise + + task = asyncio.create_task(_runner()) + mgr.active_subscriptions["test_sub"] = task + await mgr.stop_subscription("test_sub") + assert mgr.connection_states["test_sub"] == "stopped" + async def test_connection_state_transitions(self) -> None: mgr = SubscriptionManager() ws = FakeWebSocket([{"type": "connection_ack"}]) diff --git a/tests/test_auth.py b/tests/test_auth.py index 2e7a653..0192cec 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -131,7 +131,7 @@ class TestUnauthorized: return asyncio.get_event_loop().run_until_complete(_collect_response(self.mw, scope)) def test_missing_header_returns_401(self): - status, headers, body = self._run() + status, _, _ = self._run() assert status == 401 assert not self.called["value"] @@ -143,6 +143,10 @@ class TestUnauthorized: _, _, body = self._run() assert b'"error":"unauthorized"' in body + def test_missing_header_counts_toward_rate_limit(self): + self._run(ip="10.10.10.10") + assert len(self.mw._ip_failures["10.10.10.10"]) == 1 + def test_wrong_token_returns_401(self): status, _, _ = self._run(headers=[(b"authorization", b"Bearer wrong-token")]) assert status == 401 @@ -254,6 +258,16 @@ class TestRateLimiting: scope = _make_http_scope(headers=[(b"authorization", b"Bearer wrong")], client_ip=ip) status, _, _ = asyncio.get_event_loop().run_until_complete(_collect_response(mw, scope)) assert status == 401 + assert len(mw._ip_failures[ip]) == 1 + + def test_stale_warning_timestamps_are_evicted(self): + app, _ = _app_called_flag() + mw = BearerAuthMiddleware(app, token="secret") + ip = "172.16.0.2" + mw._ip_last_warn[ip] = time.monotonic() - (_RATE_WINDOW_SECS + 1) + scope = _make_http_scope(headers=[(b"authorization", b"Bearer wrong")], client_ip=ip) + asyncio.get_event_loop().run_until_complete(_collect_response(mw, scope)) + assert ip in mw._ip_last_warn # --------------------------------------------------------------------------- @@ -347,6 +361,31 @@ class TestEnsureTokenExists: s.UNRAID_MCP_DISABLE_HTTP_AUTH = orig_disabled os.environ.pop("UNRAID_MCP_BEARER_TOKEN", None) + def test_generated_token_is_not_printed_to_stderr(self, tmp_path, capsys): + import unraid_mcp.config.settings as s + + orig_token = s.UNRAID_MCP_BEARER_TOKEN + orig_disabled = s.UNRAID_MCP_DISABLE_HTTP_AUTH + try: + s.UNRAID_MCP_BEARER_TOKEN = None + s.UNRAID_MCP_DISABLE_HTTP_AUTH = False + + env_path = tmp_path / ".env" + with ( + patch("unraid_mcp.server.CREDENTIALS_DIR", tmp_path), + patch("unraid_mcp.server.CREDENTIALS_ENV_PATH", env_path), + ): + from unraid_mcp.server import ensure_token_exists + + ensure_token_exists() + + captured = capsys.readouterr() + assert "UNRAID_MCP_BEARER_TOKEN=" not in captured.err + assert s.UNRAID_MCP_BEARER_TOKEN not in captured.err + finally: + s.UNRAID_MCP_BEARER_TOKEN = orig_token + s.UNRAID_MCP_DISABLE_HTTP_AUTH = orig_disabled + class TestStartupGuard: def test_startup_guard_exits_when_http_no_token_no_disable(self): diff --git a/tests/test_live.py b/tests/test_live.py index 02cd04e..c2fec49 100644 --- a/tests/test_live.py +++ b/tests/test_live.py @@ -148,3 +148,15 @@ def test_collect_actions_all_handled(): f"COLLECT_ACTIONS keys without handlers in _handle_live: {unhandled}. " "Add an if-branch in unraid_mcp/tools/_live.py and update _HANDLED_COLLECT_SUBACTIONS." ) + + +def test_collect_actions_rejects_stale_handled_keys(monkeypatch): + import unraid_mcp.tools._live as live_module + + monkeypatch.setattr( + live_module, + "_HANDLED_COLLECT_SUBACTIONS", + frozenset({"log_tail", "notification_feed", "stale_key"}), + ) + with pytest.raises(RuntimeError, match="stale"): + live_module._assert_collect_subactions_complete() diff --git a/tests/test_live.sh b/tests/test_live.sh index ea0e810..593ce70 100755 --- a/tests/test_live.sh +++ b/tests/test_live.sh @@ -11,9 +11,9 @@ PASS=0 FAIL=0 SKIP=0 -pass() { echo " PASS: $1"; ((PASS++)); } -fail() { echo " FAIL: $1 — $2"; ((FAIL++)); } -skip() { echo " SKIP: $1 — $2"; ((SKIP++)); } +pass() { echo " PASS: $1"; ((++PASS)); } +fail() { echo " FAIL: $1 — $2"; ((++FAIL)); } +skip() { echo " SKIP: $1 — $2"; ((++SKIP)); } header() { echo; echo "=== $1 ==="; } @@ -105,9 +105,16 @@ npx mcporter call "${SERVER_NAME}.unraid" \ # ── Resources (server-level, no tool name needed) ──────────────────────────── header "Resources" -npx mcporter call "${SERVER_NAME}" --http-url "$MCP_URL" --header "$AUTH_HEADER" \ - --list-resources > /dev/null 2>&1 \ - && pass "resources/list" || skip "resources/list" "no resources defined" +resources_output="$( + npx mcporter call "${SERVER_NAME}" --http-url "$MCP_URL" --header "$AUTH_HEADER" \ + --list-resources 2>&1 +)" && pass "resources/list" || { + if printf '%s' "$resources_output" | grep -qi "no resources defined"; then + skip "resources/list" "no resources defined" + else + fail "resources/list" "$resources_output" + fi +} # ── Bearer token enforcement ───────────────────────────────────────────────── header "Bearer token enforcement" diff --git a/tests/test_review_regressions.py b/tests/test_review_regressions.py new file mode 100644 index 0000000..387195f --- /dev/null +++ b/tests/test_review_regressions.py @@ -0,0 +1,92 @@ +"""Regression coverage for packaging, manifests, and hook scripts.""" + +from __future__ import annotations + +import json +import os +import subprocess +from pathlib import Path + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] + + +def test_plugin_manifest_restores_stdio_server_definition() -> None: + plugin = json.loads((PROJECT_ROOT / ".claude-plugin" / "plugin.json").read_text()) + assert plugin["userConfig"]["unraid_mcp_token"]["sensitive"] is True + assert "mcpServers" in plugin + assert plugin["mcpServers"]["unraid"]["type"] == "stdio" + + +def test_container_configs_use_runtime_port_variable() -> None: + compose = (PROJECT_ROOT / "docker-compose.yaml").read_text() + dockerfile = (PROJECT_ROOT / "Dockerfile").read_text() + assert "${UNRAID_MCP_PORT:-6970}:${UNRAID_MCP_PORT:-6970}" in compose + assert "os.getenv('UNRAID_MCP_PORT', '6970')" in compose + assert "os.getenv('UNRAID_MCP_PORT', '6970')" in dockerfile + + +def test_test_live_script_uses_safe_counters_and_resource_failures() -> None: + script = (PROJECT_ROOT / "tests" / "test_live.sh").read_text() + assert "((++PASS))" in script + assert "((++FAIL))" in script + assert "((++SKIP))" in script + assert 'fail "resources/list" "$resources_output"' in script + + +def test_sync_env_rejects_multiline_values(tmp_path: Path) -> None: + env = os.environ.copy() + env["CLAUDE_PLUGIN_ROOT"] = str(tmp_path) + env["CLAUDE_PLUGIN_OPTION_UNRAID_API_URL"] = "https://tower.local\nINJECT=1" + + result = subprocess.run( # noqa: S603 + ["/usr/bin/bash", str(PROJECT_ROOT / "hooks" / "scripts" / "sync-env.sh")], + cwd=PROJECT_ROOT, + env=env, + capture_output=True, + text=True, + ) + + assert result.returncode != 0 + assert "control characters" in result.stderr + + +def test_sync_env_regenerates_empty_bearer_token(tmp_path: Path) -> None: + env_file = tmp_path / ".env" + env_file.write_text("UNRAID_MCP_BEARER_TOKEN=\n") + + env = os.environ.copy() + env["CLAUDE_PLUGIN_ROOT"] = str(tmp_path) + + result = subprocess.run( # noqa: S603 + ["/usr/bin/bash", str(PROJECT_ROOT / "hooks" / "scripts" / "sync-env.sh")], + cwd=PROJECT_ROOT, + env=env, + capture_output=True, + text=True, + ) + + assert result.returncode == 0 + lines = env_file.read_text().splitlines() + token_line = next(line for line in lines if line.startswith("UNRAID_MCP_BEARER_TOKEN=")) + assert token_line != "UNRAID_MCP_BEARER_TOKEN=" + + +def test_ensure_gitignore_preserves_ignore_before_negation(tmp_path: Path) -> None: + gitignore = tmp_path / ".gitignore" + gitignore.write_text("!backups/.gitkeep\n") + + env = os.environ.copy() + env["CLAUDE_PLUGIN_ROOT"] = str(tmp_path) + + result = subprocess.run( # noqa: S603 + ["/usr/bin/bash", str(PROJECT_ROOT / "hooks" / "scripts" / "ensure-gitignore.sh")], + cwd=PROJECT_ROOT, + env=env, + capture_output=True, + text=True, + ) + + assert result.returncode == 0 + lines = gitignore.read_text().splitlines() + assert lines.index("backups/*") < lines.index("!backups/.gitkeep") diff --git a/tests/test_settings.py b/tests/test_settings.py index 58703d5..6aa70b1 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -89,6 +89,21 @@ class TestSettingsUpdate: assert result["success"] is True assert result["subaction"] == "update" + async def test_update_allows_nested_json_values(self, _mock_graphql: AsyncMock) -> None: + _mock_graphql.return_value = { + "updateSettings": {"restartRequired": False, "values": {}, "warnings": []} + } + tool_fn = _make_tool() + payload = { + "themeOverrides": {"sidebar": None, "panels": ["cpu", "memory"]}, + "advanced": [1, True, {"nested": "ok"}], + } + result = await tool_fn(action="setting", subaction="update", settings_input=payload) + assert result["success"] is True + _mock_graphql.assert_awaited_once() + sent_payload = _mock_graphql.await_args.args[1]["input"] + assert sent_payload == payload + # --------------------------------------------------------------------------- # configure_ups @@ -114,3 +129,13 @@ class TestUpsConfig: ) assert result["success"] is True assert result["subaction"] == "configure_ups" + + async def test_configure_ups_rejects_nested_values(self, _mock_graphql: AsyncMock) -> None: + tool_fn = _make_tool() + with pytest.raises(ToolError, match="must be a string, number, or boolean"): + await tool_fn( + action="setting", + subaction="configure_ups", + confirm=True, + ups_config={"mode": {"nested": "invalid"}}, + ) diff --git a/tests/test_storage.py b/tests/test_storage.py index e30b441..534208f 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -238,6 +238,12 @@ class TestStorageActions: with pytest.raises(ToolError, match="Log file not found or inaccessible"): await tool_fn(action="disk", subaction="logs", log_path="/var/log/syslog") + async def test_logs_empty_log_file_payload_raises(self, _mock_graphql: AsyncMock) -> None: + _mock_graphql.return_value = {"logFile": {}} + tool_fn = _make_tool() + with pytest.raises(ToolError, match="Log file not found or inaccessible"): + await tool_fn(action="disk", subaction="logs", log_path="/var/log/syslog") + async def test_disk_details_not_found(self, _mock_graphql: AsyncMock) -> None: _mock_graphql.return_value = {"disk": None} tool_fn = _make_tool() diff --git a/unraid_mcp/config/settings.py b/unraid_mcp/config/settings.py index cc509e6..8c0e66f 100644 --- a/unraid_mcp/config/settings.py +++ b/unraid_mcp/config/settings.py @@ -175,7 +175,7 @@ def get_config_summary() -> dict[str, Any]: "missing_config": missing if not is_valid else None, # Auth fields only meaningful in HTTP mode "http_auth_enabled": is_http and not UNRAID_MCP_DISABLE_HTTP_AUTH, - "http_auth_token_set": bool(UNRAID_MCP_BEARER_TOKEN) if is_http else None, + "http_auth_token_set": bool(UNRAID_MCP_BEARER_TOKEN) if is_http else False, } diff --git a/unraid_mcp/core/auth.py b/unraid_mcp/core/auth.py index b8e2d79..4de1509 100644 --- a/unraid_mcp/core/auth.py +++ b/unraid_mcp/core/auth.py @@ -93,6 +93,8 @@ class BearerAuthMiddleware: if auth_header is None: # No Authorization header at all — prompt client per RFC 6750 + self._record_failure(client_ip) + self._maybe_warn(client_ip, "missing authorization header") await self._send_response( send, status=401, @@ -156,15 +158,12 @@ class BearerAuthMiddleware: """Return True if this IP has hit the failure rate limit.""" if ip not in self._ip_failures: return False - now = time.monotonic() - q = self._ip_failures[ip] - cutoff = now - _RATE_WINDOW_SECS - while q and q[0] < cutoff: - q.popleft() - return len(q) >= _RATE_MAX_FAILURES + self._prune_ip_state(ip) + return len(self._ip_failures.get(ip, ())) >= _RATE_MAX_FAILURES def _record_failure(self, ip: str) -> None: """Record one failed auth attempt for this IP.""" + self._prune_ip_state(ip) if ip not in self._ip_failures: self._ip_failures[ip] = deque() self._ip_failures[ip].append(time.monotonic()) @@ -179,6 +178,21 @@ class BearerAuthMiddleware: self._ip_last_warn[ip] = now logger.warning("Bearer auth rejected (%s) from %s", reason, ip) + def _prune_ip_state(self, ip: str) -> None: + """Drop stale failure and warning-tracking state for one IP.""" + now = time.monotonic() + q = self._ip_failures.get(ip) + if q is not None: + cutoff = now - _RATE_WINDOW_SECS + while q and q[0] < cutoff: + q.popleft() + if not q: + self._ip_failures.pop(ip, None) + + last_warn = self._ip_last_warn.get(ip) + if last_warn is not None and (now - last_warn) >= _RATE_WINDOW_SECS: + self._ip_last_warn.pop(ip, None) + @staticmethod async def _send_response( send: Send, diff --git a/unraid_mcp/server.py b/unraid_mcp/server.py index d1214af..1a8d68c 100644 --- a/unraid_mcp/server.py +++ b/unraid_mcp/server.py @@ -37,11 +37,13 @@ from .subscriptions.resources import register_subscription_resources from .tools.unraid import register_unraid_tool -def _chmod_safe(path: object, mode: int) -> None: - """Chmod with graceful fallback for volume mounts owned by root.""" +def _chmod_safe(path: object, mode: int, *, strict: bool = False) -> None: + """Best-effort chmod, with optional fail-closed behavior for secrets.""" try: path.chmod(mode) # type: ignore[union-attr] - except PermissionError: + except PermissionError as exc: + if strict: + raise RuntimeError(f"Failed to secure permissions on {path}") from exc logger.debug("Could not chmod %s (volume mount?) — skipping", path) @@ -148,7 +150,7 @@ def ensure_token_exists() -> None: # Ensure credentials dir exists with restricted permissions. CREDENTIALS_DIR.mkdir(parents=True, exist_ok=True) - _chmod_safe(CREDENTIALS_DIR, 0o700) + _chmod_safe(CREDENTIALS_DIR, 0o700, strict=True) # Touch the file first so set_key has a target (no-op if already exists) if not CREDENTIALS_ENV_PATH.exists(): @@ -156,13 +158,11 @@ def ensure_token_exists() -> None: # In-place .env write — preserves comments and existing keys set_key(str(CREDENTIALS_ENV_PATH), "UNRAID_MCP_BEARER_TOKEN", token, quote_mode="auto") - _chmod_safe(CREDENTIALS_ENV_PATH, 0o600) + _chmod_safe(CREDENTIALS_ENV_PATH, 0o600, strict=True) - # Print once to STDERR so the user can copy it into their MCP client config print( - f"\n[unraid-mcp] Generated HTTP bearer token (saved to {CREDENTIALS_ENV_PATH}):\n" - f" UNRAID_MCP_BEARER_TOKEN={token}\n" - " Add this to your MCP client's Authorization header: Bearer \n", + f"\n[unraid-mcp] Generated HTTP bearer token and saved it to {CREDENTIALS_ENV_PATH}.\n" + "Configure your MCP client to send Authorization: Bearer using that stored value.\n", file=sys.stderr, flush=True, ) @@ -178,64 +178,58 @@ def run_server() -> None: """Run the MCP server with the configured transport.""" is_http = _settings.UNRAID_MCP_TRANSPORT in ("streamable-http", "sse") - # Auto-generate token before the startup guard so a fresh install - # can start without manual intervention. - if is_http: - ensure_token_exists() + try: + if is_http: + ensure_token_exists() - # Hard stop: HTTP mode with no token and auth not explicitly disabled. - # We deliberately do NOT mention DISABLE_HTTP_AUTH in the error message to - # avoid inadvertently guiding users toward disabling auth. - if ( - is_http - and not _settings.UNRAID_MCP_DISABLE_HTTP_AUTH - and not _settings.UNRAID_MCP_BEARER_TOKEN - ): - print( - "FATAL: HTTP transport requires a bearer token. " - "Set UNRAID_MCP_BEARER_TOKEN in ~/.unraid-mcp/.env or restart to auto-generate.", - file=sys.stderr, - ) - sys.exit(1) + if ( + is_http + and not _settings.UNRAID_MCP_DISABLE_HTTP_AUTH + and not _settings.UNRAID_MCP_BEARER_TOKEN + ): + print( + "FATAL: HTTP transport requires a bearer token. " + "Set UNRAID_MCP_BEARER_TOKEN in ~/.unraid-mcp/.env or restart to auto-generate.", + file=sys.stderr, + ) + sys.exit(1) - # Validate Unraid API credentials - is_valid, missing = validate_required_config() - if not is_valid: - logger.warning( - "Missing configuration: %s. " - "Server will prompt for credentials on first tool call via elicitation.", - ", ".join(missing), - ) - - log_configuration_status(logger) - - if UNRAID_VERIFY_SSL is False: - logger.warning( - "SSL VERIFICATION DISABLED (UNRAID_VERIFY_SSL=false). " - "Connections to Unraid API are vulnerable to man-in-the-middle attacks. " - "Only use this in trusted networks or for development." - ) - - if is_http: - if _settings.UNRAID_MCP_DISABLE_HTTP_AUTH: + is_valid, missing = validate_required_config() + if not is_valid: logger.warning( - "HTTP auth disabled (UNRAID_MCP_DISABLE_HTTP_AUTH=true). " - "Ensure an upstream gateway enforces authentication." + "Missing configuration: %s. " + "Server will prompt for credentials on first tool call via elicitation.", + ", ".join(missing), + ) + + log_configuration_status(logger) + + if UNRAID_VERIFY_SSL is False: + logger.warning( + "SSL VERIFICATION DISABLED (UNRAID_VERIFY_SSL=false). " + "Connections to Unraid API are vulnerable to man-in-the-middle attacks. " + "Only use this in trusted networks or for development." + ) + + if is_http: + if _settings.UNRAID_MCP_DISABLE_HTTP_AUTH: + logger.warning( + "HTTP auth disabled (UNRAID_MCP_DISABLE_HTTP_AUTH=true). " + "Ensure an upstream gateway enforces authentication." + ) + else: + logger.info("HTTP bearer token authentication enabled.") + + if is_http: + logger.info( + "Starting Unraid MCP Server on %s:%s using %s transport...", + UNRAID_MCP_HOST, + UNRAID_MCP_PORT, + _settings.UNRAID_MCP_TRANSPORT, ) else: - logger.info("HTTP bearer token authentication enabled.") + logger.info("Starting Unraid MCP Server using stdio transport...") - if is_http: - logger.info( - "Starting Unraid MCP Server on %s:%s using %s transport...", - UNRAID_MCP_HOST, - UNRAID_MCP_PORT, - _settings.UNRAID_MCP_TRANSPORT, - ) - else: - logger.info("Starting Unraid MCP Server using stdio transport...") - - try: if is_http: if _settings.UNRAID_MCP_TRANSPORT == "sse": logger.warning( diff --git a/unraid_mcp/subscriptions/manager.py b/unraid_mcp/subscriptions/manager.py index 030cc55..dcac690 100644 --- a/unraid_mcp/subscriptions/manager.py +++ b/unraid_mcp/subscriptions/manager.py @@ -35,6 +35,12 @@ _last_graphql_error: dict[str, str] = {} _graphql_error_count: dict[str, int] = {} +def _clear_graphql_error_burst(subscription_name: str) -> None: + """Reset deduplicated GraphQL error tracking for one subscription.""" + _last_graphql_error.pop(subscription_name, None) + _graphql_error_count.pop(subscription_name, None) + + def _preview(message: str | bytes, n: int = 200) -> str: """Return the first *n* characters of *message* as a UTF-8 string. @@ -193,9 +199,21 @@ class SubscriptionManager: self.last_error[name] = str(e) start_errors.append((name, e)) - async with asyncio.TaskGroup() as tg: - for name, config in auto_start_configs: - tg.create_task(_start_one(name, config)) + started_names: list[str] = [] + + async def _tracked_start(name: str, config: dict[str, Any]) -> None: + await _start_one(name, config) + started_names.append(name) + + try: + async with asyncio.TaskGroup() as tg: + for name, config in auto_start_configs: + tg.create_task(_tracked_start(name, config)) + except asyncio.CancelledError: + for name in started_names: + if name in self.active_subscriptions: + await self.stop_subscription(name) + raise started = len(auto_start_configs) - len(start_errors) logger.info( @@ -211,10 +229,11 @@ class SubscriptionManager: self, subscription_name: str, query: str, variables: dict[str, Any] | None = None ) -> None: """Start a GraphQL subscription and maintain it as a resource.""" - if not re.match(r"^[a-zA-Z0-9_]+$", subscription_name): + if not re.fullmatch(r"[a-zA-Z0-9_]+", subscription_name): raise ValueError( f"subscription_name must contain only [a-zA-Z0-9_], got: {subscription_name!r}" ) + _clear_graphql_error_burst(subscription_name) logger.info(f"[SUBSCRIPTION:{subscription_name}] Starting subscription...") # Guard must be inside the lock to prevent a TOCTOU race where two @@ -275,6 +294,8 @@ class SubscriptionManager: await task except asyncio.CancelledError: logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully") + self.connection_states[subscription_name] = "stopped" + _clear_graphql_error_burst(subscription_name) logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped") async def stop_all(self) -> None: @@ -461,6 +482,7 @@ class SubscriptionManager: logger.info( f"[DATA:{subscription_name}] Received subscription data update" ) + _clear_graphql_error_burst(subscription_name) capped_data = ( _cap_log_content(payload["data"]) if isinstance(payload["data"], dict) diff --git a/unraid_mcp/tools/_disk.py b/unraid_mcp/tools/_disk.py index 1d643c6..ecec392 100644 --- a/unraid_mcp/tools/_disk.py +++ b/unraid_mcp/tools/_disk.py @@ -183,7 +183,7 @@ async def _handle_disk( return {"log_files": data.get("logFiles", [])} if subaction == "logs": result = data.get("logFile") - if result is None: + if not result: raise ToolError(f"Log file not found or inaccessible: {log_path}") return dict(result) diff --git a/unraid_mcp/tools/_health.py b/unraid_mcp/tools/_health.py index 0d553dc..88532d9 100644 --- a/unraid_mcp/tools/_health.py +++ b/unraid_mcp/tools/_health.py @@ -157,7 +157,7 @@ async def _comprehensive_health_check() -> dict[str, Any]: "timestamp": datetime.datetime.now(datetime.UTC).isoformat(), "error": str(e), } - except Exception as e: + except _client.ToolError as e: # make_graphql_request wraps httpx network errors in ToolError; catch them # here so health/check returns {"status": "unhealthy"} on real outages # rather than propagating an unhandled ToolError to the caller. diff --git a/unraid_mcp/tools/_live.py b/unraid_mcp/tools/_live.py index 875e2b3..ef7a6e6 100644 --- a/unraid_mcp/tools/_live.py +++ b/unraid_mcp/tools/_live.py @@ -23,7 +23,7 @@ _HANDLED_COLLECT_SUBACTIONS: frozenset[str] = frozenset({"log_tail", "notificati def _assert_collect_subactions_complete() -> None: - """Raise AssertionError at import time if COLLECT_ACTIONS has an unhandled key. + """Raise RuntimeError at import time if collect subactions drift. Every key in COLLECT_ACTIONS must appear in _HANDLED_COLLECT_SUBACTIONS AND have a matching if-branch in _handle_live. This assertion catches the former @@ -32,11 +32,16 @@ def _assert_collect_subactions_complete() -> None: from ..subscriptions.queries import COLLECT_ACTIONS missing = set(COLLECT_ACTIONS) - _HANDLED_COLLECT_SUBACTIONS + stale = _HANDLED_COLLECT_SUBACTIONS - set(COLLECT_ACTIONS) if missing: raise RuntimeError( f"_HANDLED_COLLECT_SUBACTIONS is missing keys from COLLECT_ACTIONS: {missing}. " "Add a handler branch in _handle_live and update _HANDLED_COLLECT_SUBACTIONS." ) + if stale: + raise RuntimeError( + f"_HANDLED_COLLECT_SUBACTIONS contains stale keys not present in COLLECT_ACTIONS: {stale}." + ) _assert_collect_subactions_complete() diff --git a/unraid_mcp/tools/_setting.py b/unraid_mcp/tools/_setting.py index 20f25d7..3491bbb 100644 --- a/unraid_mcp/tools/_setting.py +++ b/unraid_mcp/tools/_setting.py @@ -21,8 +21,8 @@ from ..core.validation import DANGEROUS_KEY_PATTERN, MAX_VALUE_LENGTH _MAX_SETTINGS_KEYS = 100 -def _validate_settings_input(settings_input: dict[str, Any]) -> dict[str, Any]: - """Validate settings_input before forwarding to the Unraid API. +def _validate_settings_mapping(settings_input: dict[str, Any]) -> dict[str, Any]: + """Validate flat scalar settings data before forwarding to the Unraid API. Enforces a key count cap and rejects dangerous key names and oversized values to prevent unvalidated bulk input from reaching the API. Modeled on @@ -60,6 +60,22 @@ def _validate_settings_input(settings_input: dict[str, Any]) -> dict[str, Any]: return validated +def _validate_json_settings_input(settings_input: dict[str, Any]) -> dict[str, Any]: + """Validate JSON-typed settings input without narrowing valid JSON members.""" + if len(settings_input) > _MAX_SETTINGS_KEYS: + raise ToolError(f"settings_input has {len(settings_input)} keys (max {_MAX_SETTINGS_KEYS})") + validated: dict[str, Any] = {} + for key, value in settings_input.items(): + if not isinstance(key, str) or not key.strip(): + raise ToolError( + f"settings_input keys must be non-empty strings, got: {type(key).__name__}" + ) + if DANGEROUS_KEY_PATTERN.search(key): + raise ToolError(f"settings_input key '{key}' contains disallowed characters") + validated[key] = value + return validated + + _SETTING_MUTATIONS: dict[str, str] = { "update": "mutation UpdateSettings($input: JSON!) { updateSettings(input: $input) { restartRequired values warnings } }", "configure_ups": "mutation ConfigureUps($config: UPSConfigInput!) { configureUps(config: $config) }", @@ -95,7 +111,7 @@ async def _handle_setting( if subaction == "update": if settings_input is None: raise ToolError("settings_input is required for setting/update") - validated_input = _validate_settings_input(settings_input) + validated_input = _validate_json_settings_input(settings_input) data = await _client.make_graphql_request( _SETTING_MUTATIONS["update"], {"input": validated_input} ) @@ -107,7 +123,7 @@ async def _handle_setting( # Validate ups_config with the same rules as settings_input — key count # cap, scalar-only values, MAX_VALUE_LENGTH — to prevent unvalidated bulk # input from reaching the GraphQL mutation. - validated_ups = _validate_settings_input(ups_config) + validated_ups = _validate_settings_mapping(ups_config) data = await _client.make_graphql_request( _SETTING_MUTATIONS["configure_ups"], {"config": validated_ups} )