new job manager / framework for creating persistent remote sessions (#2779)

lots of stuff here.

introduces a streaming framework for the RPC system with flow control.
new authentication primitives for the RPC system. this is used to create
a persistent "job manager" process (via wsh) that can survive
disconnects. and then a jobcontroller in the main server that can
create, reconnect, and manage these new persistent jobs.

the code is currently not actively hooked up to anything aside from some new
debugging wsh commands, and a switch in the term block that lets me test
viewing the output.

after PRing this change the next steps are more testing and then
integrating this functionality into the product.
This commit is contained in:
Mike Sawka 2026-01-21 16:54:18 -08:00 committed by GitHub
parent 011ca146df
commit ae3e9f05b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
48 changed files with 5983 additions and 1076 deletions

View file

@ -61,5 +61,8 @@
},
"directoryFilters": ["-tsunami/frontend/scaffold", "-dist", "-make"]
},
"tailwindCSS.lint.suggestCanonicalClasses": "ignore"
"tailwindCSS.lint.suggestCanonicalClasses": "ignore",
"go.coverageDecorator": {
"type": "gutter"
}
}

View file

@ -20,6 +20,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/blocklogger"
"github.com/wavetermdev/waveterm/pkg/filebackup"
"github.com/wavetermdev/waveterm/pkg/filestore"
"github.com/wavetermdev/waveterm/pkg/jobcontroller"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
"github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs"
@ -391,7 +392,7 @@ func createMainWshClient() {
wshfs.RpcClient = rpc
wshutil.DefaultRouter.RegisterTrustedLeaf(rpc, wshutil.DefaultRoute)
wps.Broker.SetClient(wshutil.DefaultRouter)
localConnWsh := wshutil.MakeWshRpc(wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{}, "conn:local")
localConnWsh := wshutil.MakeWshRpc(wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, wshremote.MakeRemoteRpcServerImpl(nil, wshutil.DefaultRouter, wshclient.GetBareRpcClient(), true), "conn:local")
go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName)
wshutil.DefaultRouter.RegisterTrustedLeaf(localConnWsh, wshutil.MakeConnectionRouteId(wshrpc.LocalConnName))
}
@ -572,6 +573,7 @@ func main() {
go backupCleanupLoop()
go startupActivityUpdate(firstLaunch) // must be after startConfigWatcher()
blocklogger.InitBlockLogger()
jobcontroller.InitJobController()
go func() {
defer func() {
panichandler.PanicHandler("GetSystemSummary", recover())

View file

@ -38,11 +38,14 @@ var serverCmd = &cobra.Command{
}
var connServerRouter bool
var connServerRouterDomainSocket bool
var connServerConnName string
var connServerDev bool
var ConnServerWshRouter *wshutil.WshRouter
func init() {
serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode")
serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode (stdio upstream)")
serverCmd.Flags().BoolVar(&connServerRouterDomainSocket, "router-domainsocket", false, "run in local router mode (domain socket upstream)")
serverCmd.Flags().StringVar(&connServerConnName, "conn", "", "connection name")
serverCmd.Flags().BoolVar(&connServerDev, "dev", false, "enable dev mode with file logging and PID in logs")
rootCmd.AddCommand(serverCmd)
@ -123,7 +126,12 @@ func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.Wsh
RouteId: routeId,
Conn: connServerConnName,
}
connServerClient := wshutil.MakeWshRpc(rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout}, routeId)
bareRouteId := wshutil.MakeRandomProcRouteId()
bareClient := wshutil.MakeWshRpc(wshrpc.RpcContext{}, &wshclient.WshServer{}, bareRouteId)
router.RegisterTrustedLeaf(bareClient, bareRouteId)
connServerClient := wshutil.MakeWshRpc(rpcCtx, wshremote.MakeRemoteRpcServerImpl(os.Stdout, router, bareClient, false), routeId)
router.RegisterTrustedLeaf(connServerClient, routeId)
return connServerClient, nil
}
@ -131,6 +139,7 @@ func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.Wsh
func serverRunRouter() error {
log.Printf("starting connserver router")
router := wshutil.NewWshRouter()
ConnServerWshRouter = router
termProxy := wshutil.MakeRpcProxy("connserver-term")
rawCh := make(chan []byte, wshutil.DefaultOutputChSize)
go func() {
@ -209,8 +218,112 @@ func serverRunRouter() error {
select {}
}
func serverRunRouterDomainSocket(jwtToken string) error {
log.Printf("starting connserver router (domain socket upstream)")
// extract socket name from JWT token (unverified - we're on the client side)
sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken)
if err != nil {
return fmt.Errorf("error extracting socket name from JWT: %v", err)
}
// connect to the forwarded domain socket
sockName = wavebase.ExpandHomeDirSafe(sockName)
conn, err := net.Dial("unix", sockName)
if err != nil {
return fmt.Errorf("error connecting to domain socket %s: %v", sockName, err)
}
// create router
router := wshutil.NewWshRouter()
ConnServerWshRouter = router
// create proxy for the domain socket connection
upstreamProxy := wshutil.MakeRpcProxy("connserver-upstream")
// goroutine to write to the domain socket
go func() {
defer func() {
panichandler.PanicHandler("serverRunRouterDomainSocket:WriteLoop", recover())
}()
writeErr := wshutil.AdaptOutputChToStream(upstreamProxy.ToRemoteCh, conn)
if writeErr != nil {
log.Printf("error writing to upstream domain socket: %v\n", writeErr)
}
}()
// goroutine to read from the domain socket
go func() {
defer func() {
panichandler.PanicHandler("serverRunRouterDomainSocket:ReadLoop", recover())
}()
defer func() {
log.Printf("upstream domain socket closed, shutting down")
wshutil.DoShutdown("", 0, true)
}()
wshutil.AdaptStreamToMsgCh(conn, upstreamProxy.FromRemoteCh)
}()
// register the domain socket connection as upstream
router.RegisterUpstream(upstreamProxy)
// setup the connserver rpc client (leaf)
client, err := setupConnServerRpcClientWithRouter(router)
if err != nil {
return fmt.Errorf("error setting up connserver rpc client: %v", err)
}
wshfs.RpcClient = client
// authenticate with the upstream router using the JWT
_, err = wshclient.AuthenticateCommand(client, jwtToken, &wshrpc.RpcOpts{Route: wshutil.ControlRoute})
if err != nil {
return fmt.Errorf("error authenticating with upstream: %v", err)
}
log.Printf("authenticated with upstream router")
// fetch and set JWT public key
log.Printf("trying to get JWT public key")
jwtPublicKeyB64, err := wshclient.GetJwtPublicKeyCommand(client, nil)
if err != nil {
return fmt.Errorf("error getting jwt public key: %v", err)
}
jwtPublicKeyBytes, err := base64.StdEncoding.DecodeString(jwtPublicKeyB64)
if err != nil {
return fmt.Errorf("error decoding jwt public key: %v", err)
}
err = wavejwt.SetPublicKey(jwtPublicKeyBytes)
if err != nil {
return fmt.Errorf("error setting jwt public key: %v", err)
}
log.Printf("got JWT public key")
// set up the local domain socket listener for local wsh commands
unixListener, err := MakeRemoteUnixListener()
if err != nil {
return fmt.Errorf("cannot create unix listener: %v", err)
}
log.Printf("unix listener started")
go func() {
defer func() {
panichandler.PanicHandler("serverRunRouterDomainSocket:runListener", recover())
}()
runListener(unixListener, router)
}()
// run the sysinfo loop
go func() {
defer func() {
panichandler.PanicHandler("serverRunRouterDomainSocket:RunSysInfoLoop", recover())
}()
wshremote.RunSysInfoLoop(client, connServerConnName)
}()
log.Printf("running server (router-domainsocket mode), successfully started")
select {}
}
func serverRunNormal(jwtToken string) error {
err := setupRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout}, jwtToken)
err := setupRpcClient(wshremote.MakeRemoteRpcServerImpl(os.Stdout, nil, nil, false), jwtToken)
if err != nil {
return err
}
@ -283,6 +396,20 @@ func serverRun(cmd *cobra.Command, args []string) error {
}
return err
}
if connServerRouterDomainSocket {
jwtToken, err := askForJwtToken()
if err != nil {
if logFile != nil {
fmt.Fprintf(logFile, "askForJwtToken error: %v\n", err)
}
return err
}
err = serverRunRouterDomainSocket(jwtToken)
if err != nil && logFile != nil {
fmt.Fprintf(logFile, "serverRunRouterDomainSocket error: %v\n", err)
}
return err
}
jwtToken, err := askForJwtToken()
if err != nil {
if logFile != nil {

View file

@ -0,0 +1,382 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
)
var jobDebugCmd = &cobra.Command{
Use: "jobdebug",
Short: "debugging commands for the job system",
Hidden: true,
PersistentPreRunE: preRunSetupRpcClient,
}
var jobDebugListCmd = &cobra.Command{
Use: "list",
Short: "list all jobs with debug information",
RunE: jobDebugListRun,
}
var jobDebugDeleteCmd = &cobra.Command{
Use: "delete",
Short: "delete a job entry by jobid",
RunE: jobDebugDeleteRun,
}
var jobDebugDeleteAllCmd = &cobra.Command{
Use: "deleteall",
Short: "delete all jobs",
RunE: jobDebugDeleteAllRun,
}
var jobDebugPruneCmd = &cobra.Command{
Use: "prune",
Short: "remove jobs where the job manager is no longer running",
RunE: jobDebugPruneRun,
}
var jobDebugExitCmd = &cobra.Command{
Use: "exit",
Short: "exit a job manager",
RunE: jobDebugExitRun,
}
var jobDebugDisconnectCmd = &cobra.Command{
Use: "disconnect",
Short: "disconnect from a job manager",
RunE: jobDebugDisconnectRun,
}
var jobDebugReconnectCmd = &cobra.Command{
Use: "reconnect",
Short: "reconnect to a job manager",
RunE: jobDebugReconnectRun,
}
var jobDebugReconnectConnCmd = &cobra.Command{
Use: "reconnectconn",
Short: "reconnect all jobs for a connection",
RunE: jobDebugReconnectConnRun,
}
var jobDebugGetOutputCmd = &cobra.Command{
Use: "getoutput",
Short: "get the terminal output for a job",
RunE: jobDebugGetOutputRun,
}
var jobDebugStartCmd = &cobra.Command{
Use: "start",
Short: "start a new job",
Args: cobra.MinimumNArgs(1),
RunE: jobDebugStartRun,
}
var jobDebugAttachJobCmd = &cobra.Command{
Use: "attachjob",
Short: "attach a job to a block",
RunE: jobDebugAttachJobRun,
}
var jobDebugDetachJobCmd = &cobra.Command{
Use: "detachjob",
Short: "detach a job from its block",
RunE: jobDebugDetachJobRun,
}
var jobIdFlag string
var jobDebugJsonFlag bool
var jobConnFlag string
var exitJobIdFlag string
var disconnectJobIdFlag string
var reconnectJobIdFlag string
var reconnectConnNameFlag string
var attachJobIdFlag string
var attachBlockIdFlag string
var detachJobIdFlag string
func init() {
rootCmd.AddCommand(jobDebugCmd)
jobDebugCmd.AddCommand(jobDebugListCmd)
jobDebugCmd.AddCommand(jobDebugDeleteCmd)
jobDebugCmd.AddCommand(jobDebugDeleteAllCmd)
jobDebugCmd.AddCommand(jobDebugPruneCmd)
jobDebugCmd.AddCommand(jobDebugExitCmd)
jobDebugCmd.AddCommand(jobDebugDisconnectCmd)
jobDebugCmd.AddCommand(jobDebugReconnectCmd)
jobDebugCmd.AddCommand(jobDebugReconnectConnCmd)
jobDebugCmd.AddCommand(jobDebugGetOutputCmd)
jobDebugCmd.AddCommand(jobDebugStartCmd)
jobDebugCmd.AddCommand(jobDebugAttachJobCmd)
jobDebugCmd.AddCommand(jobDebugDetachJobCmd)
jobDebugListCmd.Flags().BoolVar(&jobDebugJsonFlag, "json", false, "output as JSON")
jobDebugDeleteCmd.Flags().StringVar(&jobIdFlag, "jobid", "", "job id to delete (required)")
jobDebugDeleteCmd.MarkFlagRequired("jobid")
jobDebugExitCmd.Flags().StringVar(&exitJobIdFlag, "jobid", "", "job id to exit (required)")
jobDebugExitCmd.MarkFlagRequired("jobid")
jobDebugDisconnectCmd.Flags().StringVar(&disconnectJobIdFlag, "jobid", "", "job id to disconnect (required)")
jobDebugDisconnectCmd.MarkFlagRequired("jobid")
jobDebugReconnectCmd.Flags().StringVar(&reconnectJobIdFlag, "jobid", "", "job id to reconnect (required)")
jobDebugReconnectCmd.MarkFlagRequired("jobid")
jobDebugReconnectConnCmd.Flags().StringVar(&reconnectConnNameFlag, "conn", "", "connection name (required)")
jobDebugReconnectConnCmd.MarkFlagRequired("conn")
jobDebugGetOutputCmd.Flags().StringVar(&jobIdFlag, "jobid", "", "job id to get output for (required)")
jobDebugGetOutputCmd.MarkFlagRequired("jobid")
jobDebugStartCmd.Flags().StringVar(&jobConnFlag, "conn", "", "connection name (required)")
jobDebugStartCmd.MarkFlagRequired("conn")
jobDebugAttachJobCmd.Flags().StringVar(&attachJobIdFlag, "jobid", "", "job id to attach (required)")
jobDebugAttachJobCmd.MarkFlagRequired("jobid")
jobDebugAttachJobCmd.Flags().StringVar(&attachBlockIdFlag, "blockid", "", "block id to attach to (required)")
jobDebugAttachJobCmd.MarkFlagRequired("blockid")
jobDebugDetachJobCmd.Flags().StringVar(&detachJobIdFlag, "jobid", "", "job id to detach (required)")
jobDebugDetachJobCmd.MarkFlagRequired("jobid")
}
// jobDebugListRun prints every job known to the job controller, either as
// indented JSON (--json) or as a fixed-width table that also marks whether
// each job's manager is currently connected.
func jobDebugListRun(cmd *cobra.Command, args []string) error {
	jobs, err := wshclient.JobControllerListCommand(RpcClient, &wshrpc.RpcOpts{Timeout: 5000})
	if err != nil {
		return fmt.Errorf("getting job debug list: %w", err)
	}
	connectedIds, err := wshclient.JobControllerConnectedJobsCommand(RpcClient, &wshrpc.RpcOpts{Timeout: 5000})
	if err != nil {
		return fmt.Errorf("getting connected job ids: %w", err)
	}
	connected := make(map[string]bool, len(connectedIds))
	for _, id := range connectedIds {
		connected[id] = true
	}
	if jobDebugJsonFlag {
		out, err := json.MarshalIndent(jobs, "", " ")
		if err != nil {
			return fmt.Errorf("marshaling json: %w", err)
		}
		fmt.Printf("%s\n", string(out))
		return nil
	}
	fmt.Printf("%-36s %-20s %-9s %-10s %-30s %-8s %-10s\n", "OID", "Connection", "Connected", "Manager", "Cmd", "ExitCode", "Stream")
	for _, job := range jobs {
		connectedStatus := "no"
		if connected[job.OID] {
			connectedStatus = "yes"
		}
		// stream column: "-" until the stream finishes, then "EOF" or the
		// quoted error text
		streamStatus := "-"
		if job.StreamDone {
			streamStatus = "EOF"
			if job.StreamError != "" {
				streamStatus = fmt.Sprintf("%q", job.StreamError)
			}
		}
		// exit column: "-" while the command is running; after exit either
		// the numeric code, the signal name, or "?" when neither is known
		exitCode := "-"
		if job.CmdExitTs > 0 {
			switch {
			case job.CmdExitCode != nil:
				exitCode = fmt.Sprintf("%d", *job.CmdExitCode)
			case job.CmdExitSignal != "":
				exitCode = job.CmdExitSignal
			default:
				exitCode = "?"
			}
		}
		fmt.Printf("%-36s %-20s %-9s %-10s %-30s %-8s %-10s\n",
			job.OID, job.Connection, connectedStatus, job.JobManagerStatus, job.Cmd, exitCode, streamStatus)
	}
	return nil
}
// jobDebugDeleteRun removes the job record identified by --jobid from the
// job controller.
func jobDebugDeleteRun(cmd *cobra.Command, args []string) error {
	if err := wshclient.JobControllerDeleteJobCommand(RpcClient, jobIdFlag, &wshrpc.RpcOpts{Timeout: 5000}); err != nil {
		return fmt.Errorf("deleting job: %w", err)
	}
	fmt.Printf("Job %s deleted successfully\n", jobIdFlag)
	return nil
}
// jobDebugDeleteAllRun lists all jobs and attempts to delete each one,
// reporting per-job failures and a final deleted/total summary.
func jobDebugDeleteAllRun(cmd *cobra.Command, args []string) error {
	jobs, err := wshclient.JobControllerListCommand(RpcClient, &wshrpc.RpcOpts{Timeout: 5000})
	if err != nil {
		return fmt.Errorf("getting job debug list: %w", err)
	}
	if len(jobs) == 0 {
		fmt.Printf("No jobs to delete\n")
		return nil
	}
	deleted := 0
	for _, job := range jobs {
		delErr := wshclient.JobControllerDeleteJobCommand(RpcClient, job.OID, &wshrpc.RpcOpts{Timeout: 5000})
		if delErr != nil {
			// best-effort: report and keep deleting the rest
			fmt.Printf("Error deleting job %s: %v\n", job.OID, delErr)
			continue
		}
		deleted++
	}
	fmt.Printf("Deleted %d of %d job(s)\n", deleted, len(jobs))
	return nil
}
// jobDebugPruneRun deletes jobs whose job manager is not in the "running"
// state, leaving jobs with live managers untouched.
func jobDebugPruneRun(cmd *cobra.Command, args []string) error {
	jobs, err := wshclient.JobControllerListCommand(RpcClient, &wshrpc.RpcOpts{Timeout: 5000})
	if err != nil {
		return fmt.Errorf("getting job debug list: %w", err)
	}
	if len(jobs) == 0 {
		fmt.Printf("No jobs to prune\n")
		return nil
	}
	pruned := 0
	for _, job := range jobs {
		if job.JobManagerStatus == "running" {
			continue
		}
		if delErr := wshclient.JobControllerDeleteJobCommand(RpcClient, job.OID, &wshrpc.RpcOpts{Timeout: 5000}); delErr != nil {
			// best-effort: report and keep pruning the rest
			fmt.Printf("Error deleting job %s: %v\n", job.OID, delErr)
			continue
		}
		pruned++
	}
	if pruned == 0 {
		fmt.Printf("No jobs with stopped job managers to prune\n")
		return nil
	}
	fmt.Printf("Pruned %d job(s) with stopped job managers\n", pruned)
	return nil
}
// jobDebugExitRun asks the job controller to terminate the job manager
// process for the job given by --jobid.
func jobDebugExitRun(cmd *cobra.Command, args []string) error {
	if err := wshclient.JobControllerExitJobCommand(RpcClient, exitJobIdFlag, nil); err != nil {
		return fmt.Errorf("exiting job manager: %w", err)
	}
	fmt.Printf("Job manager for %s exited successfully\n", exitJobIdFlag)
	return nil
}
// jobDebugDisconnectRun drops the connection to the job manager for --jobid
// without terminating the manager process.
func jobDebugDisconnectRun(cmd *cobra.Command, args []string) error {
	if err := wshclient.JobControllerDisconnectJobCommand(RpcClient, disconnectJobIdFlag, nil); err != nil {
		return fmt.Errorf("disconnecting from job manager: %w", err)
	}
	fmt.Printf("Disconnected from job manager for %s successfully\n", disconnectJobIdFlag)
	return nil
}
// jobDebugReconnectRun re-establishes the connection to the job manager for
// the job given by --jobid.
func jobDebugReconnectRun(cmd *cobra.Command, args []string) error {
	if err := wshclient.JobControllerReconnectJobCommand(RpcClient, reconnectJobIdFlag, nil); err != nil {
		return fmt.Errorf("reconnecting to job manager: %w", err)
	}
	fmt.Printf("Reconnected to job manager for %s successfully\n", reconnectJobIdFlag)
	return nil
}
// jobDebugReconnectConnRun reconnects every job on the connection named by
// --conn.
func jobDebugReconnectConnRun(cmd *cobra.Command, args []string) error {
	if err := wshclient.JobControllerReconnectJobsForConnCommand(RpcClient, reconnectConnNameFlag, nil); err != nil {
		return fmt.Errorf("reconnecting jobs for connection: %w", err)
	}
	fmt.Printf("Reconnected all jobs for connection %s successfully\n", reconnectConnNameFlag)
	return nil
}
// jobDebugGetOutputRun reads the job's "term" wavefile and writes the
// decoded terminal output to stdout. No output is printed when the file is
// empty.
func jobDebugGetOutputRun(cmd *cobra.Command, args []string) error {
	req := wshrpc.FileData{
		Info: &wshrpc.FileInfo{
			Path: fmt.Sprintf("wavefile://%s/term", jobIdFlag),
		},
	}
	fileData, err := wshclient.FileReadCommand(RpcClient, req, &wshrpc.RpcOpts{Timeout: 10000})
	if err != nil {
		return fmt.Errorf("reading job output: %w", err)
	}
	if fileData.Data64 == "" {
		return nil
	}
	// file contents come back base64-encoded over RPC
	decoded, err := base64.StdEncoding.DecodeString(fileData.Data64)
	if err != nil {
		return fmt.Errorf("decoding output data: %w", err)
	}
	fmt.Printf("%s", string(decoded))
	return nil
}
// jobDebugStartRun launches a new persistent job on the connection given by
// --conn; args[0] is the command and any remaining args are its arguments.
// Prints the new job's id on success.
func jobDebugStartRun(cmd *cobra.Command, args []string) error {
	startData := wshrpc.CommandJobControllerStartJobData{
		ConnName: jobConnFlag,
		Cmd:      args[0],
		Args:     args[1:],
		Env:      make(map[string]string),
		TermSize: nil,
	}
	jobId, err := wshclient.JobControllerStartJobCommand(RpcClient, startData, &wshrpc.RpcOpts{Timeout: 10000})
	if err != nil {
		return fmt.Errorf("starting job: %w", err)
	}
	fmt.Printf("Job started successfully with ID: %s\n", jobId)
	return nil
}
// jobDebugAttachJobRun associates the job from --jobid with the block from
// --blockid so the block displays the job's output.
func jobDebugAttachJobRun(cmd *cobra.Command, args []string) error {
	attachData := wshrpc.CommandJobControllerAttachJobData{
		JobId:   attachJobIdFlag,
		BlockId: attachBlockIdFlag,
	}
	if err := wshclient.JobControllerAttachJobCommand(RpcClient, attachData, &wshrpc.RpcOpts{Timeout: 5000}); err != nil {
		return fmt.Errorf("attaching job: %w", err)
	}
	fmt.Printf("Job %s attached to block %s successfully\n", attachJobIdFlag, attachBlockIdFlag)
	return nil
}
// jobDebugDetachJobRun detaches the job given by --jobid from whichever
// block it is currently attached to.
func jobDebugDetachJobRun(cmd *cobra.Command, args []string) error {
	if err := wshclient.JobControllerDetachJobCommand(RpcClient, detachJobIdFlag, &wshrpc.RpcOpts{Timeout: 5000}); err != nil {
		return fmt.Errorf("detaching job: %w", err)
	}
	fmt.Printf("Job %s detached successfully\n", detachJobIdFlag)
	return nil
}

View file

@ -0,0 +1,119 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"bufio"
"context"
"encoding/base64"
"fmt"
"os"
"strings"
"time"
"github.com/google/uuid"
"github.com/spf13/cobra"
"github.com/wavetermdev/waveterm/pkg/jobmanager"
)
var jobManagerCmd = &cobra.Command{
Use: "jobmanager",
Hidden: true,
Short: "job manager for wave terminal",
Args: cobra.NoArgs,
RunE: jobManagerRun,
}
var jobManagerJobId string
var jobManagerClientId string
func init() {
jobManagerCmd.Flags().StringVar(&jobManagerJobId, "jobid", "", "job ID (UUID, required)")
jobManagerCmd.Flags().StringVar(&jobManagerClientId, "clientid", "", "client ID (UUID, required)")
jobManagerCmd.MarkFlagRequired("jobid")
jobManagerCmd.MarkFlagRequired("clientid")
rootCmd.AddCommand(jobManagerCmd)
}
// jobManagerRun validates the --jobid/--clientid flags, gathers the auth
// material handed to this process (public key from WAVETERM_PUBLICKEY, job
// auth token from stdin, readiness pipe on fd 3), and passes everything to
// jobmanager.SetupJobManager. On success it blocks forever; the process is
// expected to be terminated externally.
func jobManagerRun(cmd *cobra.Command, args []string) error {
	if _, err := uuid.Parse(jobManagerJobId); err != nil {
		return fmt.Errorf("invalid jobid: must be a valid UUID")
	}
	if _, err := uuid.Parse(jobManagerClientId); err != nil {
		return fmt.Errorf("invalid clientid: must be a valid UUID")
	}
	publicKeyB64 := os.Getenv("WAVETERM_PUBLICKEY")
	if publicKeyB64 == "" {
		return fmt.Errorf("WAVETERM_PUBLICKEY environment variable is not set")
	}
	publicKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64)
	if err != nil {
		return fmt.Errorf("failed to decode WAVETERM_PUBLICKEY: %v", err)
	}
	// bound how long we wait for the parent to deliver the token on stdin
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	jobAuthToken, err := readJobAuthToken(ctx)
	if err != nil {
		return fmt.Errorf("failed to read job auth token: %v", err)
	}
	// fd 3 is a pipe inherited from the launching process, used to signal
	// readiness back to it
	readyFile := os.NewFile(3, "ready-pipe")
	if _, statErr := readyFile.Stat(); statErr != nil {
		return fmt.Errorf("ready pipe (fd 3) not available: %v", statErr)
	}
	if setupErr := jobmanager.SetupJobManager(jobManagerClientId, jobManagerJobId, publicKeyBytes, jobAuthToken, readyFile); setupErr != nil {
		return fmt.Errorf("error setting up job manager: %v", setupErr)
	}
	select {}
}
// readJobAuthToken reads a single line from stdin and extracts the job auth
// token, which must arrive as "<JobAccessTokenLabel>:<token>". The blocking
// read runs in a goroutine so the ctx deadline can bound the wait; note the
// goroutine itself cannot be cancelled and is abandoned on timeout.
func readJobAuthToken(ctx context.Context) (string, error) {
	type readResult struct {
		token string
		err   error
	}
	resultCh := make(chan readResult, 1)
	go func() {
		line, readErr := bufio.NewReader(os.Stdin).ReadString('\n')
		if readErr != nil {
			resultCh <- readResult{err: fmt.Errorf("error reading from stdin: %v", readErr)}
			return
		}
		line = strings.TrimSpace(line)
		prefix := jobmanager.JobAccessTokenLabel + ":"
		if !strings.HasPrefix(line, prefix) {
			resultCh <- readResult{err: fmt.Errorf("invalid token format: expected '%s'", prefix)}
			return
		}
		token := strings.TrimSpace(strings.TrimPrefix(line, prefix))
		if token == "" {
			resultCh <- readResult{err: fmt.Errorf("empty job auth token")}
			return
		}
		resultCh <- readResult{token: token}
	}()
	select {
	case res := <-resultCh:
		return res.token, res.err
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

View file

@ -0,0 +1 @@
DROP TABLE IF EXISTS db_job;

View file

@ -0,0 +1,5 @@
CREATE TABLE IF NOT EXISTS db_job (
oid varchar(36) PRIMARY KEY,
version int NOT NULL,
data json NOT NULL
);

View file

@ -22,6 +22,21 @@ class RpcApiType {
return client.wshRpcCall("authenticate", data, opts);
}
// command "authenticatejobmanager" [call]
AuthenticateJobManagerCommand(client: WshClient, data: CommandAuthenticateJobManagerData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("authenticatejobmanager", data, opts);
}
// command "authenticatejobmanagerverify" [call]
AuthenticateJobManagerVerifyCommand(client: WshClient, data: CommandAuthenticateJobManagerData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("authenticatejobmanagerverify", data, opts);
}
// command "authenticatetojobmanager" [call]
AuthenticateToJobManagerCommand(client: WshClient, data: CommandAuthenticateToJobData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("authenticatetojobmanager", data, opts);
}
// command "authenticatetoken" [call]
AuthenticateTokenCommand(client: WshClient, data: CommandAuthenticateTokenData, opts?: RpcOpts): Promise<CommandAuthenticateRtnData> {
return client.wshRpcCall("authenticatetoken", data, opts);
@ -377,6 +392,76 @@ class RpcApiType {
return client.wshRpcCall("getwaveairatelimit", null, opts);
}
// command "jobcmdexited" [call]
JobCmdExitedCommand(client: WshClient, data: CommandJobCmdExitedData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("jobcmdexited", data, opts);
}
// command "jobcontrollerattachjob" [call]
JobControllerAttachJobCommand(client: WshClient, data: CommandJobControllerAttachJobData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("jobcontrollerattachjob", data, opts);
}
// command "jobcontrollerconnectedjobs" [call]
JobControllerConnectedJobsCommand(client: WshClient, opts?: RpcOpts): Promise<string[]> {
return client.wshRpcCall("jobcontrollerconnectedjobs", null, opts);
}
// command "jobcontrollerdeletejob" [call]
JobControllerDeleteJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("jobcontrollerdeletejob", data, opts);
}
// command "jobcontrollerdetachjob" [call]
JobControllerDetachJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("jobcontrollerdetachjob", data, opts);
}
// command "jobcontrollerdisconnectjob" [call]
JobControllerDisconnectJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("jobcontrollerdisconnectjob", data, opts);
}
// command "jobcontrollerexitjob" [call]
JobControllerExitJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("jobcontrollerexitjob", data, opts);
}
// command "jobcontrollerlist" [call]
JobControllerListCommand(client: WshClient, opts?: RpcOpts): Promise<Job[]> {
return client.wshRpcCall("jobcontrollerlist", null, opts);
}
// command "jobcontrollerreconnectjob" [call]
JobControllerReconnectJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("jobcontrollerreconnectjob", data, opts);
}
// command "jobcontrollerreconnectjobsforconn" [call]
JobControllerReconnectJobsForConnCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("jobcontrollerreconnectjobsforconn", data, opts);
}
// command "jobcontrollerstartjob" [call]
JobControllerStartJobCommand(client: WshClient, data: CommandJobControllerStartJobData, opts?: RpcOpts): Promise<string> {
return client.wshRpcCall("jobcontrollerstartjob", data, opts);
}
// command "jobinput" [call]
JobInputCommand(client: WshClient, data: CommandJobInputData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("jobinput", data, opts);
}
// command "jobprepareconnect" [call]
JobPrepareConnectCommand(client: WshClient, data: CommandJobPrepareConnectData, opts?: RpcOpts): Promise<CommandJobConnectRtnData> {
return client.wshRpcCall("jobprepareconnect", data, opts);
}
// command "jobstartstream" [call]
JobStartStreamCommand(client: WshClient, data: CommandJobStartStreamData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("jobstartstream", data, opts);
}
// command "listallappfiles" [call]
ListAllAppFilesCommand(client: WshClient, data: CommandListAllAppFilesData, opts?: RpcOpts): Promise<CommandListAllAppFilesRtnData> {
return client.wshRpcCall("listallappfiles", data, opts);
@ -432,6 +517,11 @@ class RpcApiType {
return client.wshRpcCall("recordtevent", data, opts);
}
// command "remotedisconnectfromjobmanager" [call]
RemoteDisconnectFromJobManagerCommand(client: WshClient, data: CommandRemoteDisconnectFromJobManagerData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("remotedisconnectfromjobmanager", data, opts);
}
// command "remotefilecopy" [call]
RemoteFileCopyCommand(client: WshClient, data: CommandFileCopyData, opts?: RpcOpts): Promise<boolean> {
return client.wshRpcCall("remotefilecopy", data, opts);
@ -482,6 +572,16 @@ class RpcApiType {
return client.wshRpcCall("remotemkdir", data, opts);
}
// command "remotereconnecttojobmanager" [call]
RemoteReconnectToJobManagerCommand(client: WshClient, data: CommandRemoteReconnectToJobManagerData, opts?: RpcOpts): Promise<CommandRemoteReconnectToJobManagerRtnData> {
return client.wshRpcCall("remotereconnecttojobmanager", data, opts);
}
// command "remotestartjob" [call]
RemoteStartJobCommand(client: WshClient, data: CommandRemoteStartJobData, opts?: RpcOpts): Promise<CommandStartJobRtnData> {
return client.wshRpcCall("remotestartjob", data, opts);
}
// command "remotestreamcpudata" [responsestream]
RemoteStreamCpuDataCommand(client: WshClient, opts?: RpcOpts): AsyncGenerator<TimeSeriesData, void, boolean> {
return client.wshRpcStream("remotestreamcpudata", null, opts);
@ -497,6 +597,11 @@ class RpcApiType {
return client.wshRpcStream("remotetarstream", data, opts);
}
// command "remoteterminatejobmanager" [call]
RemoteTerminateJobManagerCommand(client: WshClient, data: CommandRemoteTerminateJobManagerData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("remoteterminatejobmanager", data, opts);
}
// command "remotewritefile" [call]
RemoteWriteFileCommand(client: WshClient, data: FileData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("remotewritefile", data, opts);
@ -572,6 +677,11 @@ class RpcApiType {
return client.wshRpcCall("startbuilder", data, opts);
}
// command "startjob" [call]
StartJobCommand(client: WshClient, data: CommandStartJobData, opts?: RpcOpts): Promise<CommandStartJobRtnData> {
return client.wshRpcCall("startjob", data, opts);
}
// command "stopbuilder" [call]
StopBuilderCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("stopbuilder", data, opts);
@ -607,6 +717,11 @@ class RpcApiType {
return client.wshRpcCall("termgetscrollbacklines", data, opts);
}
// command "termupdateattachedjob" [call]
TermUpdateAttachedJobCommand(client: WshClient, data: CommandTermUpdateAttachedJobData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("termupdateattachedjob", data, opts);
}
// command "test" [call]
TestCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("test", data, opts);

View file

@ -104,6 +104,11 @@ export class TermWshClient extends WshClient {
}
}
async handle_termupdateattachedjob(rh: RpcResponseHelper, data: CommandTermUpdateAttachedJobData): Promise<void> {
console.log("term-update-attached-job", this.blockId, data);
// TODO: implement frontend logic to handle job attachment updates
}
async handle_termgetscrollbacklines(
rh: RpcResponseHelper,
data: CommandTermGetScrollbackLinesData

View file

@ -298,6 +298,7 @@ const TerminalView = ({ blockId, model }: ViewComponentProps<TermViewModel>) =>
useWebGl: !termSettings?.["term:disablewebgl"],
sendDataHandler: model.sendDataToController.bind(model),
nodeModel: model.nodeModel,
jobId: blockData?.jobid,
}
);
(window as any).term = termWrap;

View file

@ -3,7 +3,6 @@
import type { BlockNodeModel } from "@/app/block/blocktypes";
import { getFileSubject } from "@/app/store/wps";
import { sendWSCommand } from "@/app/store/ws";
import { RpcApi } from "@/app/store/wshclientapi";
import { TabRpcClient } from "@/app/store/wshrpcutil";
import { WOS, fetchWaveFile, getApi, getSettingsKeyAtom, globalStore, openLink, recordTEvent } from "@/store/global";
@ -50,6 +49,7 @@ type TermWrapOptions = {
useWebGl?: boolean;
sendDataHandler?: (data: string) => void;
nodeModel?: BlockNodeModel;
jobId?: string;
};
// for xterm OSC handlers, we return true always because we "own" the OSC number.
@ -375,6 +375,7 @@ function handleOsc16162Command(data: string, blockId: string, loaded: boolean, t
export class TermWrap {
tabId: string;
blockId: string;
jobId: string;
ptyOffset: number;
dataBytesProcessed: number;
terminal: Terminal;
@ -422,6 +423,7 @@ export class TermWrap {
this.loaded = false;
this.tabId = tabId;
this.blockId = blockId;
this.jobId = waveOptions.jobId;
this.sendDataHandler = waveOptions.sendDataHandler;
this.nodeModel = waveOptions.nodeModel;
this.ptyOffset = 0;
@ -495,6 +497,10 @@ export class TermWrap {
});
}
getZoneId(): string {
return this.jobId ?? this.blockId;
}
resetCompositionState() {
this.isComposing = false;
this.composingData = "";
@ -566,7 +572,7 @@ export class TermWrap {
});
}
this.mainFileSubject = getFileSubject(this.blockId, TermFileName);
this.mainFileSubject = getFileSubject(this.getZoneId(), TermFileName);
this.mainFileSubject.subscribe(this.handleNewFileSubjectData.bind(this));
try {
@ -699,8 +705,9 @@ export class TermWrap {
}
async loadInitialTerminalData(): Promise<void> {
let startTs = Date.now();
const { data: cacheData, fileInfo: cacheFile } = await fetchWaveFile(this.blockId, TermCacheFileName);
const startTs = Date.now();
const zoneId = this.getZoneId();
const { data: cacheData, fileInfo: cacheFile } = await fetchWaveFile(zoneId, TermCacheFileName);
let ptyOffset = 0;
if (cacheFile != null) {
ptyOffset = cacheFile.meta["ptyoffset"] ?? 0;
@ -722,7 +729,7 @@ export class TermWrap {
}
}
}
const { data: mainData, fileInfo: mainFile } = await fetchWaveFile(this.blockId, TermFileName, ptyOffset);
const { data: mainData, fileInfo: mainFile } = await fetchWaveFile(zoneId, TermFileName, ptyOffset);
console.log(
`terminal loaded cachefile:${cacheData?.byteLength ?? 0} main:${mainData?.byteLength ?? 0} bytes, ${Date.now() - startTs}ms`
);
@ -751,12 +758,7 @@ export class TermWrap {
this.fitAddon.fit();
if (oldRows !== this.terminal.rows || oldCols !== this.terminal.cols) {
const termSize: TermSize = { rows: this.terminal.rows, cols: this.terminal.cols };
const wsCommand: SetBlockTermSizeWSCommand = {
wscommand: "setblocktermsize",
blockid: this.blockId,
termsize: termSize,
};
sendWSCommand(wsCommand);
RpcApi.ControllerInputCommand(TabRpcClient, { blockid: this.blockId, termsize: termSize });
}
dlog("resize", `${this.terminal.rows}x${this.terminal.cols}`, `${oldRows}x${oldCols}`, this.hasResized);
if (!this.hasResized) {

View file

@ -162,7 +162,7 @@ export class VDomModel {
this.queueUpdate(true);
}
this.routeGoneUnsub = waveEventSubscribe({
eventType: "route:gone",
eventType: "route:down",
scope: curBackendRoute,
handler: (event: WaveEvent) => {
this.disposed = true;

View file

@ -112,6 +112,7 @@ declare global {
runtimeopts?: RuntimeOpts;
stickers?: StickerType[];
subblockids?: string[];
jobid?: string;
};
// blockcontroller.BlockControllerRuntimeStatus
@ -139,13 +140,6 @@ declare global {
files: FileInfo[];
};
// webcmd.BlockInputWSCommand
type BlockInputWSCommand = {
wscommand: "blockinput";
blockid: string;
inputdata64: string;
};
// wshrpc.BlocksListEntry
type BlocksListEntry = {
windowid: string;
@ -179,6 +173,7 @@ declare global {
tosagreed?: number;
hasoldhistory?: boolean;
tempoid?: string;
installid?: string;
};
// workspaceservice.CloseTabRtnType
@ -194,6 +189,12 @@ declare global {
data: {[key: string]: any};
};
// wshrpc.CommandAuthenticateJobManagerData
type CommandAuthenticateJobManagerData = {
jobid: string;
jobauthtoken: string;
};
// wshrpc.CommandAuthenticateRtnData
type CommandAuthenticateRtnData = {
env?: {[key: string]: string};
@ -201,6 +202,11 @@ declare global {
rpccontext?: RpcContext;
};
// wshrpc.CommandAuthenticateToJobData
type CommandAuthenticateToJobData = {
jobaccesstoken: string;
};
// wshrpc.CommandAuthenticateTokenData
type CommandAuthenticateTokenData = {
token: string;
@ -343,6 +349,59 @@ declare global {
chatid: string;
};
// wshrpc.CommandJobCmdExitedData
type CommandJobCmdExitedData = {
jobid: string;
exitcode?: number;
exitsignal?: string;
exiterr?: string;
exitts?: number;
};
// wshrpc.CommandJobConnectRtnData
type CommandJobConnectRtnData = {
seq: number;
streamdone?: boolean;
streamerror?: string;
hasexited?: boolean;
exitcode?: number;
exitsignal?: string;
exiterr?: string;
};
// wshrpc.CommandJobControllerAttachJobData
type CommandJobControllerAttachJobData = {
jobid: string;
blockid: string;
};
// wshrpc.CommandJobControllerStartJobData
type CommandJobControllerStartJobData = {
connname: string;
cmd: string;
args: string[];
env: {[key: string]: string};
termsize?: TermSize;
};
// wshrpc.CommandJobInputData
type CommandJobInputData = {
jobid: string;
inputdata64?: string;
signame?: string;
termsize?: TermSize;
};
// wshrpc.CommandJobPrepareConnectData
type CommandJobPrepareConnectData = {
streammeta: StreamMeta;
seq: number;
};
// wshrpc.CommandJobStartStreamData
type CommandJobStartStreamData = {
};
// wshrpc.CommandListAllAppFilesData
type CommandListAllAppFilesData = {
appid: string;
@ -397,6 +456,11 @@ declare global {
modts?: number;
};
// wshrpc.CommandRemoteDisconnectFromJobManagerData
type CommandRemoteDisconnectFromJobManagerData = {
jobid: string;
};
// wshrpc.CommandRemoteListEntriesData
type CommandRemoteListEntriesData = {
path: string;
@ -408,6 +472,36 @@ declare global {
fileinfo?: FileInfo[];
};
// wshrpc.CommandRemoteReconnectToJobManagerData
type CommandRemoteReconnectToJobManagerData = {
jobid: string;
jobauthtoken: string;
mainserverjwttoken: string;
jobmanagerpid: number;
jobmanagerstartts: number;
};
// wshrpc.CommandRemoteReconnectToJobManagerRtnData
type CommandRemoteReconnectToJobManagerRtnData = {
success: boolean;
jobmanagergone: boolean;
error?: string;
};
// wshrpc.CommandRemoteStartJobData
type CommandRemoteStartJobData = {
cmd: string;
args: string[];
env: {[key: string]: string};
termsize: TermSize;
streammeta?: StreamMeta;
jobauthtoken: string;
jobid: string;
mainserverjwttoken: string;
clientid: string;
publickeybase64: string;
};
// wshrpc.CommandRemoteStreamFileData
type CommandRemoteStreamFileData = {
path: string;
@ -420,6 +514,13 @@ declare global {
opts?: FileCopyOpts;
};
// wshrpc.CommandRemoteTerminateJobManagerData
type CommandRemoteTerminateJobManagerData = {
jobid: string;
jobmanagerpid: number;
jobmanagerstartts: number;
};
// wshrpc.CommandRenameAppFileData
type CommandRenameAppFileData = {
appid: string;
@ -461,9 +562,26 @@ declare global {
builderid: string;
};
// wshrpc.CommandStartJobData
type CommandStartJobData = {
cmd: string;
args: string[];
env: {[key: string]: string};
termsize: TermSize;
streammeta?: StreamMeta;
};
// wshrpc.CommandStartJobRtnData
type CommandStartJobRtnData = {
cmdpid: number;
cmdstartts: number;
jobmanagerpid: number;
jobmanagerstartts: number;
};
// wshrpc.CommandStreamAckData
type CommandStreamAckData = {
id: number;
id: string;
seq: number;
rwnd: number;
fin?: boolean;
@ -474,7 +592,7 @@ declare global {
// wshrpc.CommandStreamData
type CommandStreamData = {
id: number;
id: string;
seq: number;
data64?: string;
eof?: boolean;
@ -496,6 +614,12 @@ declare global {
lastupdated: number;
};
// wshrpc.CommandTermUpdateAttachedJobData
type CommandTermUpdateAttachedJobData = {
blockid: string;
jobid?: string;
};
// wshrpc.CommandVarData
type CommandVarData = {
key: string;
@ -793,6 +917,32 @@ declare global {
configerrors: ConfigError[];
};
// waveobj.Job
type Job = WaveObj & {
connection: string;
jobkind: string;
cmd: string;
cmdargs?: string[];
cmdenv?: {[key: string]: string};
jobauthtoken: string;
attachedblockid?: string;
terminateonreconnect?: boolean;
jobmanagerstatus: string;
jobmanagerdonereason?: string;
jobmanagerstartuperror?: string;
jobmanagerpid?: number;
jobmanagerstartts?: number;
cmdpid?: number;
cmdstartts?: number;
cmdtermsize: TermSize;
cmdexitts?: number;
cmdexitcode?: number;
cmdexitsignal?: string;
cmdexiterror?: string;
streamdone?: boolean;
streamerror?: string;
};
// waveobj.LayoutActionData
type LayoutActionData = {
actiontype: string;
@ -1062,13 +1212,6 @@ declare global {
optional: boolean;
};
// webcmd.SetBlockTermSizeWSCommand
type SetBlockTermSizeWSCommand = {
wscommand: "setblocktermsize";
blockid: string;
termsize: TermSize;
};
// wconfig.SettingsType
type SettingsType = {
"app:*"?: boolean;
@ -1186,6 +1329,14 @@ declare global {
display: StickerDisplayOptsType;
};
// wshrpc.StreamMeta
type StreamMeta = {
id: string;
rwnd: number;
readerrouteid: string;
writerrouteid: string;
};
// wps.SubscriptionRequest
type SubscriptionRequest = {
event: string;
@ -1640,7 +1791,7 @@ declare global {
type WSCommandType = {
wscommand: string;
} & ( SetBlockTermSizeWSCommand | BlockInputWSCommand | WSRpcCommand );
} & ( WSRpcCommand );
// eventbus.WSEventType
type WSEventType = {

View file

@ -0,0 +1,853 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package jobcontroller
import (
"context"
"encoding/base64"
"fmt"
"io"
"log"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/filestore"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
"github.com/wavetermdev/waveterm/pkg/streamclient"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wavejwt"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wstore"
)
// Job manager lifecycle states (stored in Job.JobManagerStatus).
const (
	JobStatus_Init    = "init"
	JobStatus_Running = "running"
	JobStatus_Done    = "done"
)

// Reasons a job manager reached the "done" state (Job.JobManagerDoneReason).
const (
	JobDoneReason_StartupError = "startuperror"
	JobDoneReason_Gone         = "gone"
	JobDoneReason_Terminated   = "terminated"
)

// RPC route connectivity states for a job manager (tracked in jobConnStates).
const (
	JobConnStatus_Disconnected = "disconnected"
	JobConnStatus_Connecting   = "connecting"
	JobConnStatus_Connected    = "connected"
)

// DefaultStreamRwnd is the receive window size (bytes) used when creating
// job output stream readers.
const DefaultStreamRwnd = 64 * 1024

// MetaKey_TotalGap is the WaveFS file meta key recording the cumulative
// number of output bytes lost across stream reconnects (see RestartStreaming).
const MetaKey_TotalGap = "totalgap"

// JobOutputFileName is the WaveFS file name holding a job's terminal output.
const JobOutputFileName = "term"
// isJobManagerRunning reports whether the job's remote manager is in the
// running state.
func isJobManagerRunning(job *waveobj.Job) bool {
	status := job.JobManagerStatus
	return status == JobStatus_Running
}
// jobConnStates tracks per-job RPC route connectivity, guarded by
// jobConnStatesLock. A missing entry means "disconnected" (entries are
// deleted rather than stored as disconnected; see SetJobConnStatus).
var (
	jobConnStates     = make(map[string]string)
	jobConnStatesLock sync.Mutex
)
// getMetaInt64 extracts an int64 value from file meta, accepting either an
// int64 or a float64 (JSON-decoded numbers). Missing keys or other types
// yield 0.
func getMetaInt64(meta wshrpc.FileMeta, key string) int64 {
	switch v := meta[key].(type) {
	case int64:
		return v
	case float64:
		return int64(v)
	default:
		return 0
	}
}
// InitJobController wires the job controller into the RPC event system.
// It registers handlers for route up/down events (used to track job manager
// connectivity in jobConnStates) and subscribes to those events across all
// scopes.
func InitJobController() {
	rpcClient := wshclient.GetBareRpcClient()
	rpcClient.EventListener.On(wps.Event_RouteUp, handleRouteUpEvent)
	rpcClient.EventListener.On(wps.Event_RouteDown, handleRouteDownEvent)
	wshclient.EventSubCommand(rpcClient, wps.SubscriptionRequest{
		Event:     wps.Event_RouteUp,
		AllScopes: true,
	}, nil)
	wshclient.EventSubCommand(rpcClient, wps.SubscriptionRequest{
		Event:     wps.Event_RouteDown,
		AllScopes: true,
	}, nil)
}
// handleRouteUpEvent marks any job routes named in the event as connected.
func handleRouteUpEvent(event *wps.WaveEvent) {
	handleRouteEvent(event, JobConnStatus_Connected)
}
// handleRouteDownEvent marks any job routes named in the event as disconnected.
func handleRouteDownEvent(event *wps.WaveEvent) {
	handleRouteEvent(event, JobConnStatus_Disconnected)
}
// handleRouteEvent updates the tracked connection status for every "job:"
// scope present in the event; non-job scopes are ignored.
func handleRouteEvent(event *wps.WaveEvent, newStatus string) {
	const jobScopePrefix = "job:"
	for _, scope := range event.Scopes {
		if !strings.HasPrefix(scope, jobScopePrefix) {
			continue
		}
		jobId := strings.TrimPrefix(scope, jobScopePrefix)
		SetJobConnStatus(jobId, newStatus)
		log.Printf("[job:%s] connection status changed to %s", jobId, newStatus)
	}
}
// GetJobConnStatus returns the tracked connectivity status for a job.
// Unknown jobs report JobConnStatus_Disconnected.
func GetJobConnStatus(jobId string) string {
	jobConnStatesLock.Lock()
	defer jobConnStatesLock.Unlock()
	if status, ok := jobConnStates[jobId]; ok {
		return status
	}
	return JobConnStatus_Disconnected
}
// SetJobConnStatus records a job's connectivity status. Disconnected jobs
// are removed from the map entirely (disconnected is the implicit default).
func SetJobConnStatus(jobId string, status string) {
	jobConnStatesLock.Lock()
	defer jobConnStatesLock.Unlock()
	switch status {
	case JobConnStatus_Disconnected:
		delete(jobConnStates, jobId)
	default:
		jobConnStates[jobId] = status
	}
}
// GetConnectedJobIds returns the ids of all jobs whose route status is
// currently connected (order is unspecified; map iteration).
func GetConnectedJobIds() []string {
	jobConnStatesLock.Lock()
	defer jobConnStatesLock.Unlock()
	var ids []string
	for jobId, status := range jobConnStates {
		if status != JobConnStatus_Connected {
			continue
		}
		ids = append(ids, jobId)
	}
	return ids
}
// ensureJobConnected loads the job and verifies that both the underlying
// SSH/connection and the job manager's RPC route are currently connected.
// Returns the job on success.
func ensureJobConnected(ctx context.Context, jobId string) (*waveobj.Job, error) {
	job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
	if err != nil {
		return nil, fmt.Errorf("failed to get job: %w", err)
	}
	connected, err := conncontroller.IsConnected(job.Connection)
	if err != nil {
		return nil, fmt.Errorf("error checking connection status: %w", err)
	}
	if !connected {
		return nil, fmt.Errorf("connection %q is not connected", job.Connection)
	}
	if status := GetJobConnStatus(jobId); status != JobConnStatus_Connected {
		return nil, fmt.Errorf("job is not connected (status: %s)", status)
	}
	return job, nil
}
// StartJobParams holds the parameters for StartJob.
type StartJobParams struct {
	ConnName string            // remote connection to run the job on (required)
	Cmd      string            // command to execute (required)
	Args     []string          // command arguments
	Env      map[string]string // environment variables passed to the command
	TermSize *waveobj.TermSize // initial terminal size; defaults to 24x80 when nil
}
// StartJob launches a new persistent remote job on the given connection.
// It creates the Job DB record and its WaveFS output file, sets up a
// flow-controlled output stream, and asks the remote side (via the
// connection route) to start the job manager and command. On success it
// starts a background loop that appends job output to WaveFS and returns
// the new job id.
//
// BUGFIX: the stream reader is now closed on every error path after it is
// created (previously it leaked on MakeFile / client-fetch /
// RemoteStartJobCommand failures; RestartStreaming already closed it).
func StartJob(ctx context.Context, params StartJobParams) (string, error) {
	if params.ConnName == "" {
		return "", fmt.Errorf("connection name is required")
	}
	if params.Cmd == "" {
		return "", fmt.Errorf("command is required")
	}
	if params.TermSize == nil {
		// default terminal size when the caller doesn't supply one
		params.TermSize = &waveobj.TermSize{Rows: 24, Cols: 80}
	}
	isConnected, err := conncontroller.IsConnected(params.ConnName)
	if err != nil {
		return "", fmt.Errorf("error checking connection status: %w", err)
	}
	if !isConnected {
		return "", fmt.Errorf("connection %q is not connected", params.ConnName)
	}
	jobId := uuid.New().String()
	// shared secret between the main server and the job manager
	jobAuthToken, err := utilfn.RandomHexString(32)
	if err != nil {
		return "", fmt.Errorf("failed to generate job auth token: %w", err)
	}
	jobAccessClaims := &wavejwt.WaveJwtClaims{
		MainServer: true,
		JobId:      jobId,
	}
	jobAccessToken, err := wavejwt.Sign(jobAccessClaims)
	if err != nil {
		return "", fmt.Errorf("failed to generate job access token: %w", err)
	}
	job := &waveobj.Job{
		OID:              jobId,
		Connection:       params.ConnName,
		Cmd:              params.Cmd,
		CmdArgs:          params.Args,
		CmdEnv:           params.Env,
		CmdTermSize:      *params.TermSize,
		JobAuthToken:     jobAuthToken,
		JobManagerStatus: JobStatus_Init,
		Meta:             make(waveobj.MetaMapType),
	}
	err = wstore.DBInsert(ctx, job)
	if err != nil {
		return "", fmt.Errorf("failed to create job in database: %w", err)
	}
	bareRpc := wshclient.GetBareRpcClient()
	broker := bareRpc.StreamBroker
	readerRouteId := wshclient.GetBareRpcClientRouteId()
	writerRouteId := wshutil.MakeJobRouteId(jobId)
	reader, streamMeta := broker.CreateStreamReader(readerRouteId, writerRouteId, DefaultStreamRwnd)
	fileOpts := wshrpc.FileOpts{
		MaxSize:  10 * 1024 * 1024,
		Circular: true,
	}
	err = filestore.WFS.MakeFile(ctx, jobId, JobOutputFileName, wshrpc.FileMeta{}, fileOpts)
	if err != nil {
		reader.Close()
		return "", fmt.Errorf("failed to create WaveFS file: %w", err)
	}
	client, err := wstore.DBGetSingleton[*waveobj.Client](ctx)
	if err != nil || client == nil {
		reader.Close()
		return "", fmt.Errorf("failed to get client: %w", err)
	}
	publicKey := wavejwt.GetPublicKey()
	publicKeyBase64 := base64.StdEncoding.EncodeToString(publicKey)
	startJobData := wshrpc.CommandRemoteStartJobData{
		Cmd:                params.Cmd,
		Args:               params.Args,
		Env:                params.Env,
		TermSize:           *params.TermSize,
		StreamMeta:         streamMeta,
		JobAuthToken:       jobAuthToken,
		JobId:              jobId,
		MainServerJwtToken: jobAccessToken,
		ClientId:           client.OID,
		PublicKeyBase64:    publicKeyBase64,
	}
	rpcOpts := &wshrpc.RpcOpts{
		Route:   wshutil.MakeConnectionRouteId(params.ConnName),
		Timeout: 30000,
	}
	log.Printf("[job:%s] sending RemoteStartJobCommand to connection %s", jobId, params.ConnName)
	rtnData, err := wshclient.RemoteStartJobCommand(bareRpc, startJobData, rpcOpts)
	if err != nil {
		reader.Close()
		log.Printf("[job:%s] RemoteStartJobCommand failed: %v", jobId, err)
		errMsg := fmt.Sprintf("failed to start job: %v", err)
		// best effort: record the startup failure on the job record
		updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
			job.JobManagerStatus = JobStatus_Done
			job.JobManagerDoneReason = JobDoneReason_StartupError
			job.JobManagerStartupError = errMsg
		})
		if updateErr != nil {
			log.Printf("[job:%s] warning: failed to record startup error: %v", jobId, updateErr)
		}
		return "", fmt.Errorf("failed to start remote job: %w", err)
	}
	log.Printf("[job:%s] RemoteStartJobCommand succeeded, cmdpid=%d cmdstartts=%d jobmanagerpid=%d jobmanagerstartts=%d", jobId, rtnData.CmdPid, rtnData.CmdStartTs, rtnData.JobManagerPid, rtnData.JobManagerStartTs)
	err = wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
		job.CmdPid = rtnData.CmdPid
		job.CmdStartTs = rtnData.CmdStartTs
		job.JobManagerPid = rtnData.JobManagerPid
		job.JobManagerStartTs = rtnData.JobManagerStartTs
		job.JobManagerStatus = JobStatus_Running
	})
	if err != nil {
		log.Printf("[job:%s] warning: failed to update job status to running: %v", jobId, err)
	} else {
		log.Printf("[job:%s] job status updated to running", jobId)
	}
	// consume job output in the background, appending it to WaveFS; the
	// loop owns (and will close) the reader from here on
	go func() {
		defer func() {
			panichandler.PanicHandler("jobcontroller:runOutputLoop", recover())
		}()
		runOutputLoop(context.Background(), jobId, reader)
	}()
	return jobId, nil
}
// handleAppendJobFile appends data to the named WaveFS file in the job's
// zone, then publishes a BlockFile append event (scoped to the job ORef)
// so subscribers (e.g. an attached terminal) see the new output.
func handleAppendJobFile(ctx context.Context, jobId string, fileName string, data []byte) error {
	if err := filestore.WFS.AppendData(ctx, jobId, fileName, data); err != nil {
		return fmt.Errorf("error appending to job file: %w", err)
	}
	eventData := &wps.WSFileEventData{
		ZoneId:   jobId,
		FileName: fileName,
		FileOp:   wps.FileOp_Append,
		Data64:   base64.StdEncoding.EncodeToString(data),
	}
	wps.Broker.Publish(wps.WaveEvent{
		Event:  wps.Event_BlockFile,
		Scopes: []string{waveobj.MakeORef(waveobj.OType_Job, jobId).String()},
		Data:   eventData,
	})
	return nil
}
// runOutputLoop pumps job output from the stream reader into the job's
// WaveFS output file until the stream ends (EOF) or errors. It records the
// final stream state on the Job record and then attempts to terminate the
// job manager (a no-op unless the command has also exited).
func runOutputLoop(ctx context.Context, jobId string, reader *streamclient.Reader) {
	log.Printf("[job:%s] output loop started", jobId)
	defer log.Printf("[job:%s] output loop finished", jobId)
	chunk := make([]byte, 4096)
	for {
		n, readErr := reader.Read(chunk)
		if n > 0 {
			log.Printf("[job:%s] received %d bytes of data", jobId, n)
			if appendErr := handleAppendJobFile(ctx, jobId, JobOutputFileName, chunk[:n]); appendErr != nil {
				log.Printf("[job:%s] error appending data to WaveFS: %v", jobId, appendErr)
			} else {
				log.Printf("[job:%s] successfully appended %d bytes to WaveFS", jobId, n)
			}
		}
		switch {
		case readErr == io.EOF:
			// clean end of stream: mark stream done (no error)
			log.Printf("[job:%s] stream ended (EOF)", jobId)
			if updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
				job.StreamDone = true
			}); updateErr != nil {
				log.Printf("[job:%s] error updating job stream status: %v", jobId, updateErr)
			}
			tryTerminateJobManager(ctx, jobId)
			return
		case readErr != nil:
			// abnormal end: record the stream error as well
			log.Printf("[job:%s] stream error: %v", jobId, readErr)
			errMsg := readErr.Error()
			if updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
				job.StreamDone = true
				job.StreamError = errMsg
			}); updateErr != nil {
				log.Printf("[job:%s] error updating job stream error: %v", jobId, updateErr)
			}
			tryTerminateJobManager(ctx, jobId)
			return
		}
	}
}
// HandleCmdJobExited records a job command's exit information (code,
// signal, error, timestamp) and then attempts to terminate the job manager
// (which only happens once the output stream has also finished).
func HandleCmdJobExited(ctx context.Context, jobId string, data wshrpc.CommandJobCmdExitedData) error {
	recordExit := func(job *waveobj.Job) {
		job.CmdExitError = data.ExitErr
		job.CmdExitCode = data.ExitCode
		job.CmdExitSignal = data.ExitSignal
		job.CmdExitTs = data.ExitTs
	}
	if err := wstore.DBUpdateFn(ctx, jobId, recordExit); err != nil {
		return fmt.Errorf("failed to update job exit status: %w", err)
	}
	tryTerminateJobManager(ctx, jobId)
	return nil
}
// tryTerminateJobManager terminates the remote job manager once BOTH the
// command has exited and the output stream has finished; otherwise it is a
// no-op (aside from logging). Errors are logged, not returned.
func tryTerminateJobManager(ctx context.Context, jobId string) {
	jobObj, getErr := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
	if getErr != nil {
		log.Printf("[job:%s] error getting job for termination check: %v", jobId, getErr)
		return
	}
	if jobObj.JobManagerStatus != JobStatus_Running {
		return
	}
	cmdExited := jobObj.CmdExitTs != 0
	streamDone := jobObj.StreamDone
	if cmdExited && streamDone {
		log.Printf("[job:%s] both job cmd exited and stream finished, terminating job manager", jobId)
		if termErr := TerminateJobManager(ctx, jobId); termErr != nil {
			log.Printf("[job:%s] error terminating job manager: %v", jobId, termErr)
		}
		return
	}
	log.Printf("[job:%s] not ready for termination: exited=%v streamDone=%v", jobId, cmdExited, streamDone)
}
// TerminateJobManager loads the job by id and terminates its remote job
// manager process (see remoteTerminateJobManager).
func TerminateJobManager(ctx context.Context, jobId string) error {
	job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
	if err != nil {
		return fmt.Errorf("failed to get job: %w", err)
	}
	return remoteTerminateJobManager(ctx, job)
}
// DisconnectJob sends a disconnect-from-job-manager command to the job's
// connection route (as opposed to TerminateJobManager, which stops the
// job manager process itself).
func DisconnectJob(ctx context.Context, jobId string) error {
	job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
	if err != nil {
		return fmt.Errorf("failed to get job: %w", err)
	}
	disconnectData := wshrpc.CommandRemoteDisconnectFromJobManagerData{JobId: jobId}
	rpcOpts := &wshrpc.RpcOpts{
		Route:   wshutil.MakeConnectionRouteId(job.Connection),
		Timeout: 5000,
	}
	rpcClient := wshclient.GetBareRpcClient()
	if err := wshclient.RemoteDisconnectFromJobManagerCommand(rpcClient, disconnectData, rpcOpts); err != nil {
		return fmt.Errorf("failed to send disconnect command: %w", err)
	}
	log.Printf("[job:%s] job disconnect command sent successfully", jobId)
	return nil
}
// remoteTerminateJobManager sends a terminate command for the job manager
// (identified by pid + start timestamp) over the job's connection route,
// then marks the job done in the DB (best effort). If the stream had not
// finished, it is force-marked done with a terminated error.
func remoteTerminateJobManager(ctx context.Context, job *waveobj.Job) error {
	jobId := job.OID
	log.Printf("[job:%s] terminating job manager", jobId)
	terminateData := wshrpc.CommandRemoteTerminateJobManagerData{
		JobId:             jobId,
		JobManagerPid:     job.JobManagerPid,
		JobManagerStartTs: job.JobManagerStartTs,
	}
	rpcOpts := &wshrpc.RpcOpts{
		Route:   wshutil.MakeConnectionRouteId(job.Connection),
		Timeout: 5000,
	}
	rpcClient := wshclient.GetBareRpcClient()
	if err := wshclient.RemoteTerminateJobManagerCommand(rpcClient, terminateData, rpcOpts); err != nil {
		log.Printf("[job:%s] error terminating job manager: %v", jobId, err)
		return fmt.Errorf("failed to terminate job manager: %w", err)
	}
	updateErr := wstore.DBUpdateFn(ctx, jobId, func(dbJob *waveobj.Job) {
		dbJob.JobManagerStatus = JobStatus_Done
		dbJob.JobManagerDoneReason = JobDoneReason_Terminated
		dbJob.TerminateOnReconnect = false
		if !dbJob.StreamDone {
			dbJob.StreamDone = true
			dbJob.StreamError = "job manager terminated"
		}
	})
	if updateErr != nil {
		log.Printf("[job:%s] error updating job status after termination: %v", jobId, updateErr)
	}
	log.Printf("[job:%s] job manager terminated successfully", jobId)
	return nil
}
// ReconnectJob re-attaches the main server to a job's still-running remote
// job manager after a connection drop: it signs a fresh job-access JWT,
// sends the reconnect command over the connection route, waits (up to 2s)
// for the job's own RPC route to register, then restarts output streaming.
// If the job was flagged TerminateOnReconnect, the job manager is
// terminated instead.
func ReconnectJob(ctx context.Context, jobId string) error {
	job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
	if err != nil {
		return fmt.Errorf("failed to get job: %w", err)
	}
	isConnected, err := conncontroller.IsConnected(job.Connection)
	if err != nil {
		return fmt.Errorf("error checking connection status: %w", err)
	}
	if !isConnected {
		return fmt.Errorf("connection %q is not connected", job.Connection)
	}
	if job.TerminateOnReconnect {
		// a pending terminate request takes precedence over reconnecting
		return remoteTerminateJobManager(ctx, job)
	}
	bareRpc := wshclient.GetBareRpcClient()
	jobAccessClaims := &wavejwt.WaveJwtClaims{
		MainServer: true,
		JobId:      jobId,
	}
	jobAccessToken, err := wavejwt.Sign(jobAccessClaims)
	if err != nil {
		return fmt.Errorf("failed to generate job access token: %w", err)
	}
	reconnectData := wshrpc.CommandRemoteReconnectToJobManagerData{
		JobId:              jobId,
		JobAuthToken:       job.JobAuthToken,
		MainServerJwtToken: jobAccessToken,
		JobManagerPid:      job.JobManagerPid,
		JobManagerStartTs:  job.JobManagerStartTs,
	}
	rpcOpts := &wshrpc.RpcOpts{
		Route:   wshutil.MakeConnectionRouteId(job.Connection),
		Timeout: 5000,
	}
	log.Printf("[job:%s] sending RemoteReconnectToJobManagerCommand to connection %s", jobId, job.Connection)
	rtnData, err := wshclient.RemoteReconnectToJobManagerCommand(bareRpc, reconnectData, rpcOpts)
	if err != nil {
		log.Printf("[job:%s] RemoteReconnectToJobManagerCommand failed: %v", jobId, err)
		return fmt.Errorf("failed to reconnect to job manager: %w", err)
	}
	if !rtnData.Success {
		log.Printf("[job:%s] RemoteReconnectToJobManagerCommand returned error: %s", jobId, rtnData.Error)
		if rtnData.JobManagerGone {
			// remote job manager process no longer exists: mark done (best effort)
			updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
				job.JobManagerStatus = JobStatus_Done
				job.JobManagerDoneReason = JobDoneReason_Gone
			})
			if updateErr != nil {
				log.Printf("[job:%s] error updating job manager running status: %v", jobId, updateErr)
			}
			return fmt.Errorf("job manager has exited: %s", rtnData.Error)
		}
		return fmt.Errorf("failed to reconnect to job manager: %s", rtnData.Error)
	}
	log.Printf("[job:%s] RemoteReconnectToJobManagerCommand succeeded, waiting for route", jobId)
	routeId := wshutil.MakeJobRouteId(jobId)
	waitCtx, cancelFn := context.WithTimeout(ctx, 2*time.Second)
	defer cancelFn()
	err = wshutil.DefaultRouter.WaitForRegister(waitCtx, routeId)
	if err != nil {
		return fmt.Errorf("route did not establish after successful reconnection: %w", err)
	}
	log.Printf("[job:%s] route established, restarting streaming", jobId)
	// connection + route were just verified; skip re-checking in RestartStreaming
	return RestartStreaming(ctx, jobId, true)
}
// ReconnectJobsForConn reconnects every job on the given (now connected)
// connection whose job manager is marked running. Per-job reconnect
// failures are logged but do not abort the remaining jobs.
func ReconnectJobsForConn(ctx context.Context, connName string) error {
	connected, err := conncontroller.IsConnected(connName)
	if err != nil {
		return fmt.Errorf("error checking connection status: %w", err)
	}
	if !connected {
		return fmt.Errorf("connection %q is not connected", connName)
	}
	allJobs, err := wstore.DBGetAllObjsByType[*waveobj.Job](ctx, waveobj.OType_Job)
	if err != nil {
		return fmt.Errorf("failed to get jobs: %w", err)
	}
	var candidates []*waveobj.Job
	for _, job := range allJobs {
		if job.Connection != connName || !isJobManagerRunning(job) {
			continue
		}
		candidates = append(candidates, job)
	}
	log.Printf("[conn:%s] found %d jobs to reconnect", connName, len(candidates))
	for _, job := range candidates {
		if reconnectErr := ReconnectJob(ctx, job.OID); reconnectErr != nil {
			log.Printf("[job:%s] error reconnecting: %v", job.OID, reconnectErr)
		}
	}
	return nil
}
// RestartStreaming re-establishes the output stream for a job after a
// reconnect. The resume sequence number is the WaveFS file size plus any
// previously recorded gap; if the job manager has already discarded bytes
// past that point, the additional gap is recorded in file meta so future
// resumes line up. On success a new background output loop is started.
// knownConnected skips connection/route checks when the caller has just
// verified them (e.g. ReconnectJob).
//
// BUGFIX: waveFile is nil when Stat fails; the previous code read
// waveFile.Size unconditionally in the log statement below and would panic
// on a missing output file. Also fixed the log text that referenced a
// nonexistent "tryExitJobManager".
func RestartStreaming(ctx context.Context, jobId string, knownConnected bool) error {
	job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
	if err != nil {
		return fmt.Errorf("failed to get job: %w", err)
	}
	if !knownConnected {
		isConnected, err := conncontroller.IsConnected(job.Connection)
		if err != nil {
			return fmt.Errorf("error checking connection status: %w", err)
		}
		if !isConnected {
			return fmt.Errorf("connection %q is not connected", job.Connection)
		}
		jobConnStatus := GetJobConnStatus(jobId)
		if jobConnStatus != JobConnStatus_Connected {
			return fmt.Errorf("job manager is not connected (status: %s)", jobConnStatus)
		}
	}
	// resume point = bytes we already have on disk + bytes known lost (gaps)
	var currentSeq int64
	var totalGap int64
	var fileSize int64
	waveFile, statErr := filestore.WFS.Stat(ctx, jobId, JobOutputFileName)
	if statErr == nil {
		fileSize = waveFile.Size
		totalGap = getMetaInt64(waveFile.Meta, MetaKey_TotalGap)
		currentSeq = fileSize + totalGap
	}
	bareRpc := wshclient.GetBareRpcClient()
	broker := bareRpc.StreamBroker
	readerRouteId := wshclient.GetBareRpcClientRouteId()
	writerRouteId := wshutil.MakeJobRouteId(jobId)
	reader, streamMeta := broker.CreateStreamReaderWithSeq(readerRouteId, writerRouteId, DefaultStreamRwnd, currentSeq)
	prepareData := wshrpc.CommandJobPrepareConnectData{
		StreamMeta: *streamMeta,
		Seq:        currentSeq,
	}
	rpcOpts := &wshrpc.RpcOpts{
		Route:   wshutil.MakeJobRouteId(jobId),
		Timeout: 5000,
	}
	log.Printf("[job:%s] sending JobPrepareConnectCommand with seq=%d (fileSize=%d, totalGap=%d)", jobId, currentSeq, fileSize, totalGap)
	rtnData, err := wshclient.JobPrepareConnectCommand(bareRpc, prepareData, rpcOpts)
	if err != nil {
		reader.Close()
		return fmt.Errorf("failed to prepare connect: %w", err)
	}
	if rtnData.HasExited {
		// the command exited while we were disconnected; record its exit state
		exitCodeStr := "nil"
		if rtnData.ExitCode != nil {
			exitCodeStr = fmt.Sprintf("%d", *rtnData.ExitCode)
		}
		log.Printf("[job:%s] job has already exited: code=%s signal=%q err=%q", jobId, exitCodeStr, rtnData.ExitSignal, rtnData.ExitErr)
		updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
			job.JobManagerStatus = JobStatus_Done
			job.CmdExitCode = rtnData.ExitCode
			job.CmdExitSignal = rtnData.ExitSignal
			job.CmdExitError = rtnData.ExitErr
		})
		if updateErr != nil {
			log.Printf("[job:%s] error updating job exit status: %v", jobId, updateErr)
		}
	}
	if rtnData.StreamDone {
		log.Printf("[job:%s] stream is already done: error=%q", jobId, rtnData.StreamError)
		updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
			if !job.StreamDone {
				job.StreamDone = true
				if rtnData.StreamError != "" {
					job.StreamError = rtnData.StreamError
				}
			}
		})
		if updateErr != nil {
			log.Printf("[job:%s] error updating job stream status: %v", jobId, updateErr)
		}
	}
	if rtnData.StreamDone && rtnData.HasExited {
		reader.Close()
		log.Printf("[job:%s] both stream done and job exited, calling tryTerminateJobManager", jobId)
		tryTerminateJobManager(ctx, jobId)
		return nil
	}
	if rtnData.StreamDone {
		reader.Close()
		log.Printf("[job:%s] stream already done, no need to restart streaming", jobId)
		return nil
	}
	if rtnData.Seq > currentSeq {
		// the job manager no longer has bytes back to our resume point;
		// record the additional gap so future resume offsets stay consistent
		gap := rtnData.Seq - currentSeq
		totalGap += gap
		log.Printf("[job:%s] detected gap: our seq=%d, server seq=%d, gap=%d, new totalGap=%d", jobId, currentSeq, rtnData.Seq, gap, totalGap)
		metaErr := filestore.WFS.WriteMeta(ctx, jobId, JobOutputFileName, wshrpc.FileMeta{
			MetaKey_TotalGap: totalGap,
		}, true)
		if metaErr != nil {
			log.Printf("[job:%s] error updating totalgap metadata: %v", jobId, metaErr)
		}
		reader.UpdateNextSeq(rtnData.Seq)
	}
	log.Printf("[job:%s] sending JobStartStreamCommand", jobId)
	startStreamData := wshrpc.CommandJobStartStreamData{}
	err = wshclient.JobStartStreamCommand(bareRpc, startStreamData, rpcOpts)
	if err != nil {
		reader.Close()
		return fmt.Errorf("failed to start stream: %w", err)
	}
	go func() {
		defer func() {
			panichandler.PanicHandler("jobcontroller:RestartStreaming:runOutputLoop", recover())
		}()
		runOutputLoop(context.Background(), jobId, reader)
	}()
	log.Printf("[job:%s] streaming restarted successfully", jobId)
	return nil
}
// DeleteJob removes a job: clears its tracked connection status, deletes
// its WaveFS zone (best effort, logged on failure), and deletes the DB
// record.
func DeleteJob(ctx context.Context, jobId string) error {
	SetJobConnStatus(jobId, JobConnStatus_Disconnected)
	if err := filestore.WFS.DeleteZone(ctx, jobId); err != nil {
		log.Printf("[job:%s] warning: error deleting WaveFS zone: %v", jobId, err)
	}
	return wstore.DBDelete(ctx, waveobj.OType_Job, jobId)
}
// AttachJobToBlock links a job to a block, setting Block.JobId and
// Job.AttachedBlockId in a single transaction, then notifies the block's
// frontend route (fire-and-forget) so it can pick up the attached job.
// Fails if the job is already attached to a block.
func AttachJobToBlock(ctx context.Context, jobId string, blockId string) error {
	err := wstore.WithTx(ctx, func(tx *wstore.TxWrap) error {
		err := wstore.DBUpdateFn(tx.Context(), blockId, func(block *waveobj.Block) {
			block.JobId = jobId
		})
		if err != nil {
			return fmt.Errorf("failed to update block: %w", err)
		}
		// DBUpdateFnErr lets the callback veto the update with an error
		// (used here to reject double-attachment)
		err = wstore.DBUpdateFnErr(tx.Context(), jobId, func(job *waveobj.Job) error {
			if job.AttachedBlockId != "" {
				return fmt.Errorf("job %s already attached to block %s", jobId, job.AttachedBlockId)
			}
			job.AttachedBlockId = blockId
			return nil
		})
		if err != nil {
			return fmt.Errorf("failed to update job: %w", err)
		}
		log.Printf("[job:%s] attached to block:%s", jobId, blockId)
		return nil
	})
	if err != nil {
		return err
	}
	// fire-and-forget notification to the block's frontend route
	rpcOpts := &wshrpc.RpcOpts{
		Route:      wshutil.MakeFeBlockRouteId(blockId),
		NoResponse: true,
	}
	bareRpc := wshclient.GetBareRpcClient()
	wshclient.TermUpdateAttachedJobCommand(bareRpc, wshrpc.CommandTermUpdateAttachedJobData{
		BlockId: blockId,
		JobId:   jobId,
	}, rpcOpts)
	return nil
}
// DetachJobFromBlock clears the job<->block association. When updateBlock
// is true, Block.JobId is also cleared (best effort). After the transaction
// commits, the block's frontend route is notified (fire-and-forget) with an
// empty job id. No-op when the job is not attached to any block.
func DetachJobFromBlock(ctx context.Context, jobId string, updateBlock bool) error {
	var blockId string
	err := wstore.WithTx(ctx, func(tx *wstore.TxWrap) error {
		job, err := wstore.DBMustGet[*waveobj.Job](tx.Context(), jobId)
		if err != nil {
			return fmt.Errorf("failed to get job: %w", err)
		}
		blockId = job.AttachedBlockId
		if blockId == "" {
			// not attached; nothing to do
			return nil
		}
		if updateBlock {
			// clear the back-reference on the block (best effort; the block
			// may already be gone)
			block, err := wstore.DBGet[*waveobj.Block](tx.Context(), blockId)
			if err == nil && block != nil {
				err = wstore.DBUpdateFn(tx.Context(), blockId, func(block *waveobj.Block) {
					block.JobId = ""
				})
				if err != nil {
					log.Printf("[job:%s] warning: failed to clear JobId from block:%s: %v", jobId, blockId, err)
				}
			}
		}
		err = wstore.DBUpdateFn(tx.Context(), jobId, func(job *waveobj.Job) {
			job.AttachedBlockId = ""
		})
		if err != nil {
			return fmt.Errorf("failed to update job: %w", err)
		}
		log.Printf("[job:%s] detached from block:%s", jobId, blockId)
		return nil
	})
	if err != nil {
		return err
	}
	if blockId != "" {
		rpcOpts := &wshrpc.RpcOpts{
			Route:      wshutil.MakeFeBlockRouteId(blockId),
			NoResponse: true,
		}
		bareRpc := wshclient.GetBareRpcClient()
		wshclient.TermUpdateAttachedJobCommand(bareRpc, wshrpc.CommandTermUpdateAttachedJobData{
			BlockId: blockId,
			JobId:   "",
		}, rpcOpts)
	}
	return nil
}
// SendInput forwards input (stdin data, signal, and/or terminal resize)
// to a connected job over its RPC route. When a terminal size is included,
// the new size is also persisted on the Job record (best effort).
func SendInput(ctx context.Context, data wshrpc.CommandJobInputData) error {
	jobId := data.JobId
	if _, err := ensureJobConnected(ctx, jobId); err != nil {
		return err
	}
	rpcOpts := &wshrpc.RpcOpts{
		Route:      wshutil.MakeJobRouteId(jobId),
		Timeout:    5000,
		NoResponse: false,
	}
	rpcClient := wshclient.GetBareRpcClient()
	if err := wshclient.JobInputCommand(rpcClient, data, rpcOpts); err != nil {
		return fmt.Errorf("failed to send input to job: %w", err)
	}
	if data.TermSize == nil {
		return nil
	}
	updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
		job.CmdTermSize = *data.TermSize
	})
	if updateErr != nil {
		log.Printf("[job:%s] warning: failed to update termsize in DB: %v", jobId, updateErr)
	}
	return nil
}

218
pkg/jobmanager/cirbuf.go Normal file
View file

@ -0,0 +1,218 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package jobmanager
import (
"context"
"fmt"
"sync"
)
// CirBuf is a mutex-guarded circular byte buffer with an adjustable
// effective window, used for stream flow control. In sync mode a full
// window blocks writers; in async mode the oldest data is overwritten.
type CirBuf struct {
	lock       sync.Mutex
	waiterChan chan chan struct{} // capacity 1; holds the single blocked writer's wake channel
	buf        []byte             // fixed-size backing storage
	readPos    int                // index of the oldest unread byte
	writePos   int                // index where the next byte is written
	count      int                // number of buffered (unread) bytes
	totalSize  int64              // total bytes ever written (monotonic)
	syncMode   bool               // true: block writers when window full; false: overwrite oldest
	windowSize int                // effective flow-control window, capped at len(buf)
}
// MakeCirBuf creates a CirBuf with a fixed backing size of maxSize bytes.
// The effective window starts at the full buffer size; initSyncMode selects
// whether writers block (sync) or old data is overwritten (async) when the
// window fills.
func MakeCirBuf(maxSize int, initSyncMode bool) *CirBuf {
	return &CirBuf{
		buf:        make([]byte, maxSize),
		waiterChan: make(chan chan struct{}, 1),
		syncMode:   initSyncMode,
		windowSize: maxSize,
	}
}
// SetEffectiveWindow changes the sync mode and effective window size for flow control.
// The windowSize is capped at the buffer size.
// When window shrinks: sync mode blocks new writes, async mode truncates old data to enforce limit.
// When window increases: blocked writers are woken up if space becomes available.
func (cb *CirBuf) SetEffectiveWindow(syncMode bool, windowSize int) {
	cb.lock.Lock()
	defer cb.lock.Unlock()
	maxSize := len(cb.buf)
	if windowSize > maxSize {
		windowSize = maxSize
	}
	oldSyncMode := cb.syncMode
	oldWindowSize := cb.windowSize
	cb.windowSize = windowSize
	cb.syncMode = syncMode
	// In async mode, enforce window size by truncating buffer if needed
	// (drop the oldest bytes by advancing readPos past the excess)
	if !syncMode && cb.count > windowSize {
		excess := cb.count - windowSize
		cb.readPos = (cb.readPos + excess) % maxSize
		cb.count = windowSize
	}
	// Only sync mode blocks writers, so only wake if we were in sync mode.
	// Wake when window grows (more space available) or switching to async (no longer blocking).
	if oldSyncMode && (windowSize > oldWindowSize || !syncMode) {
		cb.tryWakeWriter()
	}
}
// Write appends data to the buffer. In async mode it never blocks; in sync
// mode it blocks until enough data is consumed for the write to complete.
// Use WriteCtx to make an in-progress write cancelable.
func (cb *CirBuf) Write(data []byte) (int, error) {
	ctx := context.Background()
	return cb.WriteCtx(ctx, data)
}
// WriteCtx writes data to the circular buffer with context support for cancellation.
// In sync mode, blocks when buffer is full until space is available or context is cancelled.
// Returns partial byte count and context error if cancelled mid-write.
// NOTE: Only one concurrent blocked write is allowed. Multiple blocked writes will panic.
func (cb *CirBuf) WriteCtx(ctx context.Context, data []byte) (int, error) {
	if len(data) == 0 {
		return 0, nil
	}
	bytesWritten := 0
	for bytesWritten < len(data) {
		if err := ctx.Err(); err != nil {
			return bytesWritten, err
		}
		// writeAvailable writes what fits; a non-nil channel means the window
		// is full (sync mode) and the channel is closed once space frees up.
		n, spaceAvailable := cb.writeAvailable(data[bytesWritten:])
		bytesWritten += n
		if spaceAvailable != nil {
			select {
			case <-spaceAvailable:
				continue
			case <-ctx.Done():
				// Best-effort removal of our parked wakeup channel so a future
				// wake isn't consumed on behalf of this abandoned write.
				tryReadCh(cb.waiterChan)
				return bytesWritten, ctx.Err()
			}
		}
	}
	return bytesWritten, nil
}
// writeAvailable copies as many bytes of data into the buffer as the current
// effective window allows, returning the count written. If the window fills
// while in sync mode, it additionally returns a channel that will be closed
// when space becomes available (the caller should block on it and retry).
// Panics if a second writer is already parked: only one concurrently blocked
// writer is supported.
func (cb *CirBuf) writeAvailable(data []byte) (int, chan struct{}) {
	cb.lock.Lock()
	defer cb.lock.Unlock()
	size := len(cb.buf)
	written := 0
	for i := 0; i < len(data); i++ {
		if cb.syncMode && cb.count >= cb.windowSize {
			// Window full in sync mode: park a wakeup channel for
			// Consume/SetEffectiveWindow (via tryWakeWriter) to close.
			spaceAvailable := make(chan struct{})
			if !tryWriteCh(cb.waiterChan, spaceAvailable) {
				panic("CirBuf: multiple concurrent blocked writes not allowed")
			}
			return written, spaceAvailable
		}
		cb.buf[cb.writePos] = data[i]
		cb.writePos = (cb.writePos + 1) % size
		if cb.count < cb.windowSize {
			cb.count++
		} else {
			// Async mode at capacity: overwrite the oldest byte.
			cb.readPos = (cb.readPos + 1) % size
		}
		cb.totalSize++
		written++
	}
	return written, nil
}
// PeekData copies up to len(data) buffered bytes into data starting at the
// head of the buffer, without consuming them. Returns the bytes copied.
func (cb *CirBuf) PeekData(data []byte) int {
	const headOffset = 0
	return cb.PeekDataAt(headOffset, data)
}
// PeekDataAt copies up to len(data) buffered bytes into data, starting at the
// given offset from the buffer head, without consuming them. Returns the
// number of bytes copied (0 when the buffer is empty or offset is past the
// buffered data).
func (cb *CirBuf) PeekDataAt(offset int, data []byte) int {
	cb.lock.Lock()
	defer cb.lock.Unlock()
	if cb.count == 0 || offset >= cb.count {
		return 0
	}
	size := len(cb.buf)
	start := (cb.readPos + offset) % size
	total := cb.count - offset
	if total > len(data) {
		total = len(data)
	}
	// Copy in at most two segments to handle wrap-around.
	first := size - start
	if first > total {
		first = total
	}
	copy(data[:first], cb.buf[start:start+first])
	copy(data[first:total], cb.buf[:total-first])
	return total
}
// Consume discards numBytes from the head of the buffer, freeing window
// space and waking a blocked writer if one is parked. Errors if numBytes
// exceeds the buffered byte count.
func (cb *CirBuf) Consume(numBytes int) error {
	cb.lock.Lock()
	defer cb.lock.Unlock()
	if numBytes > cb.count {
		return fmt.Errorf("cannot consume %d bytes, only %d available", numBytes, cb.count)
	}
	cb.readPos = (cb.readPos + numBytes) % len(cb.buf)
	cb.count -= numBytes
	cb.tryWakeWriter()
	return nil
}
// HeadPos returns the absolute stream position of the oldest buffered byte.
func (cb *CirBuf) HeadPos() int64 {
	cb.lock.Lock()
	headPos := cb.totalSize - int64(cb.count)
	cb.lock.Unlock()
	return headPos
}
// Size returns the number of unread bytes currently buffered.
func (cb *CirBuf) Size() int {
	cb.lock.Lock()
	buffered := cb.count
	cb.lock.Unlock()
	return buffered
}
// TotalSize returns the total number of bytes ever written to the buffer.
func (cb *CirBuf) TotalSize() int64 {
	cb.lock.Lock()
	total := cb.totalSize
	cb.lock.Unlock()
	return total
}
// tryWriteCh performs a non-blocking send of val on ch, reporting whether the
// send succeeded.
func tryWriteCh[T any](ch chan<- T, val T) bool {
	sent := false
	select {
	case ch <- val:
		sent = true
	default:
	}
	return sent
}
// tryReadCh performs a non-blocking receive from ch. On success it returns a
// pointer to the received value and true; otherwise nil and false.
func tryReadCh[T any](ch <-chan T) (*T, bool) {
	var val T
	select {
	case val = <-ch:
		return &val, true
	default:
		return nil, false
	}
}
// tryWakeWriter closes the parked writer's wakeup channel, if one is waiting,
// so a blocked WriteCtx can re-check for space.
func (cb *CirBuf) tryWakeWriter() {
	waiterCh, ok := tryReadCh(cb.waiterChan)
	if !ok {
		return
	}
	close(*waiterCh)
}

208
pkg/jobmanager/jobcmd.go Normal file
View file

@ -0,0 +1,208 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package jobmanager
import (
"encoding/base64"
"fmt"
"log"
"os"
"os/exec"
"sync"
"syscall"
"time"
"github.com/creack/pty"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
)
// CmdDef describes a command to launch under the job manager: the program,
// its arguments, extra environment variables (appended to the inherited
// environment), and the initial PTY size.
type CmdDef struct {
	Cmd      string
	Args     []string
	Env      map[string]string
	TermSize waveobj.TermSize
}
// JobCmd tracks a single command running under a PTY, along with its exit
// state once the process terminates. All mutable fields are guarded by lock.
type JobCmd struct {
	jobId         string
	lock          sync.Mutex
	cmd           *exec.Cmd
	cmdPty        pty.Pty
	ptsName       string // terminal name captured at start
	cleanedUp     bool
	ptyClosed     bool // set once the PTY master has been closed
	processExited bool
	exitCode      *int   // nil until the process exits with a normal exit status
	exitSignal    string // signal name when the process was killed by a signal
	exitErr       error  // raw error from cmd.Wait, if any
	exitTs        int64  // exit time in unix milliseconds
}
// MakeJobCmd launches cmdDef.Cmd under a new PTY sized to cmdDef.TermSize and
// returns a JobCmd tracking the running process. A goroutine is started to
// reap the process and record its exit status. Returns an error if the
// terminal size is invalid or the command cannot be started.
func MakeJobCmd(jobId string, cmdDef CmdDef) (*JobCmd, error) {
	jm := &JobCmd{
		jobId: jobId,
	}
	// Default each dimension independently so a caller-supplied value for one
	// dimension is not discarded just because the other was left unset (zero).
	if cmdDef.TermSize.Rows == 0 {
		cmdDef.TermSize.Rows = 25
	}
	if cmdDef.TermSize.Cols == 0 {
		cmdDef.TermSize.Cols = 80
	}
	if cmdDef.TermSize.Rows <= 0 || cmdDef.TermSize.Cols <= 0 {
		return nil, fmt.Errorf("invalid term size: %v", cmdDef.TermSize)
	}
	ecmd := exec.Command(cmdDef.Cmd, cmdDef.Args...)
	if len(cmdDef.Env) > 0 {
		// Start from the parent environment and append the overrides.
		ecmd.Env = os.Environ()
		for key, val := range cmdDef.Env {
			ecmd.Env = append(ecmd.Env, fmt.Sprintf("%s=%s", key, val))
		}
	}
	cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(cmdDef.TermSize.Rows), Cols: uint16(cmdDef.TermSize.Cols)})
	if err != nil {
		return nil, fmt.Errorf("failed to start command: %w", err)
	}
	// Prevent the PTY master fd from leaking into any future child processes.
	setCloseOnExec(int(cmdPty.Fd()))
	jm.cmd = ecmd
	jm.cmdPty = cmdPty
	jm.ptsName = jm.cmdPty.Name()
	go jm.waitForProcess()
	return jm, nil
}
// waitForProcess reaps the job process (blocking in cmd.Wait), records the
// exit code / signal / timestamp under the lock, and then asynchronously
// notifies the attached main server that the job exited.
func (jm *JobCmd) waitForProcess() {
	if jm.cmd == nil || jm.cmd.Process == nil {
		return
	}
	err := jm.cmd.Wait()
	jm.lock.Lock()
	defer jm.lock.Unlock()
	jm.processExited = true
	jm.exitTs = time.Now().UnixMilli()
	jm.exitErr = err
	if err != nil {
		// Distinguish death-by-signal from a nonzero exit status.
		if exitErr, ok := err.(*exec.ExitError); ok {
			if status, ok := exitErr.Sys().(syscall.WaitStatus); ok {
				if status.Signaled() {
					jm.exitSignal = status.Signal().String()
				} else if status.Exited() {
					code := status.ExitStatus()
					jm.exitCode = &code
				} else {
					log.Printf("Invalid WaitStatus, not exited or signaled: %v", status)
				}
			}
		}
	} else {
		code := 0
		jm.exitCode = &code
	}
	exitCodeStr := "nil"
	if jm.exitCode != nil {
		exitCodeStr = fmt.Sprintf("%d", *jm.exitCode)
	}
	log.Printf("process exited: exitcode=%s, signal=%s, err=%v\n", exitCodeStr, jm.exitSignal, jm.exitErr)
	// Run in a goroutine: sendJobExited takes the JobManager lock and calls
	// back into this JobCmd (GetExitInfo takes jm.lock), so it must not run
	// while jm.lock is still held here.
	go WshCmdJobManager.sendJobExited()
}
// GetCmd returns the underlying exec.Cmd and its PTY master.
func (jm *JobCmd) GetCmd() (*exec.Cmd, pty.Pty) {
	jm.lock.Lock()
	cmd, cmdPty := jm.cmd, jm.cmdPty
	jm.lock.Unlock()
	return cmd, cmdPty
}
// GetPGID returns the process-group ID of the running job process. Errors if
// there is no process, the process has already exited, or the pgid lookup
// fails or yields a non-positive value.
func (jm *JobCmd) GetPGID() (int, error) {
	jm.lock.Lock()
	defer jm.lock.Unlock()
	switch {
	case jm.cmd == nil || jm.cmd.Process == nil:
		return 0, fmt.Errorf("no active process")
	case jm.processExited:
		return 0, fmt.Errorf("process already exited")
	}
	pgid, err := getProcessGroupId(jm.cmd.Process.Pid)
	if err != nil {
		return 0, fmt.Errorf("failed to get pgid: %w", err)
	}
	if pgid <= 0 {
		return 0, fmt.Errorf("invalid pgid returned: %d", pgid)
	}
	return pgid, nil
}
// GetExitInfo reports whether the process has exited and, if so, returns the
// exit details packaged for the JobCmdExited rpc notification. Returns
// (false, nil) while the process is still running.
func (jm *JobCmd) GetExitInfo() (bool, *wshrpc.CommandJobCmdExitedData) {
	jm.lock.Lock()
	defer jm.lock.Unlock()
	if !jm.processExited {
		return false, nil
	}
	exitData := &wshrpc.CommandJobCmdExitedData{
		JobId:      WshCmdJobManager.JobId,
		ExitCode:   jm.exitCode,
		ExitSignal: jm.exitSignal,
		ExitTs:     jm.exitTs,
	}
	if jm.exitErr != nil {
		exitData.ExitErr = jm.exitErr.Error()
	}
	return true, exitData
}
// TODO set up a single input handler loop + queue so we dont need to hold the lock but still get synchronized in-order execution

// HandleInput applies one input packet to the running process: optional
// base64-encoded keyboard data written to the PTY, then an optional signal,
// then an optional terminal resize. The first failure aborts the rest of the
// packet. Held under jm.lock for in-order execution (see TODO above).
func (jm *JobCmd) HandleInput(data wshrpc.CommandJobInputData) error {
	jm.lock.Lock()
	defer jm.lock.Unlock()
	if jm.cmd == nil || jm.cmdPty == nil {
		return fmt.Errorf("no active process")
	}
	if len(data.InputData64) > 0 {
		inputBuf := make([]byte, base64.StdEncoding.DecodedLen(len(data.InputData64)))
		nw, err := base64.StdEncoding.Decode(inputBuf, []byte(data.InputData64))
		if err != nil {
			return fmt.Errorf("error decoding input data: %w", err)
		}
		_, err = jm.cmdPty.Write(inputBuf[:nw])
		if err != nil {
			return fmt.Errorf("error writing to pty: %w", err)
		}
	}
	if data.SigName != "" {
		// Unknown signal names normalize to nil and are silently skipped.
		sig := normalizeSignal(data.SigName)
		if sig != nil && jm.cmd.Process != nil {
			err := jm.cmd.Process.Signal(sig)
			if err != nil {
				return fmt.Errorf("error sending signal: %w", err)
			}
		}
	}
	if data.TermSize != nil {
		err := pty.Setsize(jm.cmdPty, &pty.Winsize{
			Rows: uint16(data.TermSize.Rows),
			Cols: uint16(data.TermSize.Cols),
		})
		if err != nil {
			return fmt.Errorf("error setting terminal size: %w", err)
		}
	}
	return nil
}
// TerminateByClosingPtyMaster terminates the job by closing the PTY master.
// Idempotent: once the PTY has been closed, later calls are no-ops. The
// Close error is intentionally ignored (best-effort teardown).
func (jm *JobCmd) TerminateByClosingPtyMaster() {
	jm.lock.Lock()
	defer jm.lock.Unlock()
	if jm.ptyClosed || jm.cmdPty == nil {
		return
	}
	jm.cmdPty.Close()
	jm.ptyClosed = true
	log.Printf("pty closed for job %s\n", jm.jobId)
}

View file

@ -0,0 +1,246 @@
// Copyright 2026, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package jobmanager
import (
"fmt"
"log"
"net"
"os"
"path/filepath"
"runtime"
"sync"
"github.com/wavetermdev/waveterm/pkg/baseds"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wavejwt"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)
// JobAccessTokenLabel names the job access token in the startup handshake —
// presumably matched by the launching side; TODO confirm against wsh launch code.
const JobAccessTokenLabel = "Wave-JobAccessToken"

// JobManagerStartLabel is written to the ready file to signal that the job
// manager has finished initial setup (see SetupJobManager).
const JobManagerStartLabel = "Wave-JobManagerStart"

// WshCmdJobManager is the process-wide singleton job manager state.
var WshCmdJobManager JobManager
// JobManager holds the state of a single job-manager process: identity and
// auth material, the running command (once started), the output stream
// manager, and the currently attached/connected main-server clients.
type JobManager struct {
	ClientId      string
	JobId         string
	Cmd           *JobCmd
	JwtPublicKey  []byte
	JobAuthToken  string
	StreamManager *StreamManager

	lock                  sync.Mutex
	attachedClient        *MainServerConn    // last fully authenticated control connection
	connectedStreamClient *MainServerConn    // connection currently attached to the output stream
	pendingStreamMeta     *wshrpc.StreamMeta // set by JobPrepareConnect, consumed by JobStartStream
}
// SetupJobManager initializes the global WshCmdJobManager, installs the JWT
// public key, opens the job's domain socket, signals readiness on readyFile,
// and finally daemonizes the process. Unix (linux/darwin) only.
func SetupJobManager(clientId string, jobId string, publicKeyBytes []byte, jobAuthToken string, readyFile *os.File) error {
	if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
		return fmt.Errorf("job manager only supported on unix systems, not %s", runtime.GOOS)
	}
	WshCmdJobManager.ClientId = clientId
	WshCmdJobManager.JobId = jobId
	WshCmdJobManager.JwtPublicKey = publicKeyBytes
	WshCmdJobManager.JobAuthToken = jobAuthToken
	WshCmdJobManager.StreamManager = MakeStreamManager()
	err := wavejwt.SetPublicKey(publicKeyBytes)
	if err != nil {
		return fmt.Errorf("failed to set public key: %w", err)
	}
	err = MakeJobDomainSocket(clientId, jobId)
	if err != nil {
		return err
	}
	// Signal readiness only after the socket is listening, and close the
	// ready file before detaching from the parent.
	fmt.Fprintf(readyFile, JobManagerStartLabel+"\n")
	readyFile.Close()
	err = daemonize(clientId, jobId)
	if err != nil {
		return fmt.Errorf("failed to daemonize: %w", err)
	}
	return nil
}
// GetCmd returns the job's JobCmd, or nil if no job has been started.
func (jm *JobManager) GetCmd() *JobCmd {
	jm.lock.Lock()
	cmd := jm.Cmd
	jm.lock.Unlock()
	return cmd
}
// sendJobExited pushes a JobCmdExited notification to the attached main
// server, if one is attached and the process has in fact exited. Snapshots
// state under the lock, then performs the rpc call without holding it.
func (jm *JobManager) sendJobExited() {
	jm.lock.Lock()
	attachedClient := jm.attachedClient
	cmd := jm.Cmd
	jm.lock.Unlock()
	if attachedClient == nil {
		log.Printf("sendJobExited: no attached client, exit notification not sent\n")
		return
	}
	if attachedClient.WshRpc == nil {
		log.Printf("sendJobExited: no wsh rpc connection, exit notification not sent\n")
		return
	}
	if cmd == nil {
		log.Printf("sendJobExited: no cmd, exit notification not sent\n")
		return
	}
	exited, exitData := cmd.GetExitInfo()
	if !exited || exitData == nil {
		log.Printf("sendJobExited: process not exited yet\n")
		return
	}
	exitCodeStr := "nil"
	if exitData.ExitCode != nil {
		exitCodeStr = fmt.Sprintf("%d", *exitData.ExitCode)
	}
	log.Printf("sendJobExited: sending exit notification to main server exitcode=%s signal=%s\n", exitCodeStr, exitData.ExitSignal)
	err := wshclient.JobCmdExitedCommand(attachedClient.WshRpc, *exitData, nil)
	if err != nil {
		log.Printf("sendJobExited: error sending exit notification: %v\n", err)
	}
}
// GetJobAuthInfo returns the job id and the job auth token.
func (jm *JobManager) GetJobAuthInfo() (string, string) {
	jm.lock.Lock()
	jobId, authToken := jm.JobId, jm.JobAuthToken
	jm.lock.Unlock()
	return jobId, authToken
}
// IsJobStarted reports whether a job command has been created.
func (jm *JobManager) IsJobStarted() bool {
	jm.lock.Lock()
	started := jm.Cmd != nil
	jm.lock.Unlock()
	return started
}
// connectToStreamHelper_withlock connects mainServerConn as the stream client,
// detaching any previously connected client first, and returns the server-side
// sequence to resume from. Caller MUST hold jm.lock (the _withlock suffix):
// this method mutates jm.connectedStreamClient without taking the lock itself.
func (jm *JobManager) connectToStreamHelper_withlock(mainServerConn *MainServerConn, streamMeta wshrpc.StreamMeta, seq int64) (int64, error) {
	rwndSize := int(streamMeta.RWnd)
	if rwndSize < 0 {
		return 0, fmt.Errorf("invalid rwnd size: %d", rwndSize)
	}
	if jm.connectedStreamClient != nil {
		// Only one stream client at a time: tear down the old attachment.
		log.Printf("connectToStreamHelper: disconnecting existing client\n")
		oldStreamId := jm.StreamManager.GetStreamId()
		jm.StreamManager.ClientDisconnected()
		if oldStreamId != "" {
			mainServerConn.WshRpc.StreamBroker.DetachStreamWriter(oldStreamId)
			log.Printf("connectToStreamHelper: detached old stream id=%s\n", oldStreamId)
		}
		jm.connectedStreamClient = nil
	}
	dataSender := &routedDataSender{
		wshRpc: mainServerConn.WshRpc,
		route:  streamMeta.ReaderRouteId,
	}
	serverSeq, err := jm.StreamManager.ClientConnected(
		streamMeta.Id,
		dataSender,
		rwndSize,
		seq,
	)
	if err != nil {
		return 0, fmt.Errorf("failed to connect client: %w", err)
	}
	jm.connectedStreamClient = mainServerConn
	return serverSeq, nil
}
// disconnectFromStreamHelper detaches mainServerConn from the output stream,
// but only when it is the currently connected stream client — a stale client
// disconnecting is a no-op.
func (jm *JobManager) disconnectFromStreamHelper(mainServerConn *MainServerConn) {
	jm.lock.Lock()
	defer jm.lock.Unlock()
	current := jm.connectedStreamClient
	if current == nil || current != mainServerConn {
		return
	}
	jm.StreamManager.ClientDisconnected()
	jm.connectedStreamClient = nil
}
func GetJobSocketPath(jobId string) string {
socketDir := filepath.Join("/tmp", fmt.Sprintf("waveterm-%d", os.Getuid()))
return filepath.Join(socketDir, fmt.Sprintf("%s.sock", jobId))
}
// GetJobFilePath returns the path of a per-job auxiliary file under
// ~/.waveterm/jobs/<clientId>/<jobId>.<extension>.
func GetJobFilePath(clientId string, jobId string, extension string) string {
	fileName := fmt.Sprintf("%s.%s", jobId, extension)
	return filepath.Join(wavebase.GetHomeDir(), ".waveterm", "jobs", clientId, fileName)
}
// MakeJobDomainSocket creates the per-user socket directory (0700), binds the
// job's unix domain socket, and starts an accept loop in a goroutine. Each
// accepted connection is served on its own goroutine. The listener and socket
// file are cleaned up when the accept loop exits.
func MakeJobDomainSocket(clientId string, jobId string) error {
	socketDir := filepath.Join("/tmp", fmt.Sprintf("waveterm-%d", os.Getuid()))
	err := os.MkdirAll(socketDir, 0700)
	if err != nil {
		return fmt.Errorf("failed to create socket directory: %w", err)
	}
	socketPath := GetJobSocketPath(jobId)
	// Remove any stale socket left over from a previous run before binding.
	os.Remove(socketPath)
	listener, err := net.Listen("unix", socketPath)
	if err != nil {
		return fmt.Errorf("failed to listen on domain socket: %w", err)
	}
	go func() {
		defer func() {
			panichandler.PanicHandler("MakeJobDomainSocket:accept", recover())
			listener.Close()
			os.Remove(socketPath)
		}()
		for {
			conn, err := listener.Accept()
			if err != nil {
				log.Printf("error accepting connection: %v\n", err)
				return
			}
			go handleJobDomainSocketClient(conn)
		}
	}()
	return nil
}
// handleJobDomainSocketClient wires one accepted domain-socket connection to
// a new WshRpc instance served by a MainServerConn. Two goroutines pump bytes
// between the socket and the rpc input/output channels; when either direction
// fails the connection is closed (Close is once-guarded) and any stream
// attachment held by this client is torn down.
func handleJobDomainSocketClient(conn net.Conn) {
	inputCh := make(chan baseds.RpcInputChType, wshutil.DefaultInputChSize)
	outputCh := make(chan []byte, wshutil.DefaultOutputChSize)
	serverImpl := &MainServerConn{
		Conn:    conn,
		inputCh: inputCh,
	}
	rpcCtx := wshrpc.RpcContext{}
	wshRpc := wshutil.MakeWshRpcWithChannels(inputCh, outputCh, rpcCtx, serverImpl, "job-domain")
	serverImpl.WshRpc = wshRpc
	go func() {
		defer func() {
			panichandler.PanicHandler("handleJobDomainSocketClient:AdaptOutputChToStream", recover())
		}()
		defer serverImpl.Close()
		writeErr := wshutil.AdaptOutputChToStream(outputCh, conn)
		if writeErr != nil {
			log.Printf("error writing to domain socket: %v\n", writeErr)
		}
	}()
	go func() {
		defer func() {
			panichandler.PanicHandler("handleJobDomainSocketClient:AdaptStreamToMsgCh", recover())
		}()
		defer serverImpl.Close()
		// Detach this client's stream when the read side ends (AdaptStreamToMsgCh
		// returns once the connection is closed/broken). Previously this was
		// deferred on the outer function, which returns immediately after
		// spawning the pumps — before any stream is attached — so the cleanup
		// never ran at actual disconnect time.
		defer WshCmdJobManager.disconnectFromStreamHelper(serverImpl)
		wshutil.AdaptStreamToMsgCh(conn, inputCh)
	}()
}

View file

@ -0,0 +1,95 @@
// Copyright 2026, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
//go:build unix
package jobmanager
import (
"fmt"
"log"
"os"
"os/signal"
"strings"
"syscall"
"golang.org/x/sys/unix"
)
func getProcessGroupId(pid int) (int, error) {
pgid, err := syscall.Getpgid(pid)
if err != nil {
return 0, err
}
return pgid, nil
}
func normalizeSignal(sigName string) os.Signal {
sigName = strings.ToUpper(sigName)
sigName = strings.TrimPrefix(sigName, "SIG")
switch sigName {
case "HUP":
return syscall.SIGHUP
case "INT":
return syscall.SIGINT
case "QUIT":
return syscall.SIGQUIT
case "KILL":
return syscall.SIGKILL
case "TERM":
return syscall.SIGTERM
case "USR1":
return syscall.SIGUSR1
case "USR2":
return syscall.SIGUSR2
case "STOP":
return syscall.SIGSTOP
case "CONT":
return syscall.SIGCONT
default:
return nil
}
}
// daemonize detaches the current process: starts a new session (setsid),
// redirects stdin to /dev/null, redirects stdout/stderr and the log package
// to the job's log file, and ignores SIGHUP. The fd redirection order matters:
// stdin is swapped first so no later failure leaves a half-detached stdin.
func daemonize(clientId string, jobId string) error {
	_, err := unix.Setsid()
	if err != nil {
		return fmt.Errorf("failed to setsid: %w", err)
	}
	devNull, err := os.OpenFile("/dev/null", os.O_RDWR, 0)
	if err != nil {
		return fmt.Errorf("failed to open /dev/null: %w", err)
	}
	err = unix.Dup2(int(devNull.Fd()), int(os.Stdin.Fd()))
	if err != nil {
		return fmt.Errorf("failed to dup2 stdin: %w", err)
	}
	// The dup'd descriptor keeps stdin alive; the original can be closed.
	devNull.Close()
	logPath := GetJobFilePath(clientId, jobId, "log")
	logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
	if err != nil {
		return fmt.Errorf("failed to open log file: %w", err)
	}
	err = unix.Dup2(int(logFile.Fd()), int(os.Stdout.Fd()))
	if err != nil {
		return fmt.Errorf("failed to dup2 stdout: %w", err)
	}
	err = unix.Dup2(int(logFile.Fd()), int(os.Stderr.Fd()))
	if err != nil {
		return fmt.Errorf("failed to dup2 stderr: %w", err)
	}
	// logFile is intentionally kept open: the log package writes to it directly.
	log.SetOutput(logFile)
	log.Printf("job manager daemonized, logging to %s\n", logPath)
	signal.Ignore(syscall.SIGHUP)
	return nil
}
// setCloseOnExec marks fd close-on-exec so it is not inherited by child
// processes spawned later.
func setCloseOnExec(fd int) {
	unix.CloseOnExec(fd)
}

View file

@ -0,0 +1,29 @@
// Copyright 2026, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
//go:build windows
package jobmanager
import (
"fmt"
"os"
)
// getProcessGroupId is the windows stub; process groups are not supported.
func getProcessGroupId(pid int) (int, error) {
	return 0, fmt.Errorf("process group id not supported on windows")
}
// normalizeSignal is the windows stub; no signal names are recognized.
func normalizeSignal(sigName string) os.Signal {
	return nil
}
// daemonize is the windows stub; daemonization always fails on windows.
func daemonize(clientId string, jobId string) error {
	return fmt.Errorf("daemonize not supported on windows")
}
// setupJobManagerSignalHandlers is a no-op on windows.
func setupJobManagerSignalHandlers() {
}
// setCloseOnExec is a no-op on windows.
func setCloseOnExec(fd int) {
}

View file

@ -0,0 +1,292 @@
// Copyright 2026, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package jobmanager
import (
"context"
"fmt"
"log"
"net"
"os"
"sync"
"sync/atomic"
"github.com/shirou/gopsutil/v4/process"
"github.com/wavetermdev/waveterm/pkg/baseds"
"github.com/wavetermdev/waveterm/pkg/wavejwt"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)
// MainServerConn represents one main-server client connected over the job's
// domain socket, tracking mutual authentication state and the rpc plumbing
// for that connection.
type MainServerConn struct {
	PeerAuthenticated atomic.Bool // peer presented a valid job access token
	SelfAuthenticated atomic.Bool // we authenticated back to the server with the job auth token
	WshRpc            *wshutil.WshRpc
	Conn              net.Conn
	inputCh           chan baseds.RpcInputChType
	closeOnce         sync.Once // guards Close so teardown runs exactly once
}
// WshServerImpl marks MainServerConn as a wsh rpc server implementation.
func (*MainServerConn) WshServerImpl() {}
// Close tears down the connection exactly once: it closes the socket and the
// rpc input channel. Safe to call from multiple goroutines.
func (msc *MainServerConn) Close() {
	teardown := func() {
		msc.Conn.Close()
		close(msc.inputCh)
	}
	msc.closeOnce.Do(teardown)
}
// routedDataSender sends stream data packets over a WshRpc connection to a
// fixed reader route.
type routedDataSender struct {
	wshRpc *wshutil.WshRpc
	route  string
}
// SendData delivers one stream data packet to the configured route as a
// fire-and-forget rpc (NoResponse); send errors are logged, not returned.
func (rds *routedDataSender) SendData(dataPk wshrpc.CommandStreamData) {
	log.Printf("SendData: sending seq=%d, len=%d, eof=%t, error=%s, route=%s",
		dataPk.Seq, len(dataPk.Data64), dataPk.Eof, dataPk.Error, rds.route)
	err := wshclient.StreamDataCommand(rds.wshRpc, dataPk, &wshrpc.RpcOpts{NoResponse: true, Route: rds.route})
	if err != nil {
		log.Printf("SendData: error sending stream data: %v\n", err)
	}
}
// authenticateSelfToServer presents the job id and job auth token to the main
// server over the control route, completing our half of the mutual
// authentication. Sets SelfAuthenticated on success.
func (msc *MainServerConn) authenticateSelfToServer(jobAuthToken string) error {
	jobId, _ := WshCmdJobManager.GetJobAuthInfo()
	authData := wshrpc.CommandAuthenticateJobManagerData{
		JobId:        jobId,
		JobAuthToken: jobAuthToken,
	}
	err := wshclient.AuthenticateJobManagerCommand(msc.WshRpc, authData, &wshrpc.RpcOpts{Route: wshutil.ControlRoute})
	if err != nil {
		log.Printf("authenticateSelfToServer: failed to authenticate to server: %v\n", err)
		return fmt.Errorf("failed to authenticate to server: %w", err)
	}
	msc.SelfAuthenticated.Store(true)
	log.Printf("authenticateSelfToServer: successfully authenticated to server\n")
	return nil
}
// AuthenticateToJobManagerCommand handles the server's authentication request:
// validates the presented JWT (must carry the MainServer claim and this job's
// id), marks the peer authenticated, authenticates back to the server with the
// job auth token, and installs this connection as the attached client —
// kicking out any previously attached client.
func (msc *MainServerConn) AuthenticateToJobManagerCommand(ctx context.Context, data wshrpc.CommandAuthenticateToJobData) error {
	jobId, jobAuthToken := WshCmdJobManager.GetJobAuthInfo()
	claims, err := wavejwt.ValidateAndExtract(data.JobAccessToken)
	if err != nil {
		log.Printf("AuthenticateToJobManager: failed to validate token: %v\n", err)
		return fmt.Errorf("failed to validate token: %w", err)
	}
	if !claims.MainServer {
		log.Printf("AuthenticateToJobManager: MainServer claim not set\n")
		return fmt.Errorf("MainServer claim not set")
	}
	if claims.JobId != jobId {
		log.Printf("AuthenticateToJobManager: JobId mismatch: expected %s, got %s\n", jobId, claims.JobId)
		return fmt.Errorf("JobId mismatch")
	}
	msc.PeerAuthenticated.Store(true)
	log.Printf("AuthenticateToJobManager: authentication successful for JobId=%s\n", claims.JobId)
	// Mutual auth: if we cannot authenticate back, roll back the peer flag.
	err = msc.authenticateSelfToServer(jobAuthToken)
	if err != nil {
		msc.PeerAuthenticated.Store(false)
		return err
	}
	WshCmdJobManager.lock.Lock()
	defer WshCmdJobManager.lock.Unlock()
	if WshCmdJobManager.attachedClient != nil {
		log.Printf("AuthenticateToJobManager: kicking out existing client\n")
		WshCmdJobManager.attachedClient.Close()
	}
	WshCmdJobManager.attachedClient = msc
	return nil
}
// StartJobCommand starts the job's command (at most once per job manager):
// creates the PTY-backed JobCmd, optionally connects the output stream,
// attaches the PTY reader to the stream manager, and returns the pids and
// process start times (used by the server to later re-identify the processes).
// Requires an authenticated peer.
func (msc *MainServerConn) StartJobCommand(ctx context.Context, data wshrpc.CommandStartJobData) (*wshrpc.CommandStartJobRtnData, error) {
	log.Printf("StartJobCommand: received command=%s args=%v", data.Cmd, data.Args)
	if !msc.PeerAuthenticated.Load() {
		log.Printf("StartJobCommand: not authenticated")
		return nil, fmt.Errorf("not authenticated")
	}
	// Cheap pre-check before taking the lock; re-checked under the lock below.
	if WshCmdJobManager.IsJobStarted() {
		log.Printf("StartJobCommand: job already started")
		return nil, fmt.Errorf("job already started")
	}
	WshCmdJobManager.lock.Lock()
	defer WshCmdJobManager.lock.Unlock()
	if WshCmdJobManager.Cmd != nil {
		log.Printf("StartJobCommand: job already started (double check)")
		return nil, fmt.Errorf("job already started")
	}
	cmdDef := CmdDef{
		Cmd:      data.Cmd,
		Args:     data.Args,
		Env:      data.Env,
		TermSize: data.TermSize,
	}
	log.Printf("StartJobCommand: creating job cmd for jobid=%s", WshCmdJobManager.JobId)
	jobCmd, err := MakeJobCmd(WshCmdJobManager.JobId, cmdDef)
	if err != nil {
		log.Printf("StartJobCommand: failed to make job cmd: %v", err)
		return nil, fmt.Errorf("failed to start job: %w", err)
	}
	WshCmdJobManager.Cmd = jobCmd
	log.Printf("StartJobCommand: job cmd created successfully")
	// Optional: connect the output stream immediately at start.
	if data.StreamMeta != nil {
		serverSeq, err := WshCmdJobManager.connectToStreamHelper_withlock(msc, *data.StreamMeta, 0)
		if err != nil {
			return nil, fmt.Errorf("failed to connect stream: %w", err)
		}
		err = msc.WshRpc.StreamBroker.AttachStreamWriter(data.StreamMeta, WshCmdJobManager.StreamManager)
		if err != nil {
			return nil, fmt.Errorf("failed to attach stream writer: %w", err)
		}
		log.Printf("StartJob: connected stream streamid=%s serverSeq=%d\n", data.StreamMeta.Id, serverSeq)
	}
	_, cmdPty := jobCmd.GetCmd()
	if cmdPty != nil {
		log.Printf("StartJobCommand: attaching pty reader to stream manager")
		err = WshCmdJobManager.StreamManager.AttachReader(cmdPty)
		if err != nil {
			log.Printf("StartJobCommand: failed to attach reader: %v", err)
			return nil, fmt.Errorf("failed to attach reader to stream manager: %w", err)
		}
		log.Printf("StartJobCommand: pty reader attached successfully")
	} else {
		log.Printf("StartJobCommand: no pty to attach")
	}
	cmd, _ := jobCmd.GetCmd()
	if cmd == nil || cmd.Process == nil {
		log.Printf("StartJobCommand: cmd or process is nil")
		return nil, fmt.Errorf("cmd or process is nil")
	}
	// Collect pid + create-time pairs for both the command and this job
	// manager process.
	cmdPid := cmd.Process.Pid
	cmdProc, err := process.NewProcess(int32(cmdPid))
	if err != nil {
		log.Printf("StartJobCommand: failed to get cmd process: %v", err)
		return nil, fmt.Errorf("failed to get cmd process: %w", err)
	}
	cmdStartTs, err := cmdProc.CreateTime()
	if err != nil {
		log.Printf("StartJobCommand: failed to get cmd start time: %v", err)
		return nil, fmt.Errorf("failed to get cmd start time: %w", err)
	}
	jobManagerPid := os.Getpid()
	jobManagerProc, err := process.NewProcess(int32(jobManagerPid))
	if err != nil {
		log.Printf("StartJobCommand: failed to get job manager process: %v", err)
		return nil, fmt.Errorf("failed to get job manager process: %w", err)
	}
	jobManagerStartTs, err := jobManagerProc.CreateTime()
	if err != nil {
		log.Printf("StartJobCommand: failed to get job manager start time: %v", err)
		return nil, fmt.Errorf("failed to get job manager start time: %w", err)
	}
	log.Printf("StartJobCommand: job started successfully cmdPid=%d cmdStartTs=%d jobManagerPid=%d jobManagerStartTs=%d", cmdPid, cmdStartTs, jobManagerPid, jobManagerStartTs)
	return &wshrpc.CommandStartJobRtnData{
		CmdPid:            cmdPid,
		CmdStartTs:        cmdStartTs,
		JobManagerPid:     jobManagerPid,
		JobManagerStartTs: jobManagerStartTs,
	}, nil
}
// JobPrepareConnectCommand is phase one of a stream reconnect. If the stream
// already finished, it reports that directly; otherwise it connects the
// client with a zero receive window ("corked" — no data flows until
// JobStartStream opens the window) and remembers the requested stream meta in
// pendingStreamMeta. Also reports the process exit status if it has exited.
func (msc *MainServerConn) JobPrepareConnectCommand(ctx context.Context, data wshrpc.CommandJobPrepareConnectData) (*wshrpc.CommandJobConnectRtnData, error) {
	WshCmdJobManager.lock.Lock()
	defer WshCmdJobManager.lock.Unlock()
	if !msc.PeerAuthenticated.Load() {
		return nil, fmt.Errorf("peer not authenticated")
	}
	if !msc.SelfAuthenticated.Load() {
		return nil, fmt.Errorf("not authenticated to server")
	}
	if WshCmdJobManager.Cmd == nil {
		return nil, fmt.Errorf("job not started")
	}
	rtnData := &wshrpc.CommandJobConnectRtnData{}
	streamDone, streamError := WshCmdJobManager.StreamManager.GetStreamDoneInfo()
	if streamDone {
		log.Printf("JobPrepareConnect: stream already done, skipping connection streamError=%q\n", streamError)
		rtnData.Seq = data.Seq
		rtnData.StreamDone = true
		rtnData.StreamError = streamError
	} else {
		// Connect corked: RWnd=0 holds back data until JobStartStream.
		corkedStreamMeta := data.StreamMeta
		corkedStreamMeta.RWnd = 0
		serverSeq, err := WshCmdJobManager.connectToStreamHelper_withlock(msc, corkedStreamMeta, data.Seq)
		if err != nil {
			return nil, err
		}
		WshCmdJobManager.pendingStreamMeta = &data.StreamMeta
		rtnData.Seq = serverSeq
		rtnData.StreamDone = false
	}
	hasExited, exitData := WshCmdJobManager.Cmd.GetExitInfo()
	if hasExited && exitData != nil {
		rtnData.HasExited = true
		rtnData.ExitCode = exitData.ExitCode
		rtnData.ExitSignal = exitData.ExitSignal
		rtnData.ExitErr = exitData.ExitErr
	}
	log.Printf("JobPrepareConnect: streamid=%s clientSeq=%d serverSeq=%d streamDone=%v streamError=%q hasExited=%v\n", data.StreamMeta.Id, data.Seq, rtnData.Seq, rtnData.StreamDone, rtnData.StreamError, hasExited)
	return rtnData, nil
}
// JobStartStreamCommand is phase two of a stream reconnect: it attaches the
// stream writer for the meta saved by JobPrepareConnect and opens the receive
// window to its real size, which starts (uncorks) data flow. Errors if no
// JobPrepareConnect is pending.
func (msc *MainServerConn) JobStartStreamCommand(ctx context.Context, data wshrpc.CommandJobStartStreamData) error {
	WshCmdJobManager.lock.Lock()
	defer WshCmdJobManager.lock.Unlock()
	if !msc.PeerAuthenticated.Load() {
		return fmt.Errorf("not authenticated")
	}
	if WshCmdJobManager.Cmd == nil {
		return fmt.Errorf("job not started")
	}
	if WshCmdJobManager.pendingStreamMeta == nil {
		return fmt.Errorf("no pending stream (call JobPrepareConnect first)")
	}
	err := msc.WshRpc.StreamBroker.AttachStreamWriter(WshCmdJobManager.pendingStreamMeta, WshCmdJobManager.StreamManager)
	if err != nil {
		return fmt.Errorf("failed to attach stream writer: %w", err)
	}
	err = WshCmdJobManager.StreamManager.SetRwndSize(int(WshCmdJobManager.pendingStreamMeta.RWnd))
	if err != nil {
		return fmt.Errorf("failed to set rwnd size: %w", err)
	}
	log.Printf("JobStartStream: streamid=%s rwnd=%d streaming started\n", WshCmdJobManager.pendingStreamMeta.Id, WshCmdJobManager.pendingStreamMeta.RWnd)
	WshCmdJobManager.pendingStreamMeta = nil
	return nil
}
// JobInputCommand forwards terminal input (keyboard data, signals, resizes)
// to the running job. Requires an authenticated peer and a started job.
func (msc *MainServerConn) JobInputCommand(ctx context.Context, data wshrpc.CommandJobInputData) error {
	WshCmdJobManager.lock.Lock()
	defer WshCmdJobManager.lock.Unlock()
	switch {
	case !msc.PeerAuthenticated.Load():
		return fmt.Errorf("not authenticated")
	case WshCmdJobManager.Cmd == nil:
		return fmt.Errorf("job not started")
	}
	return WshCmdJobManager.Cmd.HandleInput(data)
}

View file

@ -0,0 +1,419 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package jobmanager
import (
"encoding/base64"
"fmt"
"io"
"log"
"sync"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
)
const (
	CwndSize      = 64 * 1024       // 64 KB window for connected mode
	CirBufSize    = 2 * 1024 * 1024 // 2 MB max buffer size
	DisconnReadSz = 4 * 1024        // 4 KB read chunks when disconnected
	MaxPacketSize = 4 * 1024        // 4 KB max data per packet
)

// DataSender abstracts the transport used to push stream data packets to the
// connected client.
type DataSender interface {
	SendData(dataPk wshrpc.CommandStreamData)
}

// streamTerminalEvent records how the stream ended: a clean EOF, or an error
// message when isEof is false.
type streamTerminalEvent struct {
	isEof bool
	err   string
}
// StreamManager handles PTY output buffering with ACK-based flow control.
// Data read from the attached reader is staged in a circular buffer; while a
// client is connected the sender loop drains it in packets, consuming bytes
// only as the client ACKs them. All fields are guarded by lock; drainCond
// wakes the sender loop when there is new work.
type StreamManager struct {
	lock      sync.Mutex
	drainCond *sync.Cond
	streamId  string
	// this is the data read from the attached reader
	buf           *CirBuf
	terminalEvent *streamTerminalEvent
	eofPos        int64 // fixed position when EOF/error occurs (-1 if not yet)
	reader        io.Reader
	cwndSize      int // congestion window (our send-side cap)
	rwndSize      int // receive window advertised by the client
	// invariant: if connected is true, dataSender is non-nil
	connected  bool
	dataSender DataSender
	// unacked state (reset on disconnect)
	sentNotAcked      int64
	terminalEventSent bool
	// terminal state - once true, stream is complete
	terminalEventAcked bool
	closed             bool
}
// MakeStreamManager creates a StreamManager with the default congestion
// window (CwndSize) and circular buffer (CirBufSize) sizes.
func MakeStreamManager() *StreamManager {
	return MakeStreamManagerWithSizes(CwndSize, CirBufSize)
}
// MakeStreamManagerWithSizes creates a StreamManager with explicit congestion
// window and circular buffer sizes and starts its sender goroutine. The
// buffer starts in sync (blocking) mode; eofPos of -1 means "no EOF yet".
func MakeStreamManagerWithSizes(cwndSize, cirbufSize int) *StreamManager {
	sm := &StreamManager{
		buf:      MakeCirBuf(cirbufSize, true),
		eofPos:   -1,
		cwndSize: cwndSize,
		rwndSize: cwndSize,
	}
	sm.drainCond = sync.NewCond(&sm.lock)
	go sm.senderLoop()
	return sm
}
// AttachReader starts reading from the given reader on a new goroutine.
// Only one reader may ever be attached; a second call returns an error.
func (sm *StreamManager) AttachReader(r io.Reader) error {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	if sm.reader != nil {
		return fmt.Errorf("reader already attached")
	}
	sm.reader = r
	go sm.readLoop()
	return nil
}
// ClientConnected transitions to CONNECTED mode. It reconciles the client's
// resume position (clientSeq) with our buffer head: bytes the client already
// has are consumed, and the returned sequence is where sending will resume.
// Switches the buffer to sync mode with window min(cwnd, rwnd) and wakes the
// sender loop. Errors if the stream is closed/complete, a client is already
// connected, dataSender is nil, or clientSeq is past our buffered data.
func (sm *StreamManager) ClientConnected(streamId string, dataSender DataSender, rwndSize int, clientSeq int64) (int64, error) {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	if sm.closed || sm.terminalEventAcked {
		return 0, fmt.Errorf("stream is closed")
	}
	if sm.connected {
		return 0, fmt.Errorf("client already connected")
	}
	if dataSender == nil {
		return 0, fmt.Errorf("dataSender cannot be nil")
	}
	headPos := sm.buf.HeadPos()
	if clientSeq > headPos {
		// Client is ahead of our head: drop the bytes it already received.
		bytesToConsume := int(clientSeq - headPos)
		available := sm.buf.Size()
		if bytesToConsume > available {
			return 0, fmt.Errorf("client seq %d is beyond our stream end (head=%d, size=%d)", clientSeq, headPos, available)
		}
		if bytesToConsume > 0 {
			if err := sm.buf.Consume(bytesToConsume); err != nil {
				return 0, fmt.Errorf("failed to consume buffer: %w", err)
			}
			headPos = sm.buf.HeadPos()
		}
	}
	sm.streamId = streamId
	sm.dataSender = dataSender
	sm.connected = true
	sm.rwndSize = rwndSize
	sm.sentNotAcked = 0
	// Effective window = min(congestion window, receive window).
	effectiveWindow := sm.cwndSize
	if sm.rwndSize < effectiveWindow {
		effectiveWindow = sm.rwndSize
	}
	sm.buf.SetEffectiveWindow(true, effectiveWindow)
	sm.drainCond.Signal()
	// Resume from whichever is further along: our head or the client's seq.
	startSeq := headPos
	if clientSeq > startSeq {
		startSeq = clientSeq
	}
	return startSeq, nil
}
// GetStreamId returns the current stream ID. It acquires sm.lock internally,
// so callers must NOT already hold sm.lock (sync.Mutex is not reentrant);
// holding an unrelated lock such as the JobManager lock is fine.
func (sm *StreamManager) GetStreamId() string {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	return sm.streamId
}
// GetStreamDoneInfo returns whether the stream is done and the error if there was one.
// The error is only meaningful if done=true, as the error is delivered as part of the stream otherwise.
func (sm *StreamManager) GetStreamDoneInfo() (done bool, streamError string) {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	switch {
	case !sm.terminalEventAcked:
		return false, ""
	case sm.terminalEvent != nil && !sm.terminalEvent.isEof:
		// Terminal event was an error, not a clean EOF.
		return true, sm.terminalEvent.err
	default:
		return true, ""
	}
}
// ClientDisconnected transitions to DISCONNECTED mode: clears the sender and
// unacked counters, re-arms the terminal event for resend on the next
// connection (unless it was already ACKed), and switches the buffer to async
// mode with the full buffer as the window so the reader never blocks.
func (sm *StreamManager) ClientDisconnected() {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	if !sm.connected {
		return
	}
	sm.connected = false
	sm.dataSender = nil
	sm.sentNotAcked = 0
	if !sm.terminalEventAcked {
		sm.terminalEventSent = false
	}
	sm.buf.SetEffectiveWindow(false, CirBufSize)
	sm.drainCond.Signal()
}
// RecvAck processes an ACK from the client
// must be connected, and streamid must match.
// A Fin ACK marks the terminal event as acknowledged (stream complete).
// Otherwise, bytes up to ackPk.Seq are consumed from the buffer, the receive
// window is refreshed from the ACK, and the sender loop is woken when either
// frees up send capacity. Stale/out-of-range ACKs are ignored.
func (sm *StreamManager) RecvAck(ackPk wshrpc.CommandStreamAckData) {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	if !sm.connected || ackPk.Id != sm.streamId {
		return
	}
	if ackPk.Fin {
		sm.terminalEventAcked = true
		sm.drainCond.Signal()
		return
	}
	seq := ackPk.Seq
	headPos := sm.buf.HeadPos()
	if seq < headPos {
		// Stale ACK for data we have already consumed.
		return
	}
	ackedBytes := seq - headPos
	if ackedBytes > sm.sentNotAcked {
		// ACK beyond what we have sent; ignore.
		return
	}
	if ackedBytes > 0 {
		if err := sm.buf.Consume(int(ackedBytes)); err != nil {
			return
		}
		sm.sentNotAcked -= ackedBytes
	}
	prevRwnd := sm.rwndSize
	sm.rwndSize = int(ackPk.RWnd)
	// Effective window = min(congestion window, receive window).
	effectiveWindow := sm.cwndSize
	if sm.rwndSize < effectiveWindow {
		effectiveWindow = sm.rwndSize
	}
	sm.buf.SetEffectiveWindow(true, effectiveWindow)
	if sm.rwndSize > prevRwnd || ackedBytes > 0 {
		sm.drainCond.Signal()
	}
}
// SetRwndSize dynamically updates the receive window size while connected,
// re-deriving the effective window as min(cwnd, rwnd) and waking the sender
// loop. Errors on a negative size or when no client is connected.
func (sm *StreamManager) SetRwndSize(rwndSize int) error {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	if rwndSize < 0 {
		return fmt.Errorf("rwndSize cannot be negative")
	}
	if !sm.connected {
		return fmt.Errorf("not connected")
	}
	sm.rwndSize = rwndSize
	win := sm.rwndSize
	if sm.cwndSize < win {
		win = sm.cwndSize
	}
	sm.buf.SetEffectiveWindow(true, win)
	sm.drainCond.Signal()
	return nil
}
// Close shuts down the sender loop. The reader loop will exit on its next iteration
// or when the underlying reader is closed.
func (sm *StreamManager) Close() {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	sm.closed = true
	sm.drainCond.Signal()
}
// readLoop is the main read goroutine: it pulls data from the underlying
// reader into the stream buffer until the stream is closed or the reader
// reports EOF/error, which is recorded as the terminal event.
func (sm *StreamManager) readLoop() {
	readBuf := make([]byte, MaxPacketSize)
	for {
		sm.lock.Lock()
		stopped := sm.closed
		sm.lock.Unlock()
		if stopped {
			return
		}
		n, readErr := sm.reader.Read(readBuf)
		log.Printf("readLoop: read %d bytes from PTY, err=%v", n, readErr)
		// deliver any bytes read even when an error accompanies them
		if n > 0 {
			sm.handleReadData(readBuf[:n])
		}
		if readErr == nil {
			continue
		}
		if readErr == io.EOF {
			sm.handleEOF()
		} else {
			sm.handleError(readErr)
		}
		return
	}
}
// handleReadData appends freshly-read data to the circular buffer and, if a
// client is attached, wakes the sender loop to drain it.
// NOTE(review): buf.Write happens outside sm.lock — assumes the circular
// buffer does its own internal locking; confirm against its implementation.
func (sm *StreamManager) handleReadData(data []byte) {
	log.Printf("handleReadData: writing %d bytes to buffer", len(data))
	sm.buf.Write(data)
	sm.lock.Lock()
	defer sm.lock.Unlock()
	// fix: the previous format string logged sm.connected twice (as both
	// "connected=" and "signaling="); log it once
	log.Printf("handleReadData: buffer size=%d, connected=%t", sm.buf.Size(), sm.connected)
	if sm.connected {
		sm.drainCond.Signal()
	}
}
// handleEOF records an EOF terminal event at the current end of the stream
// and wakes the sender loop so it can deliver the event.
func (sm *StreamManager) handleEOF() {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	total := sm.buf.TotalSize()
	log.Printf("handleEOF: PTY reached EOF, totalSize=%d", total)
	sm.eofPos = total
	sm.terminalEvent = &streamTerminalEvent{isEof: true}
	sm.drainCond.Signal()
}
// handleError records a read error as the stream's terminal event at the
// current end of the stream and wakes the sender loop to deliver it.
func (sm *StreamManager) handleError(err error) {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	total := sm.buf.TotalSize()
	log.Printf("handleError: PTY error=%v, totalSize=%d", err, total)
	sm.eofPos = total
	sm.terminalEvent = &streamTerminalEvent{err: err.Error()}
	sm.drainCond.Signal()
}
// senderLoop repeatedly asks prepareNextPacket for the next action: exit when
// done, spin again on a nil packet (post cond-wait re-check), or deliver the
// prepared packet through the returned sender.
func (sm *StreamManager) senderLoop() {
	for {
		done, pkt, sender := sm.prepareNextPacket()
		switch {
		case done:
			return
		case pkt != nil:
			sender.SendData(*pkt)
		}
	}
}
// prepareNextPacket computes the sender loop's next action while holding the
// lock. Returns done=true when the loop should exit (stream closed or the
// terminal event was acknowledged). Returns a nil packet after a cond-wait so
// the caller re-evaluates from scratch. Otherwise returns the next data (or
// terminal) packet and the DataSender to deliver it with.
func (sm *StreamManager) prepareNextPacket() (done bool, pkt *wshrpc.CommandStreamData, sender DataSender) {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	available := sm.buf.Size()
	log.Printf("prepareNextPacket: connected=%t, available=%d, closed=%t, terminalEventAcked=%t, terminalEvent=%v",
		sm.connected, available, sm.closed, sm.terminalEventAcked, sm.terminalEvent != nil)
	// exit conditions: explicit close, or the client ACKed the terminal event
	if sm.closed || sm.terminalEventAcked {
		return true, nil, nil
	}
	// no client attached: block until a connect (or close) signals us
	if !sm.connected {
		log.Printf("prepareNextPacket: waiting for connection")
		sm.drainCond.Wait()
		return false, nil, nil
	}
	if available == 0 {
		// buffer drained: emit the pending EOF/error event if there is one
		if sm.terminalEvent != nil && !sm.terminalEventSent {
			log.Printf("prepareNextPacket: preparing terminal packet")
			return false, sm.prepareTerminalPacket(), sm.dataSender
		}
		log.Printf("prepareNextPacket: no data available, waiting")
		sm.drainCond.Wait()
		return false, nil, nil
	}
	// flow control: effective window = min(rwnd, cwnd), minus bytes in flight
	effectiveRwnd := sm.rwndSize
	if sm.cwndSize < effectiveRwnd {
		effectiveRwnd = sm.cwndSize
	}
	availableToSend := int64(effectiveRwnd) - sm.sentNotAcked
	if availableToSend <= 0 {
		// window full: wait for an ACK to free space
		sm.drainCond.Wait()
		return false, nil, nil
	}
	peekSize := int(availableToSend)
	if peekSize > MaxPacketSize {
		peekSize = MaxPacketSize
	}
	if peekSize > available {
		peekSize = available
	}
	// peek past the sent-but-unacked prefix; that data stays in the buffer
	// until acknowledged so it can be retransmitted after a reconnect
	data := make([]byte, peekSize)
	n := sm.buf.PeekDataAt(int(sm.sentNotAcked), data)
	if n == 0 {
		log.Printf("prepareNextPacket: PeekDataAt returned 0 bytes, waiting for ACK")
		sm.drainCond.Wait()
		return false, nil, nil
	}
	data = data[:n]
	// sequence number is the absolute stream position of the first byte
	seq := sm.buf.HeadPos() + sm.sentNotAcked
	sm.sentNotAcked += int64(n)
	log.Printf("prepareNextPacket: sending packet seq=%d, len=%d bytes", seq, n)
	return false, &wshrpc.CommandStreamData{
		Id:     sm.streamId,
		Seq:    seq,
		Data64: base64.StdEncoding.EncodeToString(data),
	}, sm.dataSender
}
// prepareTerminalPacket builds the one-shot EOF/error packet at the recorded
// end-of-stream position, or returns nil if there is no pending terminal
// event (or it was already sent). Marks the event as sent.
func (sm *StreamManager) prepareTerminalPacket() *wshrpc.CommandStreamData {
	if sm.terminalEvent == nil || sm.terminalEventSent {
		return nil
	}
	sm.terminalEventSent = true
	pkt := &wshrpc.CommandStreamData{
		Id:  sm.streamId,
		Seq: sm.eofPos,
	}
	if sm.terminalEvent.isEof {
		pkt.Eof = true
		return pkt
	}
	pkt.Error = sm.terminalEvent.err
	return pkt
}

View file

@ -0,0 +1,348 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package jobmanager
import (
"encoding/base64"
"io"
"strings"
"sync"
"testing"
"time"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
)
// testWriter is a mutex-guarded DataSender stub that records every packet it
// is handed, for later inspection by tests.
type testWriter struct {
	mu      sync.Mutex
	packets []wshrpc.CommandStreamData
}

// SendData records a packet.
func (tw *testWriter) SendData(pkt wshrpc.CommandStreamData) {
	tw.mu.Lock()
	tw.packets = append(tw.packets, pkt)
	tw.mu.Unlock()
}

// GetPackets returns a snapshot copy of all recorded packets.
func (tw *testWriter) GetPackets() []wshrpc.CommandStreamData {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	snapshot := make([]wshrpc.CommandStreamData, len(tw.packets))
	copy(snapshot, tw.packets)
	return snapshot
}

// Clear discards all recorded packets.
func (tw *testWriter) Clear() {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	tw.packets = nil
}
// decodeData base64-decodes data64, silently ignoring decode errors
// (test helper; malformed input yields whatever was decoded).
func decodeData(data64 string) string {
	raw, _ := base64.StdEncoding.DecodeString(data64)
	return string(raw)
}
// TestBasicDisconnectedMode verifies that no packets are emitted while no
// client is attached.
// NOTE(review): tw is never attached to sm (no ClientConnected call), so
// GetPackets is trivially empty — the assertion is vacuous; consider
// strengthening the test or asserting on stream-manager state instead.
func TestBasicDisconnectedMode(t *testing.T) {
	tw := &testWriter{}
	sm := MakeStreamManager()
	reader := strings.NewReader("hello world")
	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}
	// give the read loop time to consume the reader
	time.Sleep(50 * time.Millisecond)
	packets := tw.GetPackets()
	if len(packets) > 0 {
		t.Errorf("Expected no packets in DISCONNECTED mode without client, got %d", len(packets))
	}
	sm.Close()
}
// TestConnectedModeBasicFlow exercises the happy path: attach a reader,
// connect a client, verify the data is streamed, ACK it, and verify the EOF
// packet is delivered only after the data is fully acknowledged.
// NOTE(review): relies on fixed sleeps for synchronization — may be flaky
// under load.
func TestConnectedModeBasicFlow(t *testing.T) {
	tw := &testWriter{}
	sm := MakeStreamManager()
	reader := strings.NewReader("hello")
	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}
	_, err = sm.ClientConnected("1", tw, CwndSize, 0)
	if err != nil {
		t.Fatalf("ClientConnected failed: %v", err)
	}
	time.Sleep(100 * time.Millisecond)
	packets := tw.GetPackets()
	if len(packets) == 0 {
		t.Fatal("Expected packets after ClientConnected")
	}
	// Verify we got the data (reassembled across however many packets)
	allData := ""
	for _, pkt := range packets {
		if pkt.Data64 != "" {
			allData += decodeData(pkt.Data64)
		}
	}
	if allData != "hello" {
		t.Errorf("Expected 'hello', got '%s'", allData)
	}
	// Send ACK covering all 5 bytes
	sm.RecvAck(wshrpc.CommandStreamAckData{Id: "1", Seq: 5, RWnd: CwndSize})
	time.Sleep(50 * time.Millisecond)
	// Check for EOF packet
	packets = tw.GetPackets()
	hasEof := false
	for _, pkt := range packets {
		if pkt.Eof {
			hasEof = true
		}
	}
	if !hasEof {
		t.Error("Expected EOF packet after ACKing all data")
	}
	sm.Close()
}
// TestDisconnectedToConnectedTransition verifies that data buffered while no
// client was attached is fully drained to a client that connects later.
func TestDisconnectedToConnectedTransition(t *testing.T) {
	tw := &testWriter{}
	sm := MakeStreamManager()
	reader := strings.NewReader("test data")
	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}
	// let the read loop buffer everything before any client attaches
	time.Sleep(100 * time.Millisecond)
	_, err = sm.ClientConnected("1", tw, CwndSize, 0)
	if err != nil {
		t.Fatalf("ClientConnected failed: %v", err)
	}
	time.Sleep(100 * time.Millisecond)
	packets := tw.GetPackets()
	if len(packets) == 0 {
		t.Fatal("Expected cirbuf drain after connect")
	}
	allData := ""
	for _, pkt := range packets {
		if pkt.Data64 != "" {
			allData += decodeData(pkt.Data64)
		}
	}
	if allData != "test data" {
		t.Errorf("Expected 'test data', got '%s'", allData)
	}
	sm.Close()
}
// TestConnectedToDisconnectedTransition is a smoke test: connect a client
// while a slow reader is still producing, then disconnect mid-stream and
// close. It asserts nothing beyond setup success — its value is that nothing
// deadlocks or panics during the transition.
func TestConnectedToDisconnectedTransition(t *testing.T) {
	tw := &testWriter{}
	sm := MakeStreamManager()
	reader := &slowReader{data: []byte("slow data"), delay: 50 * time.Millisecond}
	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}
	_, err = sm.ClientConnected("1", tw, CwndSize, 0)
	if err != nil {
		t.Fatalf("ClientConnected failed: %v", err)
	}
	time.Sleep(150 * time.Millisecond)
	sm.ClientDisconnected()
	time.Sleep(100 * time.Millisecond)
	sm.Close()
}
// TestFlowControl verifies the sender never exceeds the congestion window:
// with cwnd-sized data plus 500 extra bytes available, at most cwnd bytes may
// be sent before an ACK arrives.
func TestFlowControl(t *testing.T) {
	cwndSize := 1024
	tw := &testWriter{}
	sm := MakeStreamManagerWithSizes(cwndSize, 8*1024)
	// more data than the window allows in one burst
	largeData := strings.Repeat("x", cwndSize+500)
	reader := strings.NewReader(largeData)
	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}
	_, err = sm.ClientConnected("1", tw, cwndSize, 0)
	if err != nil {
		t.Fatalf("ClientConnected failed: %v", err)
	}
	time.Sleep(100 * time.Millisecond)
	packets := tw.GetPackets()
	totalData := 0
	for _, pkt := range packets {
		if pkt.Data64 != "" {
			decoded, _ := base64.StdEncoding.DecodeString(pkt.Data64)
			totalData += len(decoded)
		}
	}
	if totalData > cwndSize {
		t.Errorf("Sent %d bytes without ACK, exceeds cwnd size %d", totalData, cwndSize)
	}
	// ACK everything received so far so the remainder can flow before Close
	sm.RecvAck(wshrpc.CommandStreamAckData{Id: "1", Seq: int64(totalData), RWnd: int64(cwndSize)})
	time.Sleep(100 * time.Millisecond)
	sm.Close()
}
// TestSequenceNumbering verifies data packets carry contiguous sequence
// numbers: each packet's Seq equals the running byte count of all prior
// packet payloads, starting at 0.
func TestSequenceNumbering(t *testing.T) {
	tw := &testWriter{}
	sm := MakeStreamManager()
	reader := strings.NewReader("abcdefghij")
	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}
	_, err = sm.ClientConnected("1", tw, CwndSize, 0)
	if err != nil {
		t.Fatalf("ClientConnected failed: %v", err)
	}
	time.Sleep(100 * time.Millisecond)
	packets := tw.GetPackets()
	if len(packets) == 0 {
		t.Fatal("Expected packets")
	}
	expectedSeq := int64(0)
	for _, pkt := range packets {
		// skip non-data packets (e.g. terminal EOF/error)
		if pkt.Data64 == "" {
			continue
		}
		if pkt.Seq != expectedSeq {
			t.Errorf("Expected seq %d, got %d", expectedSeq, pkt.Seq)
		}
		decoded, _ := base64.StdEncoding.DecodeString(pkt.Data64)
		expectedSeq += int64(len(decoded))
	}
	sm.Close()
}
// TestTerminalEventOrdering verifies the EOF packet is withheld until all
// data has been ACKed, and that it carries Seq equal to the total stream
// length (4 bytes of "data").
func TestTerminalEventOrdering(t *testing.T) {
	tw := &testWriter{}
	sm := MakeStreamManager()
	reader := strings.NewReader("data")
	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}
	_, err = sm.ClientConnected("1", tw, CwndSize, 0)
	if err != nil {
		t.Fatalf("ClientConnected failed: %v", err)
	}
	time.Sleep(100 * time.Millisecond)
	packets := tw.GetPackets()
	if len(packets) == 0 {
		t.Fatal("Expected data packets")
	}
	// before ACK: data should have been sent, but no EOF yet
	hasData := false
	hasEof := false
	eofSeq := int64(-1)
	for _, pkt := range packets {
		if pkt.Data64 != "" {
			hasData = true
		}
		if pkt.Eof {
			hasEof = true
			eofSeq = pkt.Seq
		}
	}
	if !hasData {
		t.Error("Expected data packet")
	}
	if hasEof {
		t.Error("Should not have EOF before ACK")
	}
	// ACK all 4 bytes; EOF should now be released
	sm.RecvAck(wshrpc.CommandStreamAckData{Id: "1", Seq: 4, RWnd: CwndSize})
	time.Sleep(50 * time.Millisecond)
	packets = tw.GetPackets()
	hasEof = false
	for _, pkt := range packets {
		if pkt.Eof {
			hasEof = true
			eofSeq = pkt.Seq
		}
	}
	if !hasEof {
		t.Error("Expected EOF after ACKing all data")
	}
	if eofSeq != 4 {
		t.Errorf("Expected EOF at seq 4, got %d", eofSeq)
	}
	sm.Close()
}
// slowReader is an io.Reader that sleeps for delay before each successful
// read, simulating a slow data source in tests.
type slowReader struct {
	data  []byte
	pos   int
	delay time.Duration
}

// Read sleeps for the configured delay, copies as much remaining data as fits
// into p, and returns io.EOF once everything has been consumed.
func (sr *slowReader) Read(p []byte) (int, error) {
	if sr.pos >= len(sr.data) {
		return 0, io.EOF
	}
	time.Sleep(sr.delay)
	copied := copy(p, sr.data[sr.pos:])
	sr.pos += copied
	return copied, nil
}

View file

@ -85,7 +85,7 @@ type SSHConn struct {
var ConnServerCmdTemplate = strings.TrimSpace(
strings.Join([]string{
"%s version 2> /dev/null || (echo -n \"not-installed \"; uname -sm; exit 0);",
"exec %s connserver --conn %s %s",
"exec %s connserver --conn %s %s %s",
}, "\n"))
func IsLocalConnName(connName string) bool {
@ -285,8 +285,9 @@ func (conn *SSHConn) GetConfigShellPath() string {
// returns (needsInstall, clientVersion, osArchStr, error)
// if wsh is not installed, the clientVersion will be "not-installed", and it will also return an osArchStr
// if clientVersion is set, then no osArchStr will be returned
func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (bool, string, string, error) {
conn.Infof(ctx, "running StartConnServer...\n")
// if useRouterMode is true, will start connserver with --router-domainsocket flag
func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool, useRouterMode bool) (bool, string, string, error) {
conn.Infof(ctx, "running StartConnServer (routerMode=%v)...\n", useRouterMode)
allowed := WithLockRtn(conn, func() bool {
return conn.Status == Status_Connecting
})
@ -296,10 +297,19 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo
client := conn.GetClient()
wshPath := conn.getWshPath()
sockName := conn.GetDomainSocketName()
rpcCtx := wshrpc.RpcContext{
RouteId: wshutil.MakeConnectionRouteId(conn.GetName()),
SockName: sockName,
Conn: conn.GetName(),
var rpcCtx wshrpc.RpcContext
if useRouterMode {
rpcCtx = wshrpc.RpcContext{
IsRouter: true,
SockName: sockName,
Conn: conn.GetName(),
}
} else {
rpcCtx = wshrpc.RpcContext{
RouteId: wshutil.MakeConnectionRouteId(conn.GetName()),
SockName: sockName,
Conn: conn.GetName(),
}
}
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx)
if err != nil {
@ -321,7 +331,11 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo
if wavebase.IsDevMode() {
devFlag = "--dev"
}
cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath, shellutil.HardQuote(conn.GetName()), devFlag)
routerFlag := ""
if useRouterMode {
routerFlag = "--router-domainsocket"
}
cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath, shellutil.HardQuote(conn.GetName()), devFlag, routerFlag)
log.Printf("starting conn controller: %q\n", cmdStr)
shWrappedCmdStr := fmt.Sprintf("sh -c %s", shellutil.HardQuote(cmdStr))
blocklogger.Debugf(ctx, "[conndebug] wrapped command:\n%s\n", shWrappedCmdStr)
@ -702,7 +716,7 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string)
err = fmt.Errorf("error opening domain socket listener: %w", err)
return WshCheckResult{NoWshReason: "error opening domain socket", NoWshCode: NoWshCode_DomainSocketError, WshError: err}
}
needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx, false)
needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx, false, false)
if err != nil {
conn.Infof(ctx, "ERROR starting conn server: %v\n", err)
err = fmt.Errorf("error starting conn server: %w", err)
@ -716,7 +730,7 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string)
err = fmt.Errorf("error installing wsh: %w", err)
return WshCheckResult{NoWshReason: "error installing wsh/connserver", NoWshCode: NoWshCode_InstallError, WshError: err}
}
needsInstall, clientVersion, _, err = conn.StartConnServer(ctx, true)
needsInstall, clientVersion, _, err = conn.StartConnServer(ctx, true, false)
if err != nil {
conn.Infof(ctx, "ERROR starting conn server (after install): %v\n", err)
err = fmt.Errorf("error starting conn server (after install): %w", err)
@ -842,23 +856,38 @@ func (conn *SSHConn) ClearWshError() {
})
}
func getConnInternal(opts *remote.SSHOpts) *SSHConn {
func getConnInternal(opts *remote.SSHOpts, createIfNotExists bool) *SSHConn {
globalLock.Lock()
defer globalLock.Unlock()
rtn := clientControllerMap[*opts]
if rtn == nil {
if rtn == nil && createIfNotExists {
rtn = &SSHConn{Lock: &sync.Mutex{}, Status: Status_Init, WshEnabled: &atomic.Bool{}, Opts: opts, HasWaiter: &atomic.Bool{}}
clientControllerMap[*opts] = rtn
}
return rtn
}
// does NOT connect, can return nil if connection does not exist
// does NOT connect, does not return nil
func GetConn(opts *remote.SSHOpts) *SSHConn {
conn := getConnInternal(opts)
conn := getConnInternal(opts, true)
return conn
}
// IsConnected reports whether the named connection is currently in the
// Connected state. The local connection is always considered connected.
// Returns (false, nil) when no controller exists yet for the name; returns an
// error only when the connection name itself fails to parse.
func IsConnected(connName string) (bool, error) {
	if IsLocalConnName(connName) {
		return true, nil
	}
	connOpts, err := remote.ParseOpts(connName)
	if err != nil {
		return false, fmt.Errorf("error parsing connection name: %w", err)
	}
	// createIfNotExists=false: a status query must not create a controller
	conn := getConnInternal(connOpts, false)
	if conn == nil {
		return false, nil
	}
	return conn.GetStatus() == Status_Connected, nil
}
// Convenience function for ensuring a connection is established
func EnsureConnection(ctx context.Context, connName string) error {
if IsLocalConnName(connName) {
@ -888,7 +917,7 @@ func EnsureConnection(ctx context.Context, connName string) error {
}
func DisconnectClient(opts *remote.SSHOpts) error {
conn := getConnInternal(opts)
conn := getConnInternal(opts, false)
if conn == nil {
return fmt.Errorf("client %q not found", opts.String())
}

View file

@ -2,6 +2,7 @@ package streamclient
import (
"bytes"
"encoding/base64"
"io"
"testing"
"time"
@ -32,8 +33,8 @@ func (ft *fakeTransport) SendAck(ackPk wshrpc.CommandStreamAckData) {
func TestBasicReadWrite(t *testing.T) {
transport := newFakeTransport()
reader := NewReader(1, 1024, transport)
writer := NewWriter(1, 1024, transport)
reader := NewReader("1", 1024, transport)
writer := NewWriter("1", 1024, transport)
go func() {
for dataPk := range transport.dataChan {
@ -72,8 +73,8 @@ func TestBasicReadWrite(t *testing.T) {
func TestEOF(t *testing.T) {
transport := newFakeTransport()
reader := NewReader(1, 1024, transport)
writer := NewWriter(1, 1024, transport)
reader := NewReader("1", 1024, transport)
writer := NewWriter("1", 1024, transport)
go func() {
for dataPk := range transport.dataChan {
@ -110,8 +111,8 @@ func TestFlowControl(t *testing.T) {
smallWindow := int64(10)
transport := newFakeTransport()
reader := NewReader(1, smallWindow, transport)
writer := NewWriter(1, smallWindow, transport)
reader := NewReader("1", smallWindow, transport)
writer := NewWriter("1", smallWindow, transport)
go func() {
for dataPk := range transport.dataChan {
@ -163,8 +164,8 @@ func TestFlowControl(t *testing.T) {
func TestError(t *testing.T) {
transport := newFakeTransport()
reader := NewReader(1, 1024, transport)
writer := NewWriter(1, 1024, transport)
reader := NewReader("1", 1024, transport)
writer := NewWriter("1", 1024, transport)
go func() {
for dataPk := range transport.dataChan {
@ -194,8 +195,8 @@ func TestError(t *testing.T) {
func TestCancel(t *testing.T) {
transport := newFakeTransport()
reader := NewReader(1, 1024, transport)
writer := NewWriter(1, 1024, transport)
reader := NewReader("1", 1024, transport)
writer := NewWriter("1", 1024, transport)
go func() {
for dataPk := range transport.dataChan {
@ -227,8 +228,8 @@ func TestCancel(t *testing.T) {
func TestMultipleWrites(t *testing.T) {
transport := newFakeTransport()
reader := NewReader(1, 1024, transport)
writer := NewWriter(1, 1024, transport)
reader := NewReader("1", 1024, transport)
writer := NewWriter("1", 1024, transport)
go func() {
for dataPk := range transport.dataChan {
@ -265,3 +266,258 @@ func TestMultipleWrites(t *testing.T) {
t.Fatalf("Expected %q, got %q", expected, string(buf))
}
}
// TestOutOfOrderPackets verifies the reader buffers out-of-order packets and
// releases all buffered data once the gap is filled. Uses t.Errorf (not
// Fatalf) inside the goroutine, which is the safe form off the test goroutine.
func TestOutOfOrderPackets(t *testing.T) {
	transport := newFakeTransport()
	reader := NewReader("test-ooo", 1024, transport)
	packet0 := wshrpc.CommandStreamData{
		Id:     "test-ooo",
		Seq:    0,
		Data64: base64.StdEncoding.EncodeToString([]byte("AAAAA")),
	}
	packet5 := wshrpc.CommandStreamData{
		Id:     "test-ooo",
		Seq:    5,
		Data64: base64.StdEncoding.EncodeToString([]byte("BBBBB")),
	}
	packet10 := wshrpc.CommandStreamData{
		Id:     "test-ooo",
		Seq:    10,
		Data64: base64.StdEncoding.EncodeToString([]byte("CCCCC")),
	}
	packet15 := wshrpc.CommandStreamData{
		Id:     "test-ooo",
		Seq:    15,
		Data64: base64.StdEncoding.EncodeToString([]byte("DDDDD")),
	}
	// Send packets out of order: 0, 10, 15, 5
	reader.RecvData(packet0)
	reader.RecvData(packet10) // OOO - should be buffered
	reader.RecvData(packet15) // OOO - should be buffered
	reader.RecvData(packet5)  // fills the gap - should trigger processing
	// Read all data in a goroutine so a hang fails via timeout, not forever
	buf := make([]byte, 1024)
	totalRead := 0
	expectedLen := 20 // 4 packets * 5 bytes each
	readDone := make(chan struct{})
	go func() {
		for totalRead < expectedLen {
			n, err := reader.Read(buf[totalRead:])
			if err != nil {
				t.Errorf("Read failed: %v", err)
				return
			}
			totalRead += n
		}
		close(readDone)
	}()
	select {
	case <-readDone:
		// Success
	case <-time.After(2 * time.Second):
		t.Fatalf("Read didn't complete in time. Read %d bytes, expected %d", totalRead, expectedLen)
	}
	if totalRead != expectedLen {
		t.Fatalf("Expected to read %d bytes, got %d", expectedLen, totalRead)
	}
}
// TestOutOfOrderWithDuplicates verifies duplicate sequence numbers are
// dropped, with first-received-wins semantics: a later packet carrying the
// same Seq (even with different payload) must be ignored.
// NOTE(review): the single Read call is assumed to return all 15 contiguous
// bytes at once — io.Reader permits shorter reads; confirm Reader.Read's
// contract guarantees this.
func TestOutOfOrderWithDuplicates(t *testing.T) {
	transport := newFakeTransport()
	reader := NewReader("test-dup", 1024, transport)
	packet0 := wshrpc.CommandStreamData{
		Id:     "test-dup",
		Seq:    0,
		Data64: base64.StdEncoding.EncodeToString([]byte("aaaaa")),
	}
	packet10 := wshrpc.CommandStreamData{
		Id:     "test-dup",
		Seq:    10,
		Data64: base64.StdEncoding.EncodeToString([]byte("ccccc")),
	}
	packet5First := wshrpc.CommandStreamData{
		Id:     "test-dup",
		Seq:    5,
		Data64: base64.StdEncoding.EncodeToString([]byte("xxxxx")),
	}
	packet5Second := wshrpc.CommandStreamData{
		Id:     "test-dup",
		Seq:    5,
		Data64: base64.StdEncoding.EncodeToString([]byte("bbbbb")),
	}
	reader.RecvData(packet0)
	reader.RecvData(packet10)     // OOO - buffered
	reader.RecvData(packet5First) // OOO - buffered
	reader.RecvData(packet5First) // Duplicate - should be ignored
	reader.RecvData(packet5Second) // Duplicate with different data - should be ignored
	// Read all data - should get all 3 packets in order
	buf := make([]byte, 20)
	n, err := reader.Read(buf)
	if err != nil {
		t.Fatalf("Read failed: %v", err)
	}
	// Should get all 15 bytes (3 packets * 5 bytes)
	if n != 15 {
		t.Fatalf("Expected to read 15 bytes, got %d", n)
	}
	// Should be "aaaaaxxxxxccccc" (first packet received for each seq wins)
	expected := "aaaaaxxxxxccccc"
	if string(buf[:n]) != expected {
		t.Fatalf("Expected %q, got %q", expected, string(buf[:n]))
	}
}
// TestOutOfOrderWithGaps verifies partial gap-filling: far-ahead packets stay
// buffered across multiple Read calls, each newly-contiguous prefix becomes
// readable as its gap is filled, and the final interleaved sends drain in
// sequence order.
func TestOutOfOrderWithGaps(t *testing.T) {
	transport := newFakeTransport()
	reader := NewReader("test-gaps", 1024, transport)
	packet0 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    0,
		Data64: base64.StdEncoding.EncodeToString([]byte("aaaaa")),
	}
	packet20 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    20,
		Data64: base64.StdEncoding.EncodeToString([]byte("eeeee")),
	}
	packet40 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    40,
		Data64: base64.StdEncoding.EncodeToString([]byte("iiiii")),
	}
	packet5 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    5,
		Data64: base64.StdEncoding.EncodeToString([]byte("bbbbb")),
	}
	reader.RecvData(packet0)
	reader.RecvData(packet40) // Way ahead - should be buffered
	reader.RecvData(packet20) // Still ahead - should be buffered
	// Read first packet (only seq 0-4 is contiguous so far)
	buf := make([]byte, 10)
	n, err := reader.Read(buf)
	if err != nil {
		t.Fatalf("Read failed: %v", err)
	}
	if n != 5 || string(buf[:n]) != "aaaaa" {
		t.Fatalf("Expected 'aaaaa', got %q", string(buf[:n]))
	}
	// Send packet to partially fill gap
	reader.RecvData(packet5)
	// Should be able to read it now
	n, err = reader.Read(buf)
	if err != nil {
		t.Fatalf("Second read failed: %v", err)
	}
	if n != 5 || string(buf[:n]) != "bbbbb" {
		t.Fatalf("Expected 'bbbbb', got %q", string(buf[:n]))
	}
	packet10 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    10,
		Data64: base64.StdEncoding.EncodeToString([]byte("ccccc")),
	}
	packet15 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    15,
		Data64: base64.StdEncoding.EncodeToString([]byte("ddddd")),
	}
	packet25 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    25,
		Data64: base64.StdEncoding.EncodeToString([]byte("fffff")),
	}
	packet30 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    30,
		Data64: base64.StdEncoding.EncodeToString([]byte("ggggg")),
	}
	packet35 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    35,
		Data64: base64.StdEncoding.EncodeToString([]byte("hhhhh")),
	}
	reader.RecvData(packet10)
	reader.RecvData(packet15)
	reader.RecvData(packet25)
	reader.RecvData(packet30)
	reader.RecvData(packet35)
	// Read all remaining data (loop tolerates short reads)
	allData := make([]byte, 100)
	totalRead := 0
	for totalRead < 35 {
		n, err = reader.Read(allData[totalRead:])
		if err != nil {
			t.Fatalf("Read failed: %v", err)
		}
		totalRead += n
	}
	expected := "cccccdddddeeeeefffffggggghhhhhiiiii"
	if string(allData[:totalRead]) != expected {
		t.Fatalf("Expected %q, got %q", expected, string(allData[:totalRead]))
	}
}
// TestOutOfOrderWithEOF verifies an out-of-order packet carrying Eof is held
// until the gap before it is filled, after which all data is readable and the
// subsequent Read returns io.EOF.
// NOTE(review): assumes the single Read returns all 16 bytes in one call —
// io.Reader permits shorter reads; confirm Reader.Read's contract.
func TestOutOfOrderWithEOF(t *testing.T) {
	transport := newFakeTransport()
	reader := NewReader("test-eof", 1024, transport)
	packet0 := wshrpc.CommandStreamData{
		Id:     "test-eof",
		Seq:    0,
		Data64: base64.StdEncoding.EncodeToString([]byte("first")),
	}
	packet11 := wshrpc.CommandStreamData{
		Id:     "test-eof",
		Seq:    11,
		Data64: base64.StdEncoding.EncodeToString([]byte("third")),
		Eof:    true,
	}
	packet5 := wshrpc.CommandStreamData{
		Id:     "test-eof",
		Seq:    5,
		Data64: base64.StdEncoding.EncodeToString([]byte("second")),
	}
	reader.RecvData(packet0)
	reader.RecvData(packet11) // OOO with EOF
	reader.RecvData(packet5)  // Fill the gap
	// Read all data
	buf := make([]byte, 20)
	n, err := reader.Read(buf)
	if err != nil {
		t.Fatalf("Read failed: %v", err)
	}
	expected := "firstsecondthird"
	if string(buf[:n]) != expected {
		t.Fatalf("Expected %q, got %q", expected, string(buf[:n]))
	}
	// Should get EOF now
	_, err = reader.Read(buf)
	if err != io.EOF {
		t.Fatalf("Expected EOF, got %v", err)
	}
}

View file

@ -5,10 +5,9 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/utilds"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)
type workItem struct {
@ -17,36 +16,23 @@ type workItem struct {
dataPk wshrpc.CommandStreamData
}
type StreamWriter interface {
RecvAck(ackPk wshrpc.CommandStreamAckData)
}
type StreamRpcInterface interface {
StreamDataAckCommand(data wshrpc.CommandStreamAckData, opts *wshrpc.RpcOpts) error
StreamDataCommand(data wshrpc.CommandStreamData, opts *wshrpc.RpcOpts) error
}
type wshRpcAdapter struct {
rpc *wshutil.WshRpc
}
func (a *wshRpcAdapter) StreamDataAckCommand(data wshrpc.CommandStreamAckData, opts *wshrpc.RpcOpts) error {
return wshclient.StreamDataAckCommand(a.rpc, data, opts)
}
func (a *wshRpcAdapter) StreamDataCommand(data wshrpc.CommandStreamData, opts *wshrpc.RpcOpts) error {
return wshclient.StreamDataCommand(a.rpc, data, opts)
}
func AdaptWshRpc(rpc *wshutil.WshRpc) StreamRpcInterface {
return &wshRpcAdapter{rpc: rpc}
}
type Broker struct {
lock sync.Mutex
rpcClient StreamRpcInterface
streamIdCounter int64
readers map[int64]*Reader
writers map[int64]*Writer
readerRoutes map[int64]string
writerRoutes map[int64]string
readerErrorSentTime map[int64]time.Time
readers map[string]*Reader
writers map[string]StreamWriter
readerRoutes map[string]string
writerRoutes map[string]string
readerErrorSentTime map[string]time.Time
sendQueue *utilds.WorkQueue[workItem]
recvQueue *utilds.WorkQueue[workItem]
}
@ -54,12 +40,11 @@ type Broker struct {
func NewBroker(rpcClient StreamRpcInterface) *Broker {
b := &Broker{
rpcClient: rpcClient,
streamIdCounter: 0,
readers: make(map[int64]*Reader),
writers: make(map[int64]*Writer),
readerRoutes: make(map[int64]string),
writerRoutes: make(map[int64]string),
readerErrorSentTime: make(map[int64]time.Time),
readers: make(map[string]*Reader),
writers: make(map[string]StreamWriter),
readerRoutes: make(map[string]string),
writerRoutes: make(map[string]string),
readerErrorSentTime: make(map[string]time.Time),
}
b.sendQueue = utilds.NewWorkQueue(b.processSendWork)
b.recvQueue = utilds.NewWorkQueue(b.processRecvWork)
@ -67,13 +52,16 @@ func NewBroker(rpcClient StreamRpcInterface) *Broker {
}
func (b *Broker) CreateStreamReader(readerRoute string, writerRoute string, rwnd int64) (*Reader, *wshrpc.StreamMeta) {
return b.CreateStreamReaderWithSeq(readerRoute, writerRoute, rwnd, 0)
}
func (b *Broker) CreateStreamReaderWithSeq(readerRoute string, writerRoute string, rwnd int64, startSeq int64) (*Reader, *wshrpc.StreamMeta) {
b.lock.Lock()
defer b.lock.Unlock()
b.streamIdCounter++
streamId := b.streamIdCounter
streamId := uuid.New().String()
reader := NewReader(streamId, rwnd, b)
reader := NewReaderWithSeq(streamId, rwnd, startSeq, b)
b.readers[streamId] = reader
b.readerRoutes[streamId] = readerRoute
b.writerRoutes[streamId] = writerRoute
@ -88,19 +76,35 @@ func (b *Broker) CreateStreamReader(readerRoute string, writerRoute string, rwnd
return reader, meta
}
func (b *Broker) AttachStreamWriter(meta *wshrpc.StreamMeta) (*Writer, error) {
func (b *Broker) AttachStreamWriter(meta *wshrpc.StreamMeta, writer StreamWriter) error {
b.lock.Lock()
defer b.lock.Unlock()
if _, exists := b.writers[meta.Id]; exists {
return nil, fmt.Errorf("writer already registered for stream id %d", meta.Id)
return fmt.Errorf("writer already registered for stream id %s", meta.Id)
}
writer := NewWriter(meta.Id, meta.RWnd, b)
b.writers[meta.Id] = writer
b.readerRoutes[meta.Id] = meta.ReaderRouteId
b.writerRoutes[meta.Id] = meta.WriterRouteId
return nil
}
// DetachStreamWriter unregisters the writer (and its route mapping) for the
// given stream id. Safe to call for an unknown id (map deletes are no-ops).
func (b *Broker) DetachStreamWriter(streamId string) {
	b.lock.Lock()
	delete(b.writers, streamId)
	delete(b.writerRoutes, streamId)
	b.lock.Unlock()
}
// CreateStreamWriter constructs a Writer for the given stream metadata and
// registers it with the broker. Returns an error if a writer is already
// attached for the stream id.
func (b *Broker) CreateStreamWriter(meta *wshrpc.StreamMeta) (*Writer, error) {
	w := NewWriter(meta.Id, meta.RWnd, b)
	if err := b.AttachStreamWriter(meta, w); err != nil {
		return nil, err
	}
	return w, nil
}
@ -112,6 +116,9 @@ func (b *Broker) SendData(dataPk wshrpc.CommandStreamData) {
b.sendQueue.Enqueue(workItem{workType: "senddata", dataPk: dataPk})
}
// RecvData and RecvAck are designed to be non-blocking and must remain so to prevent deadlock.
// They only enqueue work items to be processed asynchronously by the work queue's goroutine.
// These methods are called from the main RPC runServer loop, so blocking here would stall all RPC processing.
func (b *Broker) RecvData(dataPk wshrpc.CommandStreamData) {
b.recvQueue.Enqueue(workItem{workType: "recvdata", dataPk: dataPk})
}
@ -220,7 +227,7 @@ func (b *Broker) Close() {
b.recvQueue.Wait()
}
func (b *Broker) cleanupReader(streamId int64) {
func (b *Broker) cleanupReader(streamId string) {
b.lock.Lock()
defer b.lock.Unlock()
@ -229,7 +236,7 @@ func (b *Broker) cleanupReader(streamId int64) {
delete(b.readerErrorSentTime, streamId)
}
func (b *Broker) cleanupWriter(streamId int64) {
func (b *Broker) cleanupWriter(streamId string) {
b.lock.Lock()
defer b.lock.Unlock()

View file

@ -68,9 +68,9 @@ func TestBrokerBasicReadWrite(t *testing.T) {
broker1, broker2 := setupBrokerPair()
reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024)
writer, err := broker2.AttachStreamWriter(meta)
writer, err := broker2.CreateStreamWriter(meta)
if err != nil {
t.Fatalf("AttachStreamWriter failed: %v", err)
t.Fatalf("CreateStreamWriter failed: %v", err)
}
testData := []byte("Hello, World!")
@ -105,9 +105,9 @@ func TestBrokerEOF(t *testing.T) {
broker1, broker2 := setupBrokerPair()
reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024)
writer, err := broker2.AttachStreamWriter(meta)
writer, err := broker2.CreateStreamWriter(meta)
if err != nil {
t.Fatalf("AttachStreamWriter failed: %v", err)
t.Fatalf("CreateStreamWriter failed: %v", err)
}
testData := []byte("Test data")
@ -134,9 +134,9 @@ func TestBrokerFlowControl(t *testing.T) {
smallWindow := int64(10)
reader, meta := broker1.CreateStreamReader("reader1", "writer1", smallWindow)
writer, err := broker2.AttachStreamWriter(meta)
writer, err := broker2.CreateStreamWriter(meta)
if err != nil {
t.Fatalf("AttachStreamWriter failed: %v", err)
t.Fatalf("CreateStreamWriter failed: %v", err)
}
largeData := make([]byte, 100)
@ -180,9 +180,9 @@ func TestBrokerError(t *testing.T) {
broker1, broker2 := setupBrokerPair()
reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024)
writer, err := broker2.AttachStreamWriter(meta)
writer, err := broker2.CreateStreamWriter(meta)
if err != nil {
t.Fatalf("AttachStreamWriter failed: %v", err)
t.Fatalf("CreateStreamWriter failed: %v", err)
}
testErr := io.ErrUnexpectedEOF
@ -202,9 +202,9 @@ func TestBrokerCancel(t *testing.T) {
broker1, broker2 := setupBrokerPair()
reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024)
writer, err := broker2.AttachStreamWriter(meta)
writer, err := broker2.CreateStreamWriter(meta)
if err != nil {
t.Fatalf("AttachStreamWriter failed: %v", err)
t.Fatalf("CreateStreamWriter failed: %v", err)
}
reader.Close()
@ -226,9 +226,9 @@ func TestBrokerMultipleWrites(t *testing.T) {
broker1, broker2 := setupBrokerPair()
reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024)
writer, err := broker2.AttachStreamWriter(meta)
writer, err := broker2.CreateStreamWriter(meta)
if err != nil {
t.Fatalf("AttachStreamWriter failed: %v", err)
t.Fatalf("CreateStreamWriter failed: %v", err)
}
messages := []string{"First", "Second", "Third"}
@ -261,9 +261,9 @@ func TestBrokerCleanup(t *testing.T) {
broker1, broker2 := setupBrokerPair()
reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024)
writer, err := broker2.AttachStreamWriter(meta)
writer, err := broker2.CreateStreamWriter(meta)
if err != nil {
t.Fatalf("AttachStreamWriter failed: %v", err)
t.Fatalf("CreateStreamWriter failed: %v", err)
}
testData := []byte("cleanup test")

View file

@ -4,6 +4,7 @@ import (
"encoding/base64"
"fmt"
"io"
"sort"
"sync"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
@ -16,7 +17,7 @@ type AckSender interface {
type Reader struct {
lock sync.Mutex
cond *sync.Cond
id int64
id string
ackSender AckSender
readWindow int64
nextSeq int64
@ -25,14 +26,19 @@ type Reader struct {
err error
closed bool
lastRwndSent int64
oooPackets []wshrpc.CommandStreamData // out-of-order packets awaiting delivery
}
func NewReader(id int64, readWindow int64, ackSender AckSender) *Reader {
func NewReader(id string, readWindow int64, ackSender AckSender) *Reader {
return NewReaderWithSeq(id, readWindow, 0, ackSender)
}
func NewReaderWithSeq(id string, readWindow int64, startSeq int64, ackSender AckSender) *Reader {
r := &Reader{
id: id,
readWindow: readWindow,
ackSender: ackSender,
nextSeq: 0,
nextSeq: startSeq,
lastRwndSent: readWindow,
}
r.cond = sync.NewCond(&r.lock)
@ -43,7 +49,7 @@ func (r *Reader) RecvData(dataPk wshrpc.CommandStreamData) {
r.lock.Lock()
defer r.lock.Unlock()
if r.closed {
if r.closed || r.eof || r.err != nil {
return
}
@ -59,18 +65,25 @@ func (r *Reader) RecvData(dataPk wshrpc.CommandStreamData) {
return
}
if dataPk.Seq != r.nextSeq {
r.err = fmt.Errorf("stream sequence mismatch: expected %d, got %d", r.nextSeq, dataPk.Seq)
r.cond.Broadcast()
r.sendAckLocked(false, true, "sequence mismatch error")
if dataPk.Seq < r.nextSeq {
return
}
if dataPk.Seq > r.nextSeq {
r.addOOOPacketLocked(dataPk)
return
}
r.recvDataOrderedLocked(dataPk)
r.processOOOPacketsLocked()
r.cond.Broadcast()
r.sendAckLocked(r.eof, false, "")
}
func (r *Reader) recvDataOrderedLocked(dataPk wshrpc.CommandStreamData) {
if dataPk.Data64 != "" {
data, err := base64.StdEncoding.DecodeString(dataPk.Data64)
if err != nil {
r.err = err
r.cond.Broadcast()
r.sendAckLocked(false, true, "base64 decode error")
return
}
@ -80,13 +93,40 @@ func (r *Reader) RecvData(dataPk wshrpc.CommandStreamData) {
if dataPk.Eof {
r.eof = true
r.cond.Broadcast()
r.sendAckLocked(true, false, "")
}
}
// addOOOPacketLocked buffers a packet that arrived ahead of the next expected
// sequence number so it can be delivered later in order. Packets whose
// sequence number is already buffered are dropped (duplicate suppression).
// Caller must hold r.lock.
func (r *Reader) addOOOPacketLocked(dataPk wshrpc.CommandStreamData) {
	duplicate := false
	for i := range r.oooPackets {
		if r.oooPackets[i].Seq == dataPk.Seq {
			duplicate = true
			break
		}
	}
	if duplicate {
		// already queued; nothing to do
		return
	}
	r.oooPackets = append(r.oooPackets, dataPk)
}
func (r *Reader) processOOOPacketsLocked() {
if len(r.oooPackets) == 0 {
return
}
r.cond.Broadcast()
r.sendAckLocked(false, false, "")
sort.Slice(r.oooPackets, func(i, j int) bool {
return r.oooPackets[i].Seq < r.oooPackets[j].Seq
})
consumed := 0
for _, pkt := range r.oooPackets {
if r.eof || r.err != nil {
// we're done, so we can clear any pending ooo packets
r.oooPackets = nil
return
}
if pkt.Seq != r.nextSeq {
break
}
r.recvDataOrderedLocked(pkt)
consumed++
}
r.oooPackets = r.oooPackets[consumed:]
}
func (r *Reader) sendAckLocked(fin bool, cancel bool, errStr string) {
@ -146,6 +186,12 @@ func (r *Reader) Read(p []byte) (int, error) {
return n, nil
}
// UpdateNextSeq overwrites the next expected sequence number under the
// reader's lock (used when resuming a stream at a known position).
func (r *Reader) UpdateNextSeq(newSeq int64) {
	r.lock.Lock()
	r.nextSeq = newSeq
	r.lock.Unlock()
}
func (r *Reader) Close() error {
r.lock.Lock()
defer r.lock.Unlock()

View file

@ -16,7 +16,7 @@ type DataSender interface {
type Writer struct {
lock sync.Mutex
cond *sync.Cond
id int64
id string
dataSender DataSender
readWindow int64
nextSeq int64
@ -31,7 +31,7 @@ type Writer struct {
closed bool
}
func NewWriter(id int64, readWindow int64, dataSender DataSender) *Writer {
func NewWriter(id string, readWindow int64, dataSender DataSender) *Writer {
w := &Writer{
id: id,
readWindow: readWindow,

View file

@ -26,11 +26,13 @@ var (
type WaveJwtClaims struct {
jwt.RegisteredClaims
Sock string `json:"sock,omitempty"`
RouteId string `json:"routeid,omitempty"`
BlockId string `json:"blockid,omitempty"`
Conn string `json:"conn,omitempty"`
Router bool `json:"router,omitempty"`
MainServer bool `json:"mainserver,omitempty"`
Sock string `json:"sock,omitempty"`
RouteId string `json:"routeid,omitempty"`
BlockId string `json:"blockid,omitempty"`
JobId string `json:"jobid,omitempty"`
Conn string `json:"conn,omitempty"`
Router bool `json:"router,omitempty"`
}
type KeyPair struct {

View file

@ -29,6 +29,7 @@ const (
OType_LayoutState = "layout"
OType_Block = "block"
OType_MainServer = "mainserver"
OType_Job = "job"
OType_Temp = "temp"
OType_Builder = "builder" // not persisted to DB
)
@ -41,6 +42,7 @@ var ValidOTypes = map[string]bool{
OType_LayoutState: true,
OType_Block: true,
OType_MainServer: true,
OType_Job: true,
OType_Temp: true,
OType_Builder: true,
}
@ -134,6 +136,7 @@ type Client struct {
TosAgreed int64 `json:"tosagreed,omitempty"` // unix milli
HasOldHistory bool `json:"hasoldhistory,omitempty"`
TempOID string `json:"tempoid,omitempty"`
InstallId string `json:"installid,omitempty"`
}
func (*Client) GetOType() string {
@ -288,6 +291,7 @@ type Block struct {
Stickers []*StickerType `json:"stickers,omitempty"`
Meta MetaMapType `json:"meta"`
SubBlockIds []string `json:"subblockids,omitempty"`
JobId string `json:"jobid,omitempty"` // if set, the block will render this jobid's pty output
}
func (*Block) GetOType() string {
@ -306,6 +310,49 @@ func (*MainServer) GetOType() string {
return OType_MainServer
}
// Job is the persisted wave object for a single persistent remote job. It
// records what command the job runs, the lifecycle state of the job manager
// process that owns it, the runtime/exit state of the command itself, and
// output-stream bookkeeping.
type Job struct {
	OID     string `json:"oid"`
	Version int    `json:"version"`

	// job metadata
	Connection      string            `json:"connection"`
	JobKind         string            `json:"jobkind"` // shell, task
	Cmd             string            `json:"cmd"`
	CmdArgs         []string          `json:"cmdargs,omitempty"`
	CmdEnv          map[string]string `json:"cmdenv,omitempty"`
	JobAuthToken    string            `json:"jobauthtoken"` // job manager -> wave
	AttachedBlockId string            `json:"attachedblockid,omitempty"`

	// reconnect option (e.g. orphaned, so we need to kill on connect)
	TerminateOnReconnect bool `json:"terminateonreconnect,omitempty"`

	// job manager state
	JobManagerStatus       string `json:"jobmanagerstatus"`               // init, running, done
	JobManagerDoneReason   string `json:"jobmanagerdonereason,omitempty"` // startuperror, gone, terminated
	JobManagerStartupError string `json:"jobmanagerstartuperror,omitempty"`
	JobManagerPid          int    `json:"jobmanagerpid,omitempty"`
	JobManagerStartTs      int64  `json:"jobmanagerstartts,omitempty"` // exact process start time (milliseconds)

	// cmd/process runtime info
	CmdPid        int      `json:"cmdpid,omitempty"`        // command process id
	CmdStartTs    int64    `json:"cmdstartts,omitempty"`    // exact command process start time (milliseconds from epoch)
	CmdTermSize   TermSize `json:"cmdtermsize"`
	CmdExitTs     int64    `json:"cmdexitts,omitempty"`     // timestamp (milliseconds) -- use CmdExitTs > 0 to check if command has exited
	CmdExitCode   *int     `json:"cmdexitcode,omitempty"`   // nil when CmdExitSignal is set. success exit is when CmdExitCode is 0
	CmdExitSignal string   `json:"cmdexitsignal,omitempty"` // empty string if CmdExitCode is set
	CmdExitError  string   `json:"cmdexiterror,omitempty"`

	// output info
	StreamDone  bool   `json:"streamdone,omitempty"`
	StreamError string `json:"streamerror,omitempty"`

	Meta MetaMapType `json:"meta"`
}

// GetOType returns the wave object type tag for Job objects.
func (*Job) GetOType() string {
	return OType_Job
}
func AllWaveObjTypes() []reflect.Type {
return []reflect.Type{
reflect.TypeOf(&Client{}),
@ -315,6 +362,7 @@ func AllWaveObjTypes() []reflect.Type {
reflect.TypeOf(&Block{}),
reflect.TypeOf(&LayoutState{}),
reflect.TypeOf(&MainServer{}),
reflect.TypeOf(&Job{}),
}
}

View file

@ -12,6 +12,7 @@ import (
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/blockcontroller"
"github.com/wavetermdev/waveterm/pkg/filestore"
"github.com/wavetermdev/waveterm/pkg/jobcontroller"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/telemetry"
"github.com/wavetermdev/waveterm/pkg/telemetry/telemetrydata"
@ -167,6 +168,19 @@ func DeleteBlock(ctx context.Context, blockId string, recursive bool) error {
}
}
}
if block.JobId != "" {
go func() {
defer func() {
panichandler.PanicHandler("DetachJobFromBlock", recover())
}()
detachCtx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
defer cancelFn()
err := jobcontroller.DetachJobFromBlock(detachCtx, block.JobId, false)
if err != nil {
log.Printf("error detaching job from block %s: %v", blockId, err)
}
}()
}
parentBlockCount, err := deleteBlockObj(ctx, blockId)
if err != nil {
return fmt.Errorf("error deleting block: %w", err)

View file

@ -50,6 +50,14 @@ func EnsureInitialData() (bool, error) {
return firstLaunch, fmt.Errorf("error updating client: %w", err)
}
}
if client.InstallId == "" {
log.Println("client.InstallId is empty")
client.InstallId = uuid.NewString()
err = wstore.DBUpdate(ctx, client)
if err != nil {
return firstLaunch, fmt.Errorf("error updating client: %w", err)
}
}
log.Printf("clientid: %s\n", client.OID)
if len(client.WindowIds) == 1 {
log.Println("client has one window")

View file

@ -9,14 +9,11 @@ import (
"github.com/wavetermdev/waveterm/pkg/tsgen/tsgenmeta"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)
const (
WSCommand_SetBlockTermSize = "setblocktermsize"
WSCommand_BlockInput = "blockinput"
WSCommand_Rpc = "rpc"
WSCommand_Rpc = "rpc"
)
type WSCommandType interface {
@ -28,8 +25,6 @@ func WSCommandTypeUnionMeta() tsgenmeta.TypeUnionMeta {
BaseType: reflect.TypeOf((*WSCommandType)(nil)).Elem(),
TypeFieldName: "wscommand",
Types: []reflect.Type{
reflect.TypeOf(SetBlockTermSizeWSCommand{}),
reflect.TypeOf(BlockInputWSCommand{}),
reflect.TypeOf(WSRpcCommand{}),
},
}
@ -44,46 +39,12 @@ func (cmd *WSRpcCommand) GetWSCommand() string {
return cmd.WSCommand
}
type SetBlockTermSizeWSCommand struct {
WSCommand string `json:"wscommand" tstype:"\"setblocktermsize\""`
BlockId string `json:"blockid"`
TermSize waveobj.TermSize `json:"termsize"`
}
func (cmd *SetBlockTermSizeWSCommand) GetWSCommand() string {
return cmd.WSCommand
}
type BlockInputWSCommand struct {
WSCommand string `json:"wscommand" tstype:"\"blockinput\""`
BlockId string `json:"blockid"`
InputData64 string `json:"inputdata64"`
}
func (cmd *BlockInputWSCommand) GetWSCommand() string {
return cmd.WSCommand
}
func ParseWSCommandMap(cmdMap map[string]any) (WSCommandType, error) {
cmdType, ok := cmdMap["wscommand"].(string)
if !ok {
return nil, fmt.Errorf("no wscommand field in command map")
}
switch cmdType {
case WSCommand_SetBlockTermSize:
var cmd SetBlockTermSizeWSCommand
err := utilfn.DoMapStructure(&cmd, cmdMap)
if err != nil {
return nil, fmt.Errorf("error decoding SetBlockTermSizeWSCommand: %w", err)
}
return &cmd, nil
case WSCommand_BlockInput:
var cmd BlockInputWSCommand
err := utilfn.DoMapStructure(&cmd, cmdMap)
if err != nil {
return nil, fmt.Errorf("error decoding BlockInputWSCommand: %w", err)
}
return &cmd, nil
case WSCommand_Rpc:
var cmd WSRpcCommand
err := utilfn.DoMapStructure(&cmd, cmdMap)
@ -94,5 +55,4 @@ func ParseWSCommandMap(cmdMap map[string]any) (WSCommandType, error) {
default:
return nil, fmt.Errorf("unknown wscommand type %q", cmdType)
}
}

View file

@ -20,7 +20,6 @@ import (
"github.com/wavetermdev/waveterm/pkg/eventbus"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/web/webcmd"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)
@ -110,40 +109,6 @@ func processWSCommand(jmsg map[string]any, outputCh chan any, rpcInputCh chan ba
}
cmdType = wsCommand.GetWSCommand()
switch cmd := wsCommand.(type) {
case *webcmd.SetBlockTermSizeWSCommand:
data := wshrpc.CommandBlockInputData{
BlockId: cmd.BlockId,
TermSize: &cmd.TermSize,
}
rpcMsg := wshutil.RpcMessage{
Command: wshrpc.Command_ControllerInput,
Data: data,
}
msgBytes, err := json.Marshal(rpcMsg)
if err != nil {
// this really should never fail since we just unmarshalled this value
log.Printf("[websocket] error marshalling rpc message: %v\n", err)
return
}
rpcInputCh <- baseds.RpcInputChType{MsgBytes: msgBytes}
case *webcmd.BlockInputWSCommand:
data := wshrpc.CommandBlockInputData{
BlockId: cmd.BlockId,
InputData64: cmd.InputData64,
}
rpcMsg := wshutil.RpcMessage{
Command: wshrpc.Command_ControllerInput,
Data: data,
}
msgBytes, err := json.Marshal(rpcMsg)
if err != nil {
// this really should never fail since we just unmarshalled this value
log.Printf("[websocket] error marshalling rpc message: %v\n", err)
return
}
rpcInputCh <- baseds.RpcInputChType{MsgBytes: msgBytes}
case *webcmd.WSRpcCommand:
rpcMsg := cmd.Message
if rpcMsg == nil {

View file

@ -16,7 +16,8 @@ const (
Event_BlockFile = "blockfile"
Event_Config = "config"
Event_UserInput = "userinput"
Event_RouteGone = "route:gone"
Event_RouteDown = "route:down"
Event_RouteUp = "route:up"
Event_WorkspaceUpdate = "workspace:update"
Event_WaveAIRateLimit = "waveai:ratelimit"
Event_WaveAppAppGoUpdated = "waveapp:appgoupdated"

View file

@ -4,8 +4,10 @@
package wshclient
import (
"fmt"
"sync"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshutil"
@ -17,21 +19,22 @@ func (*WshServer) WshServerImpl() {}
var WshServerImpl = WshServer{}
const (
DefaultOutputChSize = 32
DefaultInputChSize = 32
)
var waveSrvClient_Singleton *wshutil.WshRpc
var waveSrvClient_Once = &sync.Once{}
const BareClientRoute = "bare"
var waveSrvClient_RouteId string
func GetBareRpcClient() *wshutil.WshRpc {
waveSrvClient_Once.Do(func() {
waveSrvClient_Singleton = wshutil.MakeWshRpc(wshrpc.RpcContext{}, &WshServerImpl, "bare-client")
wshutil.DefaultRouter.RegisterTrustedLeaf(waveSrvClient_Singleton, BareClientRoute)
waveSrvClient_RouteId = fmt.Sprintf("bare:%s", uuid.New().String())
// we can safely ignore the error from RegisterTrustedLeaf since the route is valid
wshutil.DefaultRouter.RegisterTrustedLeaf(waveSrvClient_Singleton, waveSrvClient_RouteId)
wps.Broker.SetClient(wshutil.DefaultRouter)
})
return waveSrvClient_Singleton
}
// GetBareRpcClientRouteId returns the route id of the process-wide bare RPC
// client. It calls GetBareRpcClient first for its side effect: ensuring the
// singleton (and the route id it assigns) has been initialized.
func GetBareRpcClientRouteId() string {
	GetBareRpcClient()
	return waveSrvClient_RouteId
}

View file

@ -35,6 +35,24 @@ func AuthenticateCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) (
return resp, err
}
// command "authenticatejobmanager", wshserver.AuthenticateJobManagerCommand
func AuthenticateJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticateJobManagerData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "authenticatejobmanager", data, opts)
return err
}
// command "authenticatejobmanagerverify", wshserver.AuthenticateJobManagerVerifyCommand
func AuthenticateJobManagerVerifyCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticateJobManagerData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "authenticatejobmanagerverify", data, opts)
return err
}
// command "authenticatetojobmanager", wshserver.AuthenticateToJobManagerCommand
func AuthenticateToJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticateToJobData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "authenticatetojobmanager", data, opts)
return err
}
// command "authenticatetoken", wshserver.AuthenticateTokenCommand
func AuthenticateTokenCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticateTokenData, opts *wshrpc.RpcOpts) (wshrpc.CommandAuthenticateRtnData, error) {
resp, err := sendRpcRequestCallHelper[wshrpc.CommandAuthenticateRtnData](w, "authenticatetoken", data, opts)
@ -458,6 +476,90 @@ func GetWaveAIRateLimitCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) (*uctype
return resp, err
}
// command "jobcmdexited", wshserver.JobCmdExitedCommand
func JobCmdExitedCommand(w *wshutil.WshRpc, data wshrpc.CommandJobCmdExitedData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "jobcmdexited", data, opts)
return err
}
// command "jobcontrollerattachjob", wshserver.JobControllerAttachJobCommand
func JobControllerAttachJobCommand(w *wshutil.WshRpc, data wshrpc.CommandJobControllerAttachJobData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerattachjob", data, opts)
return err
}
// command "jobcontrollerconnectedjobs", wshserver.JobControllerConnectedJobsCommand
func JobControllerConnectedJobsCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]string, error) {
resp, err := sendRpcRequestCallHelper[[]string](w, "jobcontrollerconnectedjobs", nil, opts)
return resp, err
}
// command "jobcontrollerdeletejob", wshserver.JobControllerDeleteJobCommand
func JobControllerDeleteJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerdeletejob", data, opts)
return err
}
// command "jobcontrollerdetachjob", wshserver.JobControllerDetachJobCommand
func JobControllerDetachJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerdetachjob", data, opts)
return err
}
// command "jobcontrollerdisconnectjob", wshserver.JobControllerDisconnectJobCommand
func JobControllerDisconnectJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerdisconnectjob", data, opts)
return err
}
// command "jobcontrollerexitjob", wshserver.JobControllerExitJobCommand
func JobControllerExitJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerexitjob", data, opts)
return err
}
// command "jobcontrollerlist", wshserver.JobControllerListCommand
func JobControllerListCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]*waveobj.Job, error) {
resp, err := sendRpcRequestCallHelper[[]*waveobj.Job](w, "jobcontrollerlist", nil, opts)
return resp, err
}
// command "jobcontrollerreconnectjob", wshserver.JobControllerReconnectJobCommand
func JobControllerReconnectJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerreconnectjob", data, opts)
return err
}
// command "jobcontrollerreconnectjobsforconn", wshserver.JobControllerReconnectJobsForConnCommand
func JobControllerReconnectJobsForConnCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerreconnectjobsforconn", data, opts)
return err
}
// command "jobcontrollerstartjob", wshserver.JobControllerStartJobCommand
func JobControllerStartJobCommand(w *wshutil.WshRpc, data wshrpc.CommandJobControllerStartJobData, opts *wshrpc.RpcOpts) (string, error) {
resp, err := sendRpcRequestCallHelper[string](w, "jobcontrollerstartjob", data, opts)
return resp, err
}
// command "jobinput", wshserver.JobInputCommand
func JobInputCommand(w *wshutil.WshRpc, data wshrpc.CommandJobInputData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "jobinput", data, opts)
return err
}
// command "jobprepareconnect", wshserver.JobPrepareConnectCommand
func JobPrepareConnectCommand(w *wshutil.WshRpc, data wshrpc.CommandJobPrepareConnectData, opts *wshrpc.RpcOpts) (*wshrpc.CommandJobConnectRtnData, error) {
resp, err := sendRpcRequestCallHelper[*wshrpc.CommandJobConnectRtnData](w, "jobprepareconnect", data, opts)
return resp, err
}
// command "jobstartstream", wshserver.JobStartStreamCommand
func JobStartStreamCommand(w *wshutil.WshRpc, data wshrpc.CommandJobStartStreamData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "jobstartstream", data, opts)
return err
}
// command "listallappfiles", wshserver.ListAllAppFilesCommand
func ListAllAppFilesCommand(w *wshutil.WshRpc, data wshrpc.CommandListAllAppFilesData, opts *wshrpc.RpcOpts) (*wshrpc.CommandListAllAppFilesRtnData, error) {
resp, err := sendRpcRequestCallHelper[*wshrpc.CommandListAllAppFilesRtnData](w, "listallappfiles", data, opts)
@ -524,6 +626,12 @@ func RecordTEventCommand(w *wshutil.WshRpc, data telemetrydata.TEvent, opts *wsh
return err
}
// command "remotedisconnectfromjobmanager", wshserver.RemoteDisconnectFromJobManagerCommand
func RemoteDisconnectFromJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteDisconnectFromJobManagerData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "remotedisconnectfromjobmanager", data, opts)
return err
}
// command "remotefilecopy", wshserver.RemoteFileCopyCommand
func RemoteFileCopyCommand(w *wshutil.WshRpc, data wshrpc.CommandFileCopyData, opts *wshrpc.RpcOpts) (bool, error) {
resp, err := sendRpcRequestCallHelper[bool](w, "remotefilecopy", data, opts)
@ -583,6 +691,18 @@ func RemoteMkdirCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) er
return err
}
// command "remotereconnecttojobmanager", wshserver.RemoteReconnectToJobManagerCommand
func RemoteReconnectToJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteReconnectToJobManagerData, opts *wshrpc.RpcOpts) (*wshrpc.CommandRemoteReconnectToJobManagerRtnData, error) {
resp, err := sendRpcRequestCallHelper[*wshrpc.CommandRemoteReconnectToJobManagerRtnData](w, "remotereconnecttojobmanager", data, opts)
return resp, err
}
// command "remotestartjob", wshserver.RemoteStartJobCommand
func RemoteStartJobCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteStartJobData, opts *wshrpc.RpcOpts) (*wshrpc.CommandStartJobRtnData, error) {
resp, err := sendRpcRequestCallHelper[*wshrpc.CommandStartJobRtnData](w, "remotestartjob", data, opts)
return resp, err
}
// command "remotestreamcpudata", wshserver.RemoteStreamCpuDataCommand
func RemoteStreamCpuDataCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[wshrpc.TimeSeriesData] {
return sendRpcRequestResponseStreamHelper[wshrpc.TimeSeriesData](w, "remotestreamcpudata", nil, opts)
@ -598,6 +718,12 @@ func RemoteTarStreamCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteStreamTa
return sendRpcRequestResponseStreamHelper[iochantypes.Packet](w, "remotetarstream", data, opts)
}
// command "remoteterminatejobmanager", wshserver.RemoteTerminateJobManagerCommand
func RemoteTerminateJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteTerminateJobManagerData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "remoteterminatejobmanager", data, opts)
return err
}
// command "remotewritefile", wshserver.RemoteWriteFileCommand
func RemoteWriteFileCommand(w *wshutil.WshRpc, data wshrpc.FileData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "remotewritefile", data, opts)
@ -688,6 +814,12 @@ func StartBuilderCommand(w *wshutil.WshRpc, data wshrpc.CommandStartBuilderData,
return err
}
// command "startjob", wshserver.StartJobCommand
func StartJobCommand(w *wshutil.WshRpc, data wshrpc.CommandStartJobData, opts *wshrpc.RpcOpts) (*wshrpc.CommandStartJobRtnData, error) {
resp, err := sendRpcRequestCallHelper[*wshrpc.CommandStartJobRtnData](w, "startjob", data, opts)
return resp, err
}
// command "stopbuilder", wshserver.StopBuilderCommand
func StopBuilderCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "stopbuilder", data, opts)
@ -727,6 +859,12 @@ func TermGetScrollbackLinesCommand(w *wshutil.WshRpc, data wshrpc.CommandTermGet
return resp, err
}
// command "termupdateattachedjob", wshserver.TermUpdateAttachedJobCommand
func TermUpdateAttachedJobCommand(w *wshutil.WshRpc, data wshrpc.CommandTermUpdateAttachedJobData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "termupdateattachedjob", data, opts)
return err
}
// command "test", wshserver.TestCommand
func TestCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "test", data, opts)

View file

@ -4,35 +4,44 @@
package wshremote
import (
"archive/tar"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"io/fs"
"log"
"os"
"net"
"path/filepath"
"strings"
"time"
"sync"
"github.com/wavetermdev/waveterm/pkg/remote/connparse"
"github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype"
"github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs"
"github.com/wavetermdev/waveterm/pkg/suggestion"
"github.com/wavetermdev/waveterm/pkg/util/fileutil"
"github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes"
"github.com/wavetermdev/waveterm/pkg/util/tarcopy"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)
// JobManagerConnection tracks one live connection from this server process to
// a persistent job-manager process.
type JobManagerConnection struct {
	JobId     string          // id of the job this connection serves
	Conn      net.Conn        // underlying transport to the job manager
	WshRpc    *wshutil.WshRpc // rpc endpoint running over Conn
	CleanupFn func()          // tears down the connection; NOTE(review): presumably safe to call once on disconnect -- confirm idempotency at call sites
}
type ServerImpl struct {
LogWriter io.Writer
LogWriter io.Writer
Router *wshutil.WshRouter
RpcClient *wshutil.WshRpc
IsLocal bool
JobManagerMap map[string]*JobManagerConnection
Lock sync.Mutex
}
// MakeRemoteRpcServerImpl constructs a ServerImpl wired to the given log
// writer, router, and rpc client, with an initialized (empty) job-manager
// connection map. isLocal marks whether this server runs on the local machine.
func MakeRemoteRpcServerImpl(logWriter io.Writer, router *wshutil.WshRouter, rpcClient *wshutil.WshRpc, isLocal bool) *ServerImpl {
	impl := &ServerImpl{
		LogWriter: logWriter,
		Router:    router,
		RpcClient: rpcClient,
		IsLocal:   isLocal,
	}
	impl.JobManagerMap = make(map[string]*JobManagerConnection)
	return impl
}
func (*ServerImpl) WshServerImpl() {}
@ -66,785 +75,6 @@ func (impl *ServerImpl) StreamTestCommand(ctx context.Context) chan wshrpc.RespO
return ch
}
// ByteRangeType describes either the whole file (All) or the half-specified
// span [Start, End] of a streamed resource.
type ByteRangeType struct {
	All   bool
	Start int64
	End   int64
}

// parseByteRange parses a "start-end" range string. An empty string selects
// the entire resource; anything unparsable, negative, or inverted is rejected.
func parseByteRange(rangeStr string) (ByteRangeType, error) {
	// empty range means "everything"
	if rangeStr == "" {
		return ByteRangeType{All: true}, nil
	}
	var lo, hi int64
	_, scanErr := fmt.Sscanf(rangeStr, "%d-%d", &lo, &hi)
	switch {
	case scanErr != nil:
		return ByteRangeType{}, errors.New("invalid byte range")
	case lo < 0, hi < 0, lo > hi:
		return ByteRangeType{}, errors.New("invalid byte range")
	}
	return ByteRangeType{Start: lo, End: hi}, nil
}
// remoteStreamFileDir streams the entries of directory path to dataCallback
// in batches of up to wshrpc.DirChunkSize FileInfo records. For directories,
// byteRange is interpreted as an ENTRY-INDEX range (not bytes): All caps the
// listing at wshrpc.MaxDirSize entries, otherwise entries [Start, End) are
// selected (clamped to the listing length). Entries whose Info() fails are
// silently skipped.
func (impl *ServerImpl) remoteStreamFileDir(ctx context.Context, path string, byteRange ByteRangeType, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error {
	innerFilesEntries, err := os.ReadDir(path)
	if err != nil {
		return fmt.Errorf("cannot open dir %q: %w", path, err)
	}
	if byteRange.All {
		// unbounded request: still cap at MaxDirSize to limit payload
		if len(innerFilesEntries) > wshrpc.MaxDirSize {
			innerFilesEntries = innerFilesEntries[:wshrpc.MaxDirSize]
		}
	} else {
		if byteRange.Start < int64(len(innerFilesEntries)) {
			realEnd := byteRange.End
			if realEnd > int64(len(innerFilesEntries)) {
				realEnd = int64(len(innerFilesEntries))
			}
			innerFilesEntries = innerFilesEntries[byteRange.Start:realEnd]
		} else {
			// start is past the end of the listing: empty result
			innerFilesEntries = []os.DirEntry{}
		}
	}
	var fileInfoArr []*wshrpc.FileInfo
	for _, innerFileEntry := range innerFilesEntries {
		// honor cancellation between entries
		if ctx.Err() != nil {
			return ctx.Err()
		}
		innerFileInfoInt, err := innerFileEntry.Info()
		if err != nil {
			// entry may have been removed between ReadDir and Info; skip it
			continue
		}
		innerFileInfo := statToFileInfo(filepath.Join(path, innerFileInfoInt.Name()), innerFileInfoInt, false)
		fileInfoArr = append(fileInfoArr, innerFileInfo)
		// flush a full chunk to the callback
		if len(fileInfoArr) >= wshrpc.DirChunkSize {
			dataCallback(fileInfoArr, nil, byteRange)
			fileInfoArr = nil
		}
	}
	// flush any final partial chunk
	if len(fileInfoArr) > 0 {
		dataCallback(fileInfoArr, nil, byteRange)
	}
	return nil
}
// remoteStreamFileRegular streams the contents of the regular file at path to
// dataCallback in chunks of up to wshrpc.FileChunkSize bytes. If byteRange is
// not All, streaming starts at byteRange.Start and is truncated at
// byteRange.End. The loop deliberately processes read data BEFORE inspecting
// the read error, since os.File.Read may return n > 0 together with an error.
func (impl *ServerImpl) remoteStreamFileRegular(ctx context.Context, path string, byteRange ByteRangeType, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error {
	fd, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("cannot open file %q: %w", path, err)
	}
	defer utilfn.GracefulClose(fd, "remoteStreamFileRegular", path)
	var filePos int64
	if !byteRange.All && byteRange.Start > 0 {
		// seek to the requested start offset
		_, err := fd.Seek(byteRange.Start, io.SeekStart)
		if err != nil {
			return fmt.Errorf("seeking file %q: %w", path, err)
		}
		filePos = byteRange.Start
	}
	buf := make([]byte, wshrpc.FileChunkSize)
	for {
		// honor cancellation between reads
		if ctx.Err() != nil {
			return ctx.Err()
		}
		n, err := fd.Read(buf)
		if n > 0 {
			// clamp the final chunk so we never deliver past byteRange.End
			if !byteRange.All && filePos+int64(n) > byteRange.End {
				n = int(byteRange.End - filePos)
			}
			filePos += int64(n)
			dataCallback(nil, buf[:n], byteRange)
		}
		if !byteRange.All && filePos >= byteRange.End {
			break
		}
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return fmt.Errorf("reading file %q: %w", path, err)
		}
	}
	return nil
}
// remoteStreamFileInternal stats data.Path, delivers the resulting FileInfo to
// dataCallback first, then dispatches to the directory or regular-file
// streamer. A NotFound stat is not an error: the FileInfo (with NotFound set)
// is still delivered and streaming simply stops.
func (impl *ServerImpl) remoteStreamFileInternal(ctx context.Context, data wshrpc.CommandRemoteStreamFileData, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error {
	byteRange, err := parseByteRange(data.ByteRange)
	if err != nil {
		return err
	}
	path, err := wavebase.ExpandHomeDir(data.Path)
	if err != nil {
		return err
	}
	finfo, err := impl.fileInfoInternal(path, true)
	if err != nil {
		return fmt.Errorf("cannot stat file %q: %w", path, err)
	}
	// always send the stat result first, even for missing files
	dataCallback([]*wshrpc.FileInfo{finfo}, nil, byteRange)
	if finfo.NotFound {
		return nil
	}
	if finfo.IsDir {
		return impl.remoteStreamFileDir(ctx, path, byteRange, dataCallback)
	} else {
		return impl.remoteStreamFileRegular(ctx, path, byteRange, dataCallback)
	}
}
// RemoteStreamFileCommand streams a file or directory listing as a channel of
// FileData responses. The first packet carrying exactly one FileInfo puts it
// in resp.Info (the stat of the requested path); all subsequent FileInfo
// batches go in resp.Entries (directory listings). File bytes are base64
// encoded into Data64 with an accompanying FileDataAt offset. Any error from
// the underlying streamer is sent as a final error packet; the channel is
// closed when streaming ends.
func (impl *ServerImpl) RemoteStreamFileCommand(ctx context.Context, data wshrpc.CommandRemoteStreamFileData) chan wshrpc.RespOrErrorUnion[wshrpc.FileData] {
	ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.FileData], 16)
	go func() {
		defer close(ch)
		firstPk := true
		err := impl.remoteStreamFileInternal(ctx, data, func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType) {
			resp := wshrpc.FileData{}
			fileInfoLen := len(fileInfo)
			// only the very first single-FileInfo packet is the path's own
			// stat; everything after that is directory entries
			if fileInfoLen > 1 || !firstPk {
				resp.Entries = fileInfo
			} else if fileInfoLen == 1 {
				resp.Info = fileInfo[0]
			}
			if firstPk {
				firstPk = false
			}
			if len(data) > 0 {
				resp.Data64 = base64.StdEncoding.EncodeToString(data)
				// NOTE(review): Offset is always byteRange.Start, not the
				// running position of this chunk -- confirm consumers expect that
				resp.At = &wshrpc.FileDataAt{Offset: byteRange.Start, Size: len(data)}
			}
			ch <- wshrpc.RespOrErrorUnion[wshrpc.FileData]{Response: resp}
		})
		if err != nil {
			ch <- wshutil.RespErr[wshrpc.FileData](err)
		}
	}()
	return ch
}
// RemoteTarStreamCommand streams the file or directory at data.Path as a tar
// archive over the returned channel. A trailing slash on the source path
// controls whether a directory is archived with or without its own top-level
// name (pathPrefix selection below). The whole walk runs under a timeout
// (opts.Timeout ms, defaulting to fstype.DefaultTimeout); the producing
// goroutine closes the tar writer and cancels the context when it finishes.
func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc.CommandRemoteStreamTarData) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] {
	path := data.Path
	opts := data.Opts
	if opts == nil {
		opts = &wshrpc.FileCopyOpts{}
	}
	log.Printf("RemoteTarStreamCommand: path=%s\n", path)
	srcHasSlash := strings.HasSuffix(path, "/")
	path, err := wavebase.ExpandHomeDir(path)
	if err != nil {
		return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot expand path %q: %w", path, err))
	}
	cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
	finfo, err := os.Stat(cleanedPath)
	if err != nil {
		return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot stat file %q: %w", path, err))
	}
	// pathPrefix is stripped from entry names inside the archive:
	// "dir/" sources archive the directory's contents; "dir" (no slash) and
	// single files archive relative to the parent directory.
	var pathPrefix string
	singleFile := !finfo.IsDir()
	if !singleFile && srcHasSlash {
		pathPrefix = cleanedPath
	} else {
		pathPrefix = filepath.Dir(cleanedPath)
	}
	timeout := fstype.DefaultTimeout
	if opts.Timeout > 0 {
		timeout = time.Duration(opts.Timeout) * time.Millisecond
	}
	readerCtx, cancel := context.WithTimeout(ctx, timeout)
	rtn, writeHeader, fileWriter, tarClose := tarcopy.TarCopySrc(readerCtx, pathPrefix)
	go func() {
		defer func() {
			// close the tar stream before releasing the timeout context
			tarClose()
			cancel()
		}()
		walkFunc := func(path string, info fs.FileInfo, err error) error {
			if readerCtx.Err() != nil {
				return readerCtx.Err()
			}
			if err != nil {
				return err
			}
			if err = writeHeader(info, path, singleFile); err != nil {
				return err
			}
			// if not a dir, write file content
			if !info.IsDir() {
				data, err := os.Open(path)
				if err != nil {
					return err
				}
				// NOTE(review): defer inside walkFunc holds files open until
				// the walk finishes (the goroutine returns), not per entry
				defer utilfn.GracefulClose(data, "RemoteTarStreamCommand", path)
				if _, err := io.Copy(fileWriter, data); err != nil {
					return err
				}
			}
			return nil
		}
		log.Printf("RemoteTarStreamCommand: starting\n")
		err = nil
		if singleFile {
			// archive exactly one file; no walk needed
			err = walkFunc(cleanedPath, finfo, nil)
		} else {
			err = filepath.Walk(cleanedPath, walkFunc)
		}
		if err != nil {
			rtn <- wshutil.RespErr[iochantypes.Packet](err)
		}
		log.Printf("RemoteTarStreamCommand: done\n")
	}()
	log.Printf("RemoteTarStreamCommand: returning channel\n")
	return rtn
}
// RemoteFileCopyCommand copies a file or directory tree from data.SrcUri to
// data.DestUri. When both URIs resolve to the same host the copy runs over
// the local filesystem; otherwise the source is pulled in over RPC as a tar
// stream. Returns whether the source was a directory. The Overwrite and
// Merge options are mutually exclusive.
func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.CommandFileCopyData) (bool, error) {
	log.Printf("RemoteFileCopyCommand: src=%s, dest=%s\n", data.SrcUri, data.DestUri)
	opts := data.Opts
	if opts == nil {
		opts = &wshrpc.FileCopyOpts{}
	}
	destUri := data.DestUri
	srcUri := data.SrcUri
	merge := opts.Merge
	overwrite := opts.Overwrite
	if overwrite && merge {
		return false, fmt.Errorf("cannot specify both overwrite and merge")
	}
	destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri)
	if err != nil {
		return false, fmt.Errorf("cannot parse destination URI %q: %w", destUri, err)
	}
	destPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path))
	destinfo, err := os.Stat(destPathCleaned)
	if err != nil {
		// a missing destination is fine; any other stat failure is fatal
		if !errors.Is(err, fs.ErrNotExist) {
			return false, fmt.Errorf("cannot stat destination %q: %w", destPathCleaned, err)
		}
	}
	destExists := destinfo != nil
	destIsDir := destExists && destinfo.IsDir()
	destHasSlash := strings.HasSuffix(destUri, "/")
	// an existing regular file at the destination must be explicitly overwritten
	if destExists && !destIsDir {
		if !overwrite {
			return false, fmt.Errorf(fstype.OverwriteRequiredError, destPathCleaned)
		} else {
			err := os.Remove(destPathCleaned)
			if err != nil {
				return false, fmt.Errorf("cannot remove file %q: %w", destPathCleaned, err)
			}
		}
	}
	srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri)
	if err != nil {
		return false, fmt.Errorf("cannot parse source URI %q: %w", srcUri, err)
	}
	// copyFileFunc creates one directory or writes one file at path, enforcing
	// the overwrite/merge rules against whatever already exists there. Note it
	// returns the source's reported size, not the byte count from io.Copy.
	copyFileFunc := func(path string, finfo fs.FileInfo, srcFile io.Reader) (int64, error) {
		nextinfo, err := os.Stat(path)
		if err != nil && !errors.Is(err, fs.ErrNotExist) {
			return 0, fmt.Errorf("cannot stat file %q: %w", path, err)
		}
		if nextinfo != nil {
			if nextinfo.IsDir() {
				if !finfo.IsDir() {
					// try to create file in directory
					path = filepath.Join(path, filepath.Base(finfo.Name()))
					newdestinfo, err := os.Stat(path)
					if err != nil && !errors.Is(err, fs.ErrNotExist) {
						return 0, fmt.Errorf("cannot stat file %q: %w", path, err)
					}
					if newdestinfo != nil && !overwrite {
						return 0, fmt.Errorf(fstype.OverwriteRequiredError, path)
					}
				} else if overwrite {
					err := os.RemoveAll(path)
					if err != nil {
						return 0, fmt.Errorf("cannot remove directory %q: %w", path, err)
					}
				} else if !merge {
					return 0, fmt.Errorf(fstype.MergeRequiredError, path)
				}
			} else {
				if !overwrite {
					return 0, fmt.Errorf(fstype.OverwriteRequiredError, path)
				} else if finfo.IsDir() {
					err := os.RemoveAll(path)
					if err != nil {
						return 0, fmt.Errorf("cannot remove directory %q: %w", path, err)
					}
				}
			}
		}
		if finfo.IsDir() {
			// directories are created, not copied byte-for-byte
			err := os.MkdirAll(path, finfo.Mode())
			if err != nil {
				return 0, fmt.Errorf("cannot create directory %q: %w", path, err)
			}
			return 0, nil
		} else {
			err := os.MkdirAll(filepath.Dir(path), 0755)
			if err != nil {
				return 0, fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(path), err)
			}
		}
		file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode())
		if err != nil {
			return 0, fmt.Errorf("cannot create new file %q: %w", path, err)
		}
		defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", path)
		_, err = io.Copy(file, srcFile)
		if err != nil {
			return 0, fmt.Errorf("cannot write file %q: %w", path, err)
		}
		return finfo.Size(), nil
	}
	srcIsDir := false
	if srcConn.Host == destConn.Host {
		// same host: walk and copy via the local filesystem
		srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path))
		srcFileStat, err := os.Stat(srcPathCleaned)
		if err != nil {
			return false, fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err)
		}
		if srcFileStat.IsDir() {
			srcIsDir = true
			// srcPathPrefix is trimmed from walked paths to build dest-relative paths
			var srcPathPrefix string
			if destIsDir {
				srcPathPrefix = filepath.Dir(srcPathCleaned)
			} else {
				srcPathPrefix = srcPathCleaned
			}
			err = filepath.Walk(srcPathCleaned, func(path string, info fs.FileInfo, err error) error {
				if err != nil {
					return err
				}
				srcFilePath := path
				destFilePath := filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathPrefix))
				var file *os.File
				if !info.IsDir() {
					file, err = os.Open(srcFilePath)
					if err != nil {
						return fmt.Errorf("cannot open file %q: %w", srcFilePath, err)
					}
					defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", srcFilePath)
				}
				_, err = copyFileFunc(destFilePath, info, file)
				return err
			})
			if err != nil {
				return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
			}
		} else {
			file, err := os.Open(srcPathCleaned)
			if err != nil {
				return false, fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err)
			}
			defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", srcPathCleaned)
			var destFilePath string
			if destHasSlash {
				// trailing slash on dest means "copy into this directory"
				destFilePath = filepath.Join(destPathCleaned, filepath.Base(srcPathCleaned))
			} else {
				destFilePath = destPathCleaned
			}
			_, err = copyFileFunc(destFilePath, srcFileStat, file)
			if err != nil {
				return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
			}
		}
	} else {
		// different hosts: pull the source over RPC as a tar stream
		timeout := fstype.DefaultTimeout
		if opts.Timeout > 0 {
			timeout = time.Duration(opts.Timeout) * time.Millisecond
		}
		readCtx, cancel := context.WithCancelCause(ctx)
		readCtx, timeoutCancel := context.WithTimeoutCause(readCtx, timeout, fmt.Errorf("timeout copying file %q to %q", srcUri, destUri))
		defer timeoutCancel()
		copyStart := time.Now()
		ioch := wshclient.FileStreamTarCommand(wshfs.RpcClient, wshrpc.CommandRemoteStreamTarData{Path: srcUri, Opts: opts}, &wshrpc.RpcOpts{Timeout: opts.Timeout})
		numFiles := 0
		numSkipped := 0
		totalBytes := int64(0)
		err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader, singleFile bool) error {
			numFiles++
			nextpath := filepath.Join(destPathCleaned, next.Name)
			srcIsDir = !singleFile
			if singleFile && !destHasSlash {
				// custom flag to indicate that the source is a single file, not a directory the contents of a directory
				nextpath = destPathCleaned
			}
			finfo := next.FileInfo()
			n, err := copyFileFunc(nextpath, finfo, reader)
			if err != nil {
				return fmt.Errorf("cannot copy file %q: %w", next.Name, err)
			}
			totalBytes += n
			return nil
		})
		if err != nil {
			return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
		}
		totalTime := time.Since(copyStart).Seconds()
		totalMegaBytes := float64(totalBytes) / 1024 / 1024
		rate := float64(0)
		if totalTime > 0 {
			rate = totalMegaBytes / totalTime
		}
		log.Printf("RemoteFileCopyCommand: done; %d files copied in %.3fs, total of %.4f MB, %.2f MB/s, %d files skipped\n", numFiles, totalTime, totalMegaBytes, rate, numSkipped)
	}
	return srcIsDir, nil
}
// RemoteListEntriesCommand streams directory entries for data.Path in chunks
// of wshrpc.DirChunkSize. When data.Opts.All is set it walks the tree
// recursively (directories themselves are skipped) with Offset/Limit paging;
// otherwise it lists a single directory. Limit defaults to wshrpc.MaxDirSize.
// Errors are sent on the returned channel, which is closed when done.
func (impl *ServerImpl) RemoteListEntriesCommand(ctx context.Context, data wshrpc.CommandRemoteListEntriesData) chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData] {
	ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData], 16)
	go func() {
		defer close(ch)
		path, err := wavebase.ExpandHomeDir(data.Path)
		if err != nil {
			ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](err)
			return
		}
		innerFilesEntries := []os.DirEntry{}
		seen := 0
		if data.Opts.Limit == 0 {
			data.Opts.Limit = wshrpc.MaxDirSize
		}
		if data.Opts.All {
			walkErr := fs.WalkDir(os.DirFS(path), ".", func(path string, d fs.DirEntry, err error) error {
				defer func() {
					seen++
				}()
				if seen < data.Opts.Offset {
					return nil
				}
				if seen >= data.Opts.Offset+data.Opts.Limit {
					// io.EOF is used as a sentinel to stop the walk once the page is full
					return io.EOF
				}
				if err != nil {
					return err
				}
				if d.IsDir() {
					return nil
				}
				innerFilesEntries = append(innerFilesEntries, d)
				return nil
			})
			// bug fix: the walk error was previously discarded, silently
			// dropping real failures; io.EOF is our own pagination sentinel
			if walkErr != nil && !errors.Is(walkErr, io.EOF) {
				ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](fmt.Errorf("cannot walk dir %q: %w", path, walkErr))
				return
			}
		} else {
			innerFilesEntries, err = os.ReadDir(path)
			if err != nil {
				ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](fmt.Errorf("cannot open dir %q: %w", path, err))
				return
			}
		}
		var fileInfoArr []*wshrpc.FileInfo
		for _, innerFileEntry := range innerFilesEntries {
			if ctx.Err() != nil {
				ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](ctx.Err())
				return
			}
			innerFileInfoInt, err := innerFileEntry.Info()
			if err != nil {
				// entries that disappear between ReadDir/Walk and Info are skipped
				log.Printf("cannot stat file %q: %v\n", innerFileEntry.Name(), err)
				continue
			}
			innerFileInfo := statToFileInfo(filepath.Join(path, innerFileInfoInt.Name()), innerFileInfoInt, false)
			fileInfoArr = append(fileInfoArr, innerFileInfo)
			if len(fileInfoArr) >= wshrpc.DirChunkSize {
				resp := wshrpc.CommandRemoteListEntriesRtnData{FileInfo: fileInfoArr}
				ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: resp}
				fileInfoArr = nil
			}
		}
		// flush any remaining partial chunk
		if len(fileInfoArr) > 0 {
			resp := wshrpc.CommandRemoteListEntriesRtnData{FileInfo: fileInfoArr}
			ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: resp}
		}
	}()
	return ch
}
// statToFileInfo converts an fs.FileInfo for fullPath into a wshrpc.FileInfo.
// Directories are reported with a sentinel Size of -1. The extended flag is
// forwarded to mime-type detection.
func statToFileInfo(fullPath string, finfo fs.FileInfo, extended bool) *wshrpc.FileInfo {
	size := finfo.Size()
	if finfo.IsDir() {
		// directories use a sentinel size of -1
		size = -1
	}
	return &wshrpc.FileInfo{
		Path:          wavebase.ReplaceHomeDir(fullPath),
		Dir:           computeDirPart(fullPath),
		Name:          finfo.Name(),
		Size:          size,
		Mode:          finfo.Mode(),
		ModeStr:       finfo.Mode().String(),
		ModTime:       finfo.ModTime().UnixMilli(),
		IsDir:         finfo.IsDir(),
		MimeType:      fileutil.DetectMimeType(fullPath, finfo, extended),
		SupportsMkdir: true,
	}
}
// checkIsReadOnly reports whether path cannot be written. fileInfo may be nil
// when exists is false. For regular files, writability is probed by opening
// for append; for missing paths and directories, by creating and removing a
// temporary file in the containing directory.
func checkIsReadOnly(path string, fileInfo fs.FileInfo, exists bool) bool {
	if exists && !fileInfo.Mode().IsDir() {
		// regular file: writable iff it can be opened for append
		file, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND, 0666)
		if err != nil {
			return true
		}
		utilfn.GracefulClose(file, "checkIsReadOnly", path)
		return false
	}
	// missing file or directory: probe the containing directory with a temp file
	dirName := filepath.Dir(path)
	randHexStr, err := utilfn.RandomHexString(12)
	if err != nil {
		// cannot build a probe name; we're not sure, so report writable
		return false
	}
	tmpFileName := filepath.Join(dirName, "wsh-tmp-"+randHexStr)
	fd, err := os.Create(tmpFileName)
	if err != nil {
		return true
	}
	utilfn.GracefulClose(fd, "checkIsReadOnly", tmpFileName)
	os.Remove(tmpFileName)
	return false
}
// computeDirPart returns the parent-directory part of path after home
// expansion and cleaning, in forward-slash form; "/" is its own parent.
func computeDirPart(path string) string {
	cleaned := filepath.ToSlash(filepath.Clean(wavebase.ExpandHomeDirSafe(path)))
	if cleaned == "/" {
		return "/"
	}
	return filepath.Dir(cleaned)
}
// fileInfoInternal stats path (after home expansion) and converts the result
// to a wshrpc.FileInfo. A missing file yields a NotFound entry rather than an
// error. When extended is set, the ReadOnly flag is also populated.
func (*ServerImpl) fileInfoInternal(path string, extended bool) (*wshrpc.FileInfo, error) {
	cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
	finfo, statErr := os.Stat(cleanedPath)
	if os.IsNotExist(statErr) {
		// missing files are reported as NotFound entries, not errors
		return &wshrpc.FileInfo{
			Path:          wavebase.ReplaceHomeDir(path),
			Dir:           computeDirPart(path),
			NotFound:      true,
			ReadOnly:      checkIsReadOnly(cleanedPath, nil, false),
			SupportsMkdir: true,
		}, nil
	}
	if statErr != nil {
		return nil, fmt.Errorf("cannot stat file %q: %w", path, statErr)
	}
	rtn := statToFileInfo(cleanedPath, finfo, extended)
	if extended {
		rtn.ReadOnly = checkIsReadOnly(cleanedPath, finfo, true)
	}
	return rtn, nil
}
// resolvePaths joins the given path segments left to right after home
// expansion; any absolute segment restarts the accumulated path. With no
// segments it resolves to the home directory.
func resolvePaths(paths []string) string {
	if len(paths) == 0 {
		return wavebase.ExpandHomeDirSafe("~")
	}
	result := wavebase.ExpandHomeDirSafe(paths[0])
	for _, segment := range paths[1:] {
		expanded := wavebase.ExpandHomeDirSafe(segment)
		if filepath.IsAbs(expanded) {
			// an absolute component resets the accumulated path
			result = expanded
		} else {
			result = filepath.Join(result, expanded)
		}
	}
	return result
}
// RemoteFileJoinCommand resolves the given path segments into a single path
// and returns extended file info for the result.
func (impl *ServerImpl) RemoteFileJoinCommand(ctx context.Context, paths []string) (*wshrpc.FileInfo, error) {
	joined := resolvePaths(paths)
	return impl.fileInfoInternal(joined, true)
}
// RemoteFileInfoCommand returns extended file info for path.
func (impl *ServerImpl) RemoteFileInfoCommand(ctx context.Context, path string) (*wshrpc.FileInfo, error) {
	return impl.fileInfoInternal(path, true)
}
// RemoteFileTouchCommand creates an empty file at path, creating any missing
// parent directories. Unlike unix touch it fails if the file already exists.
func (impl *ServerImpl) RemoteFileTouchCommand(ctx context.Context, path string) error {
	cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
	if _, statErr := os.Stat(cleanedPath); statErr == nil {
		return fmt.Errorf("file %q already exists", path)
	}
	parentDir := filepath.Dir(cleanedPath)
	if err := os.MkdirAll(parentDir, 0755); err != nil {
		return fmt.Errorf("cannot create directory %q: %w", parentDir, err)
	}
	if err := os.WriteFile(cleanedPath, []byte{}, 0644); err != nil {
		return fmt.Errorf("cannot create file %q: %w", cleanedPath, err)
	}
	return nil
}
// RemoteFileMoveCommand moves (renames) a file or directory from data.SrcUri
// to data.DestUri. Both URIs must resolve to the same host. An existing
// regular-file destination is replaced only when opts.Overwrite is set, and
// moving a directory requires opts.Recursive.
func (impl *ServerImpl) RemoteFileMoveCommand(ctx context.Context, data wshrpc.CommandFileCopyData) error {
	opts := data.Opts
	destUri := data.DestUri
	srcUri := data.SrcUri
	overwrite := opts != nil && opts.Overwrite
	recursive := opts != nil && opts.Recursive
	destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri)
	if err != nil {
		// bug fix: this error previously reported srcUri instead of destUri
		return fmt.Errorf("cannot parse destination URI %q: %w", destUri, err)
	}
	destPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path))
	destinfo, err := os.Stat(destPathCleaned)
	if err == nil {
		// an existing regular file may be replaced only with overwrite;
		// an existing directory destination is left for os.Rename to handle
		if !destinfo.IsDir() {
			if !overwrite {
				return fmt.Errorf("destination %q already exists, use overwrite option", destUri)
			}
			if err := os.Remove(destPathCleaned); err != nil {
				return fmt.Errorf("cannot remove file %q: %w", destUri, err)
			}
		}
	} else if !errors.Is(err, fs.ErrNotExist) {
		return fmt.Errorf("cannot stat destination %q: %w", destUri, err)
	}
	srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri)
	if err != nil {
		return fmt.Errorf("cannot parse source URI %q: %w", srcUri, err)
	}
	if srcConn.Host != destConn.Host {
		return fmt.Errorf("cannot move file %q to %q: different hosts", srcUri, destUri)
	}
	srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path))
	finfo, err := os.Stat(srcPathCleaned)
	if err != nil {
		return fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err)
	}
	if finfo.IsDir() && !recursive {
		// errors.New avoids passing the shared message as a format string (vet)
		return errors.New(fstype.RecursiveRequiredError)
	}
	if err := os.Rename(srcPathCleaned, destPathCleaned); err != nil {
		return fmt.Errorf("cannot move file %q to %q: %w", srcPathCleaned, destPathCleaned, err)
	}
	return nil
}
// RemoteMkdirCommand creates a directory (and any missing parents) at path.
// It fails if the path already exists, whether as a directory or a file.
func (impl *ServerImpl) RemoteMkdirCommand(ctx context.Context, path string) error {
	cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
	stat, statErr := os.Stat(cleanedPath)
	if statErr == nil {
		if stat.IsDir() {
			return fmt.Errorf("directory %q already exists", path)
		}
		return fmt.Errorf("cannot create directory %q, file exists at path", path)
	}
	if err := os.MkdirAll(cleanedPath, 0755); err != nil {
		return fmt.Errorf("cannot create directory %q: %w", cleanedPath, err)
	}
	return nil
}
// RemoteWriteFileCommand writes base64-encoded data to the file described by
// data.Info. It supports truncate/append modes and writing at an explicit
// offset (data.At); an offset may not be combined with truncate or append and
// may not point past the current end of the file. The file is created if
// absent (mode data.Info.Mode, or 0644); parent directories are not created.
func (*ServerImpl) RemoteWriteFileCommand(ctx context.Context, data wshrpc.FileData) error {
	// bug fix: data.Info.Path was previously dereferenced without a nil check
	// even though other sites guard data.Info != nil
	if data.Info == nil {
		return fmt.Errorf("write file requires file info")
	}
	// appendFlag avoids shadowing the append builtin
	var truncate, appendFlag bool
	var atOffset int64
	if data.Info.Opts != nil {
		truncate = data.Info.Opts.Truncate
		appendFlag = data.Info.Opts.Append
	}
	if data.At != nil {
		atOffset = data.At.Offset
	}
	if truncate && atOffset > 0 {
		return fmt.Errorf("cannot specify non-zero offset with truncate option")
	}
	if appendFlag && atOffset > 0 {
		return fmt.Errorf("cannot specify non-zero offset with append option")
	}
	path, err := wavebase.ExpandHomeDir(data.Info.Path)
	if err != nil {
		return err
	}
	createMode := os.FileMode(0644)
	if data.Info.Mode > 0 {
		createMode = data.Info.Mode
	}
	dataSize := base64.StdEncoding.DecodedLen(len(data.Data64))
	dataBytes := make([]byte, dataSize)
	n, err := base64.StdEncoding.Decode(dataBytes, []byte(data.Data64))
	if err != nil {
		return fmt.Errorf("cannot decode base64 data: %w", err)
	}
	finfo, err := os.Stat(path)
	if err != nil && !errors.Is(err, fs.ErrNotExist) {
		return fmt.Errorf("cannot stat file %q: %w", path, err)
	}
	fileSize := int64(0)
	if finfo != nil {
		fileSize = finfo.Size()
	}
	if atOffset > fileSize {
		return fmt.Errorf("cannot write at offset %d, file size is %d", atOffset, fileSize)
	}
	openFlags := os.O_CREATE | os.O_WRONLY
	if truncate {
		openFlags |= os.O_TRUNC
	}
	if appendFlag {
		openFlags |= os.O_APPEND
	}
	file, err := os.OpenFile(path, openFlags, createMode)
	if err != nil {
		return fmt.Errorf("cannot open file %q: %w", path, err)
	}
	defer utilfn.GracefulClose(file, "RemoteWriteFileCommand", path)
	// n is the actual decoded length; write only that many bytes
	if atOffset > 0 && !appendFlag {
		_, err = file.WriteAt(dataBytes[:n], atOffset)
	} else {
		_, err = file.Write(dataBytes[:n])
	}
	if err != nil {
		return fmt.Errorf("cannot write to file %q: %w", path, err)
	}
	return nil
}
// RemoteFileDeleteCommand deletes the file or directory at data.Path.
// Non-empty directories are only removed when data.Recursive is set.
func (*ServerImpl) RemoteFileDeleteCommand(ctx context.Context, data wshrpc.CommandDeleteFileData) error {
	expandedPath, err := wavebase.ExpandHomeDir(data.Path)
	if err != nil {
		return fmt.Errorf("cannot delete file %q: %w", data.Path, err)
	}
	cleanedPath := filepath.Clean(expandedPath)
	err = os.Remove(cleanedPath)
	if err == nil {
		return nil
	}
	// os.Remove fails on non-empty directories; fall back to RemoveAll when
	// the target is a directory and the caller asked for recursion
	finfo, _ := os.Stat(cleanedPath)
	if finfo == nil || !finfo.IsDir() {
		return fmt.Errorf("cannot delete file %q: %w", data.Path, err)
	}
	if !data.Recursive {
		// errors.New avoids passing the shared message as a format string (vet)
		return errors.New(fstype.RecursiveRequiredError)
	}
	if err := os.RemoveAll(cleanedPath); err != nil {
		return fmt.Errorf("cannot delete directory %q: %w", data.Path, err)
	}
	return nil
}
// RemoteGetInfoCommand returns information about this remote host as reported
// by wshutil.GetInfo.
func (*ServerImpl) RemoteGetInfoCommand(ctx context.Context) (wshrpc.RemoteInfo, error) {
	return wshutil.GetInfo(), nil
}
@ -861,3 +91,14 @@ func (*ServerImpl) DisposeSuggestionsCommand(ctx context.Context, widgetId strin
suggestion.DisposeSuggestions(ctx, widgetId)
return nil
}
// getWshPath returns the expected filesystem location of the wsh binary:
// under the Wave data directory for the local server, otherwise under the
// user's ~/.waveterm/bin directory.
func (impl *ServerImpl) getWshPath() (string, error) {
	if impl.IsLocal {
		return filepath.Join(wavebase.GetWaveDataDir(), "bin", "wsh"), nil
	}
	remotePath, err := wavebase.ExpandHomeDir("~/.waveterm/bin/wsh")
	if err != nil {
		return "", fmt.Errorf("cannot expand wsh path: %w", err)
	}
	return remotePath, nil
}

View file

@ -0,0 +1,810 @@
// Copyright 2026, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wshremote
import (
"archive/tar"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"io/fs"
"log"
"os"
"path/filepath"
"strings"
"time"
"github.com/wavetermdev/waveterm/pkg/remote/connparse"
"github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype"
"github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs"
"github.com/wavetermdev/waveterm/pkg/util/fileutil"
"github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes"
"github.com/wavetermdev/waveterm/pkg/util/tarcopy"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)
// ByteRangeType describes a parsed byte range. When All is true the whole
// file is requested; otherwise the range covers [Start, End).
type ByteRangeType struct {
	All   bool
	Start int64
	End   int64
}

// parseByteRange parses a "start-end" range string. The empty string selects
// the entire file. Both bounds must be non-negative and start must not exceed
// end; anything unparsable yields an "invalid byte range" error.
func parseByteRange(rangeStr string) (ByteRangeType, error) {
	if rangeStr == "" {
		return ByteRangeType{All: true}, nil
	}
	var lo, hi int64
	if _, err := fmt.Sscanf(rangeStr, "%d-%d", &lo, &hi); err != nil {
		return ByteRangeType{}, errors.New("invalid byte range")
	}
	if lo < 0 || hi < 0 || lo > hi {
		return ByteRangeType{}, errors.New("invalid byte range")
	}
	return ByteRangeType{Start: lo, End: hi}, nil
}
// remoteStreamFileDir emits directory entries for path via dataCallback in
// chunks of wshrpc.DirChunkSize. A non-All byteRange is interpreted as an
// entry-index range [Start, End) into the listing; otherwise the listing is
// capped at wshrpc.MaxDirSize entries.
func (impl *ServerImpl) remoteStreamFileDir(ctx context.Context, path string, byteRange ByteRangeType, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error {
	innerFilesEntries, err := os.ReadDir(path)
	if err != nil {
		return fmt.Errorf("cannot open dir %q: %w", path, err)
	}
	if byteRange.All {
		if len(innerFilesEntries) > wshrpc.MaxDirSize {
			innerFilesEntries = innerFilesEntries[:wshrpc.MaxDirSize]
		}
	} else {
		if byteRange.Start < int64(len(innerFilesEntries)) {
			// clamp the end of the requested window to the listing length
			realEnd := byteRange.End
			if realEnd > int64(len(innerFilesEntries)) {
				realEnd = int64(len(innerFilesEntries))
			}
			innerFilesEntries = innerFilesEntries[byteRange.Start:realEnd]
		} else {
			innerFilesEntries = []os.DirEntry{}
		}
	}
	var fileInfoArr []*wshrpc.FileInfo
	for _, innerFileEntry := range innerFilesEntries {
		if ctx.Err() != nil {
			return ctx.Err()
		}
		innerFileInfoInt, err := innerFileEntry.Info()
		if err != nil {
			// consistency fix: log skipped entries like RemoteListEntriesCommand
			// does instead of dropping them silently
			log.Printf("cannot stat file %q: %v\n", innerFileEntry.Name(), err)
			continue
		}
		innerFileInfo := statToFileInfo(filepath.Join(path, innerFileInfoInt.Name()), innerFileInfoInt, false)
		fileInfoArr = append(fileInfoArr, innerFileInfo)
		if len(fileInfoArr) >= wshrpc.DirChunkSize {
			dataCallback(fileInfoArr, nil, byteRange)
			fileInfoArr = nil
		}
	}
	// flush any remaining partial chunk
	if len(fileInfoArr) > 0 {
		dataCallback(fileInfoArr, nil, byteRange)
	}
	return nil
}
// remoteStreamFileRegular streams the contents of the regular file at path to
// dataCallback in chunks of wshrpc.FileChunkSize, honoring an optional byte
// range [Start, End).
func (impl *ServerImpl) remoteStreamFileRegular(ctx context.Context, path string, byteRange ByteRangeType, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error {
	fd, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("cannot open file %q: %w", path, err)
	}
	defer utilfn.GracefulClose(fd, "remoteStreamFileRegular", path)
	var filePos int64
	if !byteRange.All && byteRange.Start > 0 {
		// seek to the start of the requested range
		_, err := fd.Seek(byteRange.Start, io.SeekStart)
		if err != nil {
			return fmt.Errorf("seeking file %q: %w", path, err)
		}
		filePos = byteRange.Start
	}
	buf := make([]byte, wshrpc.FileChunkSize)
	for {
		if ctx.Err() != nil {
			return ctx.Err()
		}
		n, err := fd.Read(buf)
		if n > 0 {
			// clamp the final chunk so we never emit bytes past byteRange.End
			if !byteRange.All && filePos+int64(n) > byteRange.End {
				n = int(byteRange.End - filePos)
			}
			filePos += int64(n)
			dataCallback(nil, buf[:n], byteRange)
		}
		if !byteRange.All && filePos >= byteRange.End {
			break
		}
		// Read can return data together with io.EOF, so EOF is checked only
		// after the data (if any) has been delivered
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return fmt.Errorf("reading file %q: %w", path, err)
		}
	}
	return nil
}
// remoteStreamFileInternal stats data.Path, emits its file info via
// dataCallback, and then streams either directory entries or file contents
// depending on what the path points at. Missing paths emit only the
// NotFound info packet.
func (impl *ServerImpl) remoteStreamFileInternal(ctx context.Context, data wshrpc.CommandRemoteStreamFileData, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error {
	byteRange, err := parseByteRange(data.ByteRange)
	if err != nil {
		return err
	}
	path, err := wavebase.ExpandHomeDir(data.Path)
	if err != nil {
		return err
	}
	finfo, err := impl.fileInfoInternal(path, true)
	if err != nil {
		return fmt.Errorf("cannot stat file %q: %w", path, err)
	}
	// the file info is always emitted first, even for missing files
	dataCallback([]*wshrpc.FileInfo{finfo}, nil, byteRange)
	if finfo.NotFound {
		return nil
	}
	if finfo.IsDir {
		return impl.remoteStreamFileDir(ctx, path, byteRange, dataCallback)
	}
	return impl.remoteStreamFileRegular(ctx, path, byteRange, dataCallback)
}
// RemoteStreamFileCommand streams file data (or directory entries) for
// data.Path over the returned channel. A single file info in the first packet
// is delivered as resp.Info (the target's own stat); any other info batch is
// delivered as resp.Entries. File bytes are base64-encoded into resp.Data64
// with placement info in resp.At.
func (impl *ServerImpl) RemoteStreamFileCommand(ctx context.Context, data wshrpc.CommandRemoteStreamFileData) chan wshrpc.RespOrErrorUnion[wshrpc.FileData] {
	ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.FileData], 16)
	go func() {
		defer close(ch)
		firstPk := true
		err := impl.remoteStreamFileInternal(ctx, data, func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType) {
			resp := wshrpc.FileData{}
			fileInfoLen := len(fileInfo)
			// only a lone info on the very first packet is the target stat;
			// everything else is a directory-entry batch
			if fileInfoLen > 1 || !firstPk {
				resp.Entries = fileInfo
			} else if fileInfoLen == 1 {
				resp.Info = fileInfo[0]
			}
			if firstPk {
				firstPk = false
			}
			if len(data) > 0 {
				resp.Data64 = base64.StdEncoding.EncodeToString(data)
				// NOTE(review): At.Offset is byteRange.Start for every chunk,
				// even later chunks of a multi-chunk read — confirm the client
				// treats the stream as sequential
				resp.At = &wshrpc.FileDataAt{Offset: byteRange.Start, Size: len(data)}
			}
			ch <- wshrpc.RespOrErrorUnion[wshrpc.FileData]{Response: resp}
		})
		if err != nil {
			ch <- wshutil.RespErr[wshrpc.FileData](err)
		}
	}()
	return ch
}
// RemoteTarStreamCommand streams the file or directory tree at data.Path as a
// tar archive over the returned channel. Setup failures (bad path, stat
// failure) are returned as a single-error channel; failures after streaming
// has started are sent on the stream itself. The walk is bounded by
// opts.Timeout in milliseconds (or fstype.DefaultTimeout when unset).
func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc.CommandRemoteStreamTarData) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] {
	path := data.Path
	opts := data.Opts
	if opts == nil {
		opts = &wshrpc.FileCopyOpts{}
	}
	log.Printf("RemoteTarStreamCommand: path=%s\n", path)
	// a trailing slash on the source changes how pathPrefix is chosen below
	srcHasSlash := strings.HasSuffix(path, "/")
	path, err := wavebase.ExpandHomeDir(path)
	if err != nil {
		return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot expand path %q: %w", path, err))
	}
	cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
	finfo, err := os.Stat(cleanedPath)
	if err != nil {
		return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot stat file %q: %w", path, err))
	}
	// pathPrefix is handed to tarcopy.TarCopySrc; presumably it is stripped
	// from entry names inside the archive — confirm against tarcopy
	var pathPrefix string
	singleFile := !finfo.IsDir()
	if !singleFile && srcHasSlash {
		pathPrefix = cleanedPath
	} else {
		pathPrefix = filepath.Dir(cleanedPath)
	}
	timeout := fstype.DefaultTimeout
	if opts.Timeout > 0 {
		timeout = time.Duration(opts.Timeout) * time.Millisecond
	}
	readerCtx, cancel := context.WithTimeout(ctx, timeout)
	rtn, writeHeader, fileWriter, tarClose := tarcopy.TarCopySrc(readerCtx, pathPrefix)
	go func() {
		defer func() {
			tarClose()
			cancel()
		}()
		// walkFunc writes one tar header per visited entry and, for regular
		// files, streams the file contents into the archive
		walkFunc := func(path string, info fs.FileInfo, err error) error {
			if readerCtx.Err() != nil {
				return readerCtx.Err()
			}
			if err != nil {
				return err
			}
			if err = writeHeader(info, path, singleFile); err != nil {
				return err
			}
			// if not a dir, write file content
			if !info.IsDir() {
				data, err := os.Open(path)
				if err != nil {
					return err
				}
				defer utilfn.GracefulClose(data, "RemoteTarStreamCommand", path)
				if _, err := io.Copy(fileWriter, data); err != nil {
					return err
				}
			}
			return nil
		}
		log.Printf("RemoteTarStreamCommand: starting\n")
		err = nil
		if singleFile {
			err = walkFunc(cleanedPath, finfo, nil)
		} else {
			err = filepath.Walk(cleanedPath, walkFunc)
		}
		if err != nil {
			// walk/copy failures are reported on the stream itself
			rtn <- wshutil.RespErr[iochantypes.Packet](err)
		}
		log.Printf("RemoteTarStreamCommand: done\n")
	}()
	log.Printf("RemoteTarStreamCommand: returning channel\n")
	return rtn
}
// RemoteFileCopyCommand copies a file or directory tree from data.SrcUri to
// data.DestUri. When both URIs resolve to the same host the copy runs over
// the local filesystem; otherwise the source is pulled in over RPC as a tar
// stream. Returns whether the source was a directory. The Overwrite and
// Merge options are mutually exclusive.
func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.CommandFileCopyData) (bool, error) {
	log.Printf("RemoteFileCopyCommand: src=%s, dest=%s\n", data.SrcUri, data.DestUri)
	opts := data.Opts
	if opts == nil {
		opts = &wshrpc.FileCopyOpts{}
	}
	destUri := data.DestUri
	srcUri := data.SrcUri
	merge := opts.Merge
	overwrite := opts.Overwrite
	if overwrite && merge {
		return false, fmt.Errorf("cannot specify both overwrite and merge")
	}
	destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri)
	if err != nil {
		return false, fmt.Errorf("cannot parse destination URI %q: %w", destUri, err)
	}
	destPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path))
	destinfo, err := os.Stat(destPathCleaned)
	if err != nil {
		// a missing destination is fine; any other stat failure is fatal
		if !errors.Is(err, fs.ErrNotExist) {
			return false, fmt.Errorf("cannot stat destination %q: %w", destPathCleaned, err)
		}
	}
	destExists := destinfo != nil
	destIsDir := destExists && destinfo.IsDir()
	destHasSlash := strings.HasSuffix(destUri, "/")
	// an existing regular file at the destination must be explicitly overwritten
	if destExists && !destIsDir {
		if !overwrite {
			return false, fmt.Errorf(fstype.OverwriteRequiredError, destPathCleaned)
		} else {
			err := os.Remove(destPathCleaned)
			if err != nil {
				return false, fmt.Errorf("cannot remove file %q: %w", destPathCleaned, err)
			}
		}
	}
	srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri)
	if err != nil {
		return false, fmt.Errorf("cannot parse source URI %q: %w", srcUri, err)
	}
	// copyFileFunc creates one directory or writes one file at path, enforcing
	// the overwrite/merge rules against whatever already exists there. Note it
	// returns the source's reported size, not the byte count from io.Copy.
	copyFileFunc := func(path string, finfo fs.FileInfo, srcFile io.Reader) (int64, error) {
		nextinfo, err := os.Stat(path)
		if err != nil && !errors.Is(err, fs.ErrNotExist) {
			return 0, fmt.Errorf("cannot stat file %q: %w", path, err)
		}
		if nextinfo != nil {
			if nextinfo.IsDir() {
				if !finfo.IsDir() {
					// try to create file in directory
					path = filepath.Join(path, filepath.Base(finfo.Name()))
					newdestinfo, err := os.Stat(path)
					if err != nil && !errors.Is(err, fs.ErrNotExist) {
						return 0, fmt.Errorf("cannot stat file %q: %w", path, err)
					}
					if newdestinfo != nil && !overwrite {
						return 0, fmt.Errorf(fstype.OverwriteRequiredError, path)
					}
				} else if overwrite {
					err := os.RemoveAll(path)
					if err != nil {
						return 0, fmt.Errorf("cannot remove directory %q: %w", path, err)
					}
				} else if !merge {
					return 0, fmt.Errorf(fstype.MergeRequiredError, path)
				}
			} else {
				if !overwrite {
					return 0, fmt.Errorf(fstype.OverwriteRequiredError, path)
				} else if finfo.IsDir() {
					err := os.RemoveAll(path)
					if err != nil {
						return 0, fmt.Errorf("cannot remove directory %q: %w", path, err)
					}
				}
			}
		}
		if finfo.IsDir() {
			// directories are created, not copied byte-for-byte
			err := os.MkdirAll(path, finfo.Mode())
			if err != nil {
				return 0, fmt.Errorf("cannot create directory %q: %w", path, err)
			}
			return 0, nil
		} else {
			err := os.MkdirAll(filepath.Dir(path), 0755)
			if err != nil {
				return 0, fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(path), err)
			}
		}
		file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode())
		if err != nil {
			return 0, fmt.Errorf("cannot create new file %q: %w", path, err)
		}
		defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", path)
		_, err = io.Copy(file, srcFile)
		if err != nil {
			return 0, fmt.Errorf("cannot write file %q: %w", path, err)
		}
		return finfo.Size(), nil
	}
	srcIsDir := false
	if srcConn.Host == destConn.Host {
		// same host: walk and copy via the local filesystem
		srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path))
		srcFileStat, err := os.Stat(srcPathCleaned)
		if err != nil {
			return false, fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err)
		}
		if srcFileStat.IsDir() {
			srcIsDir = true
			// srcPathPrefix is trimmed from walked paths to build dest-relative paths
			var srcPathPrefix string
			if destIsDir {
				srcPathPrefix = filepath.Dir(srcPathCleaned)
			} else {
				srcPathPrefix = srcPathCleaned
			}
			err = filepath.Walk(srcPathCleaned, func(path string, info fs.FileInfo, err error) error {
				if err != nil {
					return err
				}
				srcFilePath := path
				destFilePath := filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathPrefix))
				var file *os.File
				if !info.IsDir() {
					file, err = os.Open(srcFilePath)
					if err != nil {
						return fmt.Errorf("cannot open file %q: %w", srcFilePath, err)
					}
					defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", srcFilePath)
				}
				_, err = copyFileFunc(destFilePath, info, file)
				return err
			})
			if err != nil {
				return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
			}
		} else {
			file, err := os.Open(srcPathCleaned)
			if err != nil {
				return false, fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err)
			}
			defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", srcPathCleaned)
			var destFilePath string
			if destHasSlash {
				// trailing slash on dest means "copy into this directory"
				destFilePath = filepath.Join(destPathCleaned, filepath.Base(srcPathCleaned))
			} else {
				destFilePath = destPathCleaned
			}
			_, err = copyFileFunc(destFilePath, srcFileStat, file)
			if err != nil {
				return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
			}
		}
	} else {
		// different hosts: pull the source over RPC as a tar stream
		timeout := fstype.DefaultTimeout
		if opts.Timeout > 0 {
			timeout = time.Duration(opts.Timeout) * time.Millisecond
		}
		readCtx, cancel := context.WithCancelCause(ctx)
		readCtx, timeoutCancel := context.WithTimeoutCause(readCtx, timeout, fmt.Errorf("timeout copying file %q to %q", srcUri, destUri))
		defer timeoutCancel()
		copyStart := time.Now()
		ioch := wshclient.FileStreamTarCommand(wshfs.RpcClient, wshrpc.CommandRemoteStreamTarData{Path: srcUri, Opts: opts}, &wshrpc.RpcOpts{Timeout: opts.Timeout})
		numFiles := 0
		numSkipped := 0
		totalBytes := int64(0)
		err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader, singleFile bool) error {
			numFiles++
			nextpath := filepath.Join(destPathCleaned, next.Name)
			srcIsDir = !singleFile
			if singleFile && !destHasSlash {
				// custom flag to indicate that the source is a single file, not a directory the contents of a directory
				nextpath = destPathCleaned
			}
			finfo := next.FileInfo()
			n, err := copyFileFunc(nextpath, finfo, reader)
			if err != nil {
				return fmt.Errorf("cannot copy file %q: %w", next.Name, err)
			}
			totalBytes += n
			return nil
		})
		if err != nil {
			return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
		}
		totalTime := time.Since(copyStart).Seconds()
		totalMegaBytes := float64(totalBytes) / 1024 / 1024
		rate := float64(0)
		if totalTime > 0 {
			rate = totalMegaBytes / totalTime
		}
		log.Printf("RemoteFileCopyCommand: done; %d files copied in %.3fs, total of %.4f MB, %.2f MB/s, %d files skipped\n", numFiles, totalTime, totalMegaBytes, rate, numSkipped)
	}
	return srcIsDir, nil
}
// RemoteListEntriesCommand streams directory entries for data.Path in chunks
// of wshrpc.DirChunkSize. When data.Opts.All is set it walks the tree
// recursively (directories themselves are skipped) with Offset/Limit paging;
// otherwise it lists a single directory. Limit defaults to wshrpc.MaxDirSize.
func (impl *ServerImpl) RemoteListEntriesCommand(ctx context.Context, data wshrpc.CommandRemoteListEntriesData) chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData] {
	ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData], 16)
	go func() {
		defer close(ch)
		path, err := wavebase.ExpandHomeDir(data.Path)
		if err != nil {
			ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](err)
			return
		}
		innerFilesEntries := []os.DirEntry{}
		seen := 0
		if data.Opts.Limit == 0 {
			data.Opts.Limit = wshrpc.MaxDirSize
		}
		if data.Opts.All {
			// io.EOF is returned as a sentinel to stop the walk once the page
			// is full.
			// NOTE(review): the fs.WalkDir return value is discarded, so real
			// walk failures (not just the io.EOF sentinel) are silently
			// dropped — confirm whether that is intended.
			fs.WalkDir(os.DirFS(path), ".", func(path string, d fs.DirEntry, err error) error {
				defer func() {
					seen++
				}()
				if seen < data.Opts.Offset {
					return nil
				}
				if seen >= data.Opts.Offset+data.Opts.Limit {
					return io.EOF
				}
				if err != nil {
					return err
				}
				if d.IsDir() {
					return nil
				}
				innerFilesEntries = append(innerFilesEntries, d)
				return nil
			})
		} else {
			innerFilesEntries, err = os.ReadDir(path)
			if err != nil {
				ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](fmt.Errorf("cannot open dir %q: %w", path, err))
				return
			}
		}
		var fileInfoArr []*wshrpc.FileInfo
		for _, innerFileEntry := range innerFilesEntries {
			if ctx.Err() != nil {
				ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](ctx.Err())
				return
			}
			innerFileInfoInt, err := innerFileEntry.Info()
			if err != nil {
				// entries that disappear between listing and Info are skipped
				log.Printf("cannot stat file %q: %v\n", innerFileEntry.Name(), err)
				continue
			}
			innerFileInfo := statToFileInfo(filepath.Join(path, innerFileInfoInt.Name()), innerFileInfoInt, false)
			fileInfoArr = append(fileInfoArr, innerFileInfo)
			if len(fileInfoArr) >= wshrpc.DirChunkSize {
				resp := wshrpc.CommandRemoteListEntriesRtnData{FileInfo: fileInfoArr}
				ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: resp}
				fileInfoArr = nil
			}
		}
		// flush any remaining partial chunk
		if len(fileInfoArr) > 0 {
			resp := wshrpc.CommandRemoteListEntriesRtnData{FileInfo: fileInfoArr}
			ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: resp}
		}
	}()
	return ch
}
// statToFileInfo converts an fs.FileInfo for fullPath into a wshrpc.FileInfo.
// Directories are reported with Size == -1. extended is forwarded to mime
// detection only.
func statToFileInfo(fullPath string, finfo fs.FileInfo, extended bool) *wshrpc.FileInfo {
	isDir := finfo.IsDir()
	size := finfo.Size()
	if isDir {
		size = -1
	}
	mode := finfo.Mode()
	return &wshrpc.FileInfo{
		Path:          wavebase.ReplaceHomeDir(fullPath),
		Dir:           computeDirPart(fullPath),
		Name:          finfo.Name(),
		Size:          size,
		Mode:          mode,
		ModeStr:       mode.String(),
		ModTime:       finfo.ModTime().UnixMilli(),
		IsDir:         isDir,
		MimeType:      fileutil.DetectMimeType(fullPath, finfo, extended),
		SupportsMkdir: true,
	}
}
// checkIsReadOnly reports whether path cannot be written to.
// fileInfo may be nil when exists is false.
func checkIsReadOnly(path string, fileInfo fs.FileInfo, exists bool) bool {
	if exists && !fileInfo.Mode().IsDir() {
		// existing regular file: writable iff we can open it for append
		fd, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND, 0666)
		if err != nil {
			return true
		}
		utilfn.GracefulClose(fd, "checkIsReadOnly", path)
		return false
	}
	// missing file or directory: probe by creating a throwaway file in the
	// containing directory
	randHexStr, err := utilfn.RandomHexString(12)
	if err != nil {
		// cannot even build a probe name; err on the side of "writable"
		return false
	}
	probePath := filepath.Join(filepath.Dir(path), "wsh-tmp-"+randHexStr)
	probe, err := os.Create(probePath)
	if err != nil {
		return true
	}
	utilfn.GracefulClose(probe, "checkIsReadOnly", probePath)
	os.Remove(probePath)
	return false
}
// computeDirPart returns the slash-normalized parent directory of path
// (after home expansion); the root "/" is treated as its own parent.
func computeDirPart(path string) string {
	cleaned := filepath.ToSlash(filepath.Clean(wavebase.ExpandHomeDirSafe(path)))
	if cleaned == "/" {
		return cleaned
	}
	return filepath.Dir(cleaned)
}
// fileInfoInternal stats path (home-expanded) and returns its FileInfo.
// A missing file yields a NotFound FileInfo rather than an error; with
// extended set, the ReadOnly flag is additionally computed.
func (*ServerImpl) fileInfoInternal(path string, extended bool) (*wshrpc.FileInfo, error) {
	cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
	finfo, err := os.Stat(cleanedPath)
	if err != nil {
		if !os.IsNotExist(err) {
			return nil, fmt.Errorf("cannot stat file %q: %w", path, err)
		}
		return &wshrpc.FileInfo{
			Path:          wavebase.ReplaceHomeDir(path),
			Dir:           computeDirPart(path),
			NotFound:      true,
			ReadOnly:      checkIsReadOnly(cleanedPath, nil, false),
			SupportsMkdir: true,
		}, nil
	}
	rtn := statToFileInfo(cleanedPath, finfo, extended)
	if extended {
		rtn.ReadOnly = checkIsReadOnly(cleanedPath, finfo, true)
	}
	return rtn, nil
}
// resolvePaths joins the given path segments left to right, home-expanding
// each one; an absolute segment restarts the result from that segment. No
// segments yields the home directory.
func resolvePaths(paths []string) string {
	rtn := wavebase.ExpandHomeDirSafe("~")
	for idx, segment := range paths {
		segment = wavebase.ExpandHomeDirSafe(segment)
		if idx == 0 || filepath.IsAbs(segment) {
			rtn = segment
			continue
		}
		rtn = filepath.Join(rtn, segment)
	}
	return rtn
}
// RemoteFileJoinCommand resolves the given path segments into a single path
// and returns its extended FileInfo.
func (impl *ServerImpl) RemoteFileJoinCommand(ctx context.Context, paths []string) (*wshrpc.FileInfo, error) {
	joined := resolvePaths(paths)
	return impl.fileInfoInternal(joined, true)
}
// RemoteFileInfoCommand returns the extended FileInfo (including ReadOnly)
// for path; a missing file is reported via FileInfo.NotFound, not an error.
func (impl *ServerImpl) RemoteFileInfoCommand(ctx context.Context, path string) (*wshrpc.FileInfo, error) {
	return impl.fileInfoInternal(path, true)
}
// RemoteFileTouchCommand creates an empty file at path (home-expanded),
// creating parent directories as needed. Unlike POSIX touch, an existing
// file is an error (timestamps are not updated).
func (impl *ServerImpl) RemoteFileTouchCommand(ctx context.Context, path string) error {
	cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
	if _, statErr := os.Stat(cleanedPath); statErr == nil {
		return fmt.Errorf("file %q already exists", path)
	}
	parentDir := filepath.Dir(cleanedPath)
	if err := os.MkdirAll(parentDir, 0755); err != nil {
		return fmt.Errorf("cannot create directory %q: %w", parentDir, err)
	}
	if err := os.WriteFile(cleanedPath, []byte{}, 0644); err != nil {
		return fmt.Errorf("cannot create file %q: %w", cleanedPath, err)
	}
	return nil
}
// RemoteFileMoveCommand moves/renames a file or directory from data.SrcUri
// to data.DestUri via os.Rename; only same-host moves are supported. An
// existing destination file is only replaced with opts.Overwrite set, and
// moving a directory requires opts.Recursive.
func (impl *ServerImpl) RemoteFileMoveCommand(ctx context.Context, data wshrpc.CommandFileCopyData) error {
	opts := data.Opts
	destUri := data.DestUri
	srcUri := data.SrcUri
	overwrite := opts != nil && opts.Overwrite
	recursive := opts != nil && opts.Recursive
	destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri)
	if err != nil {
		// bugfix: this error previously reported srcUri instead of destUri
		return fmt.Errorf("cannot parse destination URI %q: %w", destUri, err)
	}
	destPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path))
	destinfo, err := os.Stat(destPathCleaned)
	if err == nil {
		if !destinfo.IsDir() {
			if !overwrite {
				return fmt.Errorf("destination %q already exists, use overwrite option", destUri)
			}
			if err := os.Remove(destPathCleaned); err != nil {
				return fmt.Errorf("cannot remove file %q: %w", destUri, err)
			}
		}
	} else if !errors.Is(err, fs.ErrNotExist) {
		return fmt.Errorf("cannot stat destination %q: %w", destUri, err)
	}
	srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri)
	if err != nil {
		return fmt.Errorf("cannot parse source URI %q: %w", srcUri, err)
	}
	if srcConn.Host != destConn.Host {
		return fmt.Errorf("cannot move file %q to %q: different hosts", srcUri, destUri)
	}
	srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path))
	finfo, err := os.Stat(srcPathCleaned)
	if err != nil {
		return fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err)
	}
	if finfo.IsDir() && !recursive {
		// error text is a shared constant; errors.New avoids go vet's
		// non-constant-format-string warning on fmt.Errorf
		return errors.New(fstype.RecursiveRequiredError)
	}
	if err := os.Rename(srcPathCleaned, destPathCleaned); err != nil {
		return fmt.Errorf("cannot move file %q to %q: %w", srcPathCleaned, destPathCleaned, err)
	}
	return nil
}
// RemoteMkdirCommand creates the directory at path (home-expanded),
// including any missing parents. Anything already existing at path —
// directory or file — is an error.
func (impl *ServerImpl) RemoteMkdirCommand(ctx context.Context, path string) error {
	cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
	existing, statErr := os.Stat(cleanedPath)
	if statErr == nil {
		if existing.IsDir() {
			return fmt.Errorf("directory %q already exists", path)
		}
		return fmt.Errorf("cannot create directory %q, file exists at path", path)
	}
	if err := os.MkdirAll(cleanedPath, 0755); err != nil {
		return fmt.Errorf("cannot create directory %q: %w", cleanedPath, err)
	}
	return nil
}
// RemoteWriteFileCommand writes base64-decoded data to the file named by
// data.Info.Path (home-expanded). Modes: truncate (offset must be 0),
// append (offset must be 0), or a positioned write at data.At.Offset, which
// must not exceed the current file size. Missing files are created with
// data.Info.Mode (default 0644).
func (*ServerImpl) RemoteWriteFileCommand(ctx context.Context, data wshrpc.FileData) error {
	// bugfix: data.Info was previously dereferenced for Path without a nil
	// check (only the Opts read was guarded)
	if data.Info == nil {
		return fmt.Errorf("cannot write file: no file info provided")
	}
	var truncate, appendMode bool
	if data.Info.Opts != nil {
		truncate = data.Info.Opts.Truncate
		appendMode = data.Info.Opts.Append
	}
	var atOffset int64
	if data.At != nil {
		atOffset = data.At.Offset
	}
	if truncate && atOffset > 0 {
		return fmt.Errorf("cannot specify non-zero offset with truncate option")
	}
	if appendMode && atOffset > 0 {
		return fmt.Errorf("cannot specify non-zero offset with append option")
	}
	path, err := wavebase.ExpandHomeDir(data.Info.Path)
	if err != nil {
		return err
	}
	createMode := os.FileMode(0644)
	if data.Info.Mode > 0 {
		createMode = data.Info.Mode
	}
	dataSize := base64.StdEncoding.DecodedLen(len(data.Data64))
	dataBytes := make([]byte, dataSize)
	n, err := base64.StdEncoding.Decode(dataBytes, []byte(data.Data64))
	if err != nil {
		return fmt.Errorf("cannot decode base64 data: %w", err)
	}
	finfo, err := os.Stat(path)
	if err != nil && !errors.Is(err, fs.ErrNotExist) {
		return fmt.Errorf("cannot stat file %q: %w", path, err)
	}
	fileSize := int64(0)
	if finfo != nil {
		fileSize = finfo.Size()
	}
	if atOffset > fileSize {
		return fmt.Errorf("cannot write at offset %d, file size is %d", atOffset, fileSize)
	}
	openFlags := os.O_CREATE | os.O_WRONLY
	if truncate {
		openFlags |= os.O_TRUNC
	}
	if appendMode {
		openFlags |= os.O_APPEND
	}
	file, err := os.OpenFile(path, openFlags, createMode)
	if err != nil {
		return fmt.Errorf("cannot open file %q: %w", path, err)
	}
	defer utilfn.GracefulClose(file, "RemoteWriteFileCommand", path)
	if atOffset > 0 && !appendMode {
		// positioned write; in append mode WriteAt is invalid, so plain
		// Write is used (the kernel appends at EOF)
		_, err = file.WriteAt(dataBytes[:n], atOffset)
	} else {
		_, err = file.Write(dataBytes[:n])
	}
	if err != nil {
		return fmt.Errorf("cannot write to file %q: %w", path, err)
	}
	return nil
}
// RemoteFileDeleteCommand deletes the file or directory at data.Path
// (home-expanded). Deleting a non-empty directory requires data.Recursive.
func (*ServerImpl) RemoteFileDeleteCommand(ctx context.Context, data wshrpc.CommandDeleteFileData) error {
	expandedPath, err := wavebase.ExpandHomeDir(data.Path)
	if err != nil {
		return fmt.Errorf("cannot delete file %q: %w", data.Path, err)
	}
	cleanedPath := filepath.Clean(expandedPath)
	err = os.Remove(cleanedPath)
	if err != nil {
		// Remove fails on non-empty directories; fall back to RemoveAll
		// when the target is a directory and recursion was requested
		finfo, _ := os.Stat(cleanedPath)
		if finfo != nil && finfo.IsDir() {
			if !data.Recursive {
				// error text is a shared constant; errors.New avoids go
				// vet's non-constant-format-string warning on fmt.Errorf
				return errors.New(fstype.RecursiveRequiredError)
			}
			err = os.RemoveAll(cleanedPath)
			if err != nil {
				return fmt.Errorf("cannot delete directory %q: %w", data.Path, err)
			}
		} else {
			return fmt.Errorf("cannot delete file %q: %w", data.Path, err)
		}
	}
	return nil
}

View file

@ -0,0 +1,352 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wshremote
import (
"bufio"
"context"
"fmt"
"log"
"net"
"os"
"os/exec"
"strings"
"sync"
"syscall"
"time"
"github.com/shirou/gopsutil/v4/process"
"github.com/wavetermdev/waveterm/pkg/jobmanager"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)
// isProcessRunning returns the process handle for pid only if a process
// with that pid exists AND its create time matches pidStartTs (guarding
// against pid reuse by an unrelated process). A nil result with a nil
// error means "not running"; a non-nil error means liveness could not be
// determined.
func isProcessRunning(pid int, pidStartTs int64) (*process.Process, error) {
	if pid <= 0 {
		return nil, nil
	}
	proc, err := process.NewProcess(int32(pid))
	if err != nil {
		// NewProcess fails when the pid does not exist; treat as not running
		return nil, nil
	}
	createTime, err := proc.CreateTime()
	if err != nil {
		return nil, err
	}
	if createTime != pidStartTs {
		// pid was recycled by a different process
		return nil, nil
	}
	return proc, nil
}
// connectToJobManager dials the job manager's unix socket for jobId,
// registers the connection as an untrusted link on the router,
// authenticates with the main-server JWT token, and waits for the job's
// route to appear. On success the connection is recorded in JobManagerMap
// and an idempotent cleanup function (closes socket, unregisters link,
// removes the map entry) is returned; on failure partial state is cleaned
// up before returning.
// returns jobRouteId, cleanupFunc, error
func (impl *ServerImpl) connectToJobManager(ctx context.Context, jobId string, mainServerJwtToken string) (string, func(), error) {
	socketPath := jobmanager.GetJobSocketPath(jobId)
	log.Printf("connectToJobManager: connecting to socket: %s\n", socketPath)
	conn, err := net.Dial("unix", socketPath)
	if err != nil {
		log.Printf("connectToJobManager: error connecting to socket: %v\n", err)
		return "", nil, fmt.Errorf("cannot connect to job manager socket: %w", err)
	}
	log.Printf("connectToJobManager: connected to socket\n")
	proxy := wshutil.MakeRpcProxy("jobmanager")
	linkId := impl.Router.RegisterUntrustedLink(proxy)
	// cleanup is called from multiple paths (error handling here, the read
	// goroutine below, and external callers) — sync.Once makes it safe
	var cleanupOnce sync.Once
	cleanup := func() {
		cleanupOnce.Do(func() {
			conn.Close()
			impl.Router.UnregisterLink(linkId)
			impl.removeJobManagerConnection(jobId)
		})
	}
	// pump outgoing rpc messages from the proxy onto the socket
	go func() {
		writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn)
		if writeErr != nil {
			log.Printf("connectToJobManager: error writing to job manager socket: %v\n", writeErr)
		}
	}()
	// pump incoming socket bytes into the proxy; when the socket closes
	// (EOF/error) the whole connection is torn down via cleanup
	go func() {
		defer func() {
			close(proxy.FromRemoteCh)
			cleanup()
		}()
		wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh)
	}()
	// authenticate over the link-scoped route (the job route does not
	// exist yet at this point)
	routeId := wshutil.MakeLinkRouteId(linkId)
	authData := wshrpc.CommandAuthenticateToJobData{
		JobAccessToken: mainServerJwtToken,
	}
	err = wshclient.AuthenticateToJobManagerCommand(impl.RpcClient, authData, &wshrpc.RpcOpts{Route: routeId})
	if err != nil {
		cleanup()
		return "", nil, fmt.Errorf("authentication to job manager failed: %w", err)
	}
	// wait briefly for the job route to be registered (presumably announced
	// by the job manager after successful auth — confirm)
	jobRouteId := wshutil.MakeJobRouteId(jobId)
	waitCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
	defer cancel()
	err = impl.Router.WaitForRegister(waitCtx, jobRouteId)
	if err != nil {
		cleanup()
		return "", nil, fmt.Errorf("timeout waiting for job route to register: %w", err)
	}
	jobConn := &JobManagerConnection{
		JobId:     jobId,
		Conn:      conn,
		CleanupFn: cleanup,
	}
	impl.addJobManagerConnection(jobConn)
	log.Printf("connectToJobManager: successfully connected and authenticated\n")
	return jobRouteId, cleanup, nil
}
// addJobManagerConnection records conn in JobManagerMap, keyed by JobId.
func (impl *ServerImpl) addJobManagerConnection(conn *JobManagerConnection) {
	impl.Lock.Lock()
	defer impl.Lock.Unlock()
	jobId := conn.JobId
	impl.JobManagerMap[jobId] = conn
	log.Printf("addJobManagerConnection: added job manager connection for jobid=%s\n", jobId)
}
// removeJobManagerConnection drops the tracked connection for jobId, if any.
func (impl *ServerImpl) removeJobManagerConnection(jobId string) {
	impl.Lock.Lock()
	defer impl.Lock.Unlock()
	if _, exists := impl.JobManagerMap[jobId]; !exists {
		return
	}
	delete(impl.JobManagerMap, jobId)
	log.Printf("removeJobManagerConnection: removed job manager connection for jobid=%s\n", jobId)
}
// getJobManagerConnection returns the tracked connection for jobId, or nil.
func (impl *ServerImpl) getJobManagerConnection(jobId string) *JobManagerConnection {
	impl.Lock.Lock()
	conn := impl.JobManagerMap[jobId]
	impl.Lock.Unlock()
	return conn
}
// RemoteStartJobCommand spawns a detached job manager process ("wsh
// jobmanager") for data.JobId, hands it the job auth token over stdin,
// waits (up to 5s) for the "Wave-JobManagerStart" handshake on an inherited
// pipe, then connects/authenticates to its unix socket and issues the
// actual StartJob RPC over the new job route.
func (impl *ServerImpl) RemoteStartJobCommand(ctx context.Context, data wshrpc.CommandRemoteStartJobData) (*wshrpc.CommandStartJobRtnData, error) {
	log.Printf("RemoteStartJobCommand: starting, jobid=%s, clientid=%s\n", data.JobId, data.ClientId)
	if impl.Router == nil {
		return nil, fmt.Errorf("cannot start remote job: no router available")
	}
	wshPath, err := impl.getWshPath()
	if err != nil {
		return nil, err
	}
	log.Printf("RemoteStartJobCommand: wshPath=%s\n", wshPath)
	// the write end of this pipe is inherited by the child via ExtraFiles;
	// the child signals readiness by writing a line to it
	readyPipeRead, readyPipeWrite, err := os.Pipe()
	if err != nil {
		return nil, fmt.Errorf("cannot create ready pipe: %w", err)
	}
	defer readyPipeRead.Close()
	defer readyPipeWrite.Close()
	cmd := exec.Command(wshPath, "jobmanager", "--jobid", data.JobId, "--clientid", data.ClientId)
	if data.PublicKeyBase64 != "" {
		cmd.Env = append(os.Environ(), "WAVETERM_PUBLICKEY="+data.PublicKeyBase64)
	}
	cmd.ExtraFiles = []*os.File{readyPipeWrite}
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return nil, fmt.Errorf("cannot create stdin pipe: %w", err)
	}
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return nil, fmt.Errorf("cannot create stdout pipe: %w", err)
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return nil, fmt.Errorf("cannot create stderr pipe: %w", err)
	}
	log.Printf("RemoteStartJobCommand: created pipes\n")
	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("cannot start job manager: %w", err)
	}
	// close the parent's copy of the write end so the read end sees EOF if
	// the child exits (the earlier deferred Close becomes a harmless no-op)
	readyPipeWrite.Close()
	log.Printf("RemoteStartJobCommand: job manager process started\n")
	// the auth token is delivered on stdin (never argv/env) so it is not
	// visible in the process table
	jobAuthTokenLine := fmt.Sprintf("Wave-JobAccessToken:%s\n", data.JobAuthToken)
	if _, err := stdin.Write([]byte(jobAuthTokenLine)); err != nil {
		cmd.Process.Kill()
		return nil, fmt.Errorf("cannot write job auth token: %w", err)
	}
	stdin.Close()
	log.Printf("RemoteStartJobCommand: wrote auth token to stdin\n")
	// drain stderr for logging only
	go func() {
		scanner := bufio.NewScanner(stderr)
		for scanner.Scan() {
			line := scanner.Text()
			log.Printf("RemoteStartJobCommand: stderr: %s\n", line)
		}
		if err := scanner.Err(); err != nil {
			log.Printf("RemoteStartJobCommand: error reading stderr: %v\n", err)
		} else {
			log.Printf("RemoteStartJobCommand: stderr EOF\n")
		}
	}()
	// drain stdout for logging only
	go func() {
		scanner := bufio.NewScanner(stdout)
		for scanner.Scan() {
			line := scanner.Text()
			log.Printf("RemoteStartJobCommand: stdout: %s\n", line)
		}
		if err := scanner.Err(); err != nil {
			log.Printf("RemoteStartJobCommand: error reading stdout: %v\n", err)
		} else {
			log.Printf("RemoteStartJobCommand: stdout EOF\n")
		}
	}()
	// watch the ready pipe for the start handshake; EOF without the marker
	// means the child exited early
	startCh := make(chan error, 1)
	go func() {
		scanner := bufio.NewScanner(readyPipeRead)
		for scanner.Scan() {
			line := scanner.Text()
			log.Printf("RemoteStartJobCommand: ready pipe line: %s\n", line)
			if strings.Contains(line, "Wave-JobManagerStart") {
				startCh <- nil
				return
			}
		}
		if err := scanner.Err(); err != nil {
			startCh <- fmt.Errorf("error reading ready pipe: %w", err)
		} else {
			log.Printf("RemoteStartJobCommand: ready pipe EOF\n")
			startCh <- fmt.Errorf("job manager exited without start signal")
		}
	}()
	timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	log.Printf("RemoteStartJobCommand: waiting for start signal\n")
	select {
	case err := <-startCh:
		if err != nil {
			cmd.Process.Kill()
			log.Printf("RemoteStartJobCommand: error from start signal: %v\n", err)
			return nil, err
		}
		log.Printf("RemoteStartJobCommand: received start signal\n")
	case <-timeoutCtx.Done():
		cmd.Process.Kill()
		log.Printf("RemoteStartJobCommand: timeout waiting for start signal\n")
		return nil, fmt.Errorf("timeout waiting for job manager to start")
	}
	// reap the child asynchronously to avoid a zombie; the exit status is
	// intentionally unused (the job manager is expected to outlive us)
	go func() {
		cmd.Wait()
	}()
	jobRouteId, cleanup, err := impl.connectToJobManager(ctx, data.JobId, data.MainServerJwtToken)
	if err != nil {
		return nil, err
	}
	startJobData := wshrpc.CommandStartJobData{
		Cmd:        data.Cmd,
		Args:       data.Args,
		Env:        data.Env,
		TermSize:   data.TermSize,
		StreamMeta: data.StreamMeta,
	}
	rtnData, err := wshclient.StartJobCommand(impl.RpcClient, startJobData, &wshrpc.RpcOpts{Route: jobRouteId})
	if err != nil {
		cleanup()
		return nil, fmt.Errorf("failed to start job: %w", err)
	}
	return rtnData, nil
}
// RemoteReconnectToJobManagerCommand re-establishes the socket connection
// to an already-running job manager process. Failures are reported in the
// return struct rather than as an RPC error; JobManagerGone indicates the
// process no longer exists (pid+start-time check failed).
func (impl *ServerImpl) RemoteReconnectToJobManagerCommand(ctx context.Context, data wshrpc.CommandRemoteReconnectToJobManagerData) (*wshrpc.CommandRemoteReconnectToJobManagerRtnData, error) {
	log.Printf("RemoteReconnectToJobManagerCommand: reconnecting, jobid=%s\n", data.JobId)
	fail := func(msg string, gone bool) (*wshrpc.CommandRemoteReconnectToJobManagerRtnData, error) {
		return &wshrpc.CommandRemoteReconnectToJobManagerRtnData{
			Success:        false,
			JobManagerGone: gone,
			Error:          msg,
		}, nil
	}
	if impl.Router == nil {
		return fail("cannot reconnect to job manager: no router available", false)
	}
	proc, err := isProcessRunning(data.JobManagerPid, data.JobManagerStartTs)
	if err != nil {
		return fail(fmt.Sprintf("error checking job manager process: %v", err), false)
	}
	if proc == nil {
		return fail(fmt.Sprintf("job manager process (pid=%d) is not running", data.JobManagerPid), true)
	}
	// tear down any stale connection before dialing a fresh one
	if existingConn := impl.getJobManagerConnection(data.JobId); existingConn != nil {
		log.Printf("RemoteReconnectToJobManagerCommand: closing existing connection for jobid=%s\n", data.JobId)
		if existingConn.CleanupFn != nil {
			existingConn.CleanupFn()
		}
	}
	if _, _, err := impl.connectToJobManager(ctx, data.JobId, data.MainServerJwtToken); err != nil {
		return fail(err.Error(), false)
	}
	log.Printf("RemoteReconnectToJobManagerCommand: successfully reconnected to job manager\n")
	return &wshrpc.CommandRemoteReconnectToJobManagerRtnData{Success: true}, nil
}
// RemoteDisconnectFromJobManagerCommand closes the local connection to the
// job manager for data.JobId. The job manager process itself keeps running;
// a missing connection is not an error.
func (impl *ServerImpl) RemoteDisconnectFromJobManagerCommand(ctx context.Context, data wshrpc.CommandRemoteDisconnectFromJobManagerData) error {
	log.Printf("RemoteDisconnectFromJobManagerCommand: disconnecting, jobid=%s\n", data.JobId)
	conn := impl.getJobManagerConnection(data.JobId)
	switch {
	case conn == nil:
		log.Printf("RemoteDisconnectFromJobManagerCommand: no connection found for jobid=%s\n", data.JobId)
	case conn.CleanupFn != nil:
		conn.CleanupFn()
		log.Printf("RemoteDisconnectFromJobManagerCommand: cleanup completed for jobid=%s\n", data.JobId)
	}
	return nil
}
// RemoteTerminateJobManagerCommand sends SIGTERM to the job manager process
// for data.JobId after verifying pid + process start time (so a recycled
// pid is never signaled). Best-effort: signal failures are only logged.
func (impl *ServerImpl) RemoteTerminateJobManagerCommand(ctx context.Context, data wshrpc.CommandRemoteTerminateJobManagerData) error {
	log.Printf("RemoteTerminateJobManagerCommand: terminating job manager, jobid=%s, pid=%d\n", data.JobId, data.JobManagerPid)
	proc, err := isProcessRunning(data.JobManagerPid, data.JobManagerStartTs)
	if err != nil {
		return fmt.Errorf("error checking job manager process: %w", err)
	}
	if proc == nil {
		log.Printf("RemoteTerminateJobManagerCommand: job manager process not running, jobid=%s\n", data.JobId)
		return nil
	}
	if sigErr := proc.SendSignal(syscall.SIGTERM); sigErr != nil {
		log.Printf("failed to send SIGTERM to job manager: %v", sigErr)
		return nil
	}
	log.Printf("RemoteTerminateJobManagerCommand: sent SIGTERM to job manager process, jobid=%s, pid=%d\n", data.JobId, data.JobManagerPid)
	return nil
}

View file

@ -22,10 +22,18 @@ type RespOrErrorUnion[T any] struct {
Error error
}
// Instructions for adding a new RPC call
// * methods must end with Command
// * methods must take context as their first parameter
// * methods may take up to one parameter, and may return either just an error, or one return value plus an error
// * after modifying WshRpcInterface, run `task generate` to regnerate bindings
type WshRpcInterface interface {
AuthenticateCommand(ctx context.Context, data string) (CommandAuthenticateRtnData, error)
AuthenticateTokenCommand(ctx context.Context, data CommandAuthenticateTokenData) (CommandAuthenticateRtnData, error)
AuthenticateTokenVerifyCommand(ctx context.Context, data CommandAuthenticateTokenData) (CommandAuthenticateRtnData, error) // (special) validates token without binding, root router only
AuthenticateJobManagerCommand(ctx context.Context, data CommandAuthenticateJobManagerData) error
AuthenticateJobManagerVerifyCommand(ctx context.Context, data CommandAuthenticateJobManagerData) error // (special) validates job auth token without binding, root router only
DisposeCommand(ctx context.Context, data CommandDisposeData) error
RouteAnnounceCommand(ctx context.Context) error // (special) announces a new route to the main router
RouteUnannounceCommand(ctx context.Context) error // (special) unannounces a route to the main router
@ -100,6 +108,10 @@ type WshRpcInterface interface {
RemoteStreamCpuDataCommand(ctx context.Context) chan RespOrErrorUnion[TimeSeriesData]
RemoteGetInfoCommand(ctx context.Context) (RemoteInfo, error)
RemoteInstallRcFilesCommand(ctx context.Context) error
RemoteStartJobCommand(ctx context.Context, data CommandRemoteStartJobData) (*CommandStartJobRtnData, error)
RemoteReconnectToJobManagerCommand(ctx context.Context, data CommandRemoteReconnectToJobManagerData) (*CommandRemoteReconnectToJobManagerRtnData, error)
RemoteDisconnectFromJobManagerCommand(ctx context.Context, data CommandRemoteDisconnectFromJobManagerData) error
RemoteTerminateJobManagerCommand(ctx context.Context, data CommandRemoteTerminateJobManagerData) error
// emain
WebSelectorCommand(ctx context.Context, data CommandWebSelectorData) ([]string, error)
@ -140,6 +152,7 @@ type WshRpcInterface interface {
// terminal
TermGetScrollbackLinesCommand(ctx context.Context, data CommandTermGetScrollbackLinesData) (*CommandTermGetScrollbackLinesRtnData, error)
TermUpdateAttachedJobCommand(ctx context.Context, data CommandTermUpdateAttachedJobData) error
// file
WshRpcFileInterface
@ -154,6 +167,26 @@ type WshRpcInterface interface {
// streams
StreamDataCommand(ctx context.Context, data CommandStreamData) error
StreamDataAckCommand(ctx context.Context, data CommandStreamAckData) error
// jobs
AuthenticateToJobManagerCommand(ctx context.Context, data CommandAuthenticateToJobData) error
StartJobCommand(ctx context.Context, data CommandStartJobData) (*CommandStartJobRtnData, error)
JobPrepareConnectCommand(ctx context.Context, data CommandJobPrepareConnectData) (*CommandJobConnectRtnData, error)
JobStartStreamCommand(ctx context.Context, data CommandJobStartStreamData) error
JobInputCommand(ctx context.Context, data CommandJobInputData) error
JobCmdExitedCommand(ctx context.Context, data CommandJobCmdExitedData) error // this is sent FROM the job manager => main server
// job controller
JobControllerDeleteJobCommand(ctx context.Context, jobId string) error
JobControllerListCommand(ctx context.Context) ([]*waveobj.Job, error)
JobControllerStartJobCommand(ctx context.Context, data CommandJobControllerStartJobData) (string, error)
JobControllerExitJobCommand(ctx context.Context, jobId string) error
JobControllerDisconnectJobCommand(ctx context.Context, jobId string) error
JobControllerReconnectJobCommand(ctx context.Context, jobId string) error
JobControllerReconnectJobsForConnCommand(ctx context.Context, connName string) error
JobControllerConnectedJobsCommand(ctx context.Context) ([]string, error)
JobControllerAttachJobCommand(ctx context.Context, data CommandJobControllerAttachJobData) error
JobControllerDetachJobCommand(ctx context.Context, jobId string) error
}
// for frontend
@ -250,6 +283,13 @@ type CommandBlockInputData struct {
TermSize *waveobj.TermSize `json:"termsize,omitempty"`
}
// CommandJobInputData carries terminal input for a running job: base64
// keystroke data, a signal name, and/or a terminal resize.
type CommandJobInputData struct {
	JobId       string            `json:"jobid"`
	InputData64 string            `json:"inputdata64,omitempty"` // base64-encoded stdin bytes
	SigName     string            `json:"signame,omitempty"`     // signal to deliver to the job
	TermSize    *waveobj.TermSize `json:"termsize,omitempty"`    // new terminal size, if resizing
}
type CommandWaitForRouteData struct {
RouteId string `json:"routeid"`
WaitMs int `json:"waitms"`
@ -614,6 +654,11 @@ type CommandTermGetScrollbackLinesRtnData struct {
LastUpdated int64 `json:"lastupdated"`
}
// CommandTermUpdateAttachedJobData updates which persistent job a terminal
// block is attached to.
type CommandTermUpdateAttachedJobData struct {
	BlockId string `json:"blockid"`
	JobId   string `json:"jobid,omitempty"` // empty presumably means detach — confirm
}
type CommandElectronEncryptData struct {
PlainText string `json:"plaintext"`
}
@ -633,7 +678,7 @@ type CommandElectronDecryptRtnData struct {
}
type CommandStreamData struct {
Id int64 `json:"id"` // streamid
Id string `json:"id"` // streamid
Seq int64 `json:"seq"` // start offset (bytes)
Data64 string `json:"data64,omitempty"`
Eof bool `json:"eof,omitempty"` // can be set with data or without
@ -641,7 +686,7 @@ type CommandStreamData struct {
}
type CommandStreamAckData struct {
Id int64 `json:"id"` // streamid
Id string `json:"id"` // streamid
Seq int64 `json:"seq"` // next expected byte
RWnd int64 `json:"rwnd"` // receive window size
Fin bool `json:"fin,omitempty"` // observed end-of-stream (eof or error)
@ -651,8 +696,108 @@ type CommandStreamAckData struct {
}
type StreamMeta struct {
Id int64 `json:"id"` // streamid
Id string `json:"id"` // streamid
RWnd int64 `json:"rwnd"` // initial receive window size
ReaderRouteId string `json:"readerrouteid"`
WriterRouteId string `json:"writerrouteid"`
}
// CommandAuthenticateToJobData is sent over a freshly-dialed job manager
// socket to authenticate the caller (see connectToJobManager).
type CommandAuthenticateToJobData struct {
	JobAccessToken string `json:"jobaccesstoken"`
}

// CommandAuthenticateJobManagerData carries a job manager's own credentials
// (job id plus its auth token) for authentication.
type CommandAuthenticateJobManagerData struct {
	JobId        string `json:"jobid"`
	JobAuthToken string `json:"jobauthtoken"`
}

// CommandStartJobData describes the command to launch inside an
// already-running job manager.
type CommandStartJobData struct {
	Cmd        string            `json:"cmd"`
	Args       []string          `json:"args"`
	Env        map[string]string `json:"env"`
	TermSize   waveobj.TermSize  `json:"termsize"`
	StreamMeta *StreamMeta       `json:"streammeta,omitempty"` // output stream configuration, if streaming
}
// CommandRemoteStartJobData asks the remote server to spawn a job manager
// process (via "wsh jobmanager") and start the given command inside it.
type CommandRemoteStartJobData struct {
	Cmd                string            `json:"cmd"`
	Args               []string          `json:"args"`
	Env                map[string]string `json:"env"`
	TermSize           waveobj.TermSize  `json:"termsize"`
	StreamMeta         *StreamMeta       `json:"streammeta,omitempty"`
	JobAuthToken       string            `json:"jobauthtoken"` // written to the job manager's stdin at startup
	JobId              string            `json:"jobid"`
	MainServerJwtToken string            `json:"mainserverjwttoken"` // used to authenticate to the job manager socket
	ClientId           string            `json:"clientid"`
	PublicKeyBase64    string            `json:"publickeybase64"` // passed via the WAVETERM_PUBLICKEY env var
}

// CommandRemoteReconnectToJobManagerData asks the remote server to re-dial
// an existing job manager, identified by pid plus process start time.
type CommandRemoteReconnectToJobManagerData struct {
	JobId              string `json:"jobid"`
	JobAuthToken       string `json:"jobauthtoken"`
	MainServerJwtToken string `json:"mainserverjwttoken"`
	JobManagerPid      int    `json:"jobmanagerpid"`
	JobManagerStartTs  int64  `json:"jobmanagerstartts"` // process create time, guards against pid reuse
}

// CommandRemoteReconnectToJobManagerRtnData reports the reconnect outcome;
// failures are carried here rather than as an RPC error.
type CommandRemoteReconnectToJobManagerRtnData struct {
	Success        bool   `json:"success"`
	JobManagerGone bool   `json:"jobmanagergone"` // the job manager process no longer exists
	Error          string `json:"error,omitempty"`
}

// CommandRemoteDisconnectFromJobManagerData closes the connection to a job
// manager without stopping the process.
type CommandRemoteDisconnectFromJobManagerData struct {
	JobId string `json:"jobid"`
}

// CommandRemoteTerminateJobManagerData asks the remote server to SIGTERM a
// job manager process, identified by pid plus process start time.
type CommandRemoteTerminateJobManagerData struct {
	JobId             string `json:"jobid"`
	JobManagerPid     int    `json:"jobmanagerpid"`
	JobManagerStartTs int64  `json:"jobmanagerstartts"` // process create time, guards against pid reuse
}
// CommandStartJobRtnData identifies the started command and its job
// manager; the pid + start-time pairs enable later liveness checks.
type CommandStartJobRtnData struct {
	CmdPid            int   `json:"cmdpid"`
	CmdStartTs        int64 `json:"cmdstartts"`
	JobManagerPid     int   `json:"jobmanagerpid"`
	JobManagerStartTs int64 `json:"jobmanagerstartts"`
}

// CommandJobPrepareConnectData prepares an output-stream (re)connect
// starting at the given sequence offset.
type CommandJobPrepareConnectData struct {
	StreamMeta StreamMeta `json:"streammeta"`
	Seq        int64      `json:"seq"` // resume offset (bytes)
}

// CommandJobStartStreamData has no fields; the command itself signals the
// job manager to begin streaming.
type CommandJobStartStreamData struct {
}

// CommandJobConnectRtnData describes the stream position and exit state of
// a job at connect time.
type CommandJobConnectRtnData struct {
	Seq         int64  `json:"seq"`
	StreamDone  bool   `json:"streamdone,omitempty"`
	StreamError string `json:"streamerror,omitempty"`
	HasExited   bool   `json:"hasexited,omitempty"`
	ExitCode    *int   `json:"exitcode,omitempty"` // nil when no numeric exit code is available
	ExitSignal  string `json:"exitsignal,omitempty"`
	ExitErr     string `json:"exiterr,omitempty"`
}

// CommandJobCmdExitedData is sent FROM the job manager to the main server
// when the job's command exits.
type CommandJobCmdExitedData struct {
	JobId      string `json:"jobid"`
	ExitCode   *int   `json:"exitcode,omitempty"`
	ExitSignal string `json:"exitsignal,omitempty"`
	ExitErr    string `json:"exiterr,omitempty"`
	ExitTs     int64  `json:"exitts,omitempty"` // exit timestamp (presumably unix millis — confirm)
}

// CommandJobControllerStartJobData holds the user-facing parameters for
// starting a new persistent job on a connection.
type CommandJobControllerStartJobData struct {
	ConnName string            `json:"connname"`
	Cmd      string            `json:"cmd"`
	Args     []string          `json:"args"`
	Env      map[string]string `json:"env"`
	TermSize *waveobj.TermSize `json:"termsize,omitempty"`
}

// CommandJobControllerAttachJobData attaches an existing job to a block.
type CommandJobControllerAttachJobData struct {
	JobId   string `json:"jobid"`
	BlockId string `json:"blockid"`
}

View file

@ -35,13 +35,16 @@ const (
// we only need consts for special commands handled in the router or
// in the RPC code / WPS code directly. other commands go through the clients
const (
Command_Authenticate = "authenticate" // $control
Command_AuthenticateToken = "authenticatetoken" // $control
Command_AuthenticateTokenVerify = "authenticatetokenverify" // $control:root (internal, for token validation only)
Command_RouteAnnounce = "routeannounce" // $control (for routing)
Command_RouteUnannounce = "routeunannounce" // $control (for routing)
Command_Ping = "ping" // $control
Command_ControllerInput = "controllerinput"
Command_EventRecv = "eventrecv"
Command_Message = "message"
Command_Authenticate = "authenticate" // $control
Command_AuthenticateToken = "authenticatetoken" // $control
Command_AuthenticateTokenVerify = "authenticatetokenverify" // $control:root (internal, for token validation only)
Command_AuthenticateJobManagerVerify = "authenticatejobmanagerverify" // $control:root (internal, for job auth token validation only)
Command_RouteAnnounce = "routeannounce" // $control (for routing)
Command_RouteUnannounce = "routeunannounce" // $control (for routing)
Command_Ping = "ping" // $control
Command_ControllerInput = "controllerinput"
Command_EventRecv = "eventrecv"
Command_Message = "message"
Command_StreamData = "streamdata"
Command_StreamDataAck = "streamdataack"
)

View file

@ -29,6 +29,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/filebackup"
"github.com/wavetermdev/waveterm/pkg/filestore"
"github.com/wavetermdev/waveterm/pkg/genconn"
"github.com/wavetermdev/waveterm/pkg/jobcontroller"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/remote"
"github.com/wavetermdev/waveterm/pkg/remote/awsconn"
@ -294,6 +295,21 @@ func (ws *WshServer) ControllerResyncCommand(ctx context.Context, data wshrpc.Co
}
func (ws *WshServer) ControllerInputCommand(ctx context.Context, data wshrpc.CommandBlockInputData) error {
block, err := wstore.DBMustGet[*waveobj.Block](ctx, data.BlockId)
if err != nil {
return fmt.Errorf("error getting block: %w", err)
}
if block.JobId != "" {
jobInputData := wshrpc.CommandJobInputData{
JobId: block.JobId,
InputData64: data.InputData64,
SigName: data.SigName,
TermSize: data.TermSize,
}
return jobcontroller.SendInput(ctx, jobInputData)
}
inputUnion := &blockcontroller.BlockInputUnion{
SigName: data.SigName,
TermSize: data.TermSize,
@ -1430,3 +1446,54 @@ func (ws *WshServer) GetSecretsLinuxStorageBackendCommand(ctx context.Context) (
}
return backend, nil
}
// JobCmdExitedCommand is invoked by a job manager when its command exits;
// it forwards the exit info to the job controller.
func (ws *WshServer) JobCmdExitedCommand(ctx context.Context, data wshrpc.CommandJobCmdExitedData) error {
	return jobcontroller.HandleCmdJobExited(ctx, data.JobId, data)
}

// JobControllerListCommand returns all job objects from the object store.
func (ws *WshServer) JobControllerListCommand(ctx context.Context) ([]*waveobj.Job, error) {
	return wstore.DBGetAllObjsByType[*waveobj.Job](ctx, waveobj.OType_Job)
}

// JobControllerDeleteJobCommand deletes the job with the given id.
func (ws *WshServer) JobControllerDeleteJobCommand(ctx context.Context, jobId string) error {
	return jobcontroller.DeleteJob(ctx, jobId)
}
// JobControllerStartJobCommand launches a new persistent job on the given
// connection and returns the new job's id.
func (ws *WshServer) JobControllerStartJobCommand(ctx context.Context, data wshrpc.CommandJobControllerStartJobData) (string, error) {
	return jobcontroller.StartJob(ctx, jobcontroller.StartJobParams{
		ConnName: data.ConnName,
		Cmd:      data.Cmd,
		Args:     data.Args,
		Env:      data.Env,
		TermSize: data.TermSize,
	})
}
// JobControllerExitJobCommand terminates the job's job manager process.
func (ws *WshServer) JobControllerExitJobCommand(ctx context.Context, jobId string) error {
	return jobcontroller.TerminateJobManager(ctx, jobId)
}

// JobControllerDisconnectJobCommand disconnects from a job without
// stopping it.
func (ws *WshServer) JobControllerDisconnectJobCommand(ctx context.Context, jobId string) error {
	return jobcontroller.DisconnectJob(ctx, jobId)
}

// JobControllerReconnectJobCommand reconnects to a disconnected job.
func (ws *WshServer) JobControllerReconnectJobCommand(ctx context.Context, jobId string) error {
	return jobcontroller.ReconnectJob(ctx, jobId)
}

// JobControllerReconnectJobsForConnCommand reconnects the jobs belonging
// to the named connection.
func (ws *WshServer) JobControllerReconnectJobsForConnCommand(ctx context.Context, connName string) error {
	return jobcontroller.ReconnectJobsForConn(ctx, connName)
}

// JobControllerConnectedJobsCommand lists ids of currently-connected jobs.
func (ws *WshServer) JobControllerConnectedJobsCommand(ctx context.Context) ([]string, error) {
	return jobcontroller.GetConnectedJobIds(), nil
}

// JobControllerAttachJobCommand attaches a job to a block.
func (ws *WshServer) JobControllerAttachJobCommand(ctx context.Context, data wshrpc.CommandJobControllerAttachJobData) error {
	return jobcontroller.AttachJobToBlock(ctx, data.JobId, data.BlockId)
}

// JobControllerDetachJobCommand detaches the job from its block.
// NOTE(review): the meaning of the hard-coded `true` flag is defined in
// jobcontroller.DetachJobFromBlock — confirm its semantics.
func (ws *WshServer) JobControllerDetachJobCommand(ctx context.Context, jobId string) error {
	return jobcontroller.DetachJobFromBlock(ctx, jobId, true)
}

View file

@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"log"
"strconv"
"strings"
"sync"
"time"
@ -34,6 +35,9 @@ const (
RoutePrefix_Tab = "tab:"
RoutePrefix_FeBlock = "feblock:"
RoutePrefix_Builder = "builder:"
RoutePrefix_Link = "link:"
RoutePrefix_Job = "job:"
RoutePrefix_Bare = "bare:"
)
// this works like a network switch
@ -118,6 +122,14 @@ func MakeBuilderRouteId(builderId string) string {
return "builder:" + builderId
}
func MakeJobRouteId(jobId string) string {
return "job:" + jobId
}
// MakeLinkRouteId returns the route id that addresses a link directly by
// its numeric id.
func MakeLinkRouteId(linkId baseds.LinkId) string {
	return RoutePrefix_Link + fmt.Sprintf("%d", linkId)
}
var DefaultRouter *WshRouter
func NewWshRouter() *WshRouter {
@ -245,6 +257,13 @@ func (router *WshRouter) getRouteInfo(rpcId string) *rpcRoutingInfo {
// returns true if message was sent, false if failed
func (router *WshRouter) sendRoutedMessage(msgBytes []byte, routeId string, commandName string, ingressLinkId baseds.LinkId) bool {
if strings.HasPrefix(routeId, RoutePrefix_Link) {
linkIdStr := strings.TrimPrefix(routeId, RoutePrefix_Link)
linkIdInt, err := strconv.ParseInt(linkIdStr, 10, 32)
if err == nil {
return router.sendMessageToLink(msgBytes, baseds.LinkId(linkIdInt), ingressLinkId)
}
}
lm := router.getLinkForRoute(routeId)
if lm != nil {
lm.client.SendRpcMessage(msgBytes, ingressLinkId, "route")
@ -448,8 +467,10 @@ func (router *WshRouter) runLinkClientRecvLoop(linkId baseds.LinkId, client Abst
} else {
// non-request messages (responses)
if !lm.trusted {
// drop responses from untrusted links
continue
// allow responses to RPCs we initiated
if rpcMsg.ResId == "" || router.getRouteInfo(rpcMsg.ResId) == nil {
continue
}
}
}
router.inputCh <- baseds.RpcInputChType{MsgBytes: msgBytes, IngressLinkId: linkId}
@ -596,7 +617,7 @@ func (router *WshRouter) UnregisterLink(linkId baseds.LinkId) {
}
func isBindableRouteId(routeId string) bool {
if routeId == "" || strings.HasPrefix(routeId, ControlPrefix) {
if routeId == "" || strings.HasPrefix(routeId, ControlPrefix) || strings.HasPrefix(routeId, RoutePrefix_Link) {
return false
}
return true
@ -676,6 +697,9 @@ func (router *WshRouter) bindRoute(linkId baseds.LinkId, routeId string, isSourc
if !strings.HasPrefix(routeId, ControlPrefix) {
router.announceUpstream(routeId)
}
if router.IsRootRouter() {
router.publishRouteToBroker(routeId)
}
return nil
}
@ -692,12 +716,19 @@ func (router *WshRouter) getUpstreamClient() AbstractRpcClient {
return lm.client
}
// publishRouteToBroker publishes a RouteUp event scoped to routeId so
// event subscribers learn the route is now reachable. A panic raised
// during the broker publish is recovered via panichandler so it cannot
// propagate into the router's caller.
func (router *WshRouter) publishRouteToBroker(routeId string) {
	defer func() {
		panichandler.PanicHandler("WshRouter:publishRouteToBroker", recover())
	}()
	wps.Broker.Publish(wps.WaveEvent{Event: wps.Event_RouteUp, Scopes: []string{routeId}})
}
// unsubscribeFromBroker drops every broker subscription held by routeId
// and publishes a route-down notification, under a panic handler so a
// broker failure cannot escape.
// NOTE(review): the back-to-back routegone/routedown PanicHandler labels
// and the paired Event_RouteGone/Event_RouteDown publishes look like old
// and new diff variants rendered together; most likely only the
// "routedown"/Event_RouteDown lines are intended — confirm against the
// repository before relying on this text.
func (router *WshRouter) unsubscribeFromBroker(routeId string) {
	defer func() {
		panichandler.PanicHandler("WshRouter:unregisterRoute:routegone", recover())
		panichandler.PanicHandler("WshRouter:unregisterRoute:routedown", recover())
	}()
	wps.Broker.UnsubscribeAll(routeId)
	wps.Broker.Publish(wps.WaveEvent{Event: wps.Event_RouteGone, Scopes: []string{routeId}})
	wps.Broker.Publish(wps.WaveEvent{Event: wps.Event_RouteDown, Scopes: []string{routeId}})
}
func sendControlUnauthenticatedErrorResponse(cmdMsg RpcMessage, linkMeta linkMeta) {

View file

@ -11,7 +11,9 @@ import (
"github.com/wavetermdev/waveterm/pkg/baseds"
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wstore"
)
type WshRouterControlImpl struct {
@ -102,6 +104,46 @@ func (impl *WshRouterControlImpl) AuthenticateCommand(ctx context.Context, data
return rtnData, nil
}
// extractTokenData resolves a one-time swap token into the authentication
// data it was registered with. The token entry is consumed on lookup
// (GetAndRemoveTokenSwapEntry), its rpc context is validated, and entries
// that describe a router or lack a route id are rejected.
func extractTokenData(token string) (wshrpc.CommandAuthenticateRtnData, error) {
	var empty wshrpc.CommandAuthenticateRtnData
	swapEntry := shellutil.GetAndRemoveTokenSwapEntry(token)
	if swapEntry == nil {
		return empty, fmt.Errorf("no token entry found")
	}
	if _, err := validateRpcContextFromAuth(swapEntry.RpcContext); err != nil {
		return empty, err
	}
	rpcCtx := swapEntry.RpcContext
	if rpcCtx.IsRouter {
		return empty, fmt.Errorf("cannot auth router via token")
	}
	if rpcCtx.RouteId == "" {
		return empty, fmt.Errorf("no routeid")
	}
	return wshrpc.CommandAuthenticateRtnData{
		Env:            swapEntry.Env,
		InitScriptText: swapEntry.ScriptText,
		RpcContext:     rpcCtx,
	}, nil
}
// AuthenticateTokenVerifyCommand verifies a swap token on behalf of a
// downstream router. It may only run on the root router (the token store
// lives there) and returns the resolved authentication data on success.
func (impl *WshRouterControlImpl) AuthenticateTokenVerifyCommand(ctx context.Context, data wshrpc.CommandAuthenticateTokenData) (wshrpc.CommandAuthenticateRtnData, error) {
	var zero wshrpc.CommandAuthenticateRtnData
	switch {
	case !impl.Router.IsRootRouter():
		return zero, fmt.Errorf("authenticatetokenverify can only be called on root router")
	case data.Token == "":
		return zero, fmt.Errorf("no token in authenticatetoken message")
	}
	authData, extractErr := extractTokenData(data.Token)
	if extractErr != nil {
		log.Printf("wshrouter authenticate-token-verify error: %v", extractErr)
		return zero, extractErr
	}
	log.Printf("wshrouter authenticate-token-verify success routeid=%q", authData.RpcContext.RouteId)
	return authData, nil
}
func (impl *WshRouterControlImpl) AuthenticateTokenCommand(ctx context.Context, data wshrpc.CommandAuthenticateTokenData) (wshrpc.CommandAuthenticateRtnData, error) {
handler := GetRpcResponseHandlerFromContext(ctx)
if handler == nil {
@ -117,29 +159,14 @@ func (impl *WshRouterControlImpl) AuthenticateTokenCommand(ctx context.Context,
}
var rtnData wshrpc.CommandAuthenticateRtnData
var rpcContext *wshrpc.RpcContext
var err error
if impl.Router.IsRootRouter() {
entry := shellutil.GetAndRemoveTokenSwapEntry(data.Token)
if entry == nil {
log.Printf("wshrouter authenticate-token error linkid=%d: no token entry found", linkId)
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token entry found")
}
_, err := validateRpcContextFromAuth(entry.RpcContext)
rtnData, err = extractTokenData(data.Token)
if err != nil {
log.Printf("wshrouter authenticate-token error linkid=%d: %v", linkId, err)
return wshrpc.CommandAuthenticateRtnData{}, err
}
if entry.RpcContext.IsRouter {
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("cannot auth router via token")
}
if entry.RpcContext.RouteId == "" {
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no routeid")
}
rpcContext = entry.RpcContext
rtnData = wshrpc.CommandAuthenticateRtnData{
Env: entry.Env,
InitScriptText: entry.ScriptText,
RpcContext: rpcContext,
}
} else {
wshRpc := GetWshRpcFromContext(ctx)
if wshRpc == nil {
@ -154,51 +181,91 @@ func (impl *WshRouterControlImpl) AuthenticateTokenCommand(ctx context.Context,
if err != nil {
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("failed to unmarshal response: %w", err)
}
rpcContext = rtnData.RpcContext
}
if rpcContext == nil {
if rtnData.RpcContext == nil {
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no rpccontext in token response")
}
log.Printf("wshrouter authenticate-token success linkid=%d routeid=%q", linkId, rpcContext.RouteId)
log.Printf("wshrouter authenticate-token success linkid=%d routeid=%q", linkId, rtnData.RpcContext.RouteId)
impl.Router.trustLink(linkId, LinkKind_Leaf)
impl.Router.bindRoute(linkId, rpcContext.RouteId, true)
impl.Router.bindRoute(linkId, rtnData.RpcContext.RouteId, true)
return rtnData, nil
}
func (impl *WshRouterControlImpl) AuthenticateTokenVerifyCommand(ctx context.Context, data wshrpc.CommandAuthenticateTokenData) (wshrpc.CommandAuthenticateRtnData, error) {
func (impl *WshRouterControlImpl) AuthenticateJobManagerVerifyCommand(ctx context.Context, data wshrpc.CommandAuthenticateJobManagerData) error {
if !impl.Router.IsRootRouter() {
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("authenticatetokenverify can only be called on root router")
return fmt.Errorf("authenticatejobmanagerverify can only be called on root router")
}
if data.Token == "" {
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token in authenticatetoken message")
if data.JobId == "" {
return fmt.Errorf("no jobid in authenticatejobmanager message")
}
entry := shellutil.GetAndRemoveTokenSwapEntry(data.Token)
if entry == nil {
log.Printf("wshrouter authenticate-token-verify error: no token entry found")
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token entry found")
if data.JobAuthToken == "" {
return fmt.Errorf("no jobauthtoken in authenticatejobmanager message")
}
_, err := validateRpcContextFromAuth(entry.RpcContext)
job, err := wstore.DBMustGet[*waveobj.Job](ctx, data.JobId)
if err != nil {
return wshrpc.CommandAuthenticateRtnData{}, err
}
if entry.RpcContext.IsRouter {
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("cannot auth router via token")
}
if entry.RpcContext.RouteId == "" {
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no routeid")
log.Printf("wshrouter authenticate-jobmanager-verify error jobid=%q: failed to get job: %v", data.JobId, err)
return fmt.Errorf("failed to get job: %w", err)
}
rtnData := wshrpc.CommandAuthenticateRtnData{
Env: entry.Env,
InitScriptText: entry.ScriptText,
RpcContext: entry.RpcContext,
if job.JobAuthToken != data.JobAuthToken {
log.Printf("wshrouter authenticate-jobmanager-verify error jobid=%q: invalid jobauthtoken", data.JobId)
return fmt.Errorf("invalid jobauthtoken")
}
log.Printf("wshrouter authenticate-token-verify success routeid=%q", entry.RpcContext.RouteId)
return rtnData, nil
log.Printf("wshrouter authenticate-jobmanager-verify success jobid=%q", data.JobId)
return nil
}
// AuthenticateJobManagerCommand authenticates a job manager process on the
// ingress link that delivered this request. On the root router the
// supplied job auth token is checked directly against the stored Job
// object; on a non-root router verification is delegated upstream via the
// AuthenticateJobManagerVerify command. On success the ingress link is
// marked trusted (leaf) and bound to the job's "job:<id>" route.
func (impl *WshRouterControlImpl) AuthenticateJobManagerCommand(ctx context.Context, data wshrpc.CommandAuthenticateJobManagerData) error {
	// The response handler carries the ingress link id; without it we
	// cannot know which link to trust/bind.
	handler := GetRpcResponseHandlerFromContext(ctx)
	if handler == nil {
		return fmt.Errorf("no response handler in context")
	}
	linkId := handler.GetIngressLinkId()
	if linkId == baseds.NoLinkId {
		return fmt.Errorf("no ingress link found")
	}
	if data.JobId == "" {
		return fmt.Errorf("no jobid in authenticatejobmanager message")
	}
	if data.JobAuthToken == "" {
		return fmt.Errorf("no jobauthtoken in authenticatejobmanager message")
	}
	if impl.Router.IsRootRouter() {
		// Root router: compare the presented token with the one persisted
		// on the Job record in the object store.
		job, err := wstore.DBMustGet[*waveobj.Job](ctx, data.JobId)
		if err != nil {
			log.Printf("wshrouter authenticate-jobmanager error linkid=%d jobid=%q: failed to get job: %v", linkId, data.JobId, err)
			return fmt.Errorf("failed to get job: %w", err)
		}
		if job.JobAuthToken != data.JobAuthToken {
			log.Printf("wshrouter authenticate-jobmanager error linkid=%d jobid=%q: invalid jobauthtoken", linkId, data.JobId)
			return fmt.Errorf("invalid jobauthtoken")
		}
	} else {
		// Non-root router: forward the verification to the root router,
		// which holds the job store.
		wshRpc := GetWshRpcFromContext(ctx)
		if wshRpc == nil {
			return fmt.Errorf("no wshrpc in context")
		}
		_, err := wshRpc.SendRpcRequest(wshrpc.Command_AuthenticateJobManagerVerify, data, &wshrpc.RpcOpts{Route: ControlRootRoute})
		if err != nil {
			log.Printf("wshrouter authenticate-jobmanager error linkid=%d jobid=%q: failed to verify job auth token: %v", linkId, data.JobId, err)
			return fmt.Errorf("failed to verify job auth token: %w", err)
		}
	}
	routeId := MakeJobRouteId(data.JobId)
	log.Printf("wshrouter authenticate-jobmanager success linkid=%d jobid=%q routeid=%q", linkId, data.JobId, routeId)
	impl.Router.trustLink(linkId, LinkKind_Leaf)
	// NOTE(review): bindRoute's return value is discarded here, so a bind
	// failure after a successful auth is silent — confirm this is intended.
	impl.Router.bindRoute(linkId, routeId, true)
	return nil
}
func validateRpcContextFromAuth(newCtx *wshrpc.RpcContext) (string, error) {

View file

@ -18,6 +18,7 @@ import (
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/baseds"
"github.com/wavetermdev/waveterm/pkg/panichandler"
"github.com/wavetermdev/waveterm/pkg/streamclient"
"github.com/wavetermdev/waveterm/pkg/util/ds"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wps"
@ -56,6 +57,7 @@ type WshRpc struct {
ServerImpl ServerImpl
EventListener *EventListener
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
StreamBroker *streamclient.Broker
Debug bool
DebugName string
ServerDone bool
@ -226,6 +228,7 @@ func MakeWshRpcWithChannels(inputCh chan baseds.RpcInputChType, outputCh chan []
ResponseHandlerMap: make(map[string]*RpcResponseHandler),
}
rtn.RpcContext.Store(&rpcCtx)
rtn.StreamBroker = streamclient.NewBroker(AdaptWshRpc(rtn))
go rtn.runServer()
return rtn
}
@ -286,6 +289,36 @@ func (w *WshRpc) handleEventRecv(req *RpcMessage) {
w.EventListener.RecvEvent(&waveEvent)
}
// handleStreamData decodes an incoming stream-data packet and hands it to
// the stream broker. Messages are dropped silently when no broker is
// configured, the payload is missing, or it fails to unmarshal.
func (w *WshRpc) handleStreamData(req *RpcMessage) {
	if w.StreamBroker == nil || req.Data == nil {
		return
	}
	var pk wshrpc.CommandStreamData
	if err := utilfn.ReUnmarshal(&pk, req.Data); err != nil {
		return
	}
	w.StreamBroker.RecvData(pk)
}
// handleStreamAck decodes an incoming stream flow-control ack and hands
// it to the stream broker. Messages are dropped silently when no broker
// is configured, the payload is missing, or it fails to unmarshal.
func (w *WshRpc) handleStreamAck(req *RpcMessage) {
	if w.StreamBroker == nil || req.Data == nil {
		return
	}
	var pk wshrpc.CommandStreamAckData
	if err := utilfn.ReUnmarshal(&pk, req.Data); err != nil {
		return
	}
	w.StreamBroker.RecvAck(pk)
}
func (w *WshRpc) handleRequestInternal(req *RpcMessage, ingressLinkId baseds.LinkId, pprofCtx context.Context) {
if req.Command == wshrpc.Command_EventRecv {
w.handleEventRecv(req)
@ -381,6 +414,17 @@ outer:
continue
}
if msg.IsRpcRequest() {
// Handle stream commands synchronously since the broker is designed to be non-blocking.
// RecvData/RecvAck just enqueue to work queues, so there's no risk of blocking the main loop.
if msg.Command == wshrpc.Command_StreamData {
w.handleStreamData(&msg)
continue
}
if msg.Command == wshrpc.Command_StreamDataAck {
w.handleStreamAck(&msg)
continue
}
ingressLinkId := inputVal.IngressLinkId
go func() {
defer func() {

View file

@ -0,0 +1,24 @@
// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wshutil
import (
"github.com/wavetermdev/waveterm/pkg/wshrpc"
)
// WshRpcStreamClientAdapter adapts a *WshRpc to the stream-client
// interface used by the stream broker, forwarding stream data and ack
// packets as rpc commands.
type WshRpcStreamClientAdapter struct {
	rpc *WshRpc
}

// StreamDataAckCommand sends a flow-control ack for received stream data.
// Uses the wshrpc command constant (rather than a raw string) so the sent
// command name cannot drift from the one the receive loop dispatches on.
func (a *WshRpcStreamClientAdapter) StreamDataAckCommand(data wshrpc.CommandStreamAckData, opts *wshrpc.RpcOpts) error {
	return a.rpc.SendCommand(wshrpc.Command_StreamDataAck, data, opts)
}

// StreamDataCommand sends a chunk of stream data.
func (a *WshRpcStreamClientAdapter) StreamDataCommand(data wshrpc.CommandStreamData, opts *wshrpc.RpcOpts) error {
	return a.rpc.SendCommand(wshrpc.Command_StreamData, data, opts)
}

// AdaptWshRpc wraps rpc in a WshRpcStreamClientAdapter.
func AdaptWshRpc(rpc *WshRpc) *WshRpcStreamClientAdapter {
	return &WshRpcStreamClientAdapter{rpc: rpc}
}

View file

@ -317,6 +317,31 @@ func DBUpdate(ctx context.Context, val waveobj.WaveObj) error {
})
}
// DBUpdateFn loads the wave object with the given id, applies updateFn to
// it, and writes the mutated object back, with the get and the update
// sharing one WithTx transaction.
func DBUpdateFn[T waveobj.WaveObj](ctx context.Context, id string, updateFn func(T)) error {
	// Delegate to DBUpdateFnErr with an always-nil error so the
	// transactional read-modify-write logic lives in exactly one place.
	return DBUpdateFnErr(ctx, id, func(val T) error {
		updateFn(val)
		return nil
	})
}
// DBUpdateFnErr loads the wave object with the given id, applies updateFn,
// and persists the result, all inside one WithTx transaction. If updateFn
// returns an error the object is not written and that error is returned.
func DBUpdateFnErr[T waveobj.WaveObj](ctx context.Context, id string, updateFn func(T) error) error {
	return WithTx(ctx, func(tx *TxWrap) error {
		obj, getErr := DBMustGet[T](tx.Context(), id)
		if getErr != nil {
			return getErr
		}
		if updateErr := updateFn(obj); updateErr != nil {
			return updateErr
		}
		return DBUpdate(tx.Context(), obj)
	})
}
func DBInsert(ctx context.Context, val waveobj.WaveObj) error {
oid := waveobj.GetOID(val)
if oid == "" {