unraid-mcp/unraid_mcp/subscriptions/manager.py

757 lines
36 KiB
Python

"""WebSocket subscription manager for real-time Unraid data.
This module manages GraphQL subscriptions over WebSocket connections,
providing real-time data streaming for MCP resources with comprehensive
error handling, reconnection logic, and authentication.
"""
import asyncio
import json
import os
import re
import time
from datetime import UTC, datetime
from typing import Any
import websockets
from websockets.typing import Subprotocol
from ..config import settings as _settings
from ..config.logging import logger
from ..core.client import redact_sensitive
from ..core.types import SubscriptionData
from .utils import build_connection_init, build_ws_ssl_context, build_ws_url
# Resource data size limits to prevent unbounded memory growth
_MAX_RESOURCE_DATA_BYTES = 1_048_576 # 1MB
_MAX_RESOURCE_DATA_LINES = 5_000
# Minimum stable connection duration (seconds) before resetting reconnect counter
_STABLE_CONNECTION_SECONDS = 30
# Track last GraphQL error per subscription to deduplicate log spam.
# Key: subscription name, Value: first error message seen in the current burst.
_last_graphql_error: dict[str, str] = {}
_graphql_error_count: dict[str, int] = {}
def _clear_graphql_error_burst(subscription_name: str) -> None:
"""Reset deduplicated GraphQL error tracking for one subscription."""
_last_graphql_error.pop(subscription_name, None)
_graphql_error_count.pop(subscription_name, None)
def _preview(message: str | bytes, n: int = 200) -> str:
"""Return the first *n* characters of *message* as a UTF-8 string.
Safe for both str and bytes payloads; replaces unmappable bytes rather
than raising. Used to log a safe, bounded snippet before/after errors.
"""
if isinstance(message, str):
return message[:n]
return message[:n].decode("utf-8", errors="replace")
# Subscription names that carry log content and therefore need _cap_log_content.
# All other subscriptions never have a 'content' field — skip the cap call for them.
_LOG_SUBSCRIPTION_NAMES: frozenset[str] = frozenset({"log_tail", "logFileSubscription"})
def _cap_log_content(data: dict[str, Any]) -> dict[str, Any]:
"""Cap log content in subscription data to prevent unbounded memory growth.
Returns a new dict — does NOT mutate the input. If any nested 'content'
field (from log subscriptions) exceeds the byte limit, truncate it to the
most recent _MAX_RESOURCE_DATA_LINES lines.
The final content is guaranteed to be <= _MAX_RESOURCE_DATA_BYTES.
"""
result: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, dict):
result[key] = _cap_log_content(value)
elif (
key == "content"
and isinstance(value, str)
# Pre-check uses byte count so multibyte UTF-8 chars cannot bypass the cap
and len(value.encode("utf-8", errors="replace")) > _MAX_RESOURCE_DATA_BYTES
):
lines = value.splitlines()
original_line_count = len(lines)
# Keep most recent lines first.
if len(lines) > _MAX_RESOURCE_DATA_LINES:
lines = lines[-_MAX_RESOURCE_DATA_LINES:]
truncated = "\n".join(lines)
# Encode once and slice bytes instead of an O(n²) line-trim loop
encoded = truncated.encode("utf-8", errors="replace")
if len(encoded) > _MAX_RESOURCE_DATA_BYTES:
truncated = encoded[-_MAX_RESOURCE_DATA_BYTES:].decode("utf-8", errors="ignore")
# Strip partial first line that may have been cut mid-character
nl_pos = truncated.find("\n")
if nl_pos != -1:
truncated = truncated[nl_pos + 1 :]
logger.warning(
f"[RESOURCE] Capped log content from {original_line_count} to "
f"{len(lines)} lines ({len(value)} -> {len(truncated)} chars)"
)
result[key] = truncated
else:
result[key] = value
return result
class SubscriptionManager:
"""Manages GraphQL subscriptions and converts them to MCP resources."""
def __init__(self) -> None:
self.active_subscriptions: dict[str, asyncio.Task[None]] = {}
self.resource_data: dict[str, SubscriptionData] = {}
# Two fine-grained locks instead of one coarse lock (P-01):
# _task_lock guards active_subscriptions dict (task lifecycle).
# _data_lock guards resource_data dict (WebSocket message writes + reads).
# Splitting prevents WebSocket message updates from blocking tool reads
# of active_subscriptions and vice versa.
self._task_lock = asyncio.Lock()
self._data_lock = asyncio.Lock()
# Configuration
self.auto_start_enabled = (
os.getenv("UNRAID_AUTO_START_SUBSCRIPTIONS", "true").lower() == "true"
)
self.reconnect_attempts: dict[str, int] = {}
self.max_reconnect_attempts = int(os.getenv("UNRAID_MAX_RECONNECT_ATTEMPTS", "10"))
self.connection_states: dict[str, str] = {} # Track connection state per subscription
self.last_error: dict[str, str] = {} # Track last error per subscription
self._connection_start_times: dict[str, float] = {} # Track when connections started
# Define subscription configurations
from .queries import SNAPSHOT_ACTIONS
self.subscription_configs: dict[str, dict] = {
action: {
"query": query,
"resource": f"unraid://live/{action}",
"description": f"Real-time {action.replace('_', ' ')} data",
"auto_start": True,
}
for action, query in SNAPSHOT_ACTIONS.items()
}
self.subscription_configs["logFileSubscription"] = {
"query": """
subscription LogFileSubscription($path: String!) {
logFile(path: $path) {
path
content
totalLines
}
}
""",
"resource": "unraid://logs/stream",
"description": "Real-time log file streaming",
"auto_start": False, # Started manually with path parameter
}
logger.info(
f"[SUBSCRIPTION_MANAGER] Initialized with auto_start={self.auto_start_enabled}, max_reconnects={self.max_reconnect_attempts}"
)
logger.debug(
f"[SUBSCRIPTION_MANAGER] Available subscriptions: {list(self.subscription_configs.keys())}"
)
async def auto_start_all_subscriptions(self) -> None:
"""Auto-start all subscriptions marked for auto-start.
All auto-start subscriptions are launched in parallel via asyncio.TaskGroup
to avoid blocking the first MCP request on a sequential startup delay.
"""
if not self.auto_start_enabled:
logger.info("[SUBSCRIPTION_MANAGER] Auto-start disabled")
return
auto_start_configs = [
(name, config)
for name, config in self.subscription_configs.items()
if config.get("auto_start", False)
]
if not auto_start_configs:
logger.info("[SUBSCRIPTION_MANAGER] No subscriptions marked for auto-start")
return
logger.info(
f"[SUBSCRIPTION_MANAGER] Starting auto-start process for {len(auto_start_configs)} subscriptions in parallel..."
)
start_errors: list[tuple[str, Exception]] = []
async def _start_one(name: str, config: dict[str, Any]) -> None:
try:
logger.info(f"[SUBSCRIPTION_MANAGER] Auto-starting subscription: {name}")
await self.start_subscription(name, str(config["query"]))
except asyncio.CancelledError:
raise # Never swallow cancellation — propagate for clean shutdown
except Exception as e:
logger.error(f"[SUBSCRIPTION_MANAGER] Failed to auto-start {name}: {e}")
async with self._task_lock:
self.last_error[name] = str(e)
start_errors.append((name, e))
started_names: list[str] = []
async def _tracked_start(name: str, config: dict[str, Any]) -> None:
await _start_one(name, config)
started_names.append(name)
try:
async with asyncio.TaskGroup() as tg:
for name, config in auto_start_configs:
tg.create_task(_tracked_start(name, config))
except asyncio.CancelledError:
for name in started_names:
if name in self.active_subscriptions:
await self.stop_subscription(name)
raise
started = len(auto_start_configs) - len(start_errors)
logger.info(
f"[SUBSCRIPTION_MANAGER] Auto-start completed. Started {started}/{len(auto_start_configs)} subscriptions"
)
if start_errors:
failed_names = ", ".join(n for n, _ in start_errors)
logger.warning(
f"[SUBSCRIPTION_MANAGER] {len(start_errors)} subscription(s) failed to auto-start: {failed_names}"
)
async def start_subscription(
self, subscription_name: str, query: str, variables: dict[str, Any] | None = None
) -> None:
"""Start a GraphQL subscription and maintain it as a resource."""
if not re.fullmatch(r"[a-zA-Z0-9_]+", subscription_name):
raise ValueError(
f"subscription_name must contain only [a-zA-Z0-9_], got: {subscription_name!r}"
)
_clear_graphql_error_burst(subscription_name)
logger.info(f"[SUBSCRIPTION:{subscription_name}] Starting subscription...")
# Guard must be inside the lock to prevent a TOCTOU race where two
# concurrent callers both pass the check before either creates the task.
async with self._task_lock:
if subscription_name in self.active_subscriptions:
logger.warning(
f"[SUBSCRIPTION:{subscription_name}] Subscription already active, skipping"
)
return
# Reset connection tracking inside the lock so state is consistent
# with the task creation that follows immediately.
self.reconnect_attempts[subscription_name] = 0
self.connection_states[subscription_name] = "starting"
self._connection_start_times.pop(subscription_name, None)
try:
task = asyncio.create_task(
self._subscription_loop(subscription_name, query, variables or {})
)
self.active_subscriptions[subscription_name] = task
logger.info(
f"[SUBSCRIPTION:{subscription_name}] Subscription task created and started"
)
self.connection_states[subscription_name] = "active"
except Exception as e:
logger.error(
f"[SUBSCRIPTION:{subscription_name}] Failed to start subscription task: {e}"
)
self.connection_states[subscription_name] = "failed"
self.last_error[subscription_name] = str(e)
raise
async def stop_subscription(self, subscription_name: str) -> None:
"""Stop a specific subscription.
Snapshots the task inside _task_lock, then releases the lock BEFORE
awaiting cancellation. Holding _task_lock across 'await task' would
deadlock if _subscription_loop's cleanup path also needs _task_lock
(it does, at loop exit). Pattern: lock → snapshot → release → await.
"""
logger.info(f"[SUBSCRIPTION:{subscription_name}] Stopping subscription...")
async with self._task_lock:
task = self.active_subscriptions.pop(subscription_name, None)
if task is None:
logger.warning(f"[SUBSCRIPTION:{subscription_name}] No active subscription to stop")
return
self.connection_states[subscription_name] = "stopped"
self._connection_start_times.pop(subscription_name, None)
# Await cancellation OUTSIDE the lock — _subscription_loop cleanup path
# acquires _task_lock at loop exit; holding it here while waiting for the
# task to finish would deadlock.
task.cancel()
try:
await task
except asyncio.CancelledError:
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Task cancelled successfully")
self.connection_states[subscription_name] = "stopped"
_clear_graphql_error_burst(subscription_name)
logger.info(f"[SUBSCRIPTION:{subscription_name}] Subscription stopped")
async def stop_all(self) -> None:
"""Stop all active subscriptions (called during server shutdown)."""
async with self._task_lock:
subscription_names = list(self.active_subscriptions.keys())
for name in subscription_names:
try:
await self.stop_subscription(name)
except Exception as e:
logger.error(f"[SHUTDOWN] Error stopping subscription '{name}': {e}", exc_info=True)
logger.info(f"[SHUTDOWN] Stopped {len(subscription_names)} subscription(s)")
async def _subscription_loop(
self, subscription_name: str, query: str, variables: dict[str, Any] | None
) -> None:
"""Main loop for maintaining a GraphQL subscription with comprehensive logging."""
retry_delay: int | float = 5
max_retry_delay = 300 # 5 minutes max
while True:
attempt = self.reconnect_attempts.get(subscription_name, 0) + 1
self.reconnect_attempts[subscription_name] = attempt
logger.info(
f"[WEBSOCKET:{subscription_name}] Connection attempt #{attempt} (max: {self.max_reconnect_attempts})"
)
if attempt > self.max_reconnect_attempts:
logger.error(
f"[WEBSOCKET:{subscription_name}] Max reconnection attempts ({self.max_reconnect_attempts}) exceeded, stopping"
)
self.connection_states[subscription_name] = "max_retries_exceeded"
break
try:
ws_url = build_ws_url()
logger.debug(f"[WEBSOCKET:{subscription_name}] Connecting to: {ws_url}")
logger.debug(
f"[WEBSOCKET:{subscription_name}] API Key present: {'Yes' if _settings.UNRAID_API_KEY else 'No'}"
)
ssl_context = build_ws_ssl_context(ws_url)
# Connection with timeout
connect_timeout = 10
logger.debug(
f"[WEBSOCKET:{subscription_name}] Connection timeout: {connect_timeout}s"
)
async with websockets.connect(
ws_url,
subprotocols=[Subprotocol("graphql-transport-ws"), Subprotocol("graphql-ws")],
open_timeout=connect_timeout,
ping_interval=20,
ping_timeout=10,
close_timeout=10,
ssl=ssl_context,
) as websocket:
selected_proto = websocket.subprotocol or "none"
logger.info(
f"[WEBSOCKET:{subscription_name}] Connected! Protocol: {selected_proto}"
)
self.connection_states[subscription_name] = "connected"
# Track connection start time — only reset retry counter
# after the connection proves stable (>30s connected)
self._connection_start_times[subscription_name] = time.monotonic()
# Initialize GraphQL-WS protocol
logger.debug(
f"[PROTOCOL:{subscription_name}] Initializing GraphQL-WS protocol..."
)
init_payload = build_connection_init()
if "payload" in init_payload:
logger.debug(f"[AUTH:{subscription_name}] Adding authentication payload")
else:
logger.warning(
f"[AUTH:{subscription_name}] No API key available for authentication"
)
logger.debug(f"[PROTOCOL:{subscription_name}] Sending connection_init message")
await websocket.send(json.dumps(init_payload))
# Wait for connection acknowledgment
logger.debug(f"[PROTOCOL:{subscription_name}] Waiting for connection_ack...")
init_raw = await asyncio.wait_for(websocket.recv(), timeout=30)
try:
init_data = json.loads(init_raw)
logger.debug(
f"[PROTOCOL:{subscription_name}] Received init response: {init_data.get('type')}"
)
except json.JSONDecodeError as e:
init_preview = (
init_raw[:200]
if isinstance(init_raw, str)
else init_raw[:200].decode("utf-8", errors="replace")
)
logger.error(
f"[PROTOCOL:{subscription_name}] Failed to decode init response: {init_preview}..."
)
# Raise rather than continue — continue skips the reconnect
# backoff at the bottom of the while loop, causing tight retry
# loops on malformed handshake responses (e.g. misconfigured
# proxies). The outer except Exception handler catches this,
# sets last_error, and the normal sleep/backoff runs.
raise RuntimeError(f"Invalid JSON in init handshake response: {e}") from e
# Handle connection acknowledgment
if init_data.get("type") == "connection_ack":
logger.info(
f"[PROTOCOL:{subscription_name}] Connection acknowledged successfully"
)
self.connection_states[subscription_name] = "authenticated"
elif init_data.get("type") == "connection_error":
error_payload = init_data.get("payload", {})
logger.error(
f"[AUTH:{subscription_name}] Authentication failed: {error_payload}"
)
self.last_error[subscription_name] = (
f"Authentication error: {error_payload}"
)
self.connection_states[subscription_name] = "auth_failed"
break
else:
logger.warning(
f"[PROTOCOL:{subscription_name}] Unexpected init response: {init_data}"
)
# Continue anyway - some servers send other messages first
# Start the subscription
logger.debug(
f"[SUBSCRIPTION:{subscription_name}] Starting GraphQL subscription..."
)
start_type = (
"subscribe" if selected_proto == "graphql-transport-ws" else "start"
)
subscription_message = {
"id": subscription_name,
"type": start_type,
"payload": {"query": query, "variables": variables},
}
logger.debug(
f"[SUBSCRIPTION:{subscription_name}] Subscription message type: {start_type}"
)
logger.debug(f"[SUBSCRIPTION:{subscription_name}] Query: {query[:100]}...")
logger.debug(
f"[SUBSCRIPTION:{subscription_name}] Variables: {redact_sensitive(variables)}"
)
await websocket.send(json.dumps(subscription_message))
logger.info(
f"[SUBSCRIPTION:{subscription_name}] Subscription started successfully"
)
self.connection_states[subscription_name] = "subscribed"
# Listen for subscription data
message_count = 0
async for message in websocket:
try:
data = json.loads(message)
message_count += 1
message_type = data.get("type", "unknown")
logger.debug(
f"[DATA:{subscription_name}] Message #{message_count}: {message_type}"
)
# Handle different message types
expected_data_type = (
"next" if selected_proto == "graphql-transport-ws" else "data"
)
if (
data.get("type") == expected_data_type
and data.get("id") == subscription_name
):
payload = data.get("payload", {})
if payload.get("data"):
logger.info(
f"[DATA:{subscription_name}] Received subscription data update"
)
_clear_graphql_error_burst(subscription_name)
capped_data = (
_cap_log_content(payload["data"])
if isinstance(payload["data"], dict)
and subscription_name in _LOG_SUBSCRIPTION_NAMES
else payload["data"]
)
new_entry = SubscriptionData(
data=capped_data,
last_updated=datetime.now(UTC),
subscription_type=subscription_name,
)
async with self._data_lock:
self.resource_data[subscription_name] = new_entry
logger.debug(
f"[RESOURCE:{subscription_name}] Resource data updated successfully"
)
elif payload.get("errors"):
err_msg = str(payload["errors"])
prev = _last_graphql_error.get(subscription_name)
count = _graphql_error_count.get(subscription_name, 0) + 1
_graphql_error_count[subscription_name] = count
if prev != err_msg:
# First occurrence of this error — log as warning
_last_graphql_error[subscription_name] = err_msg
_graphql_error_count[subscription_name] = 1
logger.warning(
"[DATA:%s] GraphQL error (will suppress repeats): %s",
subscription_name,
err_msg,
)
elif count in (10, 100, 1000):
# Periodic reminder at powers of 10
logger.warning(
"[DATA:%s] GraphQL error repeated %d times: %s",
subscription_name,
count,
err_msg,
)
else:
logger.debug(
"[DATA:%s] GraphQL error (repeat #%d)",
subscription_name,
count,
)
self.last_error[subscription_name] = (
f"GraphQL errors: {payload['errors']}"
)
else:
logger.warning(
f"[DATA:{subscription_name}] Empty or invalid data payload: {payload}"
)
elif data.get("type") == "ping":
logger.debug(
f"[PROTOCOL:{subscription_name}] Received ping, sending pong"
)
await websocket.send(json.dumps({"type": "pong"}))
elif data.get("type") == "error":
error_payload = data.get("payload", {})
logger.error(
f"[SUBSCRIPTION:{subscription_name}] Subscription error: {error_payload}"
)
self.last_error[subscription_name] = (
f"Subscription error: {error_payload}"
)
self.connection_states[subscription_name] = "error"
elif data.get("type") == "complete":
logger.info(
f"[SUBSCRIPTION:{subscription_name}] Subscription completed by server"
)
self.connection_states[subscription_name] = "completed"
break
elif data.get("type") in ["ka", "pong"]:
logger.debug(
f"[PROTOCOL:{subscription_name}] Keepalive message: {message_type}"
)
else:
logger.debug(
f"[PROTOCOL:{subscription_name}] Unhandled message type: {message_type}"
)
except json.JSONDecodeError as e:
logger.error(
f"[PROTOCOL:{subscription_name}] Failed to decode message: {_preview(message)}..."
)
logger.error(f"[PROTOCOL:{subscription_name}] JSON decode error: {e}")
except Exception as e:
logger.error(
f"[DATA:{subscription_name}] Error processing message: {e}",
exc_info=True,
)
logger.debug(
f"[DATA:{subscription_name}] Raw message: {_preview(message)}..."
)
except TimeoutError:
error_msg = "Connection or authentication timeout"
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
self.last_error[subscription_name] = error_msg
self.connection_states[subscription_name] = "timeout"
except websockets.exceptions.ConnectionClosed as e:
error_msg = f"WebSocket connection closed: {e}"
logger.warning(f"[WEBSOCKET:{subscription_name}] {error_msg}")
self.last_error[subscription_name] = error_msg
self.connection_states[subscription_name] = "disconnected"
except websockets.exceptions.InvalidURI as e:
error_msg = f"Invalid WebSocket URI: {e}"
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
self.last_error[subscription_name] = error_msg
self.connection_states[subscription_name] = "invalid_uri"
break # Don't retry on invalid URI
except ValueError as e:
# Non-retryable configuration error (e.g. UNRAID_API_URL not set)
error_msg = f"Configuration error: {e}"
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}")
self.last_error[subscription_name] = error_msg
self.connection_states[subscription_name] = "error"
break # Don't retry on configuration errors
except Exception as e:
error_msg = f"Unexpected error: {e}"
logger.error(f"[WEBSOCKET:{subscription_name}] {error_msg}", exc_info=True)
self.last_error[subscription_name] = error_msg
self.connection_states[subscription_name] = "error"
# Check if connection was stable before deciding on retry behavior
start_time = self._connection_start_times.pop(subscription_name, None)
if start_time is not None:
connected_duration = time.monotonic() - start_time
if connected_duration >= _STABLE_CONNECTION_SECONDS:
# Connection was stable — reset retry counter and backoff
logger.info(
f"[WEBSOCKET:{subscription_name}] Connection was stable "
f"({connected_duration:.0f}s >= {_STABLE_CONNECTION_SECONDS}s), "
f"resetting retry counter"
)
self.reconnect_attempts[subscription_name] = 0
retry_delay = 5
else:
logger.warning(
f"[WEBSOCKET:{subscription_name}] Connection was unstable "
f"({connected_duration:.0f}s < {_STABLE_CONNECTION_SECONDS}s), "
f"keeping retry counter at {self.reconnect_attempts.get(subscription_name, 0)}"
)
# Only escalate backoff when connection was NOT stable
retry_delay = min(retry_delay * 1.5, max_retry_delay)
else:
# No connection was established — escalate backoff
retry_delay = min(retry_delay * 1.5, max_retry_delay)
logger.info(
f"[WEBSOCKET:{subscription_name}] Reconnecting in {retry_delay:.1f} seconds..."
)
self.connection_states[subscription_name] = "reconnecting"
await asyncio.sleep(retry_delay)
# The while loop exited (via break or max_retries exceeded).
# Remove from active_subscriptions so start_subscription() can restart it.
async with self._task_lock:
self.active_subscriptions.pop(subscription_name, None)
logger.info(
f"[SUBSCRIPTION:{subscription_name}] Subscription loop ended — "
f"removed from active_subscriptions. Final state: "
f"{self.connection_states.get(subscription_name, 'unknown')}"
)
async def get_resource_data(self, resource_name: str) -> dict[str, Any] | None:
"""Get current resource data with enhanced logging."""
logger.debug(f"[RESOURCE:{resource_name}] Resource data requested")
async with self._data_lock:
if resource_name in self.resource_data:
data = self.resource_data[resource_name]
age_seconds = (datetime.now(UTC) - data.last_updated).total_seconds()
logger.debug(f"[RESOURCE:{resource_name}] Data found, age: {age_seconds:.1f}s")
return data.data
logger.debug(f"[RESOURCE:{resource_name}] No data available")
return None
def list_active_subscriptions(self) -> list[str]:
"""List all active subscriptions."""
active = list(self.active_subscriptions.keys())
logger.debug(f"[SUBSCRIPTION_MANAGER] Active subscriptions: {active}")
return active
async def get_error_state(self, name: str) -> tuple[str | None, str]:
"""Return (last_error, connection_state) for a subscription.
Public accessor so external callers (e.g. resources.py) never touch
private lock attributes directly.
Consistency note: _subscription_loop writes connection_states and
last_error without holding _task_lock (it runs as an asyncio Task with
no await between the two writes). The pair is consistent because asyncio
cooperative scheduling prevents interleaving at non-await points — not
because of the lock. _task_lock here prevents two *readers* from racing
each other, not writers. If an await is ever introduced between the two
write sites in _subscription_loop, this guarantee breaks.
"""
async with self._task_lock:
return (
self.last_error.get(name),
self.connection_states.get(name, ""),
)
async def get_subscription_status(self) -> dict[str, dict[str, Any]]:
"""Get detailed status of all subscriptions for diagnostics.
Acquires _task_lock and _data_lock separately (sequential, not simultaneous)
to prevent an ABBA deadlock:
- This method would hold _task_lock → then acquire _data_lock.
- _subscription_loop holds _data_lock (during resource_data write) and at
loop exit acquires _task_lock (to remove itself from active_subscriptions).
Holding both locks simultaneously from here creates the classic ABBA cycle.
Instead, snapshot state under each lock independently, then release before
acquiring the other.
"""
# Snapshot task-related state under _task_lock
async with self._task_lock:
active_snapshot = set(self.active_subscriptions)
state_snapshot = dict(self.connection_states)
reconnect_snapshot = dict(self.reconnect_attempts)
error_snapshot = dict(self.last_error)
# Snapshot resource data under _data_lock (acquired separately)
async with self._data_lock:
resource_snapshot = dict(self.resource_data)
status: dict[str, dict[str, Any]] = {}
for sub_name, config in self.subscription_configs.items():
sub_status: dict[str, Any] = {
"config": {
"resource": config["resource"],
"description": config["description"],
"auto_start": config.get("auto_start", False),
},
"runtime": {
"active": sub_name in active_snapshot,
"connection_state": state_snapshot.get(sub_name, "not_started"),
"reconnect_attempts": reconnect_snapshot.get(sub_name, 0),
"last_error": error_snapshot.get(sub_name),
},
}
# Add data info if available
if sub_name in resource_snapshot:
data_info = resource_snapshot[sub_name]
age_seconds = (datetime.now(UTC) - data_info.last_updated).total_seconds()
sub_status["data"] = {
"available": True,
"last_updated": data_info.last_updated.isoformat(),
"age_seconds": age_seconds,
}
else:
sub_status["data"] = {"available": False}
status[sub_name] = sub_status
logger.debug(f"[SUBSCRIPTION_MANAGER] Generated status for {len(status)} subscriptions")
return status
# Global subscription manager instance
subscription_manager = SubscriptionManager()