lobehub/apps/device-gateway/src/DeviceGatewayDO.ts
Rylan Cai 4dd271c968
feat(cli): support api key auth in cli (#13190)
* ✨ support cli api key auth

* 🔒 reject invalid x-api-key without fallback auth

* ♻️ clean up cli api key auth diff

* ♻️ clean up cli auth command diff

* ♻️ clean up remaining cli auth diff

* ♻️ split stored auth token fields

* ♻️ trim connect auth surface

* ♻️ drop redundant jwt user id carry-over

* ♻️ trim auth test wording diff

* 🐛 fix api key model imports

* 🐛 fix api key util subpath import

* 🔐 chore(cli): use env-only api key auth

* ♻️ refactor(cli): simplify auth credential flow

* ✨ feat: simplify cli api key login flow

* 🐛 fix(cli): prefer jwt for webapi auth

* ♻️ refactor(cli): trim auth http diff

* 🐛 fix(cli): skip api key auth expiry handling

* 🐛 fix(cli): restore non-jwt expiry handling

* ♻️ refactor(cli): trim connect auth expired diff

* ♻️ refactor(cli): trim login comment diff

* ♻️ refactor(cli): trim resolve token comment diff

* ♻️ refactor(cli): restore connect expiry flow

* ♻️ refactor(cli): trim login api key message

* 🐛 fix(cli): support api key gateway auth

* ♻️ refactor(cli): restore resolve token comment

* ♻️ refactor(cli): trim test-only auth diffs

* ♻️ refactor(cli): restore resolve token comments

* ✅ test(cli): add api key expiry coverage

* 🐛 fix cli auth server resolution and gateway auth

* ♻️ prune auth fix diff noise

* ♻️ unify cli server url precedence

* ♻️ simplify device gateway auth tests

* ✅ add gateway auth edge case coverage

* ✅ remove low-value gateway auth test

* 🐛 fix api key context test mock typing
2026-03-26 10:11:38 +08:00

345 lines
10 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { DurableObject } from 'cloudflare:workers';
import { Hono } from 'hono';
import { resolveSocketAuth, verifyApiKeyToken, verifyDesktopToken } from './auth';
import type { DeviceAttachment, Env } from './types';
/** Milliseconds a freshly connected socket has to send a successful `auth` message before the alarm closes it. */
const AUTH_TIMEOUT = 10 * 1000;
/** Milliseconds an authenticated socket may go without a `heartbeat` message before the alarm closes it. */
const HEARTBEAT_TIMEOUT = 90 * 1000;
/** Milliseconds between alarm sweeps while sockets remain connected. */
const HEARTBEAT_CHECK_INTERVAL = 90 * 1000;
/**
 * Durable Object that bridges HTTP callers and connected desktop devices over
 * hibernatable WebSockets.
 *
 * Lifecycle implemented below:
 * - `fetch` accepts WebSocket upgrades (initially unauthenticated, with an
 *   auth deadline `AUTH_TIMEOUT` ms away) and routes every other request to
 *   the Hono `router`.
 * - A socket becomes authenticated after a successful `{ type: 'auth' }`
 *   message; until then every other message is ignored.
 * - `alarm` closes sockets that miss their auth deadline or stop sending
 *   heartbeats, and keeps rescheduling itself while sockets remain.
 * - `/api/device/tool-call` and `/api/device/system-info` forward an RPC to one
 *   authenticated socket and wait for the matching `*_response` message.
 */
export class DeviceGatewayDO extends DurableObject<Env> {
  /**
   * In-memory table of RPCs awaiting a device response, keyed by `requestId`.
   * An entry is removed when its response arrives, when it times out, or when
   * sending the request fails.
   */
  private pendingRequests = new Map<
    string,
    {
      resolve: (result: unknown) => void;
      timer: ReturnType<typeof setTimeout>;
    }
  >();

  /** HTTP API served for non-upgrade requests. */
  private router = new Hono()
    .all('/api/device/status', async () => {
      const sockets = this.getAuthenticatedSockets();
      return Response.json({
        deviceCount: sockets.length,
        online: sockets.length > 0,
      });
    })
    .post('/api/device/tool-call', async (c) => {
      return this.handleToolCall(c.req.raw);
    })
    .post('/api/device/system-info', async (c) => {
      return this.handleSystemInfo(c.req.raw);
    })
    .all('/api/device/devices', async () => {
      const sockets = this.getAuthenticatedSockets();
      const devices = sockets.map((ws) => ws.deserializeAttachment() as DeviceAttachment);
      return Response.json({ devices });
    });

  async fetch(request: Request): Promise<Response> {
    // ─── WebSocket upgrade (from Desktop) ───
    if (request.headers.get('Upgrade') === 'websocket') {
      return this.handleWebSocketUpgrade(request);
    }
    // ─── HTTP API routes ───
    return this.router.fetch(request);
  }

  // ─── Hibernation Handlers ───

  /**
   * Handles one inbound frame from a device.
   *
   * Frames that are not JSON objects are ignored (the client is untrusted and
   * must not be able to make this handler throw). `auth` frames authenticate
   * the socket; all other frames are ignored until the socket is authenticated.
   */
  async webSocketMessage(ws: WebSocket, message: string | ArrayBuffer) {
    const data = this.parseMessage(message);
    if (!data) return;

    const att = ws.deserializeAttachment() as DeviceAttachment;

    // ─── Auth message handling ───
    if (data.type === 'auth') {
      if (att.authenticated) return; // Already authenticated, ignore
      try {
        const token = data.token as string | undefined;
        const tokenType = data.tokenType as 'apiKey' | 'jwt' | 'serviceToken' | undefined;
        const serverUrl = data.serverUrl as string | undefined;
        const storedUserId = await this.ctx.storage.get<string>('_userId');

        const verifiedUserId = await resolveSocketAuth({
          serverUrl,
          serviceToken: this.env.SERVICE_TOKEN,
          storedUserId,
          token,
          tokenType,
          verifyApiKey: verifyApiKeyToken,
          verifyJwt: async (jwt) => {
            const result = await verifyDesktopToken(this.env, jwt);
            return { userId: result.userId };
          },
        });

        // Verify userId matches the DO routing
        if (storedUserId && verifiedUserId !== storedUserId) {
          throw new Error('userId mismatch');
        }

        // Mark as authenticated
        att.authenticated = true;
        att.authDeadline = undefined;
        ws.serializeAttachment(att);
        ws.send(JSON.stringify({ type: 'auth_success' }));

        // Schedule heartbeat check for authenticated connections
        await this.scheduleHeartbeatCheck();
      } catch (err) {
        const reason = err instanceof Error ? err.message : 'Authentication failed';
        this.closeQuietly(ws, 1008, reason, { reason, type: 'auth_failed' });
      }
      return;
    }

    // ─── Reject unauthenticated messages ───
    if (!att.authenticated) return;

    // ─── Business messages (authenticated only) ───
    if (data.type === 'tool_call_response' || data.type === 'system_info_response') {
      const requestId = data.requestId as string;
      const pending = this.pendingRequests.get(requestId);
      if (pending) {
        clearTimeout(pending.timer);
        pending.resolve(data.result);
        this.pendingRequests.delete(requestId);
      }
    }

    if (data.type === 'heartbeat') {
      att.lastHeartbeat = Date.now();
      ws.serializeAttachment(att);
      ws.send(JSON.stringify({ type: 'heartbeat_ack' }));
    }
  }

  async webSocketClose(_ws: WebSocket, _code: number) {
    // Hibernation API handles connection cleanup automatically.
    // In-flight RPCs aimed at this socket are not rejected here; they resolve
    // via their own timeout in `sendRpc`.
    // NOTE(review): depending on the Workers compatibility date, the runtime may
    // not auto-reply to the client's close frame; confirm whether `ws.close()`
    // must be called here.
  }

  async webSocketError(ws: WebSocket, _error: unknown) {
    this.closeQuietly(ws, 1011, 'Internal error');
  }

  // ─── Heartbeat / Auth Timeout ───

  /**
   * Periodic sweep over all sockets:
   * - unauthenticated sockets whose `authDeadline` has been reached are sent
   *   `auth_failed` and closed with 1008;
   * - authenticated sockets silent for more than `HEARTBEAT_TIMEOUT` are closed
   *   with 1000.
   * While any socket survives, the next alarm is set to the earliest pending
   * auth deadline, or `HEARTBEAT_CHECK_INTERVAL` from now if that is sooner.
   */
  async alarm() {
    const now = Date.now();
    let nextCheck = now + HEARTBEAT_CHECK_INTERVAL;
    let hasRemainingSockets = false;

    for (const ws of this.ctx.getWebSockets()) {
      const att = ws.deserializeAttachment() as DeviceAttachment;

      // Auth timeout: close unauthenticated connections past deadline.
      // `>=` so an alarm firing exactly at the deadline still enforces it.
      if (!att.authenticated && att.authDeadline) {
        if (now >= att.authDeadline) {
          this.closeQuietly(ws, 1008, 'Authentication timeout', {
            reason: 'Authentication timeout',
            type: 'auth_failed',
          });
          continue;
        }
        // Still inside its auth window: wake up in time to enforce it instead
        // of waiting for the next heartbeat sweep.
        nextCheck = Math.min(nextCheck, att.authDeadline);
      }

      // Heartbeat timeout: only for authenticated connections
      if (att.authenticated && now - att.lastHeartbeat > HEARTBEAT_TIMEOUT) {
        this.closeQuietly(ws, 1000, 'Heartbeat timeout');
        continue;
      }

      hasRemainingSockets = true;
    }

    // Keep alarm running while there are active connections. Set it directly:
    // this handler decides the next wake-up time on its own.
    if (hasRemainingSockets) {
      await this.ctx.storage.setAlarm(nextCheck);
    }
  }

  // ─── WebSocket Upgrade ───

  /**
   * Accepts a hibernatable WebSocket for a device, replacing any existing
   * socket with the same `deviceId`. The socket starts unauthenticated with a
   * deadline `AUTH_TIMEOUT` ms away. When an `X-User-Id` header is present it
   * is stored as `_userId` for the later auth check.
   */
  private async handleWebSocketUpgrade(request: Request): Promise<Response> {
    const url = new URL(request.url);
    const userId = request.headers.get('X-User-Id');
    const deviceId = url.searchParams.get('deviceId') || 'unknown';
    const hostname = url.searchParams.get('hostname') || '';
    const platform = url.searchParams.get('platform') || '';

    // Close stale connection from the same device. Best-effort: a socket that
    // is already closing must not make the upgrade fail.
    for (const ws of this.ctx.getWebSockets()) {
      const att = ws.deserializeAttachment() as DeviceAttachment;
      if (att.deviceId === deviceId) {
        this.closeQuietly(ws, 1000, 'Replaced by new connection');
      }
    }

    const pair = new WebSocketPair();
    const [client, server] = Object.values(pair);

    this.ctx.acceptWebSocket(server);

    const now = Date.now();
    const authDeadline = now + AUTH_TIMEOUT;
    server.serializeAttachment({
      authDeadline,
      authenticated: false,
      connectedAt: now,
      deviceId,
      hostname,
      lastHeartbeat: now,
      platform,
    } satisfies DeviceAttachment);

    if (userId) {
      await this.ctx.storage.put('_userId', userId);
    }

    // Schedule auth timeout check (10s)
    await this.scheduleAuthTimeout(authDeadline);

    return new Response(null, { status: 101, webSocket: client });
  }

  /**
   * Ensures an alarm fires no later than `deadline`.
   *
   * An alarm already pending for a later heartbeat sweep (up to
   * `HEARTBEAT_CHECK_INTERVAL` away) is pulled forward; otherwise the new
   * socket's auth deadline would not be enforced on time.
   */
  private async scheduleAuthTimeout(deadline: number = Date.now() + AUTH_TIMEOUT) {
    const currentAlarm = await this.ctx.storage.getAlarm();
    if (currentAlarm === null || currentAlarm > deadline) {
      await this.ctx.storage.setAlarm(deadline);
    }
  }

  /** Starts the periodic sweep if no alarm is currently scheduled. */
  private async scheduleHeartbeatCheck() {
    const currentAlarm = await this.ctx.storage.getAlarm();
    if (!currentAlarm) {
      await this.ctx.storage.setAlarm(Date.now() + HEARTBEAT_CHECK_INTERVAL);
    }
  }

  // ─── Helpers ───

  private getAuthenticatedSockets(): WebSocket[] {
    return this.ctx.getWebSockets().filter((ws) => {
      const att = ws.deserializeAttachment() as DeviceAttachment;
      return att.authenticated;
    });
  }

  /**
   * Decodes a frame as a JSON object. Binary frames are decoded as UTF-8.
   * Returns `undefined` for invalid JSON or for a non-object top-level value
   * (e.g. `null`).
   */
  private parseMessage(message: string | ArrayBuffer): Record<string, unknown> | undefined {
    try {
      const text = typeof message === 'string' ? message : new TextDecoder().decode(message);
      const parsed: unknown = JSON.parse(text);
      return typeof parsed === 'object' && parsed !== null
        ? (parsed as Record<string, unknown>)
        : undefined;
    } catch {
      return undefined;
    }
  }

  /**
   * Best-effort close: optionally sends one final JSON frame, then closes.
   * `send`/`close` can throw on a socket that is already closing. Those errors
   * are swallowed so that one bad socket cannot abort a loop such as the alarm
   * sweep.
   */
  private closeQuietly(ws: WebSocket, code: number, reason: string, finalFrame?: object) {
    if (finalFrame) {
      try {
        ws.send(JSON.stringify(finalFrame));
      } catch {
        // Socket no longer writable; still attempt the close below.
      }
    }
    try {
      ws.close(code, reason);
    } catch {
      // Already closing/closed.
    }
  }

  /** Picks the socket for `deviceId` or, when none is specified, the first available one. */
  private selectTargetSocket(sockets: WebSocket[], deviceId?: string): WebSocket | undefined {
    if (!deviceId) return sockets[0];
    return sockets.find((ws) => {
      const att = ws.deserializeAttachment() as DeviceAttachment;
      return att.deviceId === deviceId;
    });
  }

  /**
   * Sends `{ requestId, ...payload }` to `ws` and resolves with the `result`
   * field of the matching `*_response` frame (see `webSocketMessage`).
   * Rejects with `Error('TIMEOUT')` after `timeout` ms, or with the send error
   * if the frame cannot be written. In every failure path the timer and the
   * pending entry are cleaned up.
   */
  private sendRpc(
    ws: WebSocket,
    payload: Record<string, unknown>,
    timeout: number,
  ): Promise<unknown> {
    const requestId = crypto.randomUUID();
    return new Promise((resolve, reject) => {
      const timer = setTimeout(() => {
        this.pendingRequests.delete(requestId);
        reject(new Error('TIMEOUT'));
      }, timeout);

      this.pendingRequests.set(requestId, { resolve, timer });

      try {
        ws.send(JSON.stringify({ requestId, ...payload }));
      } catch (err) {
        clearTimeout(timer);
        this.pendingRequests.delete(requestId);
        reject(err);
      }
    });
  }

  // ─── System Info RPC ───

  private async handleSystemInfo(request: Request): Promise<Response> {
    const sockets = this.getAuthenticatedSockets();
    if (sockets.length === 0) {
      return Response.json({ error: 'DEVICE_OFFLINE', success: false }, { status: 503 });
    }

    const { deviceId, timeout = 10_000 } = (await request.json()) as {
      deviceId?: string;
      timeout?: number;
    };

    const targetWs = this.selectTargetSocket(sockets, deviceId);
    if (!targetWs) {
      return Response.json({ error: 'DEVICE_NOT_FOUND', success: false }, { status: 503 });
    }

    try {
      const result = await this.sendRpc(targetWs, { type: 'system_info_request' }, timeout);
      return Response.json({ success: true, ...(result as object) });
    } catch (err) {
      return Response.json(
        {
          error: err instanceof Error ? err.message : String(err),
          success: false,
        },
        { status: 504 },
      );
    }
  }

  // ─── Tool Call RPC ───

  private async handleToolCall(request: Request): Promise<Response> {
    const sockets = this.getAuthenticatedSockets();
    if (sockets.length === 0) {
      // content: "desktop device is offline"
      return Response.json(
        { content: '桌面设备不在线', error: 'DEVICE_OFFLINE', success: false },
        { status: 503 },
      );
    }

    const {
      deviceId,
      timeout = 30_000,
      toolCall,
    } = (await request.json()) as {
      deviceId?: string;
      timeout?: number;
      toolCall: unknown;
    };

    // Select target device (specified > first available)
    const targetWs = this.selectTargetSocket(sockets, deviceId);
    if (!targetWs) {
      return Response.json({ error: 'DEVICE_NOT_FOUND', success: false }, { status: 503 });
    }

    try {
      const result = await this.sendRpc(
        targetWs,
        { toolCall, type: 'tool_call_request' },
        timeout,
      );
      return Response.json({ success: true, ...(result as object) });
    } catch (err) {
      // content: "tool call timed out (Ns)"
      return Response.json(
        {
          content: `工具调用超时(${timeout / 1000}s)`,
          error: err instanceof Error ? err.message : String(err),
          success: false,
        },
        { status: 504 },
      );
    }
  }
}