waveterm/pkg/web/ws.go
Mike Sawka ae3e9f05b7
new job manager / framework for creating persistent remote sessions (#2779)
lots of stuff here.

introduces a streaming framework for the RPC system with flow control.
new authentication primitives for the RPC system. this is used to create
a persistent "job manager" process (via wsh) that can survive
disconnects. and then a jobcontroller in the main server that can
create, reconnect, and manage these new persistent jobs.

code is currently not actively hooked up to anything minus some new
debugging wsh commands, and a switch in the term block that lets me test
viewing the output.

after PRing this change the next steps are more testing and then
integrating this functionality into the product.
2026-01-21 16:54:18 -08:00

316 lines
8.5 KiB
Go

// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package web
import (
"encoding/json"
"fmt"
"log"
"net"
"net/http"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/wavetermdev/waveterm/pkg/authkey"
"github.com/wavetermdev/waveterm/pkg/baseds"
"github.com/wavetermdev/waveterm/pkg/eventbus"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/web/webcmd"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)
// Timing and size limits used by the websocket server.
const (
	// wsReadWaitTimeout is the read deadline ReadLoop applies; it is pushed
	// forward after every successfully parsed incoming message.
	wsReadWaitTimeout = 15 * time.Second
	// wsWriteWaitTimeout is used to compute the websocket write deadline.
	wsWriteWaitTimeout = 10 * time.Second
	// wsPingPeriodTickTime is the interval between server pings after the first.
	wsPingPeriodTickTime = 10 * time.Second
	// wsInitialPingTime is the delay before the first server ping.
	wsInitialPingTime = 1 * time.Second
	// wsMaxMessageSize is the read limit (bytes) for a single incoming message.
	wsMaxMessageSize = 10 * 1024 * 1024

	// DefaultCommandTimeout is an exported default command timeout; it is not
	// referenced in this file.
	DefaultCommandTimeout = 2 * time.Second
)
// StableConnInfo is the RouteToConnMap entry for a stable id: the websocket
// connection currently registered for that id and its router link.
type StableConnInfo struct {
// ConnId is the uuid assigned to the websocket connection in HandleWsInternal.
ConnId string
// LinkId is the link returned by wshutil.DefaultRouter.RegisterTrustedRouter;
// registerConn/unregisterConn skip UnregisterLink when it equals baseds.NoLinkId.
LinkId baseds.LinkId
}
// Connection registry shared by registerConn and unregisterConn.
var (
	// GlobalLock guards RouteToConnMap; both registerConn and unregisterConn
	// hold it for their whole body.
	GlobalLock = new(sync.Mutex)
	// RouteToConnMap maps a stable id to its currently registered connection.
	RouteToConnMap = make(map[string]*StableConnInfo) // stableid => StableConnInfo
)
// RunWebSocketServer serves the "/ws" websocket endpoint on listener and blocks
// until the server stops. HTTP keep-alives are disabled. The error returned by
// Serve is logged, not returned.
func RunWebSocketServer(listener net.Listener) {
	router := mux.NewRouter()
	router.HandleFunc("/ws", HandleWs)
	server := &http.Server{
		Handler:        router,
		ReadTimeout:    HttpReadTimeout,
		WriteTimeout:   HttpWriteTimeout,
		MaxHeaderBytes: HttpMaxHeaderBytes,
	}
	server.SetKeepAlivesEnabled(false)
	log.Printf("[websocket] running websocket server on %s\n", listener.Addr())
	if err := server.Serve(listener); err != nil {
		log.Printf("[websocket] error trying to run websocket server: %v\n", err)
	}
}
// WebSocketUpgrader upgrades "/ws" requests to websocket connections.
// CheckOrigin accepts every origin; HandleWsInternal validates the request's
// auth key before calling Upgrade.
var WebSocketUpgrader = websocket.Upgrader{
	HandshakeTimeout: 1 * time.Second,
	ReadBufferSize:   4 * 1024,
	WriteBufferSize:  32 * 1024,
	CheckOrigin:      func(*http.Request) bool { return true },
}
// HandleWs is the HTTP handler for "/ws". It delegates to HandleWsInternal and
// reports any returned error as a 500 response.
//
// NOTE(review): HandleWsInternal writes its own 401 before returning an auth
// error (and the upgrader presumably replies itself when Upgrade fails), so this
// http.Error can be a second, superfluous WriteHeader — confirm and clean up.
func HandleWs(w http.ResponseWriter, r *http.Request) {
	if err := HandleWsInternal(w, r); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
}
// getMessageType returns jmsg["type"] when it holds a string, and "" when the
// key is missing or holds a non-string value.
func getMessageType(jmsg map[string]any) string {
	msgType, _ := jmsg["type"].(string) // zero value "" on failed assertion
	return msgType
}
// getStringFromMap returns jmsg[key] when it holds a string, and "" when the
// key is missing or holds a non-string value.
func getStringFromMap(jmsg map[string]any, key string) string {
	val, _ := jmsg[key].(string) // zero value "" on failed assertion
	return val
}
// processWSCommand parses jmsg as a websocket command and dispatches it.
//
// Only *webcmd.WSRpcCommand is handled: its RPC message is re-marshalled to
// JSON and sent on rpcInputCh. Any failure is reported to the client by
// sending {"type": "error", "error": <message>} on outputCh. That covers a
// parse failure, a marshal failure, or a panic, which is recovered via
// panichandler.
func processWSCommand(jmsg map[string]any, outputCh chan any, rpcInputCh chan baseds.RpcInputChType) {
	var rtnErr error
	var cmdType string
	defer func() {
		// recover() is evaluated directly inside this deferred func, so panics
		// anywhere below are caught here.
		panicCtx := "processWSCommand"
		if cmdType != "" {
			panicCtx = fmt.Sprintf("processWSCommand:%s", cmdType)
		}
		if panicErr := panichandler.PanicHandler(panicCtx, recover()); panicErr != nil {
			rtnErr = panicErr
		}
		if rtnErr == nil {
			return
		}
		outputCh <- map[string]any{"type": "error", "error": rtnErr.Error()}
	}()
	wsCommand, err := webcmd.ParseWSCommandMap(jmsg)
	if err != nil {
		rtnErr = fmt.Errorf("cannot parse wscommand: %w", err)
		return
	}
	cmdType = wsCommand.GetWSCommand()
	switch cmd := wsCommand.(type) {
	case *webcmd.WSRpcCommand:
		rpcMsg := cmd.Message
		if rpcMsg == nil {
			// nothing to forward
			return
		}
		if rpcMsg.Command != "" {
			// refine the panic context with the rpc command name
			cmdType = fmt.Sprintf("%s:%s", cmdType, rpcMsg.Command)
		}
		msgBytes, err := json.Marshal(rpcMsg)
		if err != nil {
			// Presumably unreachable, since the message was just unmarshalled. Report it
			// anyway rather than dropping the command, which would leave the
			// client waiting for a reply that never comes.
			rtnErr = fmt.Errorf("cannot marshal rpc message: %w", err)
			return
		}
		rpcInputCh <- baseds.RpcInputChType{MsgBytes: msgBytes}
	}
}
// processMessage hands jmsg to processWSCommand when it carries a non-empty
// string "wscommand" field; every other message is ignored.
func processMessage(jmsg map[string]any, outputCh chan any, rpcInputCh chan baseds.RpcInputChType) {
	if getStringFromMap(jmsg, "wscommand") != "" {
		processWSCommand(jmsg, outputCh, rpcInputCh)
	}
}
// ReadLoop reads JSON messages from conn until a read fails or a message is
// not valid JSON, and then closes closeCh.
//
// Each read is limited to wsMaxMessageSize. The read deadline is
// wsReadWaitTimeout and is extended after every successfully parsed message.
// Messages are dispatched by their "type" field:
//   - "ping" is answered with a "pong" on outputCh.
//   - "pong" is ignored.
//   - Anything else goes to processMessage on its own goroutine.
func ReadLoop(conn *websocket.Conn, outputCh chan any, closeCh chan any, rpcInputCh chan baseds.RpcInputChType, routeId string) {
	conn.SetReadLimit(wsMaxMessageSize)
	conn.SetReadDeadline(time.Now().Add(wsReadWaitTimeout))
	defer close(closeCh)
	for {
		_, data, err := conn.ReadMessage()
		if err != nil {
			log.Printf("[websocket] ReadPump error (%s): %v\n", routeId, err)
			return
		}
		jmsg := make(map[string]any)
		if err := json.Unmarshal(data, &jmsg); err != nil {
			log.Printf("[websocket] error unmarshalling json: %v\n", err)
			return
		}
		conn.SetReadDeadline(time.Now().Add(wsReadWaitTimeout))
		switch getMessageType(jmsg) {
		case "pong":
			// keepalive reply from the client; nothing to do
		case "ping":
			outputCh <- map[string]any{"type": "pong", "stime": time.Now().UnixMilli()}
		default:
			go processMessage(jmsg, outputCh, rpcInputCh)
		}
	}
}
// WritePing sends an application-level {"type": "ping", "stime": <unix ms>}
// text message on conn and returns the write error, if any.
//
// Before writing, it sets conn's write deadline to now+wsWriteWaitTimeout.
// The deadline is set on conn, so presumably it also applies to later writes
// made without a fresh deadline.
func WritePing(conn *websocket.Conn) error {
	payload, _ := json.Marshal(map[string]any{"type": "ping", "stime": time.Now().UnixMilli()}) // string/int64 values cannot fail to marshal
	_ = conn.SetWriteDeadline(time.Now().Add(wsWriteWaitTimeout)) // no error
	return conn.WriteMessage(websocket.TextMessage, payload)
}
// WriteLoop sends outgoing traffic on conn. In this file it is (together with
// WritePing, which it calls) the only code that writes to conn.
//
// Messages from outputCh are sent as text frames. A []byte is sent as-is;
// anything else is JSON-marshalled first, and a message that fails to marshal
// is logged and dropped. An application-level ping is sent wsInitialPingTime
// after start and every wsPingPeriodTickTime after that.
//
// WriteLoop returns when closeCh is closed or a write fails. On write failure
// it closes conn, so ReadLoop's blocked read fails promptly instead of
// waiting out its read deadline.
func WriteLoop(conn *websocket.Conn, outputCh chan any, closeCh chan any, routeId string) {
	ticker := time.NewTicker(wsInitialPingTime)
	defer ticker.Stop()
	initialPing := true
	for {
		select {
		case msg := <-outputCh:
			barr, isBytes := msg.([]byte)
			if !isBytes {
				var err error
				barr, err = json.Marshal(msg)
				if err != nil {
					log.Printf("[websocket] cannot marshal websocket message: %v\n", err)
					// drop this message and keep looping
					continue
				}
			}
			// The write deadline persists on conn (WritePing sets one), so refresh
			// it before every write. Otherwise a message written after the last
			// ping's deadline has passed fails with a timeout and kills the socket.
			_ = conn.SetWriteDeadline(time.Now().Add(wsWriteWaitTimeout))
			if err := conn.WriteMessage(websocket.TextMessage, barr); err != nil {
				conn.Close()
				log.Printf("[websocket] WritePump error (%s): %v\n", routeId, err)
				return
			}
		case <-ticker.C:
			if err := WritePing(conn); err != nil {
				// close here too (as on the message path) so ReadLoop unblocks
				conn.Close()
				log.Printf("[websocket] WritePump error (%s): %v\n", routeId, err)
				return
			}
			if initialPing {
				// first ping used the short initial delay; switch to the regular period
				initialPing = false
				ticker.Reset(wsPingPeriodTickTime)
			}
		case <-closeCh:
			return
		}
	}
}
// registerConn registers wproxy with wshutil.DefaultRouter as a trusted router
// and records the new link under stableId. The whole body runs under
// GlobalLock.
//
// If stableId already has an entry, a warning is logged and the old link is
// unregistered (unless it is baseds.NoLinkId). The entry is then replaced.
func registerConn(wsConnId string, stableId string, wproxy *wshutil.WshRpcProxy) {
	GlobalLock.Lock()
	defer GlobalLock.Unlock()
	if prev := RouteToConnMap[stableId]; prev != nil {
		log.Printf("[websocket] warning: replacing existing connection for stableid %q\n", stableId)
		if prev.LinkId != baseds.NoLinkId {
			wshutil.DefaultRouter.UnregisterLink(prev.LinkId)
		}
	}
	RouteToConnMap[stableId] = &StableConnInfo{
		ConnId: wsConnId,
		LinkId: wshutil.DefaultRouter.RegisterTrustedRouter(wproxy),
	}
}
// unregisterConn removes the RouteToConnMap entry for stableId and
// unregisters its router link (unless it is baseds.NoLinkId). The whole body
// runs under GlobalLock.
//
// It does this only when the entry still belongs to wsConnId. If a newer
// connection has replaced it, or there is no entry, a warning is logged and
// nothing changes.
func unregisterConn(wsConnId string, stableId string) {
	GlobalLock.Lock()
	defer GlobalLock.Unlock()
	info := RouteToConnMap[stableId]
	isCurrent := info != nil && info.ConnId == wsConnId
	if !isCurrent {
		log.Printf("[websocket] warning: trying to unregister connection %q for stableid %q but it is not the current connection (ignoring)\n", wsConnId, stableId)
		return
	}
	delete(RouteToConnMap, stableId)
	if info.LinkId == baseds.NoLinkId {
		return
	}
	wshutil.DefaultRouter.UnregisterLink(info.LinkId)
}
// HandleWsInternal accepts a websocket connection on "/ws" and services it
// until it ends.
//
// Setup, in order:
//   - Require a non-empty "stableid" query parameter.
//   - Require the request to pass authkey.ValidateIncomingRequest.
//   - Upgrade the connection.
//   - Register outputCh with the eventbus under a fresh uuid connection id.
//   - Create a wshutil RPC proxy and register it with wshutil.DefaultRouter
//     under the stable id (via registerConn).
//
// RPC messages from the proxy are forwarded to the client wrapped as
// {"eventtype": "rpc", "data": ...}. ReadLoop and WriteLoop then run until
// both have returned.
//
// It returns an error, before any upgrade, when stableid is missing, the auth
// key is invalid, or the upgrade fails. It returns nil once an upgraded
// connection has ended.
//
// NOTE(review): on auth failure this writes a 401 itself and still returns
// the error, so HandleWs's http.Error adds a second (superfluous) WriteHeader
// and body. Presumably the same happens when Upgrade fails and the upgrader
// has already replied. Confirm.
func HandleWsInternal(w http.ResponseWriter, r *http.Request) error {
stableId := r.URL.Query().Get("stableid")
if stableId == "" {
return fmt.Errorf("stableid is required")
}
err := authkey.ValidateIncomingRequest(r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(fmt.Sprintf("error validating authkey: %v", err)))
log.Printf("[websocket] error validating authkey: %v\n", err)
return err
}
conn, err := WebSocketUpgrader.Upgrade(w, r, nil)
if err != nil {
return fmt.Errorf("WebSocket Upgrade Failed: %v", err)
}
// Deferred cleanup runs in reverse registration order on return:
// unregisterConn, close(wproxy.ToRemoteCh), eventbus.UnregisterWSChannel,
// conn.Close.
defer conn.Close()
wsConnId := uuid.New().String()
// outputCh carries everything WriteLoop sends to the client (buffered, 100).
outputCh := make(chan any, 100)
// closeCh is closed by ReadLoop when it returns; that stops WriteLoop.
closeCh := make(chan any)
log.Printf("[websocket] new connection: connid:%s stableid:%s\n", wsConnId, stableId)
eventbus.RegisterWSChannel(wsConnId, stableId, outputCh)
defer eventbus.UnregisterWSChannel(wsConnId)
wproxy := wshutil.MakeRpcProxy(fmt.Sprintf("ws:%s", stableId))
defer close(wproxy.ToRemoteCh)
registerConn(wsConnId, stableId, wproxy)
defer unregisterConn(wsConnId, stableId)
// The WaitGroup tracks only ReadLoop and WriteLoop; the forwarder below is not counted.
wg := &sync.WaitGroup{}
wg.Add(2)
go func() {
defer func() {
panichandler.PanicHandler("HandleWsInternal:outputCh", recover())
}()
// no waitgroup add here
// move values from wproxy.ToRemoteCh to outputCh; the range ends when the
// deferred close(wproxy.ToRemoteCh) runs.
// NOTE(review): if WriteLoop has already returned and outputCh's buffer is
// full, the send below blocks and this goroutine never exits.
for msgBytes := range wproxy.ToRemoteCh {
rpcWSMsg := map[string]any{
"eventtype": "rpc", // TODO don't hard code this (but def is in eventbus)
"data": json.RawMessage(msgBytes),
}
outputCh <- rpcWSMsg
}
}()
go func() {
defer func() {
panichandler.PanicHandler("HandleWsInternal:ReadLoop", recover())
}()
defer wg.Done()
ReadLoop(conn, outputCh, closeCh, wproxy.FromRemoteCh, stableId)
}()
go func() {
defer func() {
panichandler.PanicHandler("HandleWsInternal:WriteLoop", recover())
}()
defer wg.Done()
WriteLoop(conn, outputCh, closeCh, stableId)
}()
wg.Wait()
// NOTE(review): goroutines that ReadLoop started (go processMessage) may still
// be running and can send on FromRemoteCh after this close. Such a send panics.
// processWSCommand's deferred PanicHandler presumably recovers it and
// reports it on outputCh.
close(wproxy.FromRemoteCh)
return nil
}