mirror of
https://github.com/wavetermdev/waveterm
synced 2026-04-21 14:37:16 +00:00
new job manager / framework for creating persistent remove sessions (#2779)
lots of stuff here. introduces a streaming framework for the RPC system with flow control. new authentication primitives for the RPC system. this is used to create a persistent "job manager" process (via wsh) that can survive disconnects. and then a jobcontroller in the main server that can create, reconnect, and manage these new persistent jobs. code is currently not actively hooked up to anything minus some new debugging wsh commands, and a switch in the term block that lets me test viewing the output. after PRing this change the next steps are more testing and then integrating this functionality into the product.
This commit is contained in:
parent
011ca146df
commit
ae3e9f05b7
48 changed files with 5983 additions and 1076 deletions
5
.vscode/settings.json
vendored
5
.vscode/settings.json
vendored
|
|
@ -61,5 +61,8 @@
|
|||
},
|
||||
"directoryFilters": ["-tsunami/frontend/scaffold", "-dist", "-make"]
|
||||
},
|
||||
"tailwindCSS.lint.suggestCanonicalClasses": "ignore"
|
||||
"tailwindCSS.lint.suggestCanonicalClasses": "ignore",
|
||||
"go.coverageDecorator": {
|
||||
"type": "gutter"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"github.com/wavetermdev/waveterm/pkg/blocklogger"
|
||||
"github.com/wavetermdev/waveterm/pkg/filebackup"
|
||||
"github.com/wavetermdev/waveterm/pkg/filestore"
|
||||
"github.com/wavetermdev/waveterm/pkg/jobcontroller"
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs"
|
||||
|
|
@ -391,7 +392,7 @@ func createMainWshClient() {
|
|||
wshfs.RpcClient = rpc
|
||||
wshutil.DefaultRouter.RegisterTrustedLeaf(rpc, wshutil.DefaultRoute)
|
||||
wps.Broker.SetClient(wshutil.DefaultRouter)
|
||||
localConnWsh := wshutil.MakeWshRpc(wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{}, "conn:local")
|
||||
localConnWsh := wshutil.MakeWshRpc(wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, wshremote.MakeRemoteRpcServerImpl(nil, wshutil.DefaultRouter, wshclient.GetBareRpcClient(), true), "conn:local")
|
||||
go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName)
|
||||
wshutil.DefaultRouter.RegisterTrustedLeaf(localConnWsh, wshutil.MakeConnectionRouteId(wshrpc.LocalConnName))
|
||||
}
|
||||
|
|
@ -572,6 +573,7 @@ func main() {
|
|||
go backupCleanupLoop()
|
||||
go startupActivityUpdate(firstLaunch) // must be after startConfigWatcher()
|
||||
blocklogger.InitBlockLogger()
|
||||
jobcontroller.InitJobController()
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("GetSystemSummary", recover())
|
||||
|
|
|
|||
|
|
@ -38,11 +38,14 @@ var serverCmd = &cobra.Command{
|
|||
}
|
||||
|
||||
var connServerRouter bool
|
||||
var connServerRouterDomainSocket bool
|
||||
var connServerConnName string
|
||||
var connServerDev bool
|
||||
var ConnServerWshRouter *wshutil.WshRouter
|
||||
|
||||
func init() {
|
||||
serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode")
|
||||
serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode (stdio upstream)")
|
||||
serverCmd.Flags().BoolVar(&connServerRouterDomainSocket, "router-domainsocket", false, "run in local router mode (domain socket upstream)")
|
||||
serverCmd.Flags().StringVar(&connServerConnName, "conn", "", "connection name")
|
||||
serverCmd.Flags().BoolVar(&connServerDev, "dev", false, "enable dev mode with file logging and PID in logs")
|
||||
rootCmd.AddCommand(serverCmd)
|
||||
|
|
@ -123,7 +126,12 @@ func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.Wsh
|
|||
RouteId: routeId,
|
||||
Conn: connServerConnName,
|
||||
}
|
||||
connServerClient := wshutil.MakeWshRpc(rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout}, routeId)
|
||||
|
||||
bareRouteId := wshutil.MakeRandomProcRouteId()
|
||||
bareClient := wshutil.MakeWshRpc(wshrpc.RpcContext{}, &wshclient.WshServer{}, bareRouteId)
|
||||
router.RegisterTrustedLeaf(bareClient, bareRouteId)
|
||||
|
||||
connServerClient := wshutil.MakeWshRpc(rpcCtx, wshremote.MakeRemoteRpcServerImpl(os.Stdout, router, bareClient, false), routeId)
|
||||
router.RegisterTrustedLeaf(connServerClient, routeId)
|
||||
return connServerClient, nil
|
||||
}
|
||||
|
|
@ -131,6 +139,7 @@ func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.Wsh
|
|||
func serverRunRouter() error {
|
||||
log.Printf("starting connserver router")
|
||||
router := wshutil.NewWshRouter()
|
||||
ConnServerWshRouter = router
|
||||
termProxy := wshutil.MakeRpcProxy("connserver-term")
|
||||
rawCh := make(chan []byte, wshutil.DefaultOutputChSize)
|
||||
go func() {
|
||||
|
|
@ -209,8 +218,112 @@ func serverRunRouter() error {
|
|||
select {}
|
||||
}
|
||||
|
||||
func serverRunRouterDomainSocket(jwtToken string) error {
|
||||
log.Printf("starting connserver router (domain socket upstream)")
|
||||
|
||||
// extract socket name from JWT token (unverified - we're on the client side)
|
||||
sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error extracting socket name from JWT: %v", err)
|
||||
}
|
||||
|
||||
// connect to the forwarded domain socket
|
||||
sockName = wavebase.ExpandHomeDirSafe(sockName)
|
||||
conn, err := net.Dial("unix", sockName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to domain socket %s: %v", sockName, err)
|
||||
}
|
||||
|
||||
// create router
|
||||
router := wshutil.NewWshRouter()
|
||||
ConnServerWshRouter = router
|
||||
|
||||
// create proxy for the domain socket connection
|
||||
upstreamProxy := wshutil.MakeRpcProxy("connserver-upstream")
|
||||
|
||||
// goroutine to write to the domain socket
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("serverRunRouterDomainSocket:WriteLoop", recover())
|
||||
}()
|
||||
writeErr := wshutil.AdaptOutputChToStream(upstreamProxy.ToRemoteCh, conn)
|
||||
if writeErr != nil {
|
||||
log.Printf("error writing to upstream domain socket: %v\n", writeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
// goroutine to read from the domain socket
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("serverRunRouterDomainSocket:ReadLoop", recover())
|
||||
}()
|
||||
defer func() {
|
||||
log.Printf("upstream domain socket closed, shutting down")
|
||||
wshutil.DoShutdown("", 0, true)
|
||||
}()
|
||||
wshutil.AdaptStreamToMsgCh(conn, upstreamProxy.FromRemoteCh)
|
||||
}()
|
||||
|
||||
// register the domain socket connection as upstream
|
||||
router.RegisterUpstream(upstreamProxy)
|
||||
|
||||
// setup the connserver rpc client (leaf)
|
||||
client, err := setupConnServerRpcClientWithRouter(router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error setting up connserver rpc client: %v", err)
|
||||
}
|
||||
wshfs.RpcClient = client
|
||||
|
||||
// authenticate with the upstream router using the JWT
|
||||
_, err = wshclient.AuthenticateCommand(client, jwtToken, &wshrpc.RpcOpts{Route: wshutil.ControlRoute})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error authenticating with upstream: %v", err)
|
||||
}
|
||||
log.Printf("authenticated with upstream router")
|
||||
|
||||
// fetch and set JWT public key
|
||||
log.Printf("trying to get JWT public key")
|
||||
jwtPublicKeyB64, err := wshclient.GetJwtPublicKeyCommand(client, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting jwt public key: %v", err)
|
||||
}
|
||||
jwtPublicKeyBytes, err := base64.StdEncoding.DecodeString(jwtPublicKeyB64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error decoding jwt public key: %v", err)
|
||||
}
|
||||
err = wavejwt.SetPublicKey(jwtPublicKeyBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error setting jwt public key: %v", err)
|
||||
}
|
||||
log.Printf("got JWT public key")
|
||||
|
||||
// set up the local domain socket listener for local wsh commands
|
||||
unixListener, err := MakeRemoteUnixListener()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot create unix listener: %v", err)
|
||||
}
|
||||
log.Printf("unix listener started")
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("serverRunRouterDomainSocket:runListener", recover())
|
||||
}()
|
||||
runListener(unixListener, router)
|
||||
}()
|
||||
|
||||
// run the sysinfo loop
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("serverRunRouterDomainSocket:RunSysInfoLoop", recover())
|
||||
}()
|
||||
wshremote.RunSysInfoLoop(client, connServerConnName)
|
||||
}()
|
||||
|
||||
log.Printf("running server (router-domainsocket mode), successfully started")
|
||||
select {}
|
||||
}
|
||||
|
||||
func serverRunNormal(jwtToken string) error {
|
||||
err := setupRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout}, jwtToken)
|
||||
err := setupRpcClient(wshremote.MakeRemoteRpcServerImpl(os.Stdout, nil, nil, false), jwtToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -283,6 +396,20 @@ func serverRun(cmd *cobra.Command, args []string) error {
|
|||
}
|
||||
return err
|
||||
}
|
||||
if connServerRouterDomainSocket {
|
||||
jwtToken, err := askForJwtToken()
|
||||
if err != nil {
|
||||
if logFile != nil {
|
||||
fmt.Fprintf(logFile, "askForJwtToken error: %v\n", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
err = serverRunRouterDomainSocket(jwtToken)
|
||||
if err != nil && logFile != nil {
|
||||
fmt.Fprintf(logFile, "serverRunRouterDomainSocket error: %v\n", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
jwtToken, err := askForJwtToken()
|
||||
if err != nil {
|
||||
if logFile != nil {
|
||||
|
|
|
|||
382
cmd/wsh/cmd/wshcmd-jobdebug.go
Normal file
382
cmd/wsh/cmd/wshcmd-jobdebug.go
Normal file
|
|
@ -0,0 +1,382 @@
|
|||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||
)
|
||||
|
||||
var jobDebugCmd = &cobra.Command{
|
||||
Use: "jobdebug",
|
||||
Short: "debugging commands for the job system",
|
||||
Hidden: true,
|
||||
PersistentPreRunE: preRunSetupRpcClient,
|
||||
}
|
||||
|
||||
var jobDebugListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "list all jobs with debug information",
|
||||
RunE: jobDebugListRun,
|
||||
}
|
||||
|
||||
var jobDebugDeleteCmd = &cobra.Command{
|
||||
Use: "delete",
|
||||
Short: "delete a job entry by jobid",
|
||||
RunE: jobDebugDeleteRun,
|
||||
}
|
||||
|
||||
var jobDebugDeleteAllCmd = &cobra.Command{
|
||||
Use: "deleteall",
|
||||
Short: "delete all jobs",
|
||||
RunE: jobDebugDeleteAllRun,
|
||||
}
|
||||
|
||||
var jobDebugPruneCmd = &cobra.Command{
|
||||
Use: "prune",
|
||||
Short: "remove jobs where the job manager is no longer running",
|
||||
RunE: jobDebugPruneRun,
|
||||
}
|
||||
|
||||
var jobDebugExitCmd = &cobra.Command{
|
||||
Use: "exit",
|
||||
Short: "exit a job manager",
|
||||
RunE: jobDebugExitRun,
|
||||
}
|
||||
|
||||
var jobDebugDisconnectCmd = &cobra.Command{
|
||||
Use: "disconnect",
|
||||
Short: "disconnect from a job manager",
|
||||
RunE: jobDebugDisconnectRun,
|
||||
}
|
||||
|
||||
var jobDebugReconnectCmd = &cobra.Command{
|
||||
Use: "reconnect",
|
||||
Short: "reconnect to a job manager",
|
||||
RunE: jobDebugReconnectRun,
|
||||
}
|
||||
|
||||
var jobDebugReconnectConnCmd = &cobra.Command{
|
||||
Use: "reconnectconn",
|
||||
Short: "reconnect all jobs for a connection",
|
||||
RunE: jobDebugReconnectConnRun,
|
||||
}
|
||||
|
||||
var jobDebugGetOutputCmd = &cobra.Command{
|
||||
Use: "getoutput",
|
||||
Short: "get the terminal output for a job",
|
||||
RunE: jobDebugGetOutputRun,
|
||||
}
|
||||
|
||||
var jobDebugStartCmd = &cobra.Command{
|
||||
Use: "start",
|
||||
Short: "start a new job",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: jobDebugStartRun,
|
||||
}
|
||||
|
||||
var jobDebugAttachJobCmd = &cobra.Command{
|
||||
Use: "attachjob",
|
||||
Short: "attach a job to a block",
|
||||
RunE: jobDebugAttachJobRun,
|
||||
}
|
||||
|
||||
var jobDebugDetachJobCmd = &cobra.Command{
|
||||
Use: "detachjob",
|
||||
Short: "detach a job from its block",
|
||||
RunE: jobDebugDetachJobRun,
|
||||
}
|
||||
|
||||
var jobIdFlag string
|
||||
var jobDebugJsonFlag bool
|
||||
var jobConnFlag string
|
||||
var exitJobIdFlag string
|
||||
var disconnectJobIdFlag string
|
||||
var reconnectJobIdFlag string
|
||||
var reconnectConnNameFlag string
|
||||
var attachJobIdFlag string
|
||||
var attachBlockIdFlag string
|
||||
var detachJobIdFlag string
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(jobDebugCmd)
|
||||
jobDebugCmd.AddCommand(jobDebugListCmd)
|
||||
jobDebugCmd.AddCommand(jobDebugDeleteCmd)
|
||||
jobDebugCmd.AddCommand(jobDebugDeleteAllCmd)
|
||||
jobDebugCmd.AddCommand(jobDebugPruneCmd)
|
||||
jobDebugCmd.AddCommand(jobDebugExitCmd)
|
||||
jobDebugCmd.AddCommand(jobDebugDisconnectCmd)
|
||||
jobDebugCmd.AddCommand(jobDebugReconnectCmd)
|
||||
jobDebugCmd.AddCommand(jobDebugReconnectConnCmd)
|
||||
jobDebugCmd.AddCommand(jobDebugGetOutputCmd)
|
||||
jobDebugCmd.AddCommand(jobDebugStartCmd)
|
||||
jobDebugCmd.AddCommand(jobDebugAttachJobCmd)
|
||||
jobDebugCmd.AddCommand(jobDebugDetachJobCmd)
|
||||
|
||||
jobDebugListCmd.Flags().BoolVar(&jobDebugJsonFlag, "json", false, "output as JSON")
|
||||
|
||||
jobDebugDeleteCmd.Flags().StringVar(&jobIdFlag, "jobid", "", "job id to delete (required)")
|
||||
jobDebugDeleteCmd.MarkFlagRequired("jobid")
|
||||
|
||||
jobDebugExitCmd.Flags().StringVar(&exitJobIdFlag, "jobid", "", "job id to exit (required)")
|
||||
jobDebugExitCmd.MarkFlagRequired("jobid")
|
||||
|
||||
jobDebugDisconnectCmd.Flags().StringVar(&disconnectJobIdFlag, "jobid", "", "job id to disconnect (required)")
|
||||
jobDebugDisconnectCmd.MarkFlagRequired("jobid")
|
||||
|
||||
jobDebugReconnectCmd.Flags().StringVar(&reconnectJobIdFlag, "jobid", "", "job id to reconnect (required)")
|
||||
jobDebugReconnectCmd.MarkFlagRequired("jobid")
|
||||
|
||||
jobDebugReconnectConnCmd.Flags().StringVar(&reconnectConnNameFlag, "conn", "", "connection name (required)")
|
||||
jobDebugReconnectConnCmd.MarkFlagRequired("conn")
|
||||
|
||||
jobDebugGetOutputCmd.Flags().StringVar(&jobIdFlag, "jobid", "", "job id to get output for (required)")
|
||||
jobDebugGetOutputCmd.MarkFlagRequired("jobid")
|
||||
|
||||
jobDebugStartCmd.Flags().StringVar(&jobConnFlag, "conn", "", "connection name (required)")
|
||||
jobDebugStartCmd.MarkFlagRequired("conn")
|
||||
|
||||
jobDebugAttachJobCmd.Flags().StringVar(&attachJobIdFlag, "jobid", "", "job id to attach (required)")
|
||||
jobDebugAttachJobCmd.MarkFlagRequired("jobid")
|
||||
jobDebugAttachJobCmd.Flags().StringVar(&attachBlockIdFlag, "blockid", "", "block id to attach to (required)")
|
||||
jobDebugAttachJobCmd.MarkFlagRequired("blockid")
|
||||
|
||||
jobDebugDetachJobCmd.Flags().StringVar(&detachJobIdFlag, "jobid", "", "job id to detach (required)")
|
||||
jobDebugDetachJobCmd.MarkFlagRequired("jobid")
|
||||
}
|
||||
|
||||
func jobDebugListRun(cmd *cobra.Command, args []string) error {
|
||||
rtnData, err := wshclient.JobControllerListCommand(RpcClient, &wshrpc.RpcOpts{Timeout: 5000})
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting job debug list: %w", err)
|
||||
}
|
||||
|
||||
connectedJobIds, err := wshclient.JobControllerConnectedJobsCommand(RpcClient, &wshrpc.RpcOpts{Timeout: 5000})
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting connected job ids: %w", err)
|
||||
}
|
||||
|
||||
connectedMap := make(map[string]bool)
|
||||
for _, jobId := range connectedJobIds {
|
||||
connectedMap[jobId] = true
|
||||
}
|
||||
|
||||
if jobDebugJsonFlag {
|
||||
jsonData, err := json.MarshalIndent(rtnData, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling json: %w", err)
|
||||
}
|
||||
fmt.Printf("%s\n", string(jsonData))
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("%-36s %-20s %-9s %-10s %-30s %-8s %-10s\n", "OID", "Connection", "Connected", "Manager", "Cmd", "ExitCode", "Stream")
|
||||
for _, job := range rtnData {
|
||||
connectedStatus := "no"
|
||||
if connectedMap[job.OID] {
|
||||
connectedStatus = "yes"
|
||||
}
|
||||
|
||||
streamStatus := "-"
|
||||
if job.StreamDone {
|
||||
if job.StreamError == "" {
|
||||
streamStatus = "EOF"
|
||||
} else {
|
||||
streamStatus = fmt.Sprintf("%q", job.StreamError)
|
||||
}
|
||||
}
|
||||
|
||||
exitCode := "-"
|
||||
if job.CmdExitTs > 0 {
|
||||
if job.CmdExitCode != nil {
|
||||
exitCode = fmt.Sprintf("%d", *job.CmdExitCode)
|
||||
} else if job.CmdExitSignal != "" {
|
||||
exitCode = job.CmdExitSignal
|
||||
} else {
|
||||
exitCode = "?"
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("%-36s %-20s %-9s %-10s %-30s %-8s %-10s\n",
|
||||
job.OID, job.Connection, connectedStatus, job.JobManagerStatus, job.Cmd, exitCode, streamStatus)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func jobDebugDeleteRun(cmd *cobra.Command, args []string) error {
|
||||
err := wshclient.JobControllerDeleteJobCommand(RpcClient, jobIdFlag, &wshrpc.RpcOpts{Timeout: 5000})
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting job: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Job %s deleted successfully\n", jobIdFlag)
|
||||
return nil
|
||||
}
|
||||
|
||||
func jobDebugDeleteAllRun(cmd *cobra.Command, args []string) error {
|
||||
rtnData, err := wshclient.JobControllerListCommand(RpcClient, &wshrpc.RpcOpts{Timeout: 5000})
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting job debug list: %w", err)
|
||||
}
|
||||
|
||||
if len(rtnData) == 0 {
|
||||
fmt.Printf("No jobs to delete\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
deletedCount := 0
|
||||
for _, job := range rtnData {
|
||||
err := wshclient.JobControllerDeleteJobCommand(RpcClient, job.OID, &wshrpc.RpcOpts{Timeout: 5000})
|
||||
if err != nil {
|
||||
fmt.Printf("Error deleting job %s: %v\n", job.OID, err)
|
||||
} else {
|
||||
deletedCount++
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Deleted %d of %d job(s)\n", deletedCount, len(rtnData))
|
||||
return nil
|
||||
}
|
||||
|
||||
func jobDebugPruneRun(cmd *cobra.Command, args []string) error {
|
||||
rtnData, err := wshclient.JobControllerListCommand(RpcClient, &wshrpc.RpcOpts{Timeout: 5000})
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting job debug list: %w", err)
|
||||
}
|
||||
|
||||
if len(rtnData) == 0 {
|
||||
fmt.Printf("No jobs to prune\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
deletedCount := 0
|
||||
for _, job := range rtnData {
|
||||
if job.JobManagerStatus != "running" {
|
||||
err := wshclient.JobControllerDeleteJobCommand(RpcClient, job.OID, &wshrpc.RpcOpts{Timeout: 5000})
|
||||
if err != nil {
|
||||
fmt.Printf("Error deleting job %s: %v\n", job.OID, err)
|
||||
} else {
|
||||
deletedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if deletedCount == 0 {
|
||||
fmt.Printf("No jobs with stopped job managers to prune\n")
|
||||
} else {
|
||||
fmt.Printf("Pruned %d job(s) with stopped job managers\n", deletedCount)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func jobDebugExitRun(cmd *cobra.Command, args []string) error {
|
||||
err := wshclient.JobControllerExitJobCommand(RpcClient, exitJobIdFlag, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exiting job manager: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Job manager for %s exited successfully\n", exitJobIdFlag)
|
||||
return nil
|
||||
}
|
||||
|
||||
func jobDebugDisconnectRun(cmd *cobra.Command, args []string) error {
|
||||
err := wshclient.JobControllerDisconnectJobCommand(RpcClient, disconnectJobIdFlag, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("disconnecting from job manager: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Disconnected from job manager for %s successfully\n", disconnectJobIdFlag)
|
||||
return nil
|
||||
}
|
||||
|
||||
func jobDebugReconnectRun(cmd *cobra.Command, args []string) error {
|
||||
err := wshclient.JobControllerReconnectJobCommand(RpcClient, reconnectJobIdFlag, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reconnecting to job manager: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Reconnected to job manager for %s successfully\n", reconnectJobIdFlag)
|
||||
return nil
|
||||
}
|
||||
|
||||
func jobDebugReconnectConnRun(cmd *cobra.Command, args []string) error {
|
||||
err := wshclient.JobControllerReconnectJobsForConnCommand(RpcClient, reconnectConnNameFlag, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reconnecting jobs for connection: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Reconnected all jobs for connection %s successfully\n", reconnectConnNameFlag)
|
||||
return nil
|
||||
}
|
||||
|
||||
func jobDebugGetOutputRun(cmd *cobra.Command, args []string) error {
|
||||
fileData, err := wshclient.FileReadCommand(RpcClient, wshrpc.FileData{
|
||||
Info: &wshrpc.FileInfo{
|
||||
Path: fmt.Sprintf("wavefile://%s/term", jobIdFlag),
|
||||
},
|
||||
}, &wshrpc.RpcOpts{Timeout: 10000})
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading job output: %w", err)
|
||||
}
|
||||
|
||||
if fileData.Data64 != "" {
|
||||
decoded, err := base64.StdEncoding.DecodeString(fileData.Data64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decoding output data: %w", err)
|
||||
}
|
||||
fmt.Printf("%s", string(decoded))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func jobDebugStartRun(cmd *cobra.Command, args []string) error {
|
||||
cmdToRun := args[0]
|
||||
cmdArgs := args[1:]
|
||||
|
||||
data := wshrpc.CommandJobControllerStartJobData{
|
||||
ConnName: jobConnFlag,
|
||||
Cmd: cmdToRun,
|
||||
Args: cmdArgs,
|
||||
Env: make(map[string]string),
|
||||
TermSize: nil,
|
||||
}
|
||||
|
||||
jobId, err := wshclient.JobControllerStartJobCommand(RpcClient, data, &wshrpc.RpcOpts{Timeout: 10000})
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting job: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Job started successfully with ID: %s\n", jobId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func jobDebugAttachJobRun(cmd *cobra.Command, args []string) error {
|
||||
data := wshrpc.CommandJobControllerAttachJobData{
|
||||
JobId: attachJobIdFlag,
|
||||
BlockId: attachBlockIdFlag,
|
||||
}
|
||||
|
||||
err := wshclient.JobControllerAttachJobCommand(RpcClient, data, &wshrpc.RpcOpts{Timeout: 5000})
|
||||
if err != nil {
|
||||
return fmt.Errorf("attaching job: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Job %s attached to block %s successfully\n", attachJobIdFlag, attachBlockIdFlag)
|
||||
return nil
|
||||
}
|
||||
|
||||
func jobDebugDetachJobRun(cmd *cobra.Command, args []string) error {
|
||||
err := wshclient.JobControllerDetachJobCommand(RpcClient, detachJobIdFlag, &wshrpc.RpcOpts{Timeout: 5000})
|
||||
if err != nil {
|
||||
return fmt.Errorf("detaching job: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Job %s detached successfully\n", detachJobIdFlag)
|
||||
return nil
|
||||
}
|
||||
119
cmd/wsh/cmd/wshcmd-jobmanager.go
Normal file
119
cmd/wsh/cmd/wshcmd-jobmanager.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/waveterm/pkg/jobmanager"
|
||||
)
|
||||
|
||||
var jobManagerCmd = &cobra.Command{
|
||||
Use: "jobmanager",
|
||||
Hidden: true,
|
||||
Short: "job manager for wave terminal",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: jobManagerRun,
|
||||
}
|
||||
|
||||
var jobManagerJobId string
|
||||
var jobManagerClientId string
|
||||
|
||||
func init() {
|
||||
jobManagerCmd.Flags().StringVar(&jobManagerJobId, "jobid", "", "job ID (UUID, required)")
|
||||
jobManagerCmd.Flags().StringVar(&jobManagerClientId, "clientid", "", "client ID (UUID, required)")
|
||||
jobManagerCmd.MarkFlagRequired("jobid")
|
||||
jobManagerCmd.MarkFlagRequired("clientid")
|
||||
rootCmd.AddCommand(jobManagerCmd)
|
||||
}
|
||||
|
||||
func jobManagerRun(cmd *cobra.Command, args []string) error {
|
||||
_, err := uuid.Parse(jobManagerJobId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid jobid: must be a valid UUID")
|
||||
}
|
||||
|
||||
_, err = uuid.Parse(jobManagerClientId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid clientid: must be a valid UUID")
|
||||
}
|
||||
|
||||
publicKeyB64 := os.Getenv("WAVETERM_PUBLICKEY")
|
||||
if publicKeyB64 == "" {
|
||||
return fmt.Errorf("WAVETERM_PUBLICKEY environment variable is not set")
|
||||
}
|
||||
|
||||
publicKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode WAVETERM_PUBLICKEY: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
jobAuthToken, err := readJobAuthToken(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read job auth token: %v", err)
|
||||
}
|
||||
|
||||
readyFile := os.NewFile(3, "ready-pipe")
|
||||
_, err = readyFile.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ready pipe (fd 3) not available: %v", err)
|
||||
}
|
||||
|
||||
err = jobmanager.SetupJobManager(jobManagerClientId, jobManagerJobId, publicKeyBytes, jobAuthToken, readyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error setting up job manager: %v", err)
|
||||
}
|
||||
|
||||
select {}
|
||||
}
|
||||
|
||||
func readJobAuthToken(ctx context.Context) (string, error) {
|
||||
resultCh := make(chan string, 1)
|
||||
errorCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
errorCh <- fmt.Errorf("error reading from stdin: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
prefix := jobmanager.JobAccessTokenLabel + ":"
|
||||
if !strings.HasPrefix(line, prefix) {
|
||||
errorCh <- fmt.Errorf("invalid token format: expected '%s'", prefix)
|
||||
return
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(line, prefix)
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
errorCh <- fmt.Errorf("empty job auth token")
|
||||
return
|
||||
}
|
||||
|
||||
resultCh <- token
|
||||
}()
|
||||
|
||||
select {
|
||||
case token := <-resultCh:
|
||||
return token, nil
|
||||
case err := <-errorCh:
|
||||
return "", err
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
}
|
||||
}
|
||||
1
db/migrations-wstore/000011_job.down.sql
Normal file
1
db/migrations-wstore/000011_job.down.sql
Normal file
|
|
@ -0,0 +1 @@
|
|||
DROP TABLE IF EXISTS db_job;
|
||||
5
db/migrations-wstore/000011_job.up.sql
Normal file
5
db/migrations-wstore/000011_job.up.sql
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
CREATE TABLE IF NOT EXISTS db_job (
|
||||
oid varchar(36) PRIMARY KEY,
|
||||
version int NOT NULL,
|
||||
data json NOT NULL
|
||||
);
|
||||
|
|
@ -22,6 +22,21 @@ class RpcApiType {
|
|||
return client.wshRpcCall("authenticate", data, opts);
|
||||
}
|
||||
|
||||
// command "authenticatejobmanager" [call]
|
||||
AuthenticateJobManagerCommand(client: WshClient, data: CommandAuthenticateJobManagerData, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("authenticatejobmanager", data, opts);
|
||||
}
|
||||
|
||||
// command "authenticatejobmanagerverify" [call]
|
||||
AuthenticateJobManagerVerifyCommand(client: WshClient, data: CommandAuthenticateJobManagerData, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("authenticatejobmanagerverify", data, opts);
|
||||
}
|
||||
|
||||
// command "authenticatetojobmanager" [call]
|
||||
AuthenticateToJobManagerCommand(client: WshClient, data: CommandAuthenticateToJobData, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("authenticatetojobmanager", data, opts);
|
||||
}
|
||||
|
||||
// command "authenticatetoken" [call]
|
||||
AuthenticateTokenCommand(client: WshClient, data: CommandAuthenticateTokenData, opts?: RpcOpts): Promise<CommandAuthenticateRtnData> {
|
||||
return client.wshRpcCall("authenticatetoken", data, opts);
|
||||
|
|
@ -377,6 +392,76 @@ class RpcApiType {
|
|||
return client.wshRpcCall("getwaveairatelimit", null, opts);
|
||||
}
|
||||
|
||||
// command "jobcmdexited" [call]
|
||||
JobCmdExitedCommand(client: WshClient, data: CommandJobCmdExitedData, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("jobcmdexited", data, opts);
|
||||
}
|
||||
|
||||
// command "jobcontrollerattachjob" [call]
|
||||
JobControllerAttachJobCommand(client: WshClient, data: CommandJobControllerAttachJobData, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("jobcontrollerattachjob", data, opts);
|
||||
}
|
||||
|
||||
// command "jobcontrollerconnectedjobs" [call]
|
||||
JobControllerConnectedJobsCommand(client: WshClient, opts?: RpcOpts): Promise<string[]> {
|
||||
return client.wshRpcCall("jobcontrollerconnectedjobs", null, opts);
|
||||
}
|
||||
|
||||
// command "jobcontrollerdeletejob" [call]
|
||||
JobControllerDeleteJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("jobcontrollerdeletejob", data, opts);
|
||||
}
|
||||
|
||||
// command "jobcontrollerdetachjob" [call]
|
||||
JobControllerDetachJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("jobcontrollerdetachjob", data, opts);
|
||||
}
|
||||
|
||||
// command "jobcontrollerdisconnectjob" [call]
|
||||
JobControllerDisconnectJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("jobcontrollerdisconnectjob", data, opts);
|
||||
}
|
||||
|
||||
// command "jobcontrollerexitjob" [call]
|
||||
JobControllerExitJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("jobcontrollerexitjob", data, opts);
|
||||
}
|
||||
|
||||
// command "jobcontrollerlist" [call]
|
||||
JobControllerListCommand(client: WshClient, opts?: RpcOpts): Promise<Job[]> {
|
||||
return client.wshRpcCall("jobcontrollerlist", null, opts);
|
||||
}
|
||||
|
||||
// command "jobcontrollerreconnectjob" [call]
|
||||
JobControllerReconnectJobCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("jobcontrollerreconnectjob", data, opts);
|
||||
}
|
||||
|
||||
// command "jobcontrollerreconnectjobsforconn" [call]
|
||||
JobControllerReconnectJobsForConnCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("jobcontrollerreconnectjobsforconn", data, opts);
|
||||
}
|
||||
|
||||
// command "jobcontrollerstartjob" [call]
|
||||
JobControllerStartJobCommand(client: WshClient, data: CommandJobControllerStartJobData, opts?: RpcOpts): Promise<string> {
|
||||
return client.wshRpcCall("jobcontrollerstartjob", data, opts);
|
||||
}
|
||||
|
||||
// command "jobinput" [call]
|
||||
JobInputCommand(client: WshClient, data: CommandJobInputData, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("jobinput", data, opts);
|
||||
}
|
||||
|
||||
// command "jobprepareconnect" [call]
|
||||
JobPrepareConnectCommand(client: WshClient, data: CommandJobPrepareConnectData, opts?: RpcOpts): Promise<CommandJobConnectRtnData> {
|
||||
return client.wshRpcCall("jobprepareconnect", data, opts);
|
||||
}
|
||||
|
||||
// command "jobstartstream" [call]
|
||||
JobStartStreamCommand(client: WshClient, data: CommandJobStartStreamData, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("jobstartstream", data, opts);
|
||||
}
|
||||
|
||||
// command "listallappfiles" [call]
|
||||
ListAllAppFilesCommand(client: WshClient, data: CommandListAllAppFilesData, opts?: RpcOpts): Promise<CommandListAllAppFilesRtnData> {
|
||||
return client.wshRpcCall("listallappfiles", data, opts);
|
||||
|
|
@ -432,6 +517,11 @@ class RpcApiType {
|
|||
return client.wshRpcCall("recordtevent", data, opts);
|
||||
}
|
||||
|
||||
// command "remotedisconnectfromjobmanager" [call]
|
||||
RemoteDisconnectFromJobManagerCommand(client: WshClient, data: CommandRemoteDisconnectFromJobManagerData, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("remotedisconnectfromjobmanager", data, opts);
|
||||
}
|
||||
|
||||
// command "remotefilecopy" [call]
|
||||
RemoteFileCopyCommand(client: WshClient, data: CommandFileCopyData, opts?: RpcOpts): Promise<boolean> {
|
||||
return client.wshRpcCall("remotefilecopy", data, opts);
|
||||
|
|
@ -482,6 +572,16 @@ class RpcApiType {
|
|||
return client.wshRpcCall("remotemkdir", data, opts);
|
||||
}
|
||||
|
||||
// command "remotereconnecttojobmanager" [call]
|
||||
RemoteReconnectToJobManagerCommand(client: WshClient, data: CommandRemoteReconnectToJobManagerData, opts?: RpcOpts): Promise<CommandRemoteReconnectToJobManagerRtnData> {
|
||||
return client.wshRpcCall("remotereconnecttojobmanager", data, opts);
|
||||
}
|
||||
|
||||
// command "remotestartjob" [call]
|
||||
RemoteStartJobCommand(client: WshClient, data: CommandRemoteStartJobData, opts?: RpcOpts): Promise<CommandStartJobRtnData> {
|
||||
return client.wshRpcCall("remotestartjob", data, opts);
|
||||
}
|
||||
|
||||
// command "remotestreamcpudata" [responsestream]
|
||||
RemoteStreamCpuDataCommand(client: WshClient, opts?: RpcOpts): AsyncGenerator<TimeSeriesData, void, boolean> {
|
||||
return client.wshRpcStream("remotestreamcpudata", null, opts);
|
||||
|
|
@ -497,6 +597,11 @@ class RpcApiType {
|
|||
return client.wshRpcStream("remotetarstream", data, opts);
|
||||
}
|
||||
|
||||
// command "remoteterminatejobmanager" [call]
|
||||
RemoteTerminateJobManagerCommand(client: WshClient, data: CommandRemoteTerminateJobManagerData, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("remoteterminatejobmanager", data, opts);
|
||||
}
|
||||
|
||||
// command "remotewritefile" [call]
|
||||
RemoteWriteFileCommand(client: WshClient, data: FileData, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("remotewritefile", data, opts);
|
||||
|
|
@ -572,6 +677,11 @@ class RpcApiType {
|
|||
return client.wshRpcCall("startbuilder", data, opts);
|
||||
}
|
||||
|
||||
// command "startjob" [call]
|
||||
StartJobCommand(client: WshClient, data: CommandStartJobData, opts?: RpcOpts): Promise<CommandStartJobRtnData> {
|
||||
return client.wshRpcCall("startjob", data, opts);
|
||||
}
|
||||
|
||||
// command "stopbuilder" [call]
|
||||
StopBuilderCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("stopbuilder", data, opts);
|
||||
|
|
@ -607,6 +717,11 @@ class RpcApiType {
|
|||
return client.wshRpcCall("termgetscrollbacklines", data, opts);
|
||||
}
|
||||
|
||||
// command "termupdateattachedjob" [call]
|
||||
TermUpdateAttachedJobCommand(client: WshClient, data: CommandTermUpdateAttachedJobData, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("termupdateattachedjob", data, opts);
|
||||
}
|
||||
|
||||
// command "test" [call]
|
||||
TestCommand(client: WshClient, data: string, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("test", data, opts);
|
||||
|
|
|
|||
|
|
@ -104,6 +104,11 @@ export class TermWshClient extends WshClient {
|
|||
}
|
||||
}
|
||||
|
||||
async handle_termupdateattachedjob(rh: RpcResponseHelper, data: CommandTermUpdateAttachedJobData): Promise<void> {
|
||||
console.log("term-update-attached-job", this.blockId, data);
|
||||
// TODO: implement frontend logic to handle job attachment updates
|
||||
}
|
||||
|
||||
async handle_termgetscrollbacklines(
|
||||
rh: RpcResponseHelper,
|
||||
data: CommandTermGetScrollbackLinesData
|
||||
|
|
|
|||
|
|
@ -298,6 +298,7 @@ const TerminalView = ({ blockId, model }: ViewComponentProps<TermViewModel>) =>
|
|||
useWebGl: !termSettings?.["term:disablewebgl"],
|
||||
sendDataHandler: model.sendDataToController.bind(model),
|
||||
nodeModel: model.nodeModel,
|
||||
jobId: blockData?.jobid,
|
||||
}
|
||||
);
|
||||
(window as any).term = termWrap;
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
import type { BlockNodeModel } from "@/app/block/blocktypes";
|
||||
import { getFileSubject } from "@/app/store/wps";
|
||||
import { sendWSCommand } from "@/app/store/ws";
|
||||
import { RpcApi } from "@/app/store/wshclientapi";
|
||||
import { TabRpcClient } from "@/app/store/wshrpcutil";
|
||||
import { WOS, fetchWaveFile, getApi, getSettingsKeyAtom, globalStore, openLink, recordTEvent } from "@/store/global";
|
||||
|
|
@ -50,6 +49,7 @@ type TermWrapOptions = {
|
|||
useWebGl?: boolean;
|
||||
sendDataHandler?: (data: string) => void;
|
||||
nodeModel?: BlockNodeModel;
|
||||
jobId?: string;
|
||||
};
|
||||
|
||||
// for xterm OSC handlers, we return true always because we "own" the OSC number.
|
||||
|
|
@ -375,6 +375,7 @@ function handleOsc16162Command(data: string, blockId: string, loaded: boolean, t
|
|||
export class TermWrap {
|
||||
tabId: string;
|
||||
blockId: string;
|
||||
jobId: string;
|
||||
ptyOffset: number;
|
||||
dataBytesProcessed: number;
|
||||
terminal: Terminal;
|
||||
|
|
@ -422,6 +423,7 @@ export class TermWrap {
|
|||
this.loaded = false;
|
||||
this.tabId = tabId;
|
||||
this.blockId = blockId;
|
||||
this.jobId = waveOptions.jobId;
|
||||
this.sendDataHandler = waveOptions.sendDataHandler;
|
||||
this.nodeModel = waveOptions.nodeModel;
|
||||
this.ptyOffset = 0;
|
||||
|
|
@ -495,6 +497,10 @@ export class TermWrap {
|
|||
});
|
||||
}
|
||||
|
||||
getZoneId(): string {
|
||||
return this.jobId ?? this.blockId;
|
||||
}
|
||||
|
||||
resetCompositionState() {
|
||||
this.isComposing = false;
|
||||
this.composingData = "";
|
||||
|
|
@ -566,7 +572,7 @@ export class TermWrap {
|
|||
});
|
||||
}
|
||||
|
||||
this.mainFileSubject = getFileSubject(this.blockId, TermFileName);
|
||||
this.mainFileSubject = getFileSubject(this.getZoneId(), TermFileName);
|
||||
this.mainFileSubject.subscribe(this.handleNewFileSubjectData.bind(this));
|
||||
|
||||
try {
|
||||
|
|
@ -699,8 +705,9 @@ export class TermWrap {
|
|||
}
|
||||
|
||||
async loadInitialTerminalData(): Promise<void> {
|
||||
let startTs = Date.now();
|
||||
const { data: cacheData, fileInfo: cacheFile } = await fetchWaveFile(this.blockId, TermCacheFileName);
|
||||
const startTs = Date.now();
|
||||
const zoneId = this.getZoneId();
|
||||
const { data: cacheData, fileInfo: cacheFile } = await fetchWaveFile(zoneId, TermCacheFileName);
|
||||
let ptyOffset = 0;
|
||||
if (cacheFile != null) {
|
||||
ptyOffset = cacheFile.meta["ptyoffset"] ?? 0;
|
||||
|
|
@ -722,7 +729,7 @@ export class TermWrap {
|
|||
}
|
||||
}
|
||||
}
|
||||
const { data: mainData, fileInfo: mainFile } = await fetchWaveFile(this.blockId, TermFileName, ptyOffset);
|
||||
const { data: mainData, fileInfo: mainFile } = await fetchWaveFile(zoneId, TermFileName, ptyOffset);
|
||||
console.log(
|
||||
`terminal loaded cachefile:${cacheData?.byteLength ?? 0} main:${mainData?.byteLength ?? 0} bytes, ${Date.now() - startTs}ms`
|
||||
);
|
||||
|
|
@ -751,12 +758,7 @@ export class TermWrap {
|
|||
this.fitAddon.fit();
|
||||
if (oldRows !== this.terminal.rows || oldCols !== this.terminal.cols) {
|
||||
const termSize: TermSize = { rows: this.terminal.rows, cols: this.terminal.cols };
|
||||
const wsCommand: SetBlockTermSizeWSCommand = {
|
||||
wscommand: "setblocktermsize",
|
||||
blockid: this.blockId,
|
||||
termsize: termSize,
|
||||
};
|
||||
sendWSCommand(wsCommand);
|
||||
RpcApi.ControllerInputCommand(TabRpcClient, { blockid: this.blockId, termsize: termSize });
|
||||
}
|
||||
dlog("resize", `${this.terminal.rows}x${this.terminal.cols}`, `${oldRows}x${oldCols}`, this.hasResized);
|
||||
if (!this.hasResized) {
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ export class VDomModel {
|
|||
this.queueUpdate(true);
|
||||
}
|
||||
this.routeGoneUnsub = waveEventSubscribe({
|
||||
eventType: "route:gone",
|
||||
eventType: "route:down",
|
||||
scope: curBackendRoute,
|
||||
handler: (event: WaveEvent) => {
|
||||
this.disposed = true;
|
||||
|
|
|
|||
185
frontend/types/gotypes.d.ts
vendored
185
frontend/types/gotypes.d.ts
vendored
|
|
@ -112,6 +112,7 @@ declare global {
|
|||
runtimeopts?: RuntimeOpts;
|
||||
stickers?: StickerType[];
|
||||
subblockids?: string[];
|
||||
jobid?: string;
|
||||
};
|
||||
|
||||
// blockcontroller.BlockControllerRuntimeStatus
|
||||
|
|
@ -139,13 +140,6 @@ declare global {
|
|||
files: FileInfo[];
|
||||
};
|
||||
|
||||
// webcmd.BlockInputWSCommand
|
||||
type BlockInputWSCommand = {
|
||||
wscommand: "blockinput";
|
||||
blockid: string;
|
||||
inputdata64: string;
|
||||
};
|
||||
|
||||
// wshrpc.BlocksListEntry
|
||||
type BlocksListEntry = {
|
||||
windowid: string;
|
||||
|
|
@ -179,6 +173,7 @@ declare global {
|
|||
tosagreed?: number;
|
||||
hasoldhistory?: boolean;
|
||||
tempoid?: string;
|
||||
installid?: string;
|
||||
};
|
||||
|
||||
// workspaceservice.CloseTabRtnType
|
||||
|
|
@ -194,6 +189,12 @@ declare global {
|
|||
data: {[key: string]: any};
|
||||
};
|
||||
|
||||
// wshrpc.CommandAuthenticateJobManagerData
|
||||
type CommandAuthenticateJobManagerData = {
|
||||
jobid: string;
|
||||
jobauthtoken: string;
|
||||
};
|
||||
|
||||
// wshrpc.CommandAuthenticateRtnData
|
||||
type CommandAuthenticateRtnData = {
|
||||
env?: {[key: string]: string};
|
||||
|
|
@ -201,6 +202,11 @@ declare global {
|
|||
rpccontext?: RpcContext;
|
||||
};
|
||||
|
||||
// wshrpc.CommandAuthenticateToJobData
|
||||
type CommandAuthenticateToJobData = {
|
||||
jobaccesstoken: string;
|
||||
};
|
||||
|
||||
// wshrpc.CommandAuthenticateTokenData
|
||||
type CommandAuthenticateTokenData = {
|
||||
token: string;
|
||||
|
|
@ -343,6 +349,59 @@ declare global {
|
|||
chatid: string;
|
||||
};
|
||||
|
||||
// wshrpc.CommandJobCmdExitedData
|
||||
type CommandJobCmdExitedData = {
|
||||
jobid: string;
|
||||
exitcode?: number;
|
||||
exitsignal?: string;
|
||||
exiterr?: string;
|
||||
exitts?: number;
|
||||
};
|
||||
|
||||
// wshrpc.CommandJobConnectRtnData
|
||||
type CommandJobConnectRtnData = {
|
||||
seq: number;
|
||||
streamdone?: boolean;
|
||||
streamerror?: string;
|
||||
hasexited?: boolean;
|
||||
exitcode?: number;
|
||||
exitsignal?: string;
|
||||
exiterr?: string;
|
||||
};
|
||||
|
||||
// wshrpc.CommandJobControllerAttachJobData
|
||||
type CommandJobControllerAttachJobData = {
|
||||
jobid: string;
|
||||
blockid: string;
|
||||
};
|
||||
|
||||
// wshrpc.CommandJobControllerStartJobData
|
||||
type CommandJobControllerStartJobData = {
|
||||
connname: string;
|
||||
cmd: string;
|
||||
args: string[];
|
||||
env: {[key: string]: string};
|
||||
termsize?: TermSize;
|
||||
};
|
||||
|
||||
// wshrpc.CommandJobInputData
|
||||
type CommandJobInputData = {
|
||||
jobid: string;
|
||||
inputdata64?: string;
|
||||
signame?: string;
|
||||
termsize?: TermSize;
|
||||
};
|
||||
|
||||
// wshrpc.CommandJobPrepareConnectData
|
||||
type CommandJobPrepareConnectData = {
|
||||
streammeta: StreamMeta;
|
||||
seq: number;
|
||||
};
|
||||
|
||||
// wshrpc.CommandJobStartStreamData
|
||||
type CommandJobStartStreamData = {
|
||||
};
|
||||
|
||||
// wshrpc.CommandListAllAppFilesData
|
||||
type CommandListAllAppFilesData = {
|
||||
appid: string;
|
||||
|
|
@ -397,6 +456,11 @@ declare global {
|
|||
modts?: number;
|
||||
};
|
||||
|
||||
// wshrpc.CommandRemoteDisconnectFromJobManagerData
|
||||
type CommandRemoteDisconnectFromJobManagerData = {
|
||||
jobid: string;
|
||||
};
|
||||
|
||||
// wshrpc.CommandRemoteListEntriesData
|
||||
type CommandRemoteListEntriesData = {
|
||||
path: string;
|
||||
|
|
@ -408,6 +472,36 @@ declare global {
|
|||
fileinfo?: FileInfo[];
|
||||
};
|
||||
|
||||
// wshrpc.CommandRemoteReconnectToJobManagerData
|
||||
type CommandRemoteReconnectToJobManagerData = {
|
||||
jobid: string;
|
||||
jobauthtoken: string;
|
||||
mainserverjwttoken: string;
|
||||
jobmanagerpid: number;
|
||||
jobmanagerstartts: number;
|
||||
};
|
||||
|
||||
// wshrpc.CommandRemoteReconnectToJobManagerRtnData
|
||||
type CommandRemoteReconnectToJobManagerRtnData = {
|
||||
success: boolean;
|
||||
jobmanagergone: boolean;
|
||||
error?: string;
|
||||
};
|
||||
|
||||
// wshrpc.CommandRemoteStartJobData
|
||||
type CommandRemoteStartJobData = {
|
||||
cmd: string;
|
||||
args: string[];
|
||||
env: {[key: string]: string};
|
||||
termsize: TermSize;
|
||||
streammeta?: StreamMeta;
|
||||
jobauthtoken: string;
|
||||
jobid: string;
|
||||
mainserverjwttoken: string;
|
||||
clientid: string;
|
||||
publickeybase64: string;
|
||||
};
|
||||
|
||||
// wshrpc.CommandRemoteStreamFileData
|
||||
type CommandRemoteStreamFileData = {
|
||||
path: string;
|
||||
|
|
@ -420,6 +514,13 @@ declare global {
|
|||
opts?: FileCopyOpts;
|
||||
};
|
||||
|
||||
// wshrpc.CommandRemoteTerminateJobManagerData
|
||||
type CommandRemoteTerminateJobManagerData = {
|
||||
jobid: string;
|
||||
jobmanagerpid: number;
|
||||
jobmanagerstartts: number;
|
||||
};
|
||||
|
||||
// wshrpc.CommandRenameAppFileData
|
||||
type CommandRenameAppFileData = {
|
||||
appid: string;
|
||||
|
|
@ -461,9 +562,26 @@ declare global {
|
|||
builderid: string;
|
||||
};
|
||||
|
||||
// wshrpc.CommandStartJobData
|
||||
type CommandStartJobData = {
|
||||
cmd: string;
|
||||
args: string[];
|
||||
env: {[key: string]: string};
|
||||
termsize: TermSize;
|
||||
streammeta?: StreamMeta;
|
||||
};
|
||||
|
||||
// wshrpc.CommandStartJobRtnData
|
||||
type CommandStartJobRtnData = {
|
||||
cmdpid: number;
|
||||
cmdstartts: number;
|
||||
jobmanagerpid: number;
|
||||
jobmanagerstartts: number;
|
||||
};
|
||||
|
||||
// wshrpc.CommandStreamAckData
|
||||
type CommandStreamAckData = {
|
||||
id: number;
|
||||
id: string;
|
||||
seq: number;
|
||||
rwnd: number;
|
||||
fin?: boolean;
|
||||
|
|
@ -474,7 +592,7 @@ declare global {
|
|||
|
||||
// wshrpc.CommandStreamData
|
||||
type CommandStreamData = {
|
||||
id: number;
|
||||
id: string;
|
||||
seq: number;
|
||||
data64?: string;
|
||||
eof?: boolean;
|
||||
|
|
@ -496,6 +614,12 @@ declare global {
|
|||
lastupdated: number;
|
||||
};
|
||||
|
||||
// wshrpc.CommandTermUpdateAttachedJobData
|
||||
type CommandTermUpdateAttachedJobData = {
|
||||
blockid: string;
|
||||
jobid?: string;
|
||||
};
|
||||
|
||||
// wshrpc.CommandVarData
|
||||
type CommandVarData = {
|
||||
key: string;
|
||||
|
|
@ -793,6 +917,32 @@ declare global {
|
|||
configerrors: ConfigError[];
|
||||
};
|
||||
|
||||
// waveobj.Job
|
||||
type Job = WaveObj & {
|
||||
connection: string;
|
||||
jobkind: string;
|
||||
cmd: string;
|
||||
cmdargs?: string[];
|
||||
cmdenv?: {[key: string]: string};
|
||||
jobauthtoken: string;
|
||||
attachedblockid?: string;
|
||||
terminateonreconnect?: boolean;
|
||||
jobmanagerstatus: string;
|
||||
jobmanagerdonereason?: string;
|
||||
jobmanagerstartuperror?: string;
|
||||
jobmanagerpid?: number;
|
||||
jobmanagerstartts?: number;
|
||||
cmdpid?: number;
|
||||
cmdstartts?: number;
|
||||
cmdtermsize: TermSize;
|
||||
cmdexitts?: number;
|
||||
cmdexitcode?: number;
|
||||
cmdexitsignal?: string;
|
||||
cmdexiterror?: string;
|
||||
streamdone?: boolean;
|
||||
streamerror?: string;
|
||||
};
|
||||
|
||||
// waveobj.LayoutActionData
|
||||
type LayoutActionData = {
|
||||
actiontype: string;
|
||||
|
|
@ -1062,13 +1212,6 @@ declare global {
|
|||
optional: boolean;
|
||||
};
|
||||
|
||||
// webcmd.SetBlockTermSizeWSCommand
|
||||
type SetBlockTermSizeWSCommand = {
|
||||
wscommand: "setblocktermsize";
|
||||
blockid: string;
|
||||
termsize: TermSize;
|
||||
};
|
||||
|
||||
// wconfig.SettingsType
|
||||
type SettingsType = {
|
||||
"app:*"?: boolean;
|
||||
|
|
@ -1186,6 +1329,14 @@ declare global {
|
|||
display: StickerDisplayOptsType;
|
||||
};
|
||||
|
||||
// wshrpc.StreamMeta
|
||||
type StreamMeta = {
|
||||
id: string;
|
||||
rwnd: number;
|
||||
readerrouteid: string;
|
||||
writerrouteid: string;
|
||||
};
|
||||
|
||||
// wps.SubscriptionRequest
|
||||
type SubscriptionRequest = {
|
||||
event: string;
|
||||
|
|
@ -1640,7 +1791,7 @@ declare global {
|
|||
|
||||
type WSCommandType = {
|
||||
wscommand: string;
|
||||
} & ( SetBlockTermSizeWSCommand | BlockInputWSCommand | WSRpcCommand );
|
||||
} & ( WSRpcCommand );
|
||||
|
||||
// eventbus.WSEventType
|
||||
type WSEventType = {
|
||||
|
|
|
|||
853
pkg/jobcontroller/jobcontroller.go
Normal file
853
pkg/jobcontroller/jobcontroller.go
Normal file
|
|
@ -0,0 +1,853 @@
|
|||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package jobcontroller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/wavetermdev/waveterm/pkg/filestore"
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
|
||||
"github.com/wavetermdev/waveterm/pkg/streamclient"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavejwt"
|
||||
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||
)
|
||||
|
||||
const (
|
||||
JobStatus_Init = "init"
|
||||
JobStatus_Running = "running"
|
||||
JobStatus_Done = "done"
|
||||
)
|
||||
|
||||
const (
|
||||
JobDoneReason_StartupError = "startuperror"
|
||||
JobDoneReason_Gone = "gone"
|
||||
JobDoneReason_Terminated = "terminated"
|
||||
)
|
||||
|
||||
const (
|
||||
JobConnStatus_Disconnected = "disconnected"
|
||||
JobConnStatus_Connecting = "connecting"
|
||||
JobConnStatus_Connected = "connected"
|
||||
)
|
||||
|
||||
const DefaultStreamRwnd = 64 * 1024
|
||||
const MetaKey_TotalGap = "totalgap"
|
||||
const JobOutputFileName = "term"
|
||||
|
||||
func isJobManagerRunning(job *waveobj.Job) bool {
|
||||
return job.JobManagerStatus == JobStatus_Running
|
||||
}
|
||||
|
||||
var (
|
||||
jobConnStates = make(map[string]string)
|
||||
jobConnStatesLock sync.Mutex
|
||||
)
|
||||
|
||||
func getMetaInt64(meta wshrpc.FileMeta, key string) int64 {
|
||||
val, ok := meta[key]
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
if intVal, ok := val.(int64); ok {
|
||||
return intVal
|
||||
}
|
||||
if floatVal, ok := val.(float64); ok {
|
||||
return int64(floatVal)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func InitJobController() {
|
||||
rpcClient := wshclient.GetBareRpcClient()
|
||||
rpcClient.EventListener.On(wps.Event_RouteUp, handleRouteUpEvent)
|
||||
rpcClient.EventListener.On(wps.Event_RouteDown, handleRouteDownEvent)
|
||||
wshclient.EventSubCommand(rpcClient, wps.SubscriptionRequest{
|
||||
Event: wps.Event_RouteUp,
|
||||
AllScopes: true,
|
||||
}, nil)
|
||||
wshclient.EventSubCommand(rpcClient, wps.SubscriptionRequest{
|
||||
Event: wps.Event_RouteDown,
|
||||
AllScopes: true,
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func handleRouteUpEvent(event *wps.WaveEvent) {
|
||||
handleRouteEvent(event, JobConnStatus_Connected)
|
||||
}
|
||||
|
||||
func handleRouteDownEvent(event *wps.WaveEvent) {
|
||||
handleRouteEvent(event, JobConnStatus_Disconnected)
|
||||
}
|
||||
|
||||
func handleRouteEvent(event *wps.WaveEvent, newStatus string) {
|
||||
for _, scope := range event.Scopes {
|
||||
if strings.HasPrefix(scope, "job:") {
|
||||
jobId := strings.TrimPrefix(scope, "job:")
|
||||
SetJobConnStatus(jobId, newStatus)
|
||||
log.Printf("[job:%s] connection status changed to %s", jobId, newStatus)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetJobConnStatus(jobId string) string {
|
||||
jobConnStatesLock.Lock()
|
||||
defer jobConnStatesLock.Unlock()
|
||||
status, exists := jobConnStates[jobId]
|
||||
if !exists {
|
||||
return JobConnStatus_Disconnected
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
func SetJobConnStatus(jobId string, status string) {
|
||||
jobConnStatesLock.Lock()
|
||||
defer jobConnStatesLock.Unlock()
|
||||
if status == JobConnStatus_Disconnected {
|
||||
delete(jobConnStates, jobId)
|
||||
} else {
|
||||
jobConnStates[jobId] = status
|
||||
}
|
||||
}
|
||||
|
||||
func GetConnectedJobIds() []string {
|
||||
jobConnStatesLock.Lock()
|
||||
defer jobConnStatesLock.Unlock()
|
||||
var connectedJobIds []string
|
||||
for jobId, status := range jobConnStates {
|
||||
if status == JobConnStatus_Connected {
|
||||
connectedJobIds = append(connectedJobIds, jobId)
|
||||
}
|
||||
}
|
||||
return connectedJobIds
|
||||
}
|
||||
|
||||
func ensureJobConnected(ctx context.Context, jobId string) (*waveobj.Job, error) {
|
||||
job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get job: %w", err)
|
||||
}
|
||||
|
||||
isConnected, err := conncontroller.IsConnected(job.Connection)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error checking connection status: %w", err)
|
||||
}
|
||||
if !isConnected {
|
||||
return nil, fmt.Errorf("connection %q is not connected", job.Connection)
|
||||
}
|
||||
|
||||
jobConnStatus := GetJobConnStatus(jobId)
|
||||
if jobConnStatus != JobConnStatus_Connected {
|
||||
return nil, fmt.Errorf("job is not connected (status: %s)", jobConnStatus)
|
||||
}
|
||||
|
||||
return job, nil
|
||||
}
|
||||
|
||||
type StartJobParams struct {
|
||||
ConnName string
|
||||
Cmd string
|
||||
Args []string
|
||||
Env map[string]string
|
||||
TermSize *waveobj.TermSize
|
||||
}
|
||||
|
||||
func StartJob(ctx context.Context, params StartJobParams) (string, error) {
|
||||
if params.ConnName == "" {
|
||||
return "", fmt.Errorf("connection name is required")
|
||||
}
|
||||
if params.Cmd == "" {
|
||||
return "", fmt.Errorf("command is required")
|
||||
}
|
||||
if params.TermSize == nil {
|
||||
params.TermSize = &waveobj.TermSize{Rows: 24, Cols: 80}
|
||||
}
|
||||
|
||||
isConnected, err := conncontroller.IsConnected(params.ConnName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error checking connection status: %w", err)
|
||||
}
|
||||
if !isConnected {
|
||||
return "", fmt.Errorf("connection %q is not connected", params.ConnName)
|
||||
}
|
||||
|
||||
jobId := uuid.New().String()
|
||||
jobAuthToken, err := utilfn.RandomHexString(32)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate job auth token: %w", err)
|
||||
}
|
||||
|
||||
jobAccessClaims := &wavejwt.WaveJwtClaims{
|
||||
MainServer: true,
|
||||
JobId: jobId,
|
||||
}
|
||||
jobAccessToken, err := wavejwt.Sign(jobAccessClaims)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate job access token: %w", err)
|
||||
}
|
||||
|
||||
job := &waveobj.Job{
|
||||
OID: jobId,
|
||||
Connection: params.ConnName,
|
||||
Cmd: params.Cmd,
|
||||
CmdArgs: params.Args,
|
||||
CmdEnv: params.Env,
|
||||
CmdTermSize: *params.TermSize,
|
||||
JobAuthToken: jobAuthToken,
|
||||
JobManagerStatus: JobStatus_Init,
|
||||
Meta: make(waveobj.MetaMapType),
|
||||
}
|
||||
|
||||
err = wstore.DBInsert(ctx, job)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create job in database: %w", err)
|
||||
}
|
||||
|
||||
bareRpc := wshclient.GetBareRpcClient()
|
||||
broker := bareRpc.StreamBroker
|
||||
readerRouteId := wshclient.GetBareRpcClientRouteId()
|
||||
writerRouteId := wshutil.MakeJobRouteId(jobId)
|
||||
reader, streamMeta := broker.CreateStreamReader(readerRouteId, writerRouteId, DefaultStreamRwnd)
|
||||
|
||||
fileOpts := wshrpc.FileOpts{
|
||||
MaxSize: 10 * 1024 * 1024,
|
||||
Circular: true,
|
||||
}
|
||||
err = filestore.WFS.MakeFile(ctx, jobId, JobOutputFileName, wshrpc.FileMeta{}, fileOpts)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create WaveFS file: %w", err)
|
||||
}
|
||||
|
||||
clientId, err := wstore.DBGetSingleton[*waveobj.Client](ctx)
|
||||
if err != nil || clientId == nil {
|
||||
return "", fmt.Errorf("failed to get client: %w", err)
|
||||
}
|
||||
|
||||
publicKey := wavejwt.GetPublicKey()
|
||||
publicKeyBase64 := base64.StdEncoding.EncodeToString(publicKey)
|
||||
|
||||
startJobData := wshrpc.CommandRemoteStartJobData{
|
||||
Cmd: params.Cmd,
|
||||
Args: params.Args,
|
||||
Env: params.Env,
|
||||
TermSize: *params.TermSize,
|
||||
StreamMeta: streamMeta,
|
||||
JobAuthToken: jobAuthToken,
|
||||
JobId: jobId,
|
||||
MainServerJwtToken: jobAccessToken,
|
||||
ClientId: clientId.OID,
|
||||
PublicKeyBase64: publicKeyBase64,
|
||||
}
|
||||
|
||||
rpcOpts := &wshrpc.RpcOpts{
|
||||
Route: wshutil.MakeConnectionRouteId(params.ConnName),
|
||||
Timeout: 30000,
|
||||
}
|
||||
|
||||
log.Printf("[job:%s] sending RemoteStartJobCommand to connection %s", jobId, params.ConnName)
|
||||
rtnData, err := wshclient.RemoteStartJobCommand(bareRpc, startJobData, rpcOpts)
|
||||
if err != nil {
|
||||
log.Printf("[job:%s] RemoteStartJobCommand failed: %v", jobId, err)
|
||||
errMsg := fmt.Sprintf("failed to start job: %v", err)
|
||||
wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
|
||||
job.JobManagerStatus = JobStatus_Done
|
||||
job.JobManagerDoneReason = JobDoneReason_StartupError
|
||||
job.JobManagerStartupError = errMsg
|
||||
})
|
||||
return "", fmt.Errorf("failed to start remote job: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[job:%s] RemoteStartJobCommand succeeded, cmdpid=%d cmdstartts=%d jobmanagerpid=%d jobmanagerstartts=%d", jobId, rtnData.CmdPid, rtnData.CmdStartTs, rtnData.JobManagerPid, rtnData.JobManagerStartTs)
|
||||
err = wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
|
||||
job.CmdPid = rtnData.CmdPid
|
||||
job.CmdStartTs = rtnData.CmdStartTs
|
||||
job.JobManagerPid = rtnData.JobManagerPid
|
||||
job.JobManagerStartTs = rtnData.JobManagerStartTs
|
||||
job.JobManagerStatus = JobStatus_Running
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("[job:%s] warning: failed to update job status to running: %v", jobId, err)
|
||||
} else {
|
||||
log.Printf("[job:%s] job status updated to running", jobId)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("jobcontroller:runOutputLoop", recover())
|
||||
}()
|
||||
runOutputLoop(context.Background(), jobId, reader)
|
||||
}()
|
||||
|
||||
return jobId, nil
|
||||
}
|
||||
|
||||
func handleAppendJobFile(ctx context.Context, jobId string, fileName string, data []byte) error {
|
||||
err := filestore.WFS.AppendData(ctx, jobId, fileName, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error appending to job file: %w", err)
|
||||
}
|
||||
wps.Broker.Publish(wps.WaveEvent{
|
||||
Event: wps.Event_BlockFile,
|
||||
Scopes: []string{
|
||||
waveobj.MakeORef(waveobj.OType_Job, jobId).String(),
|
||||
},
|
||||
Data: &wps.WSFileEventData{
|
||||
ZoneId: jobId,
|
||||
FileName: fileName,
|
||||
FileOp: wps.FileOp_Append,
|
||||
Data64: base64.StdEncoding.EncodeToString(data),
|
||||
},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func runOutputLoop(ctx context.Context, jobId string, reader *streamclient.Reader) {
|
||||
defer func() {
|
||||
log.Printf("[job:%s] output loop finished", jobId)
|
||||
}()
|
||||
|
||||
log.Printf("[job:%s] output loop started", jobId)
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := reader.Read(buf)
|
||||
if n > 0 {
|
||||
log.Printf("[job:%s] received %d bytes of data", jobId, n)
|
||||
appendErr := handleAppendJobFile(ctx, jobId, JobOutputFileName, buf[:n])
|
||||
if appendErr != nil {
|
||||
log.Printf("[job:%s] error appending data to WaveFS: %v", jobId, appendErr)
|
||||
} else {
|
||||
log.Printf("[job:%s] successfully appended %d bytes to WaveFS", jobId, n)
|
||||
}
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
log.Printf("[job:%s] stream ended (EOF)", jobId)
|
||||
updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
|
||||
job.StreamDone = true
|
||||
})
|
||||
if updateErr != nil {
|
||||
log.Printf("[job:%s] error updating job stream status: %v", jobId, updateErr)
|
||||
}
|
||||
tryTerminateJobManager(ctx, jobId)
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[job:%s] stream error: %v", jobId, err)
|
||||
streamErr := err.Error()
|
||||
updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
|
||||
job.StreamDone = true
|
||||
job.StreamError = streamErr
|
||||
})
|
||||
if updateErr != nil {
|
||||
log.Printf("[job:%s] error updating job stream error: %v", jobId, updateErr)
|
||||
}
|
||||
tryTerminateJobManager(ctx, jobId)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func HandleCmdJobExited(ctx context.Context, jobId string, data wshrpc.CommandJobCmdExitedData) error {
|
||||
err := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
|
||||
job.CmdExitError = data.ExitErr
|
||||
job.CmdExitCode = data.ExitCode
|
||||
job.CmdExitSignal = data.ExitSignal
|
||||
job.CmdExitTs = data.ExitTs
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update job exit status: %w", err)
|
||||
}
|
||||
tryTerminateJobManager(ctx, jobId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func tryTerminateJobManager(ctx context.Context, jobId string) {
|
||||
job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
|
||||
if err != nil {
|
||||
log.Printf("[job:%s] error getting job for termination check: %v", jobId, err)
|
||||
return
|
||||
}
|
||||
|
||||
if job.JobManagerStatus != JobStatus_Running {
|
||||
return
|
||||
}
|
||||
|
||||
cmdExited := job.CmdExitTs != 0
|
||||
|
||||
if !cmdExited || !job.StreamDone {
|
||||
log.Printf("[job:%s] not ready for termination: exited=%v streamDone=%v", jobId, cmdExited, job.StreamDone)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[job:%s] both job cmd exited and stream finished, terminating job manager", jobId)
|
||||
|
||||
err = TerminateJobManager(ctx, jobId)
|
||||
if err != nil {
|
||||
log.Printf("[job:%s] error terminating job manager: %v", jobId, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TerminateJobManager(ctx context.Context, jobId string) error {
|
||||
job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get job: %w", err)
|
||||
}
|
||||
|
||||
return remoteTerminateJobManager(ctx, job)
|
||||
}
|
||||
|
||||
func DisconnectJob(ctx context.Context, jobId string) error {
|
||||
job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get job: %w", err)
|
||||
}
|
||||
|
||||
bareRpc := wshclient.GetBareRpcClient()
|
||||
rpcOpts := &wshrpc.RpcOpts{
|
||||
Route: wshutil.MakeConnectionRouteId(job.Connection),
|
||||
Timeout: 5000,
|
||||
}
|
||||
|
||||
disconnectData := wshrpc.CommandRemoteDisconnectFromJobManagerData{
|
||||
JobId: jobId,
|
||||
}
|
||||
|
||||
err = wshclient.RemoteDisconnectFromJobManagerCommand(bareRpc, disconnectData, rpcOpts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send disconnect command: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[job:%s] job disconnect command sent successfully", jobId)
|
||||
return nil
|
||||
}
|
||||
|
||||
// remoteTerminateJobManager sends a terminate RPC to the job manager process
// for the given job, then marks the job record as done in the DB. The DB
// update is best-effort: a failure there is only logged, since the remote
// process has already been told to exit.
func remoteTerminateJobManager(ctx context.Context, job *waveobj.Job) error {
	log.Printf("[job:%s] terminating job manager", job.OID)

	bareRpc := wshclient.GetBareRpcClient()
	// pid + start timestamp identify the exact process to terminate
	// (presumably so the remote side can guard against pid reuse — confirm)
	terminateData := wshrpc.CommandRemoteTerminateJobManagerData{
		JobId:             job.OID,
		JobManagerPid:     job.JobManagerPid,
		JobManagerStartTs: job.JobManagerStartTs,
	}

	rpcOpts := &wshrpc.RpcOpts{
		Route:   wshutil.MakeConnectionRouteId(job.Connection),
		Timeout: 5000,
	}

	err := wshclient.RemoteTerminateJobManagerCommand(bareRpc, terminateData, rpcOpts)
	if err != nil {
		log.Printf("[job:%s] error terminating job manager: %v", job.OID, err)
		return fmt.Errorf("failed to terminate job manager: %w", err)
	}

	// record terminated state; also force the stream closed so nothing waits
	// on output that will never arrive
	updateErr := wstore.DBUpdateFn(ctx, job.OID, func(job *waveobj.Job) {
		job.JobManagerStatus = JobStatus_Done
		job.JobManagerDoneReason = JobDoneReason_Terminated
		job.TerminateOnReconnect = false
		if !job.StreamDone {
			job.StreamDone = true
			job.StreamError = "job manager terminated"
		}
	})
	if updateErr != nil {
		log.Printf("[job:%s] error updating job status after termination: %v", job.OID, updateErr)
	}

	log.Printf("[job:%s] job manager terminated successfully", job.OID)
	return nil
}
|
||||
|
||||
// ReconnectJob re-attaches the main server to an existing job manager on an
// already-connected remote. Steps: verify the connection is up, honor a
// pending TerminateOnReconnect flag, sign a fresh job-access JWT, send the
// reconnect RPC (marking the job gone if the manager no longer exists), wait
// for the job route to register, then restart output streaming.
func ReconnectJob(ctx context.Context, jobId string) error {
	job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
	if err != nil {
		return fmt.Errorf("failed to get job: %w", err)
	}
	isConnected, err := conncontroller.IsConnected(job.Connection)
	if err != nil {
		return fmt.Errorf("error checking connection status: %w", err)
	}
	if !isConnected {
		return fmt.Errorf("connection %q is not connected", job.Connection)
	}

	// a prior termination request could not be delivered while disconnected;
	// deliver it now instead of reconnecting
	if job.TerminateOnReconnect {
		return remoteTerminateJobManager(ctx, job)
	}

	bareRpc := wshclient.GetBareRpcClient()

	// short-lived JWT proving this is the main server, scoped to this job
	jobAccessClaims := &wavejwt.WaveJwtClaims{
		MainServer: true,
		JobId:      jobId,
	}
	jobAccessToken, err := wavejwt.Sign(jobAccessClaims)
	if err != nil {
		return fmt.Errorf("failed to generate job access token: %w", err)
	}

	// pid + start ts let the remote side validate it is talking to the same
	// job manager process that was originally started
	reconnectData := wshrpc.CommandRemoteReconnectToJobManagerData{
		JobId:              jobId,
		JobAuthToken:       job.JobAuthToken,
		MainServerJwtToken: jobAccessToken,
		JobManagerPid:      job.JobManagerPid,
		JobManagerStartTs:  job.JobManagerStartTs,
	}

	rpcOpts := &wshrpc.RpcOpts{
		Route:   wshutil.MakeConnectionRouteId(job.Connection),
		Timeout: 5000,
	}

	log.Printf("[job:%s] sending RemoteReconnectToJobManagerCommand to connection %s", jobId, job.Connection)
	rtnData, err := wshclient.RemoteReconnectToJobManagerCommand(bareRpc, reconnectData, rpcOpts)
	if err != nil {
		log.Printf("[job:%s] RemoteReconnectToJobManagerCommand failed: %v", jobId, err)
		return fmt.Errorf("failed to reconnect to job manager: %w", err)
	}

	if !rtnData.Success {
		log.Printf("[job:%s] RemoteReconnectToJobManagerCommand returned error: %s", jobId, rtnData.Error)
		if rtnData.JobManagerGone {
			// the job manager process no longer exists; record it as done so
			// we stop trying to reconnect
			updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
				job.JobManagerStatus = JobStatus_Done
				job.JobManagerDoneReason = JobDoneReason_Gone
			})
			if updateErr != nil {
				log.Printf("[job:%s] error updating job manager running status: %v", jobId, updateErr)
			}
			return fmt.Errorf("job manager has exited: %s", rtnData.Error)
		}
		return fmt.Errorf("failed to reconnect to job manager: %s", rtnData.Error)
	}

	log.Printf("[job:%s] RemoteReconnectToJobManagerCommand succeeded, waiting for route", jobId)

	// the reconnected job manager registers its own route asynchronously;
	// give it up to 2s to appear before streaming can be restarted
	routeId := wshutil.MakeJobRouteId(jobId)
	waitCtx, cancelFn := context.WithTimeout(ctx, 2*time.Second)
	defer cancelFn()
	err = wshutil.DefaultRouter.WaitForRegister(waitCtx, routeId)
	if err != nil {
		return fmt.Errorf("route did not establish after successful reconnection: %w", err)
	}

	log.Printf("[job:%s] route established, restarting streaming", jobId)
	// knownConnected=true: we just verified the connection and route above
	return RestartStreaming(ctx, jobId, true)
}
|
||||
|
||||
func ReconnectJobsForConn(ctx context.Context, connName string) error {
|
||||
isConnected, err := conncontroller.IsConnected(connName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking connection status: %w", err)
|
||||
}
|
||||
if !isConnected {
|
||||
return fmt.Errorf("connection %q is not connected", connName)
|
||||
}
|
||||
|
||||
allJobs, err := wstore.DBGetAllObjsByType[*waveobj.Job](ctx, waveobj.OType_Job)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get jobs: %w", err)
|
||||
}
|
||||
|
||||
var jobsToReconnect []*waveobj.Job
|
||||
for _, job := range allJobs {
|
||||
if job.Connection == connName && isJobManagerRunning(job) {
|
||||
jobsToReconnect = append(jobsToReconnect, job)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[conn:%s] found %d jobs to reconnect", connName, len(jobsToReconnect))
|
||||
|
||||
for _, job := range jobsToReconnect {
|
||||
err = ReconnectJob(ctx, job.OID)
|
||||
if err != nil {
|
||||
log.Printf("[job:%s] error reconnecting: %v", job.OID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RestartStreaming(ctx context.Context, jobId string, knownConnected bool) error {
|
||||
job, err := wstore.DBMustGet[*waveobj.Job](ctx, jobId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get job: %w", err)
|
||||
}
|
||||
|
||||
if !knownConnected {
|
||||
isConnected, err := conncontroller.IsConnected(job.Connection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking connection status: %w", err)
|
||||
}
|
||||
if !isConnected {
|
||||
return fmt.Errorf("connection %q is not connected", job.Connection)
|
||||
}
|
||||
|
||||
jobConnStatus := GetJobConnStatus(jobId)
|
||||
if jobConnStatus != JobConnStatus_Connected {
|
||||
return fmt.Errorf("job manager is not connected (status: %s)", jobConnStatus)
|
||||
}
|
||||
}
|
||||
|
||||
var currentSeq int64 = 0
|
||||
var totalGap int64 = 0
|
||||
waveFile, err := filestore.WFS.Stat(ctx, jobId, JobOutputFileName)
|
||||
if err == nil {
|
||||
currentSeq = waveFile.Size
|
||||
totalGap = getMetaInt64(waveFile.Meta, MetaKey_TotalGap)
|
||||
currentSeq += totalGap
|
||||
}
|
||||
|
||||
bareRpc := wshclient.GetBareRpcClient()
|
||||
broker := bareRpc.StreamBroker
|
||||
readerRouteId := wshclient.GetBareRpcClientRouteId()
|
||||
writerRouteId := wshutil.MakeJobRouteId(jobId)
|
||||
|
||||
reader, streamMeta := broker.CreateStreamReaderWithSeq(readerRouteId, writerRouteId, DefaultStreamRwnd, currentSeq)
|
||||
|
||||
prepareData := wshrpc.CommandJobPrepareConnectData{
|
||||
StreamMeta: *streamMeta,
|
||||
Seq: currentSeq,
|
||||
}
|
||||
|
||||
rpcOpts := &wshrpc.RpcOpts{
|
||||
Route: wshutil.MakeJobRouteId(jobId),
|
||||
Timeout: 5000,
|
||||
}
|
||||
|
||||
log.Printf("[job:%s] sending JobPrepareConnectCommand with seq=%d (fileSize=%d, totalGap=%d)", jobId, currentSeq, waveFile.Size, totalGap)
|
||||
rtnData, err := wshclient.JobPrepareConnectCommand(bareRpc, prepareData, rpcOpts)
|
||||
if err != nil {
|
||||
reader.Close()
|
||||
return fmt.Errorf("failed to prepare connect: %w", err)
|
||||
}
|
||||
|
||||
if rtnData.HasExited {
|
||||
exitCodeStr := "nil"
|
||||
if rtnData.ExitCode != nil {
|
||||
exitCodeStr = fmt.Sprintf("%d", *rtnData.ExitCode)
|
||||
}
|
||||
log.Printf("[job:%s] job has already exited: code=%s signal=%q err=%q", jobId, exitCodeStr, rtnData.ExitSignal, rtnData.ExitErr)
|
||||
updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
|
||||
job.JobManagerStatus = JobStatus_Done
|
||||
job.CmdExitCode = rtnData.ExitCode
|
||||
job.CmdExitSignal = rtnData.ExitSignal
|
||||
job.CmdExitError = rtnData.ExitErr
|
||||
})
|
||||
if updateErr != nil {
|
||||
log.Printf("[job:%s] error updating job exit status: %v", jobId, updateErr)
|
||||
}
|
||||
}
|
||||
|
||||
if rtnData.StreamDone {
|
||||
log.Printf("[job:%s] stream is already done: error=%q", jobId, rtnData.StreamError)
|
||||
updateErr := wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
|
||||
if !job.StreamDone {
|
||||
job.StreamDone = true
|
||||
if rtnData.StreamError != "" {
|
||||
job.StreamError = rtnData.StreamError
|
||||
}
|
||||
}
|
||||
})
|
||||
if updateErr != nil {
|
||||
log.Printf("[job:%s] error updating job stream status: %v", jobId, updateErr)
|
||||
}
|
||||
}
|
||||
|
||||
if rtnData.StreamDone && rtnData.HasExited {
|
||||
reader.Close()
|
||||
log.Printf("[job:%s] both stream done and job exited, calling tryExitJobManager", jobId)
|
||||
tryTerminateJobManager(ctx, jobId)
|
||||
return nil
|
||||
}
|
||||
|
||||
if rtnData.StreamDone {
|
||||
reader.Close()
|
||||
log.Printf("[job:%s] stream already done, no need to restart streaming", jobId)
|
||||
return nil
|
||||
}
|
||||
|
||||
if rtnData.Seq > currentSeq {
|
||||
gap := rtnData.Seq - currentSeq
|
||||
totalGap += gap
|
||||
log.Printf("[job:%s] detected gap: our seq=%d, server seq=%d, gap=%d, new totalGap=%d", jobId, currentSeq, rtnData.Seq, gap, totalGap)
|
||||
|
||||
metaErr := filestore.WFS.WriteMeta(ctx, jobId, JobOutputFileName, wshrpc.FileMeta{
|
||||
MetaKey_TotalGap: totalGap,
|
||||
}, true)
|
||||
if metaErr != nil {
|
||||
log.Printf("[job:%s] error updating totalgap metadata: %v", jobId, metaErr)
|
||||
}
|
||||
|
||||
reader.UpdateNextSeq(rtnData.Seq)
|
||||
}
|
||||
|
||||
log.Printf("[job:%s] sending JobStartStreamCommand", jobId)
|
||||
startStreamData := wshrpc.CommandJobStartStreamData{}
|
||||
err = wshclient.JobStartStreamCommand(bareRpc, startStreamData, rpcOpts)
|
||||
if err != nil {
|
||||
reader.Close()
|
||||
return fmt.Errorf("failed to start stream: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("jobcontroller:RestartStreaming:runOutputLoop", recover())
|
||||
}()
|
||||
runOutputLoop(context.Background(), jobId, reader)
|
||||
}()
|
||||
|
||||
log.Printf("[job:%s] streaming restarted successfully", jobId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteJob(ctx context.Context, jobId string) error {
|
||||
SetJobConnStatus(jobId, JobConnStatus_Disconnected)
|
||||
err := filestore.WFS.DeleteZone(ctx, jobId)
|
||||
if err != nil {
|
||||
log.Printf("[job:%s] warning: error deleting WaveFS zone: %v", jobId, err)
|
||||
}
|
||||
return wstore.DBDelete(ctx, waveobj.OType_Job, jobId)
|
||||
}
|
||||
|
||||
// AttachJobToBlock links jobId and blockId both ways inside one transaction
// (block.JobId and job.AttachedBlockId), failing if the job is already
// attached to some block, then sends a fire-and-forget notification to the
// block's frontend route.
func AttachJobToBlock(ctx context.Context, jobId string, blockId string) error {
	err := wstore.WithTx(ctx, func(tx *wstore.TxWrap) error {
		err := wstore.DBUpdateFn(tx.Context(), blockId, func(block *waveobj.Block) {
			block.JobId = jobId
		})
		if err != nil {
			return fmt.Errorf("failed to update block: %w", err)
		}

		err = wstore.DBUpdateFnErr(tx.Context(), jobId, func(job *waveobj.Job) error {
			// a job may be attached to at most one block at a time
			if job.AttachedBlockId != "" {
				return fmt.Errorf("job %s already attached to block %s", jobId, job.AttachedBlockId)
			}
			job.AttachedBlockId = blockId
			return nil
		})
		if err != nil {
			return fmt.Errorf("failed to update job: %w", err)
		}

		log.Printf("[job:%s] attached to block:%s", jobId, blockId)
		return nil
	})
	if err != nil {
		return err
	}

	// best-effort FE notification; NoResponse means errors are not surfaced
	rpcOpts := &wshrpc.RpcOpts{
		Route:      wshutil.MakeFeBlockRouteId(blockId),
		NoResponse: true,
	}
	bareRpc := wshclient.GetBareRpcClient()
	wshclient.TermUpdateAttachedJobCommand(bareRpc, wshrpc.CommandTermUpdateAttachedJobData{
		BlockId: blockId,
		JobId:   jobId,
	}, rpcOpts)

	return nil
}
|
||||
|
||||
// DetachJobFromBlock clears the job<->block link. It always clears
// job.AttachedBlockId; when updateBlock is true it also clears block.JobId
// (best-effort — a failure there is only logged so detach still completes,
// e.g. when the block is being deleted). If a block was attached, a
// fire-and-forget notification is sent to its frontend route afterward.
func DetachJobFromBlock(ctx context.Context, jobId string, updateBlock bool) error {
	var blockId string
	err := wstore.WithTx(ctx, func(tx *wstore.TxWrap) error {
		job, err := wstore.DBMustGet[*waveobj.Job](tx.Context(), jobId)
		if err != nil {
			return fmt.Errorf("failed to get job: %w", err)
		}

		// capture for the post-tx notification below
		blockId = job.AttachedBlockId
		if blockId == "" {
			return nil
		}

		if updateBlock {
			// only touch the block if it still exists
			block, err := wstore.DBGet[*waveobj.Block](tx.Context(), blockId)
			if err == nil && block != nil {
				err = wstore.DBUpdateFn(tx.Context(), blockId, func(block *waveobj.Block) {
					block.JobId = ""
				})
				if err != nil {
					log.Printf("[job:%s] warning: failed to clear JobId from block:%s: %v", jobId, blockId, err)
				}
			}
		}

		err = wstore.DBUpdateFn(tx.Context(), jobId, func(job *waveobj.Job) {
			job.AttachedBlockId = ""
		})
		if err != nil {
			return fmt.Errorf("failed to update job: %w", err)
		}

		log.Printf("[job:%s] detached from block:%s", jobId, blockId)
		return nil
	})
	if err != nil {
		return err
	}

	if blockId != "" {
		// best-effort FE notification; NoResponse means errors are not surfaced
		rpcOpts := &wshrpc.RpcOpts{
			Route:      wshutil.MakeFeBlockRouteId(blockId),
			NoResponse: true,
		}
		bareRpc := wshclient.GetBareRpcClient()
		wshclient.TermUpdateAttachedJobCommand(bareRpc, wshrpc.CommandTermUpdateAttachedJobData{
			BlockId: blockId,
			JobId:   "",
		}, rpcOpts)
	}

	return nil
}
|
||||
|
||||
// SendInput forwards terminal input (data, signal, and/or resize) to the job
// manager for data.JobId, ensuring the job is connected first. When a resize
// is included, the new terminal size is also persisted on the job record
// (best-effort) so it survives reconnects.
func SendInput(ctx context.Context, data wshrpc.CommandJobInputData) error {
	jobId := data.JobId
	_, err := ensureJobConnected(ctx, jobId)
	if err != nil {
		return err
	}

	rpcOpts := &wshrpc.RpcOpts{
		Route:      wshutil.MakeJobRouteId(jobId),
		Timeout:    5000,
		NoResponse: false,
	}

	bareRpc := wshclient.GetBareRpcClient()
	err = wshclient.JobInputCommand(bareRpc, data, rpcOpts)
	if err != nil {
		return fmt.Errorf("failed to send input to job: %w", err)
	}

	if data.TermSize != nil {
		// persist only after the input RPC succeeded, so the stored size
		// tracks what the job manager actually applied
		err = wstore.DBUpdateFn(ctx, jobId, func(job *waveobj.Job) {
			job.CmdTermSize = *data.TermSize
		})
		if err != nil {
			log.Printf("[job:%s] warning: failed to update termsize in DB: %v", jobId, err)
		}
	}

	return nil
}
|
||||
218
pkg/jobmanager/cirbuf.go
Normal file
218
pkg/jobmanager/cirbuf.go
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package jobmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CirBuf is a fixed-capacity circular byte buffer with an adjustable
// "effective window" used for stream flow control. In sync mode a writer
// blocks when the window is full; in async mode the oldest data is
// overwritten instead.
type CirBuf struct {
	lock       sync.Mutex
	waiterChan chan chan struct{} // holds at most one blocked writer's wakeup channel
	buf        []byte             // backing storage; capacity fixed at creation
	readPos    int                // index of the oldest unconsumed byte
	writePos   int                // index where the next byte will be written
	count      int                // bytes currently stored (kept <= windowSize)
	totalSize  int64              // total bytes ever written (monotonic)
	syncMode   bool               // true: full window blocks writers; false: overwrite oldest
	windowSize int                // effective flow-control window (<= len(buf))
}
|
||||
|
||||
func MakeCirBuf(maxSize int, initSyncMode bool) *CirBuf {
|
||||
cb := &CirBuf{
|
||||
buf: make([]byte, maxSize),
|
||||
syncMode: initSyncMode,
|
||||
waiterChan: make(chan chan struct{}, 1),
|
||||
windowSize: maxSize,
|
||||
}
|
||||
return cb
|
||||
}
|
||||
|
||||
// SetEffectiveWindow changes the sync mode and effective window size for flow control.
// The windowSize is capped at the buffer size.
// When window shrinks: sync mode blocks new writes, async mode truncates old data to enforce limit.
// When window increases: blocked writers are woken up if space becomes available.
func (cb *CirBuf) SetEffectiveWindow(syncMode bool, windowSize int) {
	cb.lock.Lock()
	defer cb.lock.Unlock()

	maxSize := len(cb.buf)
	if windowSize > maxSize {
		windowSize = maxSize
	}

	oldSyncMode := cb.syncMode
	oldWindowSize := cb.windowSize
	cb.windowSize = windowSize
	cb.syncMode = syncMode

	// In async mode, enforce window size by truncating buffer if needed
	if !syncMode && cb.count > windowSize {
		// drop the oldest bytes by advancing the read position
		excess := cb.count - windowSize
		cb.readPos = (cb.readPos + excess) % maxSize
		cb.count = windowSize
	}

	// Only sync mode blocks writers, so only wake if we were in sync mode.
	// Wake when window grows (more space available) or switching to async (no longer blocking).
	if oldSyncMode && (windowSize > oldWindowSize || !syncMode) {
		cb.tryWakeWriter()
	}
}
|
||||
|
||||
// Write writes all of data to the buffer using a background context.
// Write will never block if syncMode is false.
// If syncMode is true, write will block until enough data is consumed to
// allow the write to finish; to cancel a write in progress use WriteCtx.
func (cb *CirBuf) Write(data []byte) (int, error) {
	return cb.WriteCtx(context.Background(), data)
}
|
||||
|
||||
// WriteCtx writes data to the circular buffer with context support for cancellation.
// In sync mode, blocks when buffer is full until space is available or context is cancelled.
// Returns partial byte count and context error if cancelled mid-write.
// NOTE: Only one concurrent blocked write is allowed. Multiple blocked writes will panic.
func (cb *CirBuf) WriteCtx(ctx context.Context, data []byte) (int, error) {
	if len(data) == 0 {
		return 0, nil
	}

	bytesWritten := 0
	for bytesWritten < len(data) {
		if err := ctx.Err(); err != nil {
			return bytesWritten, err
		}

		// writeAvailable copies what currently fits; a non-nil channel means
		// the window filled and the channel will be closed when space frees up
		n, spaceAvailable := cb.writeAvailable(data[bytesWritten:])
		bytesWritten += n

		if spaceAvailable != nil {
			select {
			case <-spaceAvailable:
				continue
			case <-ctx.Done():
				// remove the waiter we registered in writeAvailable so a
				// later Consume doesn't wake a channel nobody listens on
				tryReadCh(cb.waiterChan)
				return bytesWritten, ctx.Err()
			}
		}
	}

	return bytesWritten, nil
}
|
||||
|
||||
// writeAvailable copies as much of data as currently fits and returns the
// number of bytes written. In sync mode, when the effective window is full,
// it registers a wakeup channel in waiterChan and returns it non-nil so the
// caller can block (panics if another writer is already registered). In
// async mode a full window overwrites the oldest byte instead of blocking.
func (cb *CirBuf) writeAvailable(data []byte) (int, chan struct{}) {
	cb.lock.Lock()
	defer cb.lock.Unlock()

	size := len(cb.buf)
	written := 0

	for i := 0; i < len(data); i++ {
		if cb.syncMode && cb.count >= cb.windowSize {
			spaceAvailable := make(chan struct{})
			if !tryWriteCh(cb.waiterChan, spaceAvailable) {
				panic("CirBuf: multiple concurrent blocked writes not allowed")
			}
			return written, spaceAvailable
		}

		cb.buf[cb.writePos] = data[i]
		cb.writePos = (cb.writePos + 1) % size
		if cb.count < cb.windowSize {
			cb.count++
		} else {
			// async overwrite: window full, drop the oldest byte
			cb.readPos = (cb.readPos + 1) % size
		}
		cb.totalSize++
		written++
	}

	return written, nil
}
|
||||
|
||||
// PeekData copies up to len(data) bytes from the front of the buffer without
// consuming them; returns the number of bytes copied.
func (cb *CirBuf) PeekData(data []byte) int {
	return cb.PeekDataAt(0, data)
}
|
||||
|
||||
func (cb *CirBuf) PeekDataAt(offset int, data []byte) int {
|
||||
cb.lock.Lock()
|
||||
defer cb.lock.Unlock()
|
||||
|
||||
if cb.count == 0 || offset >= cb.count {
|
||||
return 0
|
||||
}
|
||||
|
||||
size := len(cb.buf)
|
||||
pos := (cb.readPos + offset) % size
|
||||
maxRead := cb.count - offset
|
||||
read := 0
|
||||
|
||||
for i := 0; i < len(data) && i < maxRead; i++ {
|
||||
data[i] = cb.buf[pos]
|
||||
pos = (pos + 1) % size
|
||||
read++
|
||||
}
|
||||
|
||||
return read
|
||||
}
|
||||
|
||||
// Consume discards numBytes from the front of the buffer, erroring if fewer
// bytes are stored. Freeing window space may unblock a waiting writer, so a
// wake is attempted afterward.
// NOTE(review): a negative numBytes is not rejected and would corrupt
// readPos/count — confirm callers never pass one.
func (cb *CirBuf) Consume(numBytes int) error {
	cb.lock.Lock()
	defer cb.lock.Unlock()

	if numBytes > cb.count {
		return fmt.Errorf("cannot consume %d bytes, only %d available", numBytes, cb.count)
	}

	size := len(cb.buf)
	cb.readPos = (cb.readPos + numBytes) % size
	cb.count -= numBytes

	cb.tryWakeWriter()

	return nil
}
|
||||
|
||||
func (cb *CirBuf) HeadPos() int64 {
|
||||
cb.lock.Lock()
|
||||
defer cb.lock.Unlock()
|
||||
return cb.totalSize - int64(cb.count)
|
||||
}
|
||||
|
||||
func (cb *CirBuf) Size() int {
|
||||
cb.lock.Lock()
|
||||
defer cb.lock.Unlock()
|
||||
return cb.count
|
||||
}
|
||||
|
||||
func (cb *CirBuf) TotalSize() int64 {
|
||||
cb.lock.Lock()
|
||||
defer cb.lock.Unlock()
|
||||
return cb.totalSize
|
||||
}
|
||||
|
||||
// tryWriteCh performs a non-blocking send of val on ch, reporting whether the
// send succeeded.
func tryWriteCh[T any](ch chan<- T, val T) bool {
	select {
	case ch <- val:
		return true
	default:
	}
	return false
}
|
||||
|
||||
// tryReadCh performs a non-blocking receive from ch, returning a pointer to
// the received value and true, or (nil, false) when nothing is ready.
func tryReadCh[T any](ch <-chan T) (*T, bool) {
	select {
	case v := <-ch:
		return &v, true
	default:
		return nil, false
	}
}
|
||||
|
||||
func (cb *CirBuf) tryWakeWriter() {
|
||||
if waiterCh, ok := tryReadCh(cb.waiterChan); ok {
|
||||
close(*waiterCh)
|
||||
}
|
||||
}
|
||||
208
pkg/jobmanager/jobcmd.go
Normal file
208
pkg/jobmanager/jobcmd.go
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package jobmanager
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
)
|
||||
|
||||
// CmdDef describes a command to launch under a pty: executable path, args,
// extra environment entries (appended to os.Environ), and the initial
// terminal size.
type CmdDef struct {
	Cmd      string
	Args     []string
	Env      map[string]string
	TermSize waveobj.TermSize
}
|
||||
|
||||
// JobCmd wraps a single pty-attached command for a job and tracks its exit
// state. All mutable fields are guarded by lock.
type JobCmd struct {
	jobId         string
	lock          sync.Mutex
	cmd           *exec.Cmd
	cmdPty        pty.Pty // pty master for the running command
	ptsName       string  // pty slave device name, captured at start
	cleanedUp     bool
	ptyClosed     bool   // set once the pty master has been closed
	processExited bool   // set by waitForProcess after Wait returns
	exitCode      *int   // nil when killed by a signal or not yet exited
	exitSignal    string // signal name when the process died from a signal
	exitErr       error  // raw error from cmd.Wait
	exitTs        int64  // exit time in unix millis
}
|
||||
|
||||
// MakeJobCmd starts cmdDef under a new pty and returns a JobCmd tracking it.
// A goroutine is spawned to reap the process and record its exit status.
func MakeJobCmd(jobId string, cmdDef CmdDef) (*JobCmd, error) {
	jm := &JobCmd{
		jobId: jobId,
	}
	// default to 80x25 when no size was supplied
	// NOTE(review): if only ONE dimension is zero, both are reset to the
	// defaults — confirm that is intended rather than filling just the
	// missing dimension.
	if cmdDef.TermSize.Rows == 0 || cmdDef.TermSize.Cols == 0 {
		cmdDef.TermSize.Rows = 25
		cmdDef.TermSize.Cols = 80
	}
	// rejects negative sizes (zero was normalized above)
	if cmdDef.TermSize.Rows <= 0 || cmdDef.TermSize.Cols <= 0 {
		return nil, fmt.Errorf("invalid term size: %v", cmdDef.TermSize)
	}
	ecmd := exec.Command(cmdDef.Cmd, cmdDef.Args...)
	if len(cmdDef.Env) > 0 {
		// extend (not replace) the inherited environment
		ecmd.Env = os.Environ()
		for key, val := range cmdDef.Env {
			ecmd.Env = append(ecmd.Env, fmt.Sprintf("%s=%s", key, val))
		}
	}
	cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(cmdDef.TermSize.Rows), Cols: uint16(cmdDef.TermSize.Cols)})
	if err != nil {
		return nil, fmt.Errorf("failed to start command: %w", err)
	}
	// keep the pty master fd from leaking into future child processes
	setCloseOnExec(int(cmdPty.Fd()))
	jm.cmd = ecmd
	jm.cmdPty = cmdPty
	jm.ptsName = jm.cmdPty.Name()
	go jm.waitForProcess()
	return jm, nil
}
|
||||
|
||||
// waitForProcess reaps the child process, records its exit disposition
// (code, signal, timestamp, error) under the lock, then asynchronously
// notifies the main server via the global job manager.
func (jm *JobCmd) waitForProcess() {
	if jm.cmd == nil || jm.cmd.Process == nil {
		return
	}
	err := jm.cmd.Wait()
	jm.lock.Lock()
	defer jm.lock.Unlock()

	jm.processExited = true
	jm.exitTs = time.Now().UnixMilli()
	jm.exitErr = err
	if err != nil {
		// distinguish signal-death from a normal nonzero exit
		if exitErr, ok := err.(*exec.ExitError); ok {
			if status, ok := exitErr.Sys().(syscall.WaitStatus); ok {
				if status.Signaled() {
					jm.exitSignal = status.Signal().String()
				} else if status.Exited() {
					code := status.ExitStatus()
					jm.exitCode = &code
				} else {
					log.Printf("Invalid WaitStatus, not exited or signaled: %v", status)
				}
			}
		}
	} else {
		// clean exit: Wait returned nil, so exit code is 0
		code := 0
		jm.exitCode = &code
	}
	exitCodeStr := "nil"
	if jm.exitCode != nil {
		exitCodeStr = fmt.Sprintf("%d", *jm.exitCode)
	}
	log.Printf("process exited: exitcode=%s, signal=%s, err=%v\n", exitCodeStr, jm.exitSignal, jm.exitErr)

	// notify asynchronously; sendJobExited takes its own locks
	go WshCmdJobManager.sendJobExited()
}
|
||||
|
||||
func (jm *JobCmd) GetCmd() (*exec.Cmd, pty.Pty) {
|
||||
jm.lock.Lock()
|
||||
defer jm.lock.Unlock()
|
||||
return jm.cmd, jm.cmdPty
|
||||
}
|
||||
|
||||
func (jm *JobCmd) GetPGID() (int, error) {
|
||||
jm.lock.Lock()
|
||||
defer jm.lock.Unlock()
|
||||
if jm.cmd == nil || jm.cmd.Process == nil {
|
||||
return 0, fmt.Errorf("no active process")
|
||||
}
|
||||
if jm.processExited {
|
||||
return 0, fmt.Errorf("process already exited")
|
||||
}
|
||||
pgid, err := getProcessGroupId(jm.cmd.Process.Pid)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get pgid: %w", err)
|
||||
}
|
||||
if pgid <= 0 {
|
||||
return 0, fmt.Errorf("invalid pgid returned: %d", pgid)
|
||||
}
|
||||
return pgid, nil
|
||||
}
|
||||
|
||||
// GetExitInfo returns (true, exit data) once the process has exited, or
// (false, nil) while it is still running. The returned data is shaped for
// the JobCmdExited RPC notification.
func (jm *JobCmd) GetExitInfo() (bool, *wshrpc.CommandJobCmdExitedData) {
	jm.lock.Lock()
	defer jm.lock.Unlock()
	if !jm.processExited {
		return false, nil
	}
	// JobId comes from the process-wide job manager singleton
	exitData := &wshrpc.CommandJobCmdExitedData{
		JobId:      WshCmdJobManager.JobId,
		ExitCode:   jm.exitCode,
		ExitSignal: jm.exitSignal,
		ExitTs:     jm.exitTs,
	}
	if jm.exitErr != nil {
		exitData.ExitErr = jm.exitErr.Error()
	}
	return true, exitData
}
|
||||
|
||||
// TODO set up a single input handler loop + queue so we dont need to hold the lock but still get synchronized in-order execution
// HandleInput applies one input RPC to the running command. It can carry any
// combination of: base64 stdin data (written to the pty), a signal name
// (delivered to the process), and a terminal resize — processed in that
// order, all under the lock.
func (jm *JobCmd) HandleInput(data wshrpc.CommandJobInputData) error {
	jm.lock.Lock()
	defer jm.lock.Unlock()

	if jm.cmd == nil || jm.cmdPty == nil {
		return fmt.Errorf("no active process")
	}

	if len(data.InputData64) > 0 {
		inputBuf := make([]byte, base64.StdEncoding.DecodedLen(len(data.InputData64)))
		nw, err := base64.StdEncoding.Decode(inputBuf, []byte(data.InputData64))
		if err != nil {
			return fmt.Errorf("error decoding input data: %w", err)
		}
		_, err = jm.cmdPty.Write(inputBuf[:nw])
		if err != nil {
			return fmt.Errorf("error writing to pty: %w", err)
		}
	}

	if data.SigName != "" {
		// an unrecognized signal name (nil from normalizeSignal) is silently ignored
		sig := normalizeSignal(data.SigName)
		if sig != nil && jm.cmd.Process != nil {
			err := jm.cmd.Process.Signal(sig)
			if err != nil {
				return fmt.Errorf("error sending signal: %w", err)
			}
		}
	}

	if data.TermSize != nil {
		err := pty.Setsize(jm.cmdPty, &pty.Winsize{
			Rows: uint16(data.TermSize.Rows),
			Cols: uint16(data.TermSize.Cols),
		})
		if err != nil {
			return fmt.Errorf("error setting terminal size: %w", err)
		}
	}

	return nil
}
|
||||
|
||||
func (jm *JobCmd) TerminateByClosingPtyMaster() {
|
||||
jm.lock.Lock()
|
||||
defer jm.lock.Unlock()
|
||||
if jm.ptyClosed {
|
||||
return
|
||||
}
|
||||
if jm.cmdPty != nil {
|
||||
jm.cmdPty.Close()
|
||||
jm.ptyClosed = true
|
||||
log.Printf("pty closed for job %s\n", jm.jobId)
|
||||
}
|
||||
}
|
||||
246
pkg/jobmanager/jobmanager.go
Normal file
246
pkg/jobmanager/jobmanager.go
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
// Copyright 2026, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package jobmanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/baseds"
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavejwt"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
)
|
||||
|
||||
// JobManagerStartLabel is written to the ready file by SetupJobManager to
// signal that startup completed. JobAccessTokenLabel is a companion sentinel
// (usage not visible in this file — presumably part of the same handshake).
const JobAccessTokenLabel = "Wave-JobAccessToken"
const JobManagerStartLabel = "Wave-JobManagerStart"

// WshCmdJobManager is the process-wide singleton JobManager for this wsh
// job-manager process, initialized by SetupJobManager.
var WshCmdJobManager JobManager
|
||||
|
||||
// JobManager holds the state for a single persistent job-manager process:
// identity and auth material, the running command, and the currently attached
// main-server / stream clients. attachedClient, connectedStreamClient, and
// pendingStreamMeta are guarded by lock.
type JobManager struct {
	ClientId      string
	JobId         string
	Cmd           *JobCmd
	JwtPublicKey  []byte // public key for verifying main-server JWTs
	JobAuthToken  string
	StreamManager *StreamManager

	lock                  sync.Mutex
	attachedClient        *MainServerConn // receives exit notifications
	connectedStreamClient *MainServerConn // at most one stream consumer at a time
	pendingStreamMeta     *wshrpc.StreamMeta
}
|
||||
|
||||
func SetupJobManager(clientId string, jobId string, publicKeyBytes []byte, jobAuthToken string, readyFile *os.File) error {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
|
||||
return fmt.Errorf("job manager only supported on unix systems, not %s", runtime.GOOS)
|
||||
}
|
||||
WshCmdJobManager.ClientId = clientId
|
||||
WshCmdJobManager.JobId = jobId
|
||||
WshCmdJobManager.JwtPublicKey = publicKeyBytes
|
||||
WshCmdJobManager.JobAuthToken = jobAuthToken
|
||||
WshCmdJobManager.StreamManager = MakeStreamManager()
|
||||
err := wavejwt.SetPublicKey(publicKeyBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set public key: %w", err)
|
||||
}
|
||||
err = MakeJobDomainSocket(clientId, jobId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(readyFile, JobManagerStartLabel+"\n")
|
||||
readyFile.Close()
|
||||
|
||||
err = daemonize(clientId, jobId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to daemonize: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (jm *JobManager) GetCmd() *JobCmd {
|
||||
jm.lock.Lock()
|
||||
defer jm.lock.Unlock()
|
||||
return jm.Cmd
|
||||
}
|
||||
|
||||
// sendJobExited pushes the command's exit status to the attached main server
// over RPC. It is a logged no-op when there is no attached client, no rpc
// connection, no command, or the process has not actually exited yet.
func (jm *JobManager) sendJobExited() {
	// snapshot under lock; the RPC call below must not hold jm.lock
	jm.lock.Lock()
	attachedClient := jm.attachedClient
	cmd := jm.Cmd
	jm.lock.Unlock()

	if attachedClient == nil {
		log.Printf("sendJobExited: no attached client, exit notification not sent\n")
		return
	}
	if attachedClient.WshRpc == nil {
		log.Printf("sendJobExited: no wsh rpc connection, exit notification not sent\n")
		return
	}
	if cmd == nil {
		log.Printf("sendJobExited: no cmd, exit notification not sent\n")
		return
	}

	exited, exitData := cmd.GetExitInfo()
	if !exited || exitData == nil {
		log.Printf("sendJobExited: process not exited yet\n")
		return
	}

	exitCodeStr := "nil"
	if exitData.ExitCode != nil {
		exitCodeStr = fmt.Sprintf("%d", *exitData.ExitCode)
	}
	log.Printf("sendJobExited: sending exit notification to main server exitcode=%s signal=%s\n", exitCodeStr, exitData.ExitSignal)
	err := wshclient.JobCmdExitedCommand(attachedClient.WshRpc, *exitData, nil)
	if err != nil {
		log.Printf("sendJobExited: error sending exit notification: %v\n", err)
	}
}
|
||||
|
||||
func (jm *JobManager) GetJobAuthInfo() (string, string) {
|
||||
jm.lock.Lock()
|
||||
defer jm.lock.Unlock()
|
||||
return jm.JobId, jm.JobAuthToken
|
||||
}
|
||||
|
||||
func (jm *JobManager) IsJobStarted() bool {
|
||||
jm.lock.Lock()
|
||||
defer jm.lock.Unlock()
|
||||
return jm.Cmd != nil
|
||||
}
|
||||
|
||||
// connectToStreamHelper_withlock attaches mainServerConn as the single stream
// consumer, disconnecting (and detaching the stream writer of) any previously
// connected client first. Returns the server-side sequence number to resume
// from. Caller must already hold jm.lock (per the _withlock suffix).
func (jm *JobManager) connectToStreamHelper_withlock(mainServerConn *MainServerConn, streamMeta wshrpc.StreamMeta, seq int64) (int64, error) {
	rwndSize := int(streamMeta.RWnd)
	if rwndSize < 0 {
		return 0, fmt.Errorf("invalid rwnd size: %d", rwndSize)
	}

	if jm.connectedStreamClient != nil {
		log.Printf("connectToStreamHelper: disconnecting existing client\n")
		oldStreamId := jm.StreamManager.GetStreamId()
		jm.StreamManager.ClientDisconnected()
		if oldStreamId != "" {
			// NOTE(review): the old stream is detached via the NEW
			// connection's broker — confirm that is the intended owner
			mainServerConn.WshRpc.StreamBroker.DetachStreamWriter(oldStreamId)
			log.Printf("connectToStreamHelper: detached old stream id=%s\n", oldStreamId)
		}
		jm.connectedStreamClient = nil
	}
	// sends stream data back to the reader's route on the new connection
	dataSender := &routedDataSender{
		wshRpc: mainServerConn.WshRpc,
		route:  streamMeta.ReaderRouteId,
	}
	serverSeq, err := jm.StreamManager.ClientConnected(
		streamMeta.Id,
		dataSender,
		rwndSize,
		seq,
	)
	if err != nil {
		return 0, fmt.Errorf("failed to connect client: %w", err)
	}
	jm.connectedStreamClient = mainServerConn
	return serverSeq, nil
}
|
||||
|
||||
// disconnectFromStreamHelper detaches mainServerConn from the job's output
// stream, but only if it is still the currently connected client — a stale
// connection that has already been replaced is a no-op.
func (jm *JobManager) disconnectFromStreamHelper(mainServerConn *MainServerConn) {
	jm.lock.Lock()
	defer jm.lock.Unlock()
	if jm.connectedStreamClient == nil || jm.connectedStreamClient != mainServerConn {
		return
	}
	jm.StreamManager.ClientDisconnected()
	jm.connectedStreamClient = nil
}
|
||||
|
||||
// GetJobSocketPath returns the unix domain socket path for a job:
// /tmp/waveterm-<uid>/<jobId>.sock.
func GetJobSocketPath(jobId string) string {
	dir := filepath.Join("/tmp", fmt.Sprintf("waveterm-%d", os.Getuid()))
	return filepath.Join(dir, jobId+".sock")
}
|
||||
|
||||
func GetJobFilePath(clientId string, jobId string, extension string) string {
|
||||
homeDir := wavebase.GetHomeDir()
|
||||
jobDir := filepath.Join(homeDir, ".waveterm", "jobs", clientId)
|
||||
return filepath.Join(jobDir, fmt.Sprintf("%s.%s", jobId, extension))
|
||||
}
|
||||
|
||||
func MakeJobDomainSocket(clientId string, jobId string) error {
|
||||
socketDir := filepath.Join("/tmp", fmt.Sprintf("waveterm-%d", os.Getuid()))
|
||||
err := os.MkdirAll(socketDir, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create socket directory: %w", err)
|
||||
}
|
||||
|
||||
socketPath := GetJobSocketPath(jobId)
|
||||
|
||||
os.Remove(socketPath)
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on domain socket: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("MakeJobDomainSocket:accept", recover())
|
||||
listener.Close()
|
||||
os.Remove(socketPath)
|
||||
}()
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("error accepting connection: %v\n", err)
|
||||
return
|
||||
}
|
||||
go handleJobDomainSocketClient(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleJobDomainSocketClient serves one main-server connection over the job's
// domain socket. It builds a WshRpc on top of the raw connection (via input /
// output channels) and starts two pump goroutines: one writing outputCh to the
// socket, one reading the socket into inputCh. Either direction failing closes
// the MainServerConn, and the deferred disconnect releases the output stream if
// this connection owned it.
func handleJobDomainSocketClient(conn net.Conn) {
	inputCh := make(chan baseds.RpcInputChType, wshutil.DefaultInputChSize)
	outputCh := make(chan []byte, wshutil.DefaultOutputChSize)

	serverImpl := &MainServerConn{
		Conn:    conn,
		inputCh: inputCh,
	}
	// The RPC context starts empty; the peer proves its identity later via
	// AuthenticateToJobManagerCommand.
	rpcCtx := wshrpc.RpcContext{}
	wshRpc := wshutil.MakeWshRpcWithChannels(inputCh, outputCh, rpcCtx, serverImpl, "job-domain")
	serverImpl.WshRpc = wshRpc
	defer WshCmdJobManager.disconnectFromStreamHelper(serverImpl)

	// Writer pump: outputCh -> socket.
	go func() {
		defer func() {
			panichandler.PanicHandler("handleJobDomainSocketClient:AdaptOutputChToStream", recover())
		}()
		defer serverImpl.Close()
		writeErr := wshutil.AdaptOutputChToStream(outputCh, conn)
		if writeErr != nil {
			log.Printf("error writing to domain socket: %v\n", writeErr)
		}
	}()

	// Reader pump: socket -> inputCh.
	go func() {
		defer func() {
			panichandler.PanicHandler("handleJobDomainSocketClient:AdaptStreamToMsgCh", recover())
		}()
		defer serverImpl.Close()
		wshutil.AdaptStreamToMsgCh(conn, inputCh)
	}()

	// wshRpc is retained via serverImpl.WshRpc; this keeps the variable "used"
	// for readability of the setup above.
	_ = wshRpc
}
|
||||
95
pkg/jobmanager/jobmanager_unix.go
Normal file
95
pkg/jobmanager/jobmanager_unix.go
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
// Copyright 2026, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build unix
|
||||
|
||||
package jobmanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// getProcessGroupId returns the process group id of pid (unix implementation).
// On failure it returns 0 and the syscall error.
func getProcessGroupId(pid int) (int, error) {
	groupId, err := syscall.Getpgid(pid)
	if err != nil {
		return 0, err
	}
	return groupId, nil
}
|
||||
|
||||
// sigByName maps canonical signal names (upper case, no "SIG" prefix) to
// their signals.
var sigByName = map[string]os.Signal{
	"HUP":  syscall.SIGHUP,
	"INT":  syscall.SIGINT,
	"QUIT": syscall.SIGQUIT,
	"KILL": syscall.SIGKILL,
	"TERM": syscall.SIGTERM,
	"USR1": syscall.SIGUSR1,
	"USR2": syscall.SIGUSR2,
	"STOP": syscall.SIGSTOP,
	"CONT": syscall.SIGCONT,
}

// normalizeSignal resolves a user-supplied signal name ("term", "SIGTERM",
// "Term", ...) to the corresponding os.Signal, or nil if unrecognized.
func normalizeSignal(sigName string) os.Signal {
	canonical := strings.TrimPrefix(strings.ToUpper(sigName), "SIG")
	return sigByName[canonical]
}
|
||||
|
||||
// daemonize detaches the job manager from its controlling terminal so it can
// survive the parent's disconnect: it creates a new session, redirects stdin
// to /dev/null, redirects stdout/stderr to the per-job log file, points the
// standard logger at that file, and ignores SIGHUP.
//
// NOTE(review): on the dup2 error paths the opened devNull/logFile descriptors
// are not closed — harmless in practice since an error here is fatal for the
// process, but worth confirming.
func daemonize(clientId string, jobId string) error {
	// New session: detach from the controlling terminal / parent process group.
	_, err := unix.Setsid()
	if err != nil {
		return fmt.Errorf("failed to setsid: %w", err)
	}

	// stdin <- /dev/null
	devNull, err := os.OpenFile("/dev/null", os.O_RDWR, 0)
	if err != nil {
		return fmt.Errorf("failed to open /dev/null: %w", err)
	}
	err = unix.Dup2(int(devNull.Fd()), int(os.Stdin.Fd()))
	if err != nil {
		return fmt.Errorf("failed to dup2 stdin: %w", err)
	}
	devNull.Close()

	// stdout/stderr -> append-only per-job log file.
	logPath := GetJobFilePath(clientId, jobId, "log")
	logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
	if err != nil {
		return fmt.Errorf("failed to open log file: %w", err)
	}
	err = unix.Dup2(int(logFile.Fd()), int(os.Stdout.Fd()))
	if err != nil {
		return fmt.Errorf("failed to dup2 stdout: %w", err)
	}
	err = unix.Dup2(int(logFile.Fd()), int(os.Stderr.Fd()))
	if err != nil {
		return fmt.Errorf("failed to dup2 stderr: %w", err)
	}

	// logFile is intentionally kept open for the lifetime of the process.
	log.SetOutput(logFile)
	log.Printf("job manager daemonized, logging to %s\n", logPath)

	// Survive terminal hangups after the controlling session goes away.
	signal.Ignore(syscall.SIGHUP)

	return nil
}
|
||||
|
||||
// setCloseOnExec marks fd close-on-exec so child processes (the job command)
// do not inherit it (unix implementation).
func setCloseOnExec(fd int) {
	unix.CloseOnExec(fd)
}
|
||||
29
pkg/jobmanager/jobmanager_windows.go
Normal file
29
pkg/jobmanager/jobmanager_windows.go
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright 2026, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build windows
|
||||
|
||||
package jobmanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// getProcessGroupId is the windows stub; process groups are a unix concept,
// so it always fails.
func getProcessGroupId(pid int) (int, error) {
	return 0, fmt.Errorf("process group id not supported on windows")
}
|
||||
|
||||
// normalizeSignal is the windows stub; no posix signal names are supported,
// so every lookup resolves to nil.
func normalizeSignal(sigName string) os.Signal {
	return nil
}
|
||||
|
||||
// daemonize is the windows stub; session detachment via setsid/dup2 has no
// windows equivalent here, so it always fails.
func daemonize(clientId string, jobId string) error {
	return fmt.Errorf("daemonize not supported on windows")
}
|
||||
|
||||
// setupJobManagerSignalHandlers is a no-op on windows (no posix signal
// handling to install).
func setupJobManagerSignalHandlers() {
}
|
||||
|
||||
// setCloseOnExec is a no-op on windows; descriptor inheritance is handled
// differently there.
func setCloseOnExec(fd int) {
}
|
||||
292
pkg/jobmanager/mainserverconn.go
Normal file
292
pkg/jobmanager/mainserverconn.go
Normal file
|
|
@ -0,0 +1,292 @@
|
|||
// Copyright 2026, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package jobmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/shirou/gopsutil/v4/process"
|
||||
"github.com/wavetermdev/waveterm/pkg/baseds"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavejwt"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
)
|
||||
|
||||
// MainServerConn represents one main-server connection over the job's domain
// socket, and doubles as the wsh server implementation for that connection.
type MainServerConn struct {
	// PeerAuthenticated is set once the peer's JWT is validated in
	// AuthenticateToJobManagerCommand; gates all job commands.
	PeerAuthenticated atomic.Bool
	// SelfAuthenticated is set once we have successfully authenticated
	// ourselves back to the main server (authenticateSelfToServer).
	SelfAuthenticated atomic.Bool
	WshRpc            *wshutil.WshRpc
	Conn              net.Conn
	// inputCh carries messages read from Conn into the WshRpc; closed by Close.
	inputCh chan baseds.RpcInputChType
	// closeOnce guards Close against the two pump goroutines racing to close.
	closeOnce sync.Once
}
|
||||
|
||||
// WshServerImpl marks MainServerConn as a wsh RPC server implementation.
func (*MainServerConn) WshServerImpl() {}
|
||||
|
||||
func (msc *MainServerConn) Close() {
|
||||
msc.closeOnce.Do(func() {
|
||||
msc.Conn.Close()
|
||||
close(msc.inputCh)
|
||||
})
|
||||
}
|
||||
|
||||
// routedDataSender implements DataSender by forwarding stream data packets
// over an RPC client to a specific route (the reader's route on the main
// server).
type routedDataSender struct {
	wshRpc *wshutil.WshRpc
	route  string
}
|
||||
|
||||
// SendData forwards one stream data packet to the configured route.
// Fire-and-forget (NoResponse); send failures are logged, not propagated —
// delivery is recovered via the stream's ACK/reconnect protocol.
func (rds *routedDataSender) SendData(dataPk wshrpc.CommandStreamData) {
	log.Printf("SendData: sending seq=%d, len=%d, eof=%t, error=%s, route=%s",
		dataPk.Seq, len(dataPk.Data64), dataPk.Eof, dataPk.Error, rds.route)
	err := wshclient.StreamDataCommand(rds.wshRpc, dataPk, &wshrpc.RpcOpts{NoResponse: true, Route: rds.route})
	if err != nil {
		log.Printf("SendData: error sending stream data: %v\n", err)
	}
}
|
||||
|
||||
// authenticateSelfToServer proves this job manager's identity to the main
// server by sending the job id plus the shared job auth token over the
// control route. Sets SelfAuthenticated on success.
func (msc *MainServerConn) authenticateSelfToServer(jobAuthToken string) error {
	jobId, _ := WshCmdJobManager.GetJobAuthInfo()
	authData := wshrpc.CommandAuthenticateJobManagerData{
		JobId:        jobId,
		JobAuthToken: jobAuthToken,
	}
	err := wshclient.AuthenticateJobManagerCommand(msc.WshRpc, authData, &wshrpc.RpcOpts{Route: wshutil.ControlRoute})
	if err != nil {
		log.Printf("authenticateSelfToServer: failed to authenticate to server: %v\n", err)
		return fmt.Errorf("failed to authenticate to server: %w", err)
	}
	msc.SelfAuthenticated.Store(true)
	log.Printf("authenticateSelfToServer: successfully authenticated to server\n")
	return nil
}
|
||||
|
||||
// AuthenticateToJobManagerCommand performs the mutual authentication handshake
// with a connecting main server: it validates the peer's JWT (must carry the
// MainServer claim and this job's id), then authenticates itself back to the
// server with the job auth token. On full success this connection becomes the
// job's attached client, kicking out any previous one.
func (msc *MainServerConn) AuthenticateToJobManagerCommand(ctx context.Context, data wshrpc.CommandAuthenticateToJobData) error {
	jobId, jobAuthToken := WshCmdJobManager.GetJobAuthInfo()

	claims, err := wavejwt.ValidateAndExtract(data.JobAccessToken)
	if err != nil {
		log.Printf("AuthenticateToJobManager: failed to validate token: %v\n", err)
		return fmt.Errorf("failed to validate token: %w", err)
	}
	if !claims.MainServer {
		log.Printf("AuthenticateToJobManager: MainServer claim not set\n")
		return fmt.Errorf("MainServer claim not set")
	}
	if claims.JobId != jobId {
		log.Printf("AuthenticateToJobManager: JobId mismatch: expected %s, got %s\n", jobId, claims.JobId)
		return fmt.Errorf("JobId mismatch")
	}
	msc.PeerAuthenticated.Store(true)
	log.Printf("AuthenticateToJobManager: authentication successful for JobId=%s\n", claims.JobId)

	// Reverse direction: if we cannot prove ourselves to the server, roll back
	// the peer authentication so this connection stays unprivileged.
	err = msc.authenticateSelfToServer(jobAuthToken)
	if err != nil {
		msc.PeerAuthenticated.Store(false)
		return err
	}

	WshCmdJobManager.lock.Lock()
	defer WshCmdJobManager.lock.Unlock()

	// Single attached client policy: closing the old connection lets its pump
	// goroutines unwind and release any stream it held.
	if WshCmdJobManager.attachedClient != nil {
		log.Printf("AuthenticateToJobManager: kicking out existing client\n")
		WshCmdJobManager.attachedClient.Close()
	}
	WshCmdJobManager.attachedClient = msc
	return nil
}
|
||||
|
||||
// StartJobCommand starts the job's command process (at most once per job
// manager). It optionally wires the caller up as the output-stream consumer,
// attaches the PTY to the stream manager, and returns pid/start-time pairs
// for both the command and the job manager so the server can later verify
// process identity across reconnects.
func (msc *MainServerConn) StartJobCommand(ctx context.Context, data wshrpc.CommandStartJobData) (*wshrpc.CommandStartJobRtnData, error) {
	log.Printf("StartJobCommand: received command=%s args=%v", data.Cmd, data.Args)
	if !msc.PeerAuthenticated.Load() {
		log.Printf("StartJobCommand: not authenticated")
		return nil, fmt.Errorf("not authenticated")
	}
	// Cheap advisory check before taking the lock; re-verified below.
	if WshCmdJobManager.IsJobStarted() {
		log.Printf("StartJobCommand: job already started")
		return nil, fmt.Errorf("job already started")
	}

	WshCmdJobManager.lock.Lock()
	defer WshCmdJobManager.lock.Unlock()

	// Double check under the lock — the advisory check above can race.
	if WshCmdJobManager.Cmd != nil {
		log.Printf("StartJobCommand: job already started (double check)")
		return nil, fmt.Errorf("job already started")
	}

	cmdDef := CmdDef{
		Cmd:      data.Cmd,
		Args:     data.Args,
		Env:      data.Env,
		TermSize: data.TermSize,
	}
	log.Printf("StartJobCommand: creating job cmd for jobid=%s", WshCmdJobManager.JobId)
	jobCmd, err := MakeJobCmd(WshCmdJobManager.JobId, cmdDef)
	if err != nil {
		log.Printf("StartJobCommand: failed to make job cmd: %v", err)
		return nil, fmt.Errorf("failed to start job: %w", err)
	}
	WshCmdJobManager.Cmd = jobCmd
	log.Printf("StartJobCommand: job cmd created successfully")

	// Optional: the caller can attach as the stream consumer in the same call.
	if data.StreamMeta != nil {
		serverSeq, err := WshCmdJobManager.connectToStreamHelper_withlock(msc, *data.StreamMeta, 0)
		if err != nil {
			return nil, fmt.Errorf("failed to connect stream: %w", err)
		}
		err = msc.WshRpc.StreamBroker.AttachStreamWriter(data.StreamMeta, WshCmdJobManager.StreamManager)
		if err != nil {
			return nil, fmt.Errorf("failed to attach stream writer: %w", err)
		}
		log.Printf("StartJob: connected stream streamid=%s serverSeq=%d\n", data.StreamMeta.Id, serverSeq)
	}

	// Feed the command's PTY output into the stream manager's buffer.
	_, cmdPty := jobCmd.GetCmd()
	if cmdPty != nil {
		log.Printf("StartJobCommand: attaching pty reader to stream manager")
		err = WshCmdJobManager.StreamManager.AttachReader(cmdPty)
		if err != nil {
			log.Printf("StartJobCommand: failed to attach reader: %v", err)
			return nil, fmt.Errorf("failed to attach reader to stream manager: %w", err)
		}
		log.Printf("StartJobCommand: pty reader attached successfully")
	} else {
		log.Printf("StartJobCommand: no pty to attach")
	}

	cmd, _ := jobCmd.GetCmd()
	if cmd == nil || cmd.Process == nil {
		log.Printf("StartJobCommand: cmd or process is nil")
		return nil, fmt.Errorf("cmd or process is nil")
	}
	// pid + create-time uniquely identify a process even across pid reuse.
	cmdPid := cmd.Process.Pid
	cmdProc, err := process.NewProcess(int32(cmdPid))
	if err != nil {
		log.Printf("StartJobCommand: failed to get cmd process: %v", err)
		return nil, fmt.Errorf("failed to get cmd process: %w", err)
	}
	cmdStartTs, err := cmdProc.CreateTime()
	if err != nil {
		log.Printf("StartJobCommand: failed to get cmd start time: %v", err)
		return nil, fmt.Errorf("failed to get cmd start time: %w", err)
	}

	jobManagerPid := os.Getpid()
	jobManagerProc, err := process.NewProcess(int32(jobManagerPid))
	if err != nil {
		log.Printf("StartJobCommand: failed to get job manager process: %v", err)
		return nil, fmt.Errorf("failed to get job manager process: %w", err)
	}
	jobManagerStartTs, err := jobManagerProc.CreateTime()
	if err != nil {
		log.Printf("StartJobCommand: failed to get job manager start time: %v", err)
		return nil, fmt.Errorf("failed to get job manager start time: %w", err)
	}

	log.Printf("StartJobCommand: job started successfully cmdPid=%d cmdStartTs=%d jobManagerPid=%d jobManagerStartTs=%d", cmdPid, cmdStartTs, jobManagerPid, jobManagerStartTs)
	return &wshrpc.CommandStartJobRtnData{
		CmdPid:            cmdPid,
		CmdStartTs:        cmdStartTs,
		JobManagerPid:     jobManagerPid,
		JobManagerStartTs: jobManagerStartTs,
	}, nil
}
|
||||
|
||||
// JobPrepareConnectCommand is the first half of a stream reconnect handshake.
// If the stream already fully completed it just reports the terminal state;
// otherwise it connects the caller with a corked stream (RWnd forced to 0 so
// nothing is sent yet) and stashes the real stream meta for the follow-up
// JobStartStreamCommand. The response also carries the command's exit info
// if it has already exited.
func (msc *MainServerConn) JobPrepareConnectCommand(ctx context.Context, data wshrpc.CommandJobPrepareConnectData) (*wshrpc.CommandJobConnectRtnData, error) {
	WshCmdJobManager.lock.Lock()
	defer WshCmdJobManager.lock.Unlock()

	if !msc.PeerAuthenticated.Load() {
		return nil, fmt.Errorf("peer not authenticated")
	}
	if !msc.SelfAuthenticated.Load() {
		return nil, fmt.Errorf("not authenticated to server")
	}
	if WshCmdJobManager.Cmd == nil {
		return nil, fmt.Errorf("job not started")
	}

	rtnData := &wshrpc.CommandJobConnectRtnData{}
	streamDone, streamError := WshCmdJobManager.StreamManager.GetStreamDoneInfo()

	if streamDone {
		// Fully delivered and acked previously — nothing to reconnect.
		log.Printf("JobPrepareConnect: stream already done, skipping connection streamError=%q\n", streamError)
		rtnData.Seq = data.Seq
		rtnData.StreamDone = true
		rtnData.StreamError = streamError
	} else {
		// Cork the stream: connect with a zero receive window so no data
		// flows until JobStartStreamCommand uncorks it.
		corkedStreamMeta := data.StreamMeta
		corkedStreamMeta.RWnd = 0
		serverSeq, err := WshCmdJobManager.connectToStreamHelper_withlock(msc, corkedStreamMeta, data.Seq)
		if err != nil {
			return nil, err
		}
		WshCmdJobManager.pendingStreamMeta = &data.StreamMeta
		rtnData.Seq = serverSeq
		rtnData.StreamDone = false
	}

	hasExited, exitData := WshCmdJobManager.Cmd.GetExitInfo()
	if hasExited && exitData != nil {
		rtnData.HasExited = true
		rtnData.ExitCode = exitData.ExitCode
		rtnData.ExitSignal = exitData.ExitSignal
		rtnData.ExitErr = exitData.ExitErr
	}

	log.Printf("JobPrepareConnect: streamid=%s clientSeq=%d serverSeq=%d streamDone=%v streamError=%q hasExited=%v\n", data.StreamMeta.Id, data.Seq, rtnData.Seq, rtnData.StreamDone, rtnData.StreamError, hasExited)
	return rtnData, nil
}
|
||||
|
||||
// JobStartStreamCommand is the second half of the reconnect handshake: it
// attaches the stream writer prepared by JobPrepareConnectCommand and uncorks
// the stream by restoring the real receive window, which lets buffered output
// start flowing to the client.
func (msc *MainServerConn) JobStartStreamCommand(ctx context.Context, data wshrpc.CommandJobStartStreamData) error {
	WshCmdJobManager.lock.Lock()
	defer WshCmdJobManager.lock.Unlock()

	if !msc.PeerAuthenticated.Load() {
		return fmt.Errorf("not authenticated")
	}
	if WshCmdJobManager.Cmd == nil {
		return fmt.Errorf("job not started")
	}
	if WshCmdJobManager.pendingStreamMeta == nil {
		return fmt.Errorf("no pending stream (call JobPrepareConnect first)")
	}

	err := msc.WshRpc.StreamBroker.AttachStreamWriter(WshCmdJobManager.pendingStreamMeta, WshCmdJobManager.StreamManager)
	if err != nil {
		return fmt.Errorf("failed to attach stream writer: %w", err)
	}

	// Uncork: the prepare step connected with RWnd=0; restore the client's
	// real window so the sender loop can make progress.
	err = WshCmdJobManager.StreamManager.SetRwndSize(int(WshCmdJobManager.pendingStreamMeta.RWnd))
	if err != nil {
		return fmt.Errorf("failed to set rwnd size: %w", err)
	}

	log.Printf("JobStartStream: streamid=%s rwnd=%d streaming started\n", WshCmdJobManager.pendingStreamMeta.Id, WshCmdJobManager.pendingStreamMeta.RWnd)
	WshCmdJobManager.pendingStreamMeta = nil
	return nil
}
|
||||
|
||||
// JobInputCommand forwards terminal input (keystrokes, resize, etc.) to the
// running command.
//
// NOTE(review): HandleInput is called while holding the manager lock, so a
// slow PTY write would block other job commands — confirm HandleInput cannot
// block indefinitely.
func (msc *MainServerConn) JobInputCommand(ctx context.Context, data wshrpc.CommandJobInputData) error {
	WshCmdJobManager.lock.Lock()
	defer WshCmdJobManager.lock.Unlock()

	if !msc.PeerAuthenticated.Load() {
		return fmt.Errorf("not authenticated")
	}
	if WshCmdJobManager.Cmd == nil {
		return fmt.Errorf("job not started")
	}

	return WshCmdJobManager.Cmd.HandleInput(data)
}
|
||||
|
||||
419
pkg/jobmanager/streammanager.go
Normal file
419
pkg/jobmanager/streammanager.go
Normal file
|
|
@ -0,0 +1,419 @@
|
|||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package jobmanager
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
)
|
||||
|
||||
// Flow-control tuning constants for the PTY output stream.
const (
	CwndSize      = 64 * 1024       // 64 KB window for connected mode (sender-side congestion window)
	CirBufSize    = 2 * 1024 * 1024 // 2 MB max buffer size for retained output
	DisconnReadSz = 4 * 1024        // 4 KB read chunks when disconnected
	MaxPacketSize = 4 * 1024        // 4 KB max data per packet
)
|
||||
|
||||
// DataSender delivers one stream data packet to the connected client.
// Implementations are expected to be non-blocking from the stream manager's
// perspective (routedDataSender logs and drops on send error).
type DataSender interface {
	SendData(dataPk wshrpc.CommandStreamData)
}
|
||||
|
||||
// streamTerminalEvent records how the underlying reader ended: clean EOF
// (isEof true) or a read error (err non-empty). Exactly one of the two is
// meaningful.
type streamTerminalEvent struct {
	isEof bool
	err   string
}
|
||||
|
||||
// StreamManager handles PTY output buffering with ACK-based flow control.
// Data read from the attached reader is held in a circular buffer and drained
// to at most one connected client; unacked data is retained so a reconnecting
// client can resume from its last acknowledged sequence number.
type StreamManager struct {
	lock      sync.Mutex
	drainCond *sync.Cond // signaled whenever the sender loop may have work

	streamId string

	// this is the data read from the attached reader
	buf           *CirBuf
	terminalEvent *streamTerminalEvent
	eofPos        int64 // fixed position when EOF/error occurs (-1 if not yet)

	reader io.Reader

	cwndSize int // sender-side window cap
	rwndSize int // receiver-advertised window (updated by ACKs)
	// invariant: if connected is true, dataSender is non-nil
	connected  bool
	dataSender DataSender

	// unacked state (reset on disconnect)
	sentNotAcked      int64
	terminalEventSent bool

	// terminal state - once true, stream is complete
	terminalEventAcked bool
	closed             bool
}
|
||||
|
||||
func MakeStreamManager() *StreamManager {
|
||||
return MakeStreamManagerWithSizes(CwndSize, CirBufSize)
|
||||
}
|
||||
|
||||
// MakeStreamManagerWithSizes constructs a StreamManager with explicit window
// and buffer sizes and starts its sender goroutine. The sender loop runs
// until Close (or a fully-acked terminal event) shuts it down.
func MakeStreamManagerWithSizes(cwndSize, cirbufSize int) *StreamManager {
	sm := &StreamManager{
		buf:      MakeCirBuf(cirbufSize, true),
		eofPos:   -1, // -1 means "no terminal event yet"
		cwndSize: cwndSize,
		rwndSize: cwndSize,
	}
	sm.drainCond = sync.NewCond(&sm.lock)
	go sm.senderLoop()
	return sm
}
|
||||
|
||||
// AttachReader starts reading from the given reader in a background goroutine
// (readLoop). Only one reader may ever be attached; subsequent calls fail.
func (sm *StreamManager) AttachReader(r io.Reader) error {
	sm.lock.Lock()
	defer sm.lock.Unlock()

	if sm.reader != nil {
		return fmt.Errorf("reader already attached")
	}

	sm.reader = r
	go sm.readLoop()

	return nil
}
|
||||
|
||||
// ClientConnected transitions to CONNECTED mode. clientSeq is the highest
// sequence number the client already holds: any buffered bytes up to it are
// consumed (they were implicitly acked), and the returned sequence number is
// where the stream will resume (max of our head position and clientSeq).
// rwndSize is the client's advertised receive window.
func (sm *StreamManager) ClientConnected(streamId string, dataSender DataSender, rwndSize int, clientSeq int64) (int64, error) {
	sm.lock.Lock()
	defer sm.lock.Unlock()

	if sm.closed || sm.terminalEventAcked {
		return 0, fmt.Errorf("stream is closed")
	}

	if sm.connected {
		return 0, fmt.Errorf("client already connected")
	}

	if dataSender == nil {
		return 0, fmt.Errorf("dataSender cannot be nil")
	}

	// Drop buffered bytes the client already has. A clientSeq beyond our
	// buffered range means the client is ahead of us — protocol error.
	headPos := sm.buf.HeadPos()
	if clientSeq > headPos {
		bytesToConsume := int(clientSeq - headPos)
		available := sm.buf.Size()
		if bytesToConsume > available {
			return 0, fmt.Errorf("client seq %d is beyond our stream end (head=%d, size=%d)", clientSeq, headPos, available)
		}
		if bytesToConsume > 0 {
			if err := sm.buf.Consume(bytesToConsume); err != nil {
				return 0, fmt.Errorf("failed to consume buffer: %w", err)
			}
			headPos = sm.buf.HeadPos()
		}
	}

	sm.streamId = streamId
	sm.dataSender = dataSender
	sm.connected = true
	sm.rwndSize = rwndSize
	sm.sentNotAcked = 0
	// Effective window = min(cwnd, rwnd); the buffer uses it to bound reads.
	effectiveWindow := sm.cwndSize
	if sm.rwndSize < effectiveWindow {
		effectiveWindow = sm.rwndSize
	}
	sm.buf.SetEffectiveWindow(true, effectiveWindow)
	sm.drainCond.Signal()

	startSeq := headPos
	if clientSeq > startSeq {
		startSeq = clientSeq
	}

	return startSeq, nil
}
|
||||
|
||||
// GetStreamId returns the current stream ID. It acquires sm.lock itself, so
// it must NOT be called while already holding sm.lock (it would deadlock) —
// external callers like the job manager (which holds a different lock) are
// fine.
func (sm *StreamManager) GetStreamId() string {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	return sm.streamId
}
|
||||
|
||||
// GetStreamDoneInfo returns whether the stream is done and the error if there was one.
// "Done" means the terminal event (EOF or error) has been acked by a client
// (Fin received). The error is only meaningful if done=true, as the error is
// delivered as part of the stream otherwise.
func (sm *StreamManager) GetStreamDoneInfo() (done bool, streamError string) {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	if !sm.terminalEventAcked {
		return false, ""
	}
	if sm.terminalEvent != nil && !sm.terminalEvent.isEof {
		return true, sm.terminalEvent.err
	}
	return true, ""
}
|
||||
|
||||
// ClientDisconnected transitions to DISCONNECTED mode: it drops the sender,
// resets the unacked counters (unacked data stays buffered for resend on
// reconnect), and re-arms the terminal event if it was sent but never acked.
//
// NOTE(review): the window is reset to the package-level CirBufSize, not the
// size this instance was constructed with (MakeStreamManagerWithSizes) —
// confirm this is intentional for custom-sized buffers.
func (sm *StreamManager) ClientDisconnected() {
	sm.lock.Lock()
	defer sm.lock.Unlock()

	if !sm.connected {
		return
	}

	sm.connected = false
	sm.dataSender = nil
	sm.sentNotAcked = 0
	if !sm.terminalEventAcked {
		sm.terminalEventSent = false
	}
	sm.buf.SetEffectiveWindow(false, CirBufSize)
	sm.drainCond.Signal()
}
|
||||
|
||||
// RecvAck processes an ACK from the client.
// must be connected, and streamid must match (stale/foreign ACKs are dropped).
// A Fin ACK acknowledges the terminal event and completes the stream. A data
// ACK at sequence seq consumes everything up to seq from the buffer, updates
// the receive window, and wakes the sender if it may now make progress.
func (sm *StreamManager) RecvAck(ackPk wshrpc.CommandStreamAckData) {
	sm.lock.Lock()
	defer sm.lock.Unlock()

	if !sm.connected || ackPk.Id != sm.streamId {
		return
	}

	if ackPk.Fin {
		sm.terminalEventAcked = true
		sm.drainCond.Signal()
		return
	}

	// Ignore ACKs behind our head (duplicates) ...
	seq := ackPk.Seq
	headPos := sm.buf.HeadPos()
	if seq < headPos {
		return
	}

	// ... and ACKs for bytes we never sent.
	ackedBytes := seq - headPos
	if ackedBytes > sm.sentNotAcked {
		return
	}

	if ackedBytes > 0 {
		if err := sm.buf.Consume(int(ackedBytes)); err != nil {
			return
		}
		sm.sentNotAcked -= ackedBytes
	}

	// Refresh the receive window from the ACK and recompute the effective
	// (min of cwnd/rwnd) window for the buffer.
	prevRwnd := sm.rwndSize
	sm.rwndSize = int(ackPk.RWnd)
	effectiveWindow := sm.cwndSize
	if sm.rwndSize < effectiveWindow {
		effectiveWindow = sm.rwndSize
	}
	sm.buf.SetEffectiveWindow(true, effectiveWindow)

	// Only wake the sender if something actually changed in its favor.
	if sm.rwndSize > prevRwnd || ackedBytes > 0 {
		sm.drainCond.Signal()
	}
}
|
||||
|
||||
// SetRwndSize dynamically updates the receive window size. Used to uncork a
// stream that was connected with RWnd=0 during the reconnect handshake.
// Requires a connected client.
func (sm *StreamManager) SetRwndSize(rwndSize int) error {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	if rwndSize < 0 {
		return fmt.Errorf("rwndSize cannot be negative")
	}
	if !sm.connected {
		return fmt.Errorf("not connected")
	}
	sm.rwndSize = rwndSize
	// Effective window = min(cwnd, rwnd), same rule as ClientConnected/RecvAck.
	effectiveWindow := sm.cwndSize
	if sm.rwndSize < effectiveWindow {
		effectiveWindow = sm.rwndSize
	}
	sm.buf.SetEffectiveWindow(true, effectiveWindow)
	sm.drainCond.Signal()
	return nil
}
|
||||
|
||||
// Close shuts down the sender loop. The reader loop will exit on its next iteration
// or when the underlying reader is closed. The signal wakes a sender loop
// blocked in drainCond.Wait so it can observe closed and return.
func (sm *StreamManager) Close() {
	sm.lock.Lock()
	defer sm.lock.Unlock()
	sm.closed = true
	sm.drainCond.Signal()
}
|
||||
|
||||
// readLoop is the main read goroutine: it pulls chunks from the attached
// reader into the circular buffer until EOF or error, then records the
// terminal event and exits.
//
// NOTE(review): the closed check happens between reads, so a Read blocked on
// the PTY only unblocks when the underlying reader is closed — Close alone
// does not interrupt it.
func (sm *StreamManager) readLoop() {
	readBuf := make([]byte, MaxPacketSize)
	for {
		sm.lock.Lock()
		closed := sm.closed
		sm.lock.Unlock()

		if closed {
			return
		}

		n, err := sm.reader.Read(readBuf)
		log.Printf("readLoop: read %d bytes from PTY, err=%v", n, err)

		// Per io.Reader contract, data may accompany an error; handle it first.
		if n > 0 {
			sm.handleReadData(readBuf[:n])
		}

		if err != nil {
			if err == io.EOF {
				sm.handleEOF()
			} else {
				sm.handleError(err)
			}
			return
		}
	}
}
|
||||
|
||||
// handleReadData appends a chunk of reader output to the buffer and, when a
// client is connected, wakes the sender loop. The buffer write happens before
// taking sm.lock (CirBuf manages its own blocking/windowing).
func (sm *StreamManager) handleReadData(data []byte) {
	log.Printf("handleReadData: writing %d bytes to buffer", len(data))
	sm.buf.Write(data)
	sm.lock.Lock()
	defer sm.lock.Unlock()
	log.Printf("handleReadData: buffer size=%d, connected=%t, signaling=%t", sm.buf.Size(), sm.connected, sm.connected)
	if sm.connected {
		sm.drainCond.Signal()
	}
}
|
||||
|
||||
// handleEOF records a clean end-of-stream: the terminal event is pinned at
// the total number of bytes ever produced, and the sender loop is woken so it
// can emit the EOF packet once the buffer drains.
func (sm *StreamManager) handleEOF() {
	sm.lock.Lock()
	defer sm.lock.Unlock()

	log.Printf("handleEOF: PTY reached EOF, totalSize=%d", sm.buf.TotalSize())
	sm.eofPos = sm.buf.TotalSize()
	sm.terminalEvent = &streamTerminalEvent{isEof: true}
	sm.drainCond.Signal()
}
|
||||
|
||||
// handleError records a read failure as the stream's terminal event (mirrors
// handleEOF, but carries the error text instead of the EOF flag).
func (sm *StreamManager) handleError(err error) {
	sm.lock.Lock()
	defer sm.lock.Unlock()

	log.Printf("handleError: PTY error=%v, totalSize=%d", err, sm.buf.TotalSize())
	sm.eofPos = sm.buf.TotalSize()
	sm.terminalEvent = &streamTerminalEvent{err: err.Error()}
	sm.drainCond.Signal()
}
|
||||
|
||||
func (sm *StreamManager) senderLoop() {
|
||||
for {
|
||||
done, pkt, sender := sm.prepareNextPacket()
|
||||
if done {
|
||||
return
|
||||
}
|
||||
if pkt == nil {
|
||||
continue
|
||||
}
|
||||
sender.SendData(*pkt)
|
||||
}
|
||||
}
|
||||
|
||||
// prepareNextPacket decides, under the lock, what the sender loop should do
// next. Returns done=true when the stream is finished (closed or terminal
// event acked). Returns a nil packet after waking from a wait (the caller
// loops and re-evaluates). Otherwise returns the next data or terminal packet
// plus the sender to deliver it with; the actual send happens outside the
// lock in senderLoop.
func (sm *StreamManager) prepareNextPacket() (done bool, pkt *wshrpc.CommandStreamData, sender DataSender) {
	sm.lock.Lock()
	defer sm.lock.Unlock()

	available := sm.buf.Size()
	log.Printf("prepareNextPacket: connected=%t, available=%d, closed=%t, terminalEventAcked=%t, terminalEvent=%v",
		sm.connected, available, sm.closed, sm.terminalEventAcked, sm.terminalEvent != nil)

	if sm.closed || sm.terminalEventAcked {
		return true, nil, nil
	}

	// No client: park until ClientConnected (or Close) signals.
	if !sm.connected {
		log.Printf("prepareNextPacket: waiting for connection")
		sm.drainCond.Wait()
		return false, nil, nil
	}

	// Buffer drained: emit the terminal packet if one is pending, else park.
	if available == 0 {
		if sm.terminalEvent != nil && !sm.terminalEventSent {
			log.Printf("prepareNextPacket: preparing terminal packet")
			return false, sm.prepareTerminalPacket(), sm.dataSender
		}
		log.Printf("prepareNextPacket: no data available, waiting")
		sm.drainCond.Wait()
		return false, nil, nil
	}

	// Flow control: we may have at most min(rwnd, cwnd) bytes in flight.
	effectiveRwnd := sm.rwndSize
	if sm.cwndSize < effectiveRwnd {
		effectiveRwnd = sm.cwndSize
	}
	availableToSend := int64(effectiveRwnd) - sm.sentNotAcked

	if availableToSend <= 0 {
		// Window full — wait for an ACK to free space.
		sm.drainCond.Wait()
		return false, nil, nil
	}

	peekSize := int(availableToSend)
	if peekSize > MaxPacketSize {
		peekSize = MaxPacketSize
	}
	if peekSize > available {
		peekSize = available
	}

	// Peek (not consume) past the in-flight bytes: data is only consumed from
	// the buffer when the client ACKs it, so it survives a reconnect.
	data := make([]byte, peekSize)
	n := sm.buf.PeekDataAt(int(sm.sentNotAcked), data)
	if n == 0 {
		log.Printf("prepareNextPacket: PeekDataAt returned 0 bytes, waiting for ACK")
		sm.drainCond.Wait()
		return false, nil, nil
	}
	data = data[:n]

	seq := sm.buf.HeadPos() + sm.sentNotAcked
	sm.sentNotAcked += int64(n)

	log.Printf("prepareNextPacket: sending packet seq=%d, len=%d bytes", seq, n)
	return false, &wshrpc.CommandStreamData{
		Id:     sm.streamId,
		Seq:    seq,
		Data64: base64.StdEncoding.EncodeToString(data),
	}, sm.dataSender
}
|
||||
|
||||
// prepareTerminalPacket builds the final EOF/error packet, pinned at eofPos
// (the stream's total byte count at termination). Marks the event as sent so
// it is not emitted twice; ClientDisconnected re-arms it if it was never
// acked. Caller must hold sm.lock.
func (sm *StreamManager) prepareTerminalPacket() *wshrpc.CommandStreamData {
	if sm.terminalEventSent || sm.terminalEvent == nil {
		return nil
	}

	pkt := &wshrpc.CommandStreamData{
		Id:  sm.streamId,
		Seq: sm.eofPos,
	}

	if sm.terminalEvent.isEof {
		pkt.Eof = true
	} else {
		pkt.Error = sm.terminalEvent.err
	}

	sm.terminalEventSent = true
	return pkt
}
|
||||
348
pkg/jobmanager/streammanager_test.go
Normal file
348
pkg/jobmanager/streammanager_test.go
Normal file
|
|
@ -0,0 +1,348 @@
|
|||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package jobmanager
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
)
|
||||
|
||||
// testWriter is a DataSender test double that records every packet it is
// asked to send, guarded by a mutex since the sender loop runs concurrently
// with test assertions.
type testWriter struct {
	mu      sync.Mutex
	packets []wshrpc.CommandStreamData
}
|
||||
|
||||
func (tw *testWriter) SendData(pkt wshrpc.CommandStreamData) {
|
||||
tw.mu.Lock()
|
||||
defer tw.mu.Unlock()
|
||||
tw.packets = append(tw.packets, pkt)
|
||||
}
|
||||
|
||||
func (tw *testWriter) GetPackets() []wshrpc.CommandStreamData {
|
||||
tw.mu.Lock()
|
||||
defer tw.mu.Unlock()
|
||||
result := make([]wshrpc.CommandStreamData, len(tw.packets))
|
||||
copy(result, tw.packets)
|
||||
return result
|
||||
}
|
||||
|
||||
func (tw *testWriter) Clear() {
|
||||
tw.mu.Lock()
|
||||
defer tw.mu.Unlock()
|
||||
tw.packets = nil
|
||||
}
|
||||
|
||||
// decodeData base64-decodes data64, ignoring decode errors (test helper:
// a bad payload simply decodes to whatever prefix was valid).
func decodeData(data64 string) string {
	raw, _ := base64.StdEncoding.DecodeString(data64)
	return string(raw)
}
|
||||
|
||||
// TestBasicDisconnectedMode verifies that a StreamManager with an attached
// reader but no connected client delivers nothing: the writer tw is created
// but never connected, so it must stay empty.
func TestBasicDisconnectedMode(t *testing.T) {
	tw := &testWriter{}
	sm := MakeStreamManager()

	reader := strings.NewReader("hello world")
	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}

	// Give any background reading time to run before checking.
	time.Sleep(50 * time.Millisecond)

	packets := tw.GetPackets()
	if len(packets) > 0 {
		t.Errorf("Expected no packets in DISCONNECTED mode without client, got %d", len(packets))
	}

	sm.Close()
}
|
||||
|
||||
// TestConnectedModeBasicFlow exercises the happy path: attach a reader,
// connect a client, receive the payload, ACK it, and then observe the EOF
// packet. Note the EOF is only asserted after the ACK — see
// TestTerminalEventOrdering for that ordering requirement.
func TestConnectedModeBasicFlow(t *testing.T) {
	tw := &testWriter{}
	sm := MakeStreamManager()

	reader := strings.NewReader("hello")
	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}

	_, err = sm.ClientConnected("1", tw, CwndSize, 0)
	if err != nil {
		t.Fatalf("ClientConnected failed: %v", err)
	}

	// Allow the data to flow to the connected writer.
	time.Sleep(100 * time.Millisecond)

	packets := tw.GetPackets()
	if len(packets) == 0 {
		t.Fatal("Expected packets after ClientConnected")
	}

	// Verify we got the data (payload may be split across packets).
	allData := ""
	for _, pkt := range packets {
		if pkt.Data64 != "" {
			allData += decodeData(pkt.Data64)
		}
	}

	if allData != "hello" {
		t.Errorf("Expected 'hello', got '%s'", allData)
	}

	// Send ACK covering all 5 payload bytes.
	sm.RecvAck(wshrpc.CommandStreamAckData{Id: "1", Seq: 5, RWnd: CwndSize})

	time.Sleep(50 * time.Millisecond)

	// Check for EOF packet
	packets = tw.GetPackets()
	hasEof := false
	for _, pkt := range packets {
		if pkt.Eof {
			hasEof = true
		}
	}

	if !hasEof {
		t.Error("Expected EOF packet after ACKing all data")
	}

	sm.Close()
}
|
||||
|
||||
// TestDisconnectedToConnectedTransition verifies that data read while no
// client is attached is buffered internally and then drained to the client
// once one connects.
func TestDisconnectedToConnectedTransition(t *testing.T) {
	tw := &testWriter{}
	sm := MakeStreamManager()

	reader := strings.NewReader("test data")
	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}

	// Let the reader drain into the internal buffer before connecting.
	time.Sleep(100 * time.Millisecond)

	_, err = sm.ClientConnected("1", tw, CwndSize, 0)
	if err != nil {
		t.Fatalf("ClientConnected failed: %v", err)
	}

	time.Sleep(100 * time.Millisecond)

	packets := tw.GetPackets()
	if len(packets) == 0 {
		t.Fatal("Expected cirbuf drain after connect")
	}

	// Reassemble the payload; it may arrive split across packets.
	allData := ""
	for _, pkt := range packets {
		if pkt.Data64 != "" {
			allData += decodeData(pkt.Data64)
		}
	}

	if allData != "test data" {
		t.Errorf("Expected 'test data', got '%s'", allData)
	}

	sm.Close()
}
|
||||
|
||||
// TestConnectedToDisconnectedTransition is a smoke test: it connects a client
// while a slow reader is still producing data, disconnects mid-stream, and
// then closes. It makes no assertions; it passes as long as the transition
// does not deadlock or panic.
func TestConnectedToDisconnectedTransition(t *testing.T) {
	tw := &testWriter{}
	sm := MakeStreamManager()

	// slowReader sleeps 50ms per read, so data is still flowing when we
	// disconnect below.
	reader := &slowReader{data: []byte("slow data"), delay: 50 * time.Millisecond}
	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}

	_, err = sm.ClientConnected("1", tw, CwndSize, 0)
	if err != nil {
		t.Fatalf("ClientConnected failed: %v", err)
	}

	time.Sleep(150 * time.Millisecond)

	sm.ClientDisconnected()

	time.Sleep(100 * time.Millisecond)

	sm.Close()
}
|
||||
|
||||
// TestFlowControl verifies that the sender never emits more than one
// congestion window (cwnd) of unACKed data, even when more input is
// available.
func TestFlowControl(t *testing.T) {
	cwndSize := 1024
	tw := &testWriter{}
	sm := MakeStreamManagerWithSizes(cwndSize, 8*1024)

	// More data than fits in a single congestion window.
	largeData := strings.Repeat("x", cwndSize+500)
	reader := strings.NewReader(largeData)

	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}

	_, err = sm.ClientConnected("1", tw, cwndSize, 0)
	if err != nil {
		t.Fatalf("ClientConnected failed: %v", err)
	}

	time.Sleep(100 * time.Millisecond)

	// Count the payload bytes sent before any ACK.
	packets := tw.GetPackets()
	totalData := 0
	for _, pkt := range packets {
		if pkt.Data64 != "" {
			decoded, _ := base64.StdEncoding.DecodeString(pkt.Data64)
			totalData += len(decoded)
		}
	}

	if totalData > cwndSize {
		t.Errorf("Sent %d bytes without ACK, exceeds cwnd size %d", totalData, cwndSize)
	}

	// ACK everything received so far to let the stream make progress
	// before shutdown.
	sm.RecvAck(wshrpc.CommandStreamAckData{Id: "1", Seq: int64(totalData), RWnd: int64(cwndSize)})

	time.Sleep(100 * time.Millisecond)

	sm.Close()
}
|
||||
|
||||
// TestSequenceNumbering verifies that each data packet's Seq equals the byte
// offset of its payload within the stream: the first packet starts at 0 and
// each subsequent packet's Seq advances by the previous payload length.
func TestSequenceNumbering(t *testing.T) {
	tw := &testWriter{}
	sm := MakeStreamManager()

	reader := strings.NewReader("abcdefghij")
	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}

	_, err = sm.ClientConnected("1", tw, CwndSize, 0)
	if err != nil {
		t.Fatalf("ClientConnected failed: %v", err)
	}

	time.Sleep(100 * time.Millisecond)

	packets := tw.GetPackets()
	if len(packets) == 0 {
		t.Fatal("Expected packets")
	}

	expectedSeq := int64(0)
	for _, pkt := range packets {
		// Skip non-data packets (e.g. the terminal packet).
		if pkt.Data64 == "" {
			continue
		}

		if pkt.Seq != expectedSeq {
			t.Errorf("Expected seq %d, got %d", expectedSeq, pkt.Seq)
		}

		decoded, _ := base64.StdEncoding.DecodeString(pkt.Data64)
		expectedSeq += int64(len(decoded))
	}

	sm.Close()
}
|
||||
|
||||
// TestTerminalEventOrdering verifies that the EOF packet is withheld until
// all data has been ACKed by the client, and that the EOF's Seq equals the
// total payload length (4 bytes here).
func TestTerminalEventOrdering(t *testing.T) {
	tw := &testWriter{}
	sm := MakeStreamManager()

	reader := strings.NewReader("data")
	err := sm.AttachReader(reader)
	if err != nil {
		t.Fatalf("AttachReader failed: %v", err)
	}

	_, err = sm.ClientConnected("1", tw, CwndSize, 0)
	if err != nil {
		t.Fatalf("ClientConnected failed: %v", err)
	}

	time.Sleep(100 * time.Millisecond)

	packets := tw.GetPackets()
	if len(packets) == 0 {
		t.Fatal("Expected data packets")
	}

	hasData := false
	hasEof := false
	eofSeq := int64(-1)

	// First pass, before ACK: data must be present, EOF must not.
	for _, pkt := range packets {
		if pkt.Data64 != "" {
			hasData = true
		}
		if pkt.Eof {
			hasEof = true
			eofSeq = pkt.Seq
		}
	}

	if !hasData {
		t.Error("Expected data packet")
	}

	if hasEof {
		t.Error("Should not have EOF before ACK")
	}

	// ACK all 4 payload bytes; this should release the EOF packet.
	sm.RecvAck(wshrpc.CommandStreamAckData{Id: "1", Seq: 4, RWnd: CwndSize})

	time.Sleep(50 * time.Millisecond)

	// Second pass, after ACK: EOF must now appear at seq 4.
	packets = tw.GetPackets()
	hasEof = false
	for _, pkt := range packets {
		if pkt.Eof {
			hasEof = true
			eofSeq = pkt.Seq
		}
	}

	if !hasEof {
		t.Error("Expected EOF after ACKing all data")
	}

	if eofSeq != 4 {
		t.Errorf("Expected EOF at seq 4, got %d", eofSeq)
	}

	sm.Close()
}
|
||||
|
||||
// slowReader is an io.Reader test helper that sleeps before each read,
// simulating a slow data source.
type slowReader struct {
	data  []byte        // payload to serve
	pos   int           // read cursor into data
	delay time.Duration // sleep applied before every non-EOF Read
}
|
||||
|
||||
func (sr *slowReader) Read(p []byte) (n int, err error) {
|
||||
if sr.pos >= len(sr.data) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
time.Sleep(sr.delay)
|
||||
|
||||
n = copy(p, sr.data[sr.pos:])
|
||||
sr.pos += n
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
|
@ -85,7 +85,7 @@ type SSHConn struct {
|
|||
var ConnServerCmdTemplate = strings.TrimSpace(
|
||||
strings.Join([]string{
|
||||
"%s version 2> /dev/null || (echo -n \"not-installed \"; uname -sm; exit 0);",
|
||||
"exec %s connserver --conn %s %s",
|
||||
"exec %s connserver --conn %s %s %s",
|
||||
}, "\n"))
|
||||
|
||||
func IsLocalConnName(connName string) bool {
|
||||
|
|
@ -285,8 +285,9 @@ func (conn *SSHConn) GetConfigShellPath() string {
|
|||
// returns (needsInstall, clientVersion, osArchStr, error)
|
||||
// if wsh is not installed, the clientVersion will be "not-installed", and it will also return an osArchStr
|
||||
// if clientVersion is set, then no osArchStr will be returned
|
||||
func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (bool, string, string, error) {
|
||||
conn.Infof(ctx, "running StartConnServer...\n")
|
||||
// if useRouterMode is true, will start connserver with --router-domainsocket flag
|
||||
func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool, useRouterMode bool) (bool, string, string, error) {
|
||||
conn.Infof(ctx, "running StartConnServer (routerMode=%v)...\n", useRouterMode)
|
||||
allowed := WithLockRtn(conn, func() bool {
|
||||
return conn.Status == Status_Connecting
|
||||
})
|
||||
|
|
@ -296,10 +297,19 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo
|
|||
client := conn.GetClient()
|
||||
wshPath := conn.getWshPath()
|
||||
sockName := conn.GetDomainSocketName()
|
||||
rpcCtx := wshrpc.RpcContext{
|
||||
RouteId: wshutil.MakeConnectionRouteId(conn.GetName()),
|
||||
SockName: sockName,
|
||||
Conn: conn.GetName(),
|
||||
var rpcCtx wshrpc.RpcContext
|
||||
if useRouterMode {
|
||||
rpcCtx = wshrpc.RpcContext{
|
||||
IsRouter: true,
|
||||
SockName: sockName,
|
||||
Conn: conn.GetName(),
|
||||
}
|
||||
} else {
|
||||
rpcCtx = wshrpc.RpcContext{
|
||||
RouteId: wshutil.MakeConnectionRouteId(conn.GetName()),
|
||||
SockName: sockName,
|
||||
Conn: conn.GetName(),
|
||||
}
|
||||
}
|
||||
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx)
|
||||
if err != nil {
|
||||
|
|
@ -321,7 +331,11 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo
|
|||
if wavebase.IsDevMode() {
|
||||
devFlag = "--dev"
|
||||
}
|
||||
cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath, shellutil.HardQuote(conn.GetName()), devFlag)
|
||||
routerFlag := ""
|
||||
if useRouterMode {
|
||||
routerFlag = "--router-domainsocket"
|
||||
}
|
||||
cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath, shellutil.HardQuote(conn.GetName()), devFlag, routerFlag)
|
||||
log.Printf("starting conn controller: %q\n", cmdStr)
|
||||
shWrappedCmdStr := fmt.Sprintf("sh -c %s", shellutil.HardQuote(cmdStr))
|
||||
blocklogger.Debugf(ctx, "[conndebug] wrapped command:\n%s\n", shWrappedCmdStr)
|
||||
|
|
@ -702,7 +716,7 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string)
|
|||
err = fmt.Errorf("error opening domain socket listener: %w", err)
|
||||
return WshCheckResult{NoWshReason: "error opening domain socket", NoWshCode: NoWshCode_DomainSocketError, WshError: err}
|
||||
}
|
||||
needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx, false)
|
||||
needsInstall, clientVersion, osArchStr, err := conn.StartConnServer(ctx, false, false)
|
||||
if err != nil {
|
||||
conn.Infof(ctx, "ERROR starting conn server: %v\n", err)
|
||||
err = fmt.Errorf("error starting conn server: %w", err)
|
||||
|
|
@ -716,7 +730,7 @@ func (conn *SSHConn) tryEnableWsh(ctx context.Context, clientDisplayName string)
|
|||
err = fmt.Errorf("error installing wsh: %w", err)
|
||||
return WshCheckResult{NoWshReason: "error installing wsh/connserver", NoWshCode: NoWshCode_InstallError, WshError: err}
|
||||
}
|
||||
needsInstall, clientVersion, _, err = conn.StartConnServer(ctx, true)
|
||||
needsInstall, clientVersion, _, err = conn.StartConnServer(ctx, true, false)
|
||||
if err != nil {
|
||||
conn.Infof(ctx, "ERROR starting conn server (after install): %v\n", err)
|
||||
err = fmt.Errorf("error starting conn server (after install): %w", err)
|
||||
|
|
@ -842,23 +856,38 @@ func (conn *SSHConn) ClearWshError() {
|
|||
})
|
||||
}
|
||||
|
||||
func getConnInternal(opts *remote.SSHOpts) *SSHConn {
|
||||
func getConnInternal(opts *remote.SSHOpts, createIfNotExists bool) *SSHConn {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
rtn := clientControllerMap[*opts]
|
||||
if rtn == nil {
|
||||
if rtn == nil && createIfNotExists {
|
||||
rtn = &SSHConn{Lock: &sync.Mutex{}, Status: Status_Init, WshEnabled: &atomic.Bool{}, Opts: opts, HasWaiter: &atomic.Bool{}}
|
||||
clientControllerMap[*opts] = rtn
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
// does NOT connect, can return nil if connection does not exist
|
||||
// does NOT connect, does not return nil
|
||||
func GetConn(opts *remote.SSHOpts) *SSHConn {
|
||||
conn := getConnInternal(opts)
|
||||
conn := getConnInternal(opts, true)
|
||||
return conn
|
||||
}
|
||||
|
||||
func IsConnected(connName string) (bool, error) {
|
||||
if IsLocalConnName(connName) {
|
||||
return true, nil
|
||||
}
|
||||
connOpts, err := remote.ParseOpts(connName)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error parsing connection name: %w", err)
|
||||
}
|
||||
conn := getConnInternal(connOpts, false)
|
||||
if conn == nil {
|
||||
return false, nil
|
||||
}
|
||||
return conn.GetStatus() == Status_Connected, nil
|
||||
}
|
||||
|
||||
// Convenience function for ensuring a connection is established
|
||||
func EnsureConnection(ctx context.Context, connName string) error {
|
||||
if IsLocalConnName(connName) {
|
||||
|
|
@ -888,7 +917,7 @@ func EnsureConnection(ctx context.Context, connName string) error {
|
|||
}
|
||||
|
||||
func DisconnectClient(opts *remote.SSHOpts) error {
|
||||
conn := getConnInternal(opts)
|
||||
conn := getConnInternal(opts, false)
|
||||
if conn == nil {
|
||||
return fmt.Errorf("client %q not found", opts.String())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package streamclient
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -32,8 +33,8 @@ func (ft *fakeTransport) SendAck(ackPk wshrpc.CommandStreamAckData) {
|
|||
func TestBasicReadWrite(t *testing.T) {
|
||||
transport := newFakeTransport()
|
||||
|
||||
reader := NewReader(1, 1024, transport)
|
||||
writer := NewWriter(1, 1024, transport)
|
||||
reader := NewReader("1", 1024, transport)
|
||||
writer := NewWriter("1", 1024, transport)
|
||||
|
||||
go func() {
|
||||
for dataPk := range transport.dataChan {
|
||||
|
|
@ -72,8 +73,8 @@ func TestBasicReadWrite(t *testing.T) {
|
|||
func TestEOF(t *testing.T) {
|
||||
transport := newFakeTransport()
|
||||
|
||||
reader := NewReader(1, 1024, transport)
|
||||
writer := NewWriter(1, 1024, transport)
|
||||
reader := NewReader("1", 1024, transport)
|
||||
writer := NewWriter("1", 1024, transport)
|
||||
|
||||
go func() {
|
||||
for dataPk := range transport.dataChan {
|
||||
|
|
@ -110,8 +111,8 @@ func TestFlowControl(t *testing.T) {
|
|||
smallWindow := int64(10)
|
||||
transport := newFakeTransport()
|
||||
|
||||
reader := NewReader(1, smallWindow, transport)
|
||||
writer := NewWriter(1, smallWindow, transport)
|
||||
reader := NewReader("1", smallWindow, transport)
|
||||
writer := NewWriter("1", smallWindow, transport)
|
||||
|
||||
go func() {
|
||||
for dataPk := range transport.dataChan {
|
||||
|
|
@ -163,8 +164,8 @@ func TestFlowControl(t *testing.T) {
|
|||
func TestError(t *testing.T) {
|
||||
transport := newFakeTransport()
|
||||
|
||||
reader := NewReader(1, 1024, transport)
|
||||
writer := NewWriter(1, 1024, transport)
|
||||
reader := NewReader("1", 1024, transport)
|
||||
writer := NewWriter("1", 1024, transport)
|
||||
|
||||
go func() {
|
||||
for dataPk := range transport.dataChan {
|
||||
|
|
@ -194,8 +195,8 @@ func TestError(t *testing.T) {
|
|||
func TestCancel(t *testing.T) {
|
||||
transport := newFakeTransport()
|
||||
|
||||
reader := NewReader(1, 1024, transport)
|
||||
writer := NewWriter(1, 1024, transport)
|
||||
reader := NewReader("1", 1024, transport)
|
||||
writer := NewWriter("1", 1024, transport)
|
||||
|
||||
go func() {
|
||||
for dataPk := range transport.dataChan {
|
||||
|
|
@ -227,8 +228,8 @@ func TestCancel(t *testing.T) {
|
|||
func TestMultipleWrites(t *testing.T) {
|
||||
transport := newFakeTransport()
|
||||
|
||||
reader := NewReader(1, 1024, transport)
|
||||
writer := NewWriter(1, 1024, transport)
|
||||
reader := NewReader("1", 1024, transport)
|
||||
writer := NewWriter("1", 1024, transport)
|
||||
|
||||
go func() {
|
||||
for dataPk := range transport.dataChan {
|
||||
|
|
@ -265,3 +266,258 @@ func TestMultipleWrites(t *testing.T) {
|
|||
t.Fatalf("Expected %q, got %q", expected, string(buf))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutOfOrderPackets(t *testing.T) {
|
||||
transport := newFakeTransport()
|
||||
reader := NewReader("test-ooo", 1024, transport)
|
||||
|
||||
packet0 := wshrpc.CommandStreamData{
|
||||
Id: "test-ooo",
|
||||
Seq: 0,
|
||||
Data64: base64.StdEncoding.EncodeToString([]byte("AAAAA")),
|
||||
}
|
||||
packet5 := wshrpc.CommandStreamData{
|
||||
Id: "test-ooo",
|
||||
Seq: 5,
|
||||
Data64: base64.StdEncoding.EncodeToString([]byte("BBBBB")),
|
||||
}
|
||||
packet10 := wshrpc.CommandStreamData{
|
||||
Id: "test-ooo",
|
||||
Seq: 10,
|
||||
Data64: base64.StdEncoding.EncodeToString([]byte("CCCCC")),
|
||||
}
|
||||
packet15 := wshrpc.CommandStreamData{
|
||||
Id: "test-ooo",
|
||||
Seq: 15,
|
||||
Data64: base64.StdEncoding.EncodeToString([]byte("DDDDD")),
|
||||
}
|
||||
|
||||
// Send packets out of order: 0, 10, 15, 5
|
||||
reader.RecvData(packet0)
|
||||
reader.RecvData(packet10) // OOO - should be buffered
|
||||
reader.RecvData(packet15) // OOO - should be buffered
|
||||
reader.RecvData(packet5) // fills the gap - should trigger processing
|
||||
|
||||
// Read all data
|
||||
buf := make([]byte, 1024)
|
||||
totalRead := 0
|
||||
expectedLen := 20 // 4 packets * 5 bytes each
|
||||
|
||||
readDone := make(chan struct{})
|
||||
go func() {
|
||||
for totalRead < expectedLen {
|
||||
n, err := reader.Read(buf[totalRead:])
|
||||
if err != nil {
|
||||
t.Errorf("Read failed: %v", err)
|
||||
return
|
||||
}
|
||||
totalRead += n
|
||||
}
|
||||
close(readDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-readDone:
|
||||
// Success
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("Read didn't complete in time. Read %d bytes, expected %d", totalRead, expectedLen)
|
||||
}
|
||||
|
||||
if totalRead != expectedLen {
|
||||
t.Fatalf("Expected to read %d bytes, got %d", expectedLen, totalRead)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOutOfOrderWithDuplicates verifies that when multiple packets arrive
// with the same sequence number, the first one received wins and later
// duplicates are dropped — even when the duplicate carries different data.
func TestOutOfOrderWithDuplicates(t *testing.T) {
	transport := newFakeTransport()
	reader := NewReader("test-dup", 1024, transport)

	packet0 := wshrpc.CommandStreamData{
		Id:     "test-dup",
		Seq:    0,
		Data64: base64.StdEncoding.EncodeToString([]byte("aaaaa")),
	}
	packet10 := wshrpc.CommandStreamData{
		Id:     "test-dup",
		Seq:    10,
		Data64: base64.StdEncoding.EncodeToString([]byte("ccccc")),
	}
	packet5First := wshrpc.CommandStreamData{
		Id:     "test-dup",
		Seq:    5,
		Data64: base64.StdEncoding.EncodeToString([]byte("xxxxx")),
	}
	packet5Second := wshrpc.CommandStreamData{
		Id:     "test-dup",
		Seq:    5,
		Data64: base64.StdEncoding.EncodeToString([]byte("bbbbb")),
	}

	reader.RecvData(packet0)
	reader.RecvData(packet10)     // OOO - buffered
	reader.RecvData(packet5First) // OOO - buffered
	reader.RecvData(packet5First) // Duplicate - should be ignored
	reader.RecvData(packet5Second) // Duplicate with different data - should be ignored

	// Read all data - should get all 3 packets in order
	buf := make([]byte, 20)
	n, err := reader.Read(buf)
	if err != nil {
		t.Fatalf("Read failed: %v", err)
	}

	// Should get all 15 bytes (3 packets * 5 bytes)
	if n != 15 {
		t.Fatalf("Expected to read 15 bytes, got %d", n)
	}

	// Should be "aaaaaxxxxxccccc" (first packet received for each seq wins)
	expected := "aaaaaxxxxxccccc"
	if string(buf[:n]) != expected {
		t.Fatalf("Expected %q, got %q", expected, string(buf[:n]))
	}
}
|
||||
|
||||
// TestOutOfOrderWithGaps verifies incremental delivery around sequence gaps:
// in-order data up to the first gap is readable immediately, data past a gap
// stays buffered, and filling the gap releases the buffered data in order.
func TestOutOfOrderWithGaps(t *testing.T) {
	transport := newFakeTransport()
	reader := NewReader("test-gaps", 1024, transport)

	packet0 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    0,
		Data64: base64.StdEncoding.EncodeToString([]byte("aaaaa")),
	}
	packet20 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    20,
		Data64: base64.StdEncoding.EncodeToString([]byte("eeeee")),
	}
	packet40 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    40,
		Data64: base64.StdEncoding.EncodeToString([]byte("iiiii")),
	}
	packet5 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    5,
		Data64: base64.StdEncoding.EncodeToString([]byte("bbbbb")),
	}

	reader.RecvData(packet0)
	reader.RecvData(packet40) // Way ahead - should be buffered
	reader.RecvData(packet20) // Still ahead - should be buffered

	// Read first packet
	buf := make([]byte, 10)
	n, err := reader.Read(buf)
	if err != nil {
		t.Fatalf("Read failed: %v", err)
	}
	if n != 5 || string(buf[:n]) != "aaaaa" {
		t.Fatalf("Expected 'aaaaa', got %q", string(buf[:n]))
	}

	// Send packet to partially fill gap
	reader.RecvData(packet5)

	// Should be able to read it now
	n, err = reader.Read(buf)
	if err != nil {
		t.Fatalf("Second read failed: %v", err)
	}
	if n != 5 || string(buf[:n]) != "bbbbb" {
		t.Fatalf("Expected 'bbbbb', got %q", string(buf[:n]))
	}

	// Fill in every remaining gap (10, 15 before packet20; 25, 30, 35
	// before packet40) so the whole stream becomes contiguous.
	packet10 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    10,
		Data64: base64.StdEncoding.EncodeToString([]byte("ccccc")),
	}
	packet15 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    15,
		Data64: base64.StdEncoding.EncodeToString([]byte("ddddd")),
	}
	packet25 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    25,
		Data64: base64.StdEncoding.EncodeToString([]byte("fffff")),
	}
	packet30 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    30,
		Data64: base64.StdEncoding.EncodeToString([]byte("ggggg")),
	}
	packet35 := wshrpc.CommandStreamData{
		Id:     "test-gaps",
		Seq:    35,
		Data64: base64.StdEncoding.EncodeToString([]byte("hhhhh")),
	}

	reader.RecvData(packet10)
	reader.RecvData(packet15)
	reader.RecvData(packet25)
	reader.RecvData(packet30)
	reader.RecvData(packet35)

	// Read all remaining data at once
	allData := make([]byte, 100)
	totalRead := 0
	for totalRead < 35 {
		n, err = reader.Read(allData[totalRead:])
		if err != nil {
			t.Fatalf("Read failed: %v", err)
		}
		totalRead += n
	}

	expected := "cccccdddddeeeeefffffggggghhhhhiiiii"
	if string(allData[:totalRead]) != expected {
		t.Fatalf("Expected %q, got %q", expected, string(allData[:totalRead]))
	}
}
|
||||
|
||||
// TestOutOfOrderWithEOF verifies that an out-of-order packet carrying the EOF
// flag is buffered like any other, and that the reader only reports io.EOF
// after all preceding data has been delivered in order.
func TestOutOfOrderWithEOF(t *testing.T) {
	transport := newFakeTransport()
	reader := NewReader("test-eof", 1024, transport)

	packet0 := wshrpc.CommandStreamData{
		Id:     "test-eof",
		Seq:    0,
		Data64: base64.StdEncoding.EncodeToString([]byte("first")),
	}
	packet11 := wshrpc.CommandStreamData{
		Id:     "test-eof",
		Seq:    11,
		Data64: base64.StdEncoding.EncodeToString([]byte("third")),
		Eof:    true,
	}
	packet5 := wshrpc.CommandStreamData{
		Id:     "test-eof",
		Seq:    5,
		Data64: base64.StdEncoding.EncodeToString([]byte("second")),
	}

	reader.RecvData(packet0)
	reader.RecvData(packet11) // OOO with EOF
	reader.RecvData(packet5)  // Fill the gap

	// Read all data
	buf := make([]byte, 20)
	n, err := reader.Read(buf)
	if err != nil {
		t.Fatalf("Read failed: %v", err)
	}

	expected := "firstsecondthird"
	if string(buf[:n]) != expected {
		t.Fatalf("Expected %q, got %q", expected, string(buf[:n]))
	}

	// Should get EOF now
	_, err = reader.Read(buf)
	if err != io.EOF {
		t.Fatalf("Expected EOF, got %v", err)
	}
}
|
||||
|
|
|
|||
|
|
@ -5,10 +5,9 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/wavetermdev/waveterm/pkg/utilds"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
)
|
||||
|
||||
type workItem struct {
|
||||
|
|
@ -17,36 +16,23 @@ type workItem struct {
|
|||
dataPk wshrpc.CommandStreamData
|
||||
}
|
||||
|
||||
type StreamWriter interface {
|
||||
RecvAck(ackPk wshrpc.CommandStreamAckData)
|
||||
}
|
||||
|
||||
type StreamRpcInterface interface {
|
||||
StreamDataAckCommand(data wshrpc.CommandStreamAckData, opts *wshrpc.RpcOpts) error
|
||||
StreamDataCommand(data wshrpc.CommandStreamData, opts *wshrpc.RpcOpts) error
|
||||
}
|
||||
|
||||
type wshRpcAdapter struct {
|
||||
rpc *wshutil.WshRpc
|
||||
}
|
||||
|
||||
func (a *wshRpcAdapter) StreamDataAckCommand(data wshrpc.CommandStreamAckData, opts *wshrpc.RpcOpts) error {
|
||||
return wshclient.StreamDataAckCommand(a.rpc, data, opts)
|
||||
}
|
||||
|
||||
func (a *wshRpcAdapter) StreamDataCommand(data wshrpc.CommandStreamData, opts *wshrpc.RpcOpts) error {
|
||||
return wshclient.StreamDataCommand(a.rpc, data, opts)
|
||||
}
|
||||
|
||||
func AdaptWshRpc(rpc *wshutil.WshRpc) StreamRpcInterface {
|
||||
return &wshRpcAdapter{rpc: rpc}
|
||||
}
|
||||
|
||||
type Broker struct {
|
||||
lock sync.Mutex
|
||||
rpcClient StreamRpcInterface
|
||||
streamIdCounter int64
|
||||
readers map[int64]*Reader
|
||||
writers map[int64]*Writer
|
||||
readerRoutes map[int64]string
|
||||
writerRoutes map[int64]string
|
||||
readerErrorSentTime map[int64]time.Time
|
||||
readers map[string]*Reader
|
||||
writers map[string]StreamWriter
|
||||
readerRoutes map[string]string
|
||||
writerRoutes map[string]string
|
||||
readerErrorSentTime map[string]time.Time
|
||||
sendQueue *utilds.WorkQueue[workItem]
|
||||
recvQueue *utilds.WorkQueue[workItem]
|
||||
}
|
||||
|
|
@ -54,12 +40,11 @@ type Broker struct {
|
|||
func NewBroker(rpcClient StreamRpcInterface) *Broker {
|
||||
b := &Broker{
|
||||
rpcClient: rpcClient,
|
||||
streamIdCounter: 0,
|
||||
readers: make(map[int64]*Reader),
|
||||
writers: make(map[int64]*Writer),
|
||||
readerRoutes: make(map[int64]string),
|
||||
writerRoutes: make(map[int64]string),
|
||||
readerErrorSentTime: make(map[int64]time.Time),
|
||||
readers: make(map[string]*Reader),
|
||||
writers: make(map[string]StreamWriter),
|
||||
readerRoutes: make(map[string]string),
|
||||
writerRoutes: make(map[string]string),
|
||||
readerErrorSentTime: make(map[string]time.Time),
|
||||
}
|
||||
b.sendQueue = utilds.NewWorkQueue(b.processSendWork)
|
||||
b.recvQueue = utilds.NewWorkQueue(b.processRecvWork)
|
||||
|
|
@ -67,13 +52,16 @@ func NewBroker(rpcClient StreamRpcInterface) *Broker {
|
|||
}
|
||||
|
||||
func (b *Broker) CreateStreamReader(readerRoute string, writerRoute string, rwnd int64) (*Reader, *wshrpc.StreamMeta) {
|
||||
return b.CreateStreamReaderWithSeq(readerRoute, writerRoute, rwnd, 0)
|
||||
}
|
||||
|
||||
func (b *Broker) CreateStreamReaderWithSeq(readerRoute string, writerRoute string, rwnd int64, startSeq int64) (*Reader, *wshrpc.StreamMeta) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
b.streamIdCounter++
|
||||
streamId := b.streamIdCounter
|
||||
streamId := uuid.New().String()
|
||||
|
||||
reader := NewReader(streamId, rwnd, b)
|
||||
reader := NewReaderWithSeq(streamId, rwnd, startSeq, b)
|
||||
b.readers[streamId] = reader
|
||||
b.readerRoutes[streamId] = readerRoute
|
||||
b.writerRoutes[streamId] = writerRoute
|
||||
|
|
@ -88,19 +76,35 @@ func (b *Broker) CreateStreamReader(readerRoute string, writerRoute string, rwnd
|
|||
return reader, meta
|
||||
}
|
||||
|
||||
func (b *Broker) AttachStreamWriter(meta *wshrpc.StreamMeta) (*Writer, error) {
|
||||
func (b *Broker) AttachStreamWriter(meta *wshrpc.StreamMeta, writer StreamWriter) error {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
if _, exists := b.writers[meta.Id]; exists {
|
||||
return nil, fmt.Errorf("writer already registered for stream id %d", meta.Id)
|
||||
return fmt.Errorf("writer already registered for stream id %s", meta.Id)
|
||||
}
|
||||
|
||||
writer := NewWriter(meta.Id, meta.RWnd, b)
|
||||
b.writers[meta.Id] = writer
|
||||
b.readerRoutes[meta.Id] = meta.ReaderRouteId
|
||||
b.writerRoutes[meta.Id] = meta.WriterRouteId
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Broker) DetachStreamWriter(streamId string) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
delete(b.writers, streamId)
|
||||
delete(b.writerRoutes, streamId)
|
||||
}
|
||||
|
||||
func (b *Broker) CreateStreamWriter(meta *wshrpc.StreamMeta) (*Writer, error) {
|
||||
writer := NewWriter(meta.Id, meta.RWnd, b)
|
||||
err := b.AttachStreamWriter(meta, writer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return writer, nil
|
||||
}
|
||||
|
||||
|
|
@ -112,6 +116,9 @@ func (b *Broker) SendData(dataPk wshrpc.CommandStreamData) {
|
|||
b.sendQueue.Enqueue(workItem{workType: "senddata", dataPk: dataPk})
|
||||
}
|
||||
|
||||
// RecvData and RecvAck are designed to be non-blocking and must remain so to prevent deadlock.
|
||||
// They only enqueue work items to be processed asynchronously by the work queue's goroutine.
|
||||
// These methods are called from the main RPC runServer loop, so blocking here would stall all RPC processing.
|
||||
func (b *Broker) RecvData(dataPk wshrpc.CommandStreamData) {
|
||||
b.recvQueue.Enqueue(workItem{workType: "recvdata", dataPk: dataPk})
|
||||
}
|
||||
|
|
@ -220,7 +227,7 @@ func (b *Broker) Close() {
|
|||
b.recvQueue.Wait()
|
||||
}
|
||||
|
||||
func (b *Broker) cleanupReader(streamId int64) {
|
||||
func (b *Broker) cleanupReader(streamId string) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
|
|
@ -229,7 +236,7 @@ func (b *Broker) cleanupReader(streamId int64) {
|
|||
delete(b.readerErrorSentTime, streamId)
|
||||
}
|
||||
|
||||
func (b *Broker) cleanupWriter(streamId int64) {
|
||||
func (b *Broker) cleanupWriter(streamId string) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
|
|
|
|||
|
|
@ -68,9 +68,9 @@ func TestBrokerBasicReadWrite(t *testing.T) {
|
|||
broker1, broker2 := setupBrokerPair()
|
||||
|
||||
reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024)
|
||||
writer, err := broker2.AttachStreamWriter(meta)
|
||||
writer, err := broker2.CreateStreamWriter(meta)
|
||||
if err != nil {
|
||||
t.Fatalf("AttachStreamWriter failed: %v", err)
|
||||
t.Fatalf("CreateStreamWriter failed: %v", err)
|
||||
}
|
||||
|
||||
testData := []byte("Hello, World!")
|
||||
|
|
@ -105,9 +105,9 @@ func TestBrokerEOF(t *testing.T) {
|
|||
broker1, broker2 := setupBrokerPair()
|
||||
|
||||
reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024)
|
||||
writer, err := broker2.AttachStreamWriter(meta)
|
||||
writer, err := broker2.CreateStreamWriter(meta)
|
||||
if err != nil {
|
||||
t.Fatalf("AttachStreamWriter failed: %v", err)
|
||||
t.Fatalf("CreateStreamWriter failed: %v", err)
|
||||
}
|
||||
|
||||
testData := []byte("Test data")
|
||||
|
|
@ -134,9 +134,9 @@ func TestBrokerFlowControl(t *testing.T) {
|
|||
|
||||
smallWindow := int64(10)
|
||||
reader, meta := broker1.CreateStreamReader("reader1", "writer1", smallWindow)
|
||||
writer, err := broker2.AttachStreamWriter(meta)
|
||||
writer, err := broker2.CreateStreamWriter(meta)
|
||||
if err != nil {
|
||||
t.Fatalf("AttachStreamWriter failed: %v", err)
|
||||
t.Fatalf("CreateStreamWriter failed: %v", err)
|
||||
}
|
||||
|
||||
largeData := make([]byte, 100)
|
||||
|
|
@ -180,9 +180,9 @@ func TestBrokerError(t *testing.T) {
|
|||
broker1, broker2 := setupBrokerPair()
|
||||
|
||||
reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024)
|
||||
writer, err := broker2.AttachStreamWriter(meta)
|
||||
writer, err := broker2.CreateStreamWriter(meta)
|
||||
if err != nil {
|
||||
t.Fatalf("AttachStreamWriter failed: %v", err)
|
||||
t.Fatalf("CreateStreamWriter failed: %v", err)
|
||||
}
|
||||
|
||||
testErr := io.ErrUnexpectedEOF
|
||||
|
|
@ -202,9 +202,9 @@ func TestBrokerCancel(t *testing.T) {
|
|||
broker1, broker2 := setupBrokerPair()
|
||||
|
||||
reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024)
|
||||
writer, err := broker2.AttachStreamWriter(meta)
|
||||
writer, err := broker2.CreateStreamWriter(meta)
|
||||
if err != nil {
|
||||
t.Fatalf("AttachStreamWriter failed: %v", err)
|
||||
t.Fatalf("CreateStreamWriter failed: %v", err)
|
||||
}
|
||||
|
||||
reader.Close()
|
||||
|
|
@ -226,9 +226,9 @@ func TestBrokerMultipleWrites(t *testing.T) {
|
|||
broker1, broker2 := setupBrokerPair()
|
||||
|
||||
reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024)
|
||||
writer, err := broker2.AttachStreamWriter(meta)
|
||||
writer, err := broker2.CreateStreamWriter(meta)
|
||||
if err != nil {
|
||||
t.Fatalf("AttachStreamWriter failed: %v", err)
|
||||
t.Fatalf("CreateStreamWriter failed: %v", err)
|
||||
}
|
||||
|
||||
messages := []string{"First", "Second", "Third"}
|
||||
|
|
@ -261,9 +261,9 @@ func TestBrokerCleanup(t *testing.T) {
|
|||
broker1, broker2 := setupBrokerPair()
|
||||
|
||||
reader, meta := broker1.CreateStreamReader("reader1", "writer1", 1024)
|
||||
writer, err := broker2.AttachStreamWriter(meta)
|
||||
writer, err := broker2.CreateStreamWriter(meta)
|
||||
if err != nil {
|
||||
t.Fatalf("AttachStreamWriter failed: %v", err)
|
||||
t.Fatalf("CreateStreamWriter failed: %v", err)
|
||||
}
|
||||
|
||||
testData := []byte("cleanup test")
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
|
|
@ -16,7 +17,7 @@ type AckSender interface {
|
|||
type Reader struct {
|
||||
lock sync.Mutex
|
||||
cond *sync.Cond
|
||||
id int64
|
||||
id string
|
||||
ackSender AckSender
|
||||
readWindow int64
|
||||
nextSeq int64
|
||||
|
|
@ -25,14 +26,19 @@ type Reader struct {
|
|||
err error
|
||||
closed bool
|
||||
lastRwndSent int64
|
||||
oooPackets []wshrpc.CommandStreamData // out-of-order packets awaiting delivery
|
||||
}
|
||||
|
||||
func NewReader(id int64, readWindow int64, ackSender AckSender) *Reader {
|
||||
func NewReader(id string, readWindow int64, ackSender AckSender) *Reader {
|
||||
return NewReaderWithSeq(id, readWindow, 0, ackSender)
|
||||
}
|
||||
|
||||
func NewReaderWithSeq(id string, readWindow int64, startSeq int64, ackSender AckSender) *Reader {
|
||||
r := &Reader{
|
||||
id: id,
|
||||
readWindow: readWindow,
|
||||
ackSender: ackSender,
|
||||
nextSeq: 0,
|
||||
nextSeq: startSeq,
|
||||
lastRwndSent: readWindow,
|
||||
}
|
||||
r.cond = sync.NewCond(&r.lock)
|
||||
|
|
@ -43,7 +49,7 @@ func (r *Reader) RecvData(dataPk wshrpc.CommandStreamData) {
|
|||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
|
||||
if r.closed {
|
||||
if r.closed || r.eof || r.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -59,18 +65,25 @@ func (r *Reader) RecvData(dataPk wshrpc.CommandStreamData) {
|
|||
return
|
||||
}
|
||||
|
||||
if dataPk.Seq != r.nextSeq {
|
||||
r.err = fmt.Errorf("stream sequence mismatch: expected %d, got %d", r.nextSeq, dataPk.Seq)
|
||||
r.cond.Broadcast()
|
||||
r.sendAckLocked(false, true, "sequence mismatch error")
|
||||
if dataPk.Seq < r.nextSeq {
|
||||
return
|
||||
}
|
||||
if dataPk.Seq > r.nextSeq {
|
||||
r.addOOOPacketLocked(dataPk)
|
||||
return
|
||||
}
|
||||
|
||||
r.recvDataOrderedLocked(dataPk)
|
||||
r.processOOOPacketsLocked()
|
||||
r.cond.Broadcast()
|
||||
r.sendAckLocked(r.eof, false, "")
|
||||
}
|
||||
|
||||
func (r *Reader) recvDataOrderedLocked(dataPk wshrpc.CommandStreamData) {
|
||||
if dataPk.Data64 != "" {
|
||||
data, err := base64.StdEncoding.DecodeString(dataPk.Data64)
|
||||
if err != nil {
|
||||
r.err = err
|
||||
r.cond.Broadcast()
|
||||
r.sendAckLocked(false, true, "base64 decode error")
|
||||
return
|
||||
}
|
||||
|
|
@ -80,13 +93,40 @@ func (r *Reader) RecvData(dataPk wshrpc.CommandStreamData) {
|
|||
|
||||
if dataPk.Eof {
|
||||
r.eof = true
|
||||
r.cond.Broadcast()
|
||||
r.sendAckLocked(true, false, "")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Reader) addOOOPacketLocked(dataPk wshrpc.CommandStreamData) {
|
||||
for _, pkt := range r.oooPackets {
|
||||
if pkt.Seq == dataPk.Seq {
|
||||
// this handles duplicates
|
||||
return
|
||||
}
|
||||
}
|
||||
r.oooPackets = append(r.oooPackets, dataPk)
|
||||
}
|
||||
|
||||
func (r *Reader) processOOOPacketsLocked() {
|
||||
if len(r.oooPackets) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
r.cond.Broadcast()
|
||||
r.sendAckLocked(false, false, "")
|
||||
sort.Slice(r.oooPackets, func(i, j int) bool {
|
||||
return r.oooPackets[i].Seq < r.oooPackets[j].Seq
|
||||
})
|
||||
consumed := 0
|
||||
for _, pkt := range r.oooPackets {
|
||||
if r.eof || r.err != nil {
|
||||
// we're done, so we can clear any pending ooo packets
|
||||
r.oooPackets = nil
|
||||
return
|
||||
}
|
||||
if pkt.Seq != r.nextSeq {
|
||||
break
|
||||
}
|
||||
r.recvDataOrderedLocked(pkt)
|
||||
consumed++
|
||||
}
|
||||
r.oooPackets = r.oooPackets[consumed:]
|
||||
}
|
||||
|
||||
func (r *Reader) sendAckLocked(fin bool, cancel bool, errStr string) {
|
||||
|
|
@ -146,6 +186,12 @@ func (r *Reader) Read(p []byte) (int, error) {
|
|||
return n, nil
|
||||
}
|
||||
|
||||
func (r *Reader) UpdateNextSeq(newSeq int64) {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
r.nextSeq = newSeq
|
||||
}
|
||||
|
||||
func (r *Reader) Close() error {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ type DataSender interface {
|
|||
type Writer struct {
|
||||
lock sync.Mutex
|
||||
cond *sync.Cond
|
||||
id int64
|
||||
id string
|
||||
dataSender DataSender
|
||||
readWindow int64
|
||||
nextSeq int64
|
||||
|
|
@ -31,7 +31,7 @@ type Writer struct {
|
|||
closed bool
|
||||
}
|
||||
|
||||
func NewWriter(id int64, readWindow int64, dataSender DataSender) *Writer {
|
||||
func NewWriter(id string, readWindow int64, dataSender DataSender) *Writer {
|
||||
w := &Writer{
|
||||
id: id,
|
||||
readWindow: readWindow,
|
||||
|
|
|
|||
|
|
@ -26,11 +26,13 @@ var (
|
|||
|
||||
type WaveJwtClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
Sock string `json:"sock,omitempty"`
|
||||
RouteId string `json:"routeid,omitempty"`
|
||||
BlockId string `json:"blockid,omitempty"`
|
||||
Conn string `json:"conn,omitempty"`
|
||||
Router bool `json:"router,omitempty"`
|
||||
MainServer bool `json:"mainserver,omitempty"`
|
||||
Sock string `json:"sock,omitempty"`
|
||||
RouteId string `json:"routeid,omitempty"`
|
||||
BlockId string `json:"blockid,omitempty"`
|
||||
JobId string `json:"jobid,omitempty"`
|
||||
Conn string `json:"conn,omitempty"`
|
||||
Router bool `json:"router,omitempty"`
|
||||
}
|
||||
|
||||
type KeyPair struct {
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ const (
|
|||
OType_LayoutState = "layout"
|
||||
OType_Block = "block"
|
||||
OType_MainServer = "mainserver"
|
||||
OType_Job = "job"
|
||||
OType_Temp = "temp"
|
||||
OType_Builder = "builder" // not persisted to DB
|
||||
)
|
||||
|
|
@ -41,6 +42,7 @@ var ValidOTypes = map[string]bool{
|
|||
OType_LayoutState: true,
|
||||
OType_Block: true,
|
||||
OType_MainServer: true,
|
||||
OType_Job: true,
|
||||
OType_Temp: true,
|
||||
OType_Builder: true,
|
||||
}
|
||||
|
|
@ -134,6 +136,7 @@ type Client struct {
|
|||
TosAgreed int64 `json:"tosagreed,omitempty"` // unix milli
|
||||
HasOldHistory bool `json:"hasoldhistory,omitempty"`
|
||||
TempOID string `json:"tempoid,omitempty"`
|
||||
InstallId string `json:"installid,omitempty"`
|
||||
}
|
||||
|
||||
func (*Client) GetOType() string {
|
||||
|
|
@ -288,6 +291,7 @@ type Block struct {
|
|||
Stickers []*StickerType `json:"stickers,omitempty"`
|
||||
Meta MetaMapType `json:"meta"`
|
||||
SubBlockIds []string `json:"subblockids,omitempty"`
|
||||
JobId string `json:"jobid,omitempty"` // if set, the block will render this jobid's pty output
|
||||
}
|
||||
|
||||
func (*Block) GetOType() string {
|
||||
|
|
@ -306,6 +310,49 @@ func (*MainServer) GetOType() string {
|
|||
return OType_MainServer
|
||||
}
|
||||
|
||||
type Job struct {
|
||||
OID string `json:"oid"`
|
||||
Version int `json:"version"`
|
||||
|
||||
// job metadata
|
||||
Connection string `json:"connection"`
|
||||
JobKind string `json:"jobkind"` // shell, task
|
||||
Cmd string `json:"cmd"`
|
||||
CmdArgs []string `json:"cmdargs,omitempty"`
|
||||
CmdEnv map[string]string `json:"cmdenv,omitempty"`
|
||||
JobAuthToken string `json:"jobauthtoken"` // job manger -> wave
|
||||
AttachedBlockId string `json:"attachedblockid,omitempty"`
|
||||
|
||||
// reconnect option (e.g. orphaned, so we need to kill on connect)
|
||||
TerminateOnReconnect bool `json:"terminateonreconnect,omitempty"`
|
||||
|
||||
// job manager state
|
||||
JobManagerStatus string `json:"jobmanagerstatus"` // init, running, done
|
||||
JobManagerDoneReason string `json:"jobmanagerdonereason,omitempty"` // startuperror, gone, terminated
|
||||
JobManagerStartupError string `json:"jobmanagerstartuperror,omitempty"`
|
||||
JobManagerPid int `json:"jobmanagerpid,omitempty"`
|
||||
JobManagerStartTs int64 `json:"jobmanagerstartts,omitempty"` // exact process start time (milliseconds)
|
||||
|
||||
// cmd/process runtime info
|
||||
CmdPid int `json:"cmdpid,omitempty"` // command process id
|
||||
CmdStartTs int64 `json:"cmdstartts,omitempty"` // exact command process start time (milliseconds from epoch)
|
||||
CmdTermSize TermSize `json:"cmdtermsize"`
|
||||
CmdExitTs int64 `json:"cmdexitts,omitempty"` // timestamp (milliseconds) -- use CmdExitTs > 0 to check if command has exited
|
||||
CmdExitCode *int `json:"cmdexitcode,omitempty"` // nil when CmdExitSignal is set. success exit is when CmdExitCode is 0
|
||||
CmdExitSignal string `json:"cmdexitsignal,omitempty"` // empty string if CmdExitCode is set
|
||||
CmdExitError string `json:"cmdexiterror,omitempty"`
|
||||
|
||||
// output info
|
||||
StreamDone bool `json:"streamdone,omitempty"`
|
||||
StreamError string `json:"streamerror,omitempty"`
|
||||
|
||||
Meta MetaMapType `json:"meta"`
|
||||
}
|
||||
|
||||
func (*Job) GetOType() string {
|
||||
return OType_Job
|
||||
}
|
||||
|
||||
func AllWaveObjTypes() []reflect.Type {
|
||||
return []reflect.Type{
|
||||
reflect.TypeOf(&Client{}),
|
||||
|
|
@ -315,6 +362,7 @@ func AllWaveObjTypes() []reflect.Type {
|
|||
reflect.TypeOf(&Block{}),
|
||||
reflect.TypeOf(&LayoutState{}),
|
||||
reflect.TypeOf(&MainServer{}),
|
||||
reflect.TypeOf(&Job{}),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/wavetermdev/waveterm/pkg/blockcontroller"
|
||||
"github.com/wavetermdev/waveterm/pkg/filestore"
|
||||
"github.com/wavetermdev/waveterm/pkg/jobcontroller"
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
"github.com/wavetermdev/waveterm/pkg/telemetry"
|
||||
"github.com/wavetermdev/waveterm/pkg/telemetry/telemetrydata"
|
||||
|
|
@ -167,6 +168,19 @@ func DeleteBlock(ctx context.Context, blockId string, recursive bool) error {
|
|||
}
|
||||
}
|
||||
}
|
||||
if block.JobId != "" {
|
||||
go func() {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("DetachJobFromBlock", recover())
|
||||
}()
|
||||
detachCtx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancelFn()
|
||||
err := jobcontroller.DetachJobFromBlock(detachCtx, block.JobId, false)
|
||||
if err != nil {
|
||||
log.Printf("error detaching job from block %s: %v", blockId, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
parentBlockCount, err := deleteBlockObj(ctx, blockId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deleting block: %w", err)
|
||||
|
|
|
|||
|
|
@ -50,6 +50,14 @@ func EnsureInitialData() (bool, error) {
|
|||
return firstLaunch, fmt.Errorf("error updating client: %w", err)
|
||||
}
|
||||
}
|
||||
if client.InstallId == "" {
|
||||
log.Println("client.InstallId is empty")
|
||||
client.InstallId = uuid.NewString()
|
||||
err = wstore.DBUpdate(ctx, client)
|
||||
if err != nil {
|
||||
return firstLaunch, fmt.Errorf("error updating client: %w", err)
|
||||
}
|
||||
}
|
||||
log.Printf("clientid: %s\n", client.OID)
|
||||
if len(client.WindowIds) == 1 {
|
||||
log.Println("client has one window")
|
||||
|
|
|
|||
|
|
@ -9,14 +9,11 @@ import (
|
|||
|
||||
"github.com/wavetermdev/waveterm/pkg/tsgen/tsgenmeta"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
)
|
||||
|
||||
const (
|
||||
WSCommand_SetBlockTermSize = "setblocktermsize"
|
||||
WSCommand_BlockInput = "blockinput"
|
||||
WSCommand_Rpc = "rpc"
|
||||
WSCommand_Rpc = "rpc"
|
||||
)
|
||||
|
||||
type WSCommandType interface {
|
||||
|
|
@ -28,8 +25,6 @@ func WSCommandTypeUnionMeta() tsgenmeta.TypeUnionMeta {
|
|||
BaseType: reflect.TypeOf((*WSCommandType)(nil)).Elem(),
|
||||
TypeFieldName: "wscommand",
|
||||
Types: []reflect.Type{
|
||||
reflect.TypeOf(SetBlockTermSizeWSCommand{}),
|
||||
reflect.TypeOf(BlockInputWSCommand{}),
|
||||
reflect.TypeOf(WSRpcCommand{}),
|
||||
},
|
||||
}
|
||||
|
|
@ -44,46 +39,12 @@ func (cmd *WSRpcCommand) GetWSCommand() string {
|
|||
return cmd.WSCommand
|
||||
}
|
||||
|
||||
type SetBlockTermSizeWSCommand struct {
|
||||
WSCommand string `json:"wscommand" tstype:"\"setblocktermsize\""`
|
||||
BlockId string `json:"blockid"`
|
||||
TermSize waveobj.TermSize `json:"termsize"`
|
||||
}
|
||||
|
||||
func (cmd *SetBlockTermSizeWSCommand) GetWSCommand() string {
|
||||
return cmd.WSCommand
|
||||
}
|
||||
|
||||
type BlockInputWSCommand struct {
|
||||
WSCommand string `json:"wscommand" tstype:"\"blockinput\""`
|
||||
BlockId string `json:"blockid"`
|
||||
InputData64 string `json:"inputdata64"`
|
||||
}
|
||||
|
||||
func (cmd *BlockInputWSCommand) GetWSCommand() string {
|
||||
return cmd.WSCommand
|
||||
}
|
||||
|
||||
func ParseWSCommandMap(cmdMap map[string]any) (WSCommandType, error) {
|
||||
cmdType, ok := cmdMap["wscommand"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no wscommand field in command map")
|
||||
}
|
||||
switch cmdType {
|
||||
case WSCommand_SetBlockTermSize:
|
||||
var cmd SetBlockTermSizeWSCommand
|
||||
err := utilfn.DoMapStructure(&cmd, cmdMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decoding SetBlockTermSizeWSCommand: %w", err)
|
||||
}
|
||||
return &cmd, nil
|
||||
case WSCommand_BlockInput:
|
||||
var cmd BlockInputWSCommand
|
||||
err := utilfn.DoMapStructure(&cmd, cmdMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decoding BlockInputWSCommand: %w", err)
|
||||
}
|
||||
return &cmd, nil
|
||||
case WSCommand_Rpc:
|
||||
var cmd WSRpcCommand
|
||||
err := utilfn.DoMapStructure(&cmd, cmdMap)
|
||||
|
|
@ -94,5 +55,4 @@ func ParseWSCommandMap(cmdMap map[string]any) (WSCommandType, error) {
|
|||
default:
|
||||
return nil, fmt.Errorf("unknown wscommand type %q", cmdType)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import (
|
|||
"github.com/wavetermdev/waveterm/pkg/eventbus"
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
"github.com/wavetermdev/waveterm/pkg/web/webcmd"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
)
|
||||
|
||||
|
|
@ -110,40 +109,6 @@ func processWSCommand(jmsg map[string]any, outputCh chan any, rpcInputCh chan ba
|
|||
}
|
||||
cmdType = wsCommand.GetWSCommand()
|
||||
switch cmd := wsCommand.(type) {
|
||||
case *webcmd.SetBlockTermSizeWSCommand:
|
||||
data := wshrpc.CommandBlockInputData{
|
||||
BlockId: cmd.BlockId,
|
||||
TermSize: &cmd.TermSize,
|
||||
}
|
||||
rpcMsg := wshutil.RpcMessage{
|
||||
Command: wshrpc.Command_ControllerInput,
|
||||
Data: data,
|
||||
}
|
||||
msgBytes, err := json.Marshal(rpcMsg)
|
||||
if err != nil {
|
||||
// this really should never fail since we just unmarshalled this value
|
||||
log.Printf("[websocket] error marshalling rpc message: %v\n", err)
|
||||
return
|
||||
}
|
||||
rpcInputCh <- baseds.RpcInputChType{MsgBytes: msgBytes}
|
||||
|
||||
case *webcmd.BlockInputWSCommand:
|
||||
data := wshrpc.CommandBlockInputData{
|
||||
BlockId: cmd.BlockId,
|
||||
InputData64: cmd.InputData64,
|
||||
}
|
||||
rpcMsg := wshutil.RpcMessage{
|
||||
Command: wshrpc.Command_ControllerInput,
|
||||
Data: data,
|
||||
}
|
||||
msgBytes, err := json.Marshal(rpcMsg)
|
||||
if err != nil {
|
||||
// this really should never fail since we just unmarshalled this value
|
||||
log.Printf("[websocket] error marshalling rpc message: %v\n", err)
|
||||
return
|
||||
}
|
||||
rpcInputCh <- baseds.RpcInputChType{MsgBytes: msgBytes}
|
||||
|
||||
case *webcmd.WSRpcCommand:
|
||||
rpcMsg := cmd.Message
|
||||
if rpcMsg == nil {
|
||||
|
|
|
|||
|
|
@ -16,7 +16,8 @@ const (
|
|||
Event_BlockFile = "blockfile"
|
||||
Event_Config = "config"
|
||||
Event_UserInput = "userinput"
|
||||
Event_RouteGone = "route:gone"
|
||||
Event_RouteDown = "route:down"
|
||||
Event_RouteUp = "route:up"
|
||||
Event_WorkspaceUpdate = "workspace:update"
|
||||
Event_WaveAIRateLimit = "waveai:ratelimit"
|
||||
Event_WaveAppAppGoUpdated = "waveapp:appgoupdated"
|
||||
|
|
|
|||
|
|
@ -4,8 +4,10 @@
|
|||
package wshclient
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
|
|
@ -17,21 +19,22 @@ func (*WshServer) WshServerImpl() {}
|
|||
|
||||
var WshServerImpl = WshServer{}
|
||||
|
||||
const (
|
||||
DefaultOutputChSize = 32
|
||||
DefaultInputChSize = 32
|
||||
)
|
||||
|
||||
var waveSrvClient_Singleton *wshutil.WshRpc
|
||||
var waveSrvClient_Once = &sync.Once{}
|
||||
|
||||
const BareClientRoute = "bare"
|
||||
var waveSrvClient_RouteId string
|
||||
|
||||
func GetBareRpcClient() *wshutil.WshRpc {
|
||||
waveSrvClient_Once.Do(func() {
|
||||
waveSrvClient_Singleton = wshutil.MakeWshRpc(wshrpc.RpcContext{}, &WshServerImpl, "bare-client")
|
||||
wshutil.DefaultRouter.RegisterTrustedLeaf(waveSrvClient_Singleton, BareClientRoute)
|
||||
waveSrvClient_RouteId = fmt.Sprintf("bare:%s", uuid.New().String())
|
||||
// we can safely ignore the error from RegisterTrustedLeaf since the route is valid
|
||||
wshutil.DefaultRouter.RegisterTrustedLeaf(waveSrvClient_Singleton, waveSrvClient_RouteId)
|
||||
wps.Broker.SetClient(wshutil.DefaultRouter)
|
||||
})
|
||||
return waveSrvClient_Singleton
|
||||
}
|
||||
|
||||
func GetBareRpcClientRouteId() string {
|
||||
GetBareRpcClient()
|
||||
return waveSrvClient_RouteId
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,6 +35,24 @@ func AuthenticateCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) (
|
|||
return resp, err
|
||||
}
|
||||
|
||||
// command "authenticatejobmanager", wshserver.AuthenticateJobManagerCommand
|
||||
func AuthenticateJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticateJobManagerData, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "authenticatejobmanager", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "authenticatejobmanagerverify", wshserver.AuthenticateJobManagerVerifyCommand
|
||||
func AuthenticateJobManagerVerifyCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticateJobManagerData, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "authenticatejobmanagerverify", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "authenticatetojobmanager", wshserver.AuthenticateToJobManagerCommand
|
||||
func AuthenticateToJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticateToJobData, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "authenticatetojobmanager", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "authenticatetoken", wshserver.AuthenticateTokenCommand
|
||||
func AuthenticateTokenCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticateTokenData, opts *wshrpc.RpcOpts) (wshrpc.CommandAuthenticateRtnData, error) {
|
||||
resp, err := sendRpcRequestCallHelper[wshrpc.CommandAuthenticateRtnData](w, "authenticatetoken", data, opts)
|
||||
|
|
@ -458,6 +476,90 @@ func GetWaveAIRateLimitCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) (*uctype
|
|||
return resp, err
|
||||
}
|
||||
|
||||
// command "jobcmdexited", wshserver.JobCmdExitedCommand
|
||||
func JobCmdExitedCommand(w *wshutil.WshRpc, data wshrpc.CommandJobCmdExitedData, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "jobcmdexited", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "jobcontrollerattachjob", wshserver.JobControllerAttachJobCommand
|
||||
func JobControllerAttachJobCommand(w *wshutil.WshRpc, data wshrpc.CommandJobControllerAttachJobData, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerattachjob", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "jobcontrollerconnectedjobs", wshserver.JobControllerConnectedJobsCommand
|
||||
func JobControllerConnectedJobsCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]string, error) {
|
||||
resp, err := sendRpcRequestCallHelper[[]string](w, "jobcontrollerconnectedjobs", nil, opts)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// command "jobcontrollerdeletejob", wshserver.JobControllerDeleteJobCommand
|
||||
func JobControllerDeleteJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerdeletejob", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "jobcontrollerdetachjob", wshserver.JobControllerDetachJobCommand
|
||||
func JobControllerDetachJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerdetachjob", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "jobcontrollerdisconnectjob", wshserver.JobControllerDisconnectJobCommand
|
||||
func JobControllerDisconnectJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerdisconnectjob", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "jobcontrollerexitjob", wshserver.JobControllerExitJobCommand
|
||||
func JobControllerExitJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerexitjob", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "jobcontrollerlist", wshserver.JobControllerListCommand
|
||||
func JobControllerListCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]*waveobj.Job, error) {
|
||||
resp, err := sendRpcRequestCallHelper[[]*waveobj.Job](w, "jobcontrollerlist", nil, opts)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// command "jobcontrollerreconnectjob", wshserver.JobControllerReconnectJobCommand
|
||||
func JobControllerReconnectJobCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerreconnectjob", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "jobcontrollerreconnectjobsforconn", wshserver.JobControllerReconnectJobsForConnCommand
|
||||
func JobControllerReconnectJobsForConnCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "jobcontrollerreconnectjobsforconn", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "jobcontrollerstartjob", wshserver.JobControllerStartJobCommand
|
||||
func JobControllerStartJobCommand(w *wshutil.WshRpc, data wshrpc.CommandJobControllerStartJobData, opts *wshrpc.RpcOpts) (string, error) {
|
||||
resp, err := sendRpcRequestCallHelper[string](w, "jobcontrollerstartjob", data, opts)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// command "jobinput", wshserver.JobInputCommand
|
||||
func JobInputCommand(w *wshutil.WshRpc, data wshrpc.CommandJobInputData, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "jobinput", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "jobprepareconnect", wshserver.JobPrepareConnectCommand
|
||||
func JobPrepareConnectCommand(w *wshutil.WshRpc, data wshrpc.CommandJobPrepareConnectData, opts *wshrpc.RpcOpts) (*wshrpc.CommandJobConnectRtnData, error) {
|
||||
resp, err := sendRpcRequestCallHelper[*wshrpc.CommandJobConnectRtnData](w, "jobprepareconnect", data, opts)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// command "jobstartstream", wshserver.JobStartStreamCommand
|
||||
func JobStartStreamCommand(w *wshutil.WshRpc, data wshrpc.CommandJobStartStreamData, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "jobstartstream", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "listallappfiles", wshserver.ListAllAppFilesCommand
|
||||
func ListAllAppFilesCommand(w *wshutil.WshRpc, data wshrpc.CommandListAllAppFilesData, opts *wshrpc.RpcOpts) (*wshrpc.CommandListAllAppFilesRtnData, error) {
|
||||
resp, err := sendRpcRequestCallHelper[*wshrpc.CommandListAllAppFilesRtnData](w, "listallappfiles", data, opts)
|
||||
|
|
@ -524,6 +626,12 @@ func RecordTEventCommand(w *wshutil.WshRpc, data telemetrydata.TEvent, opts *wsh
|
|||
return err
|
||||
}
|
||||
|
||||
// command "remotedisconnectfromjobmanager", wshserver.RemoteDisconnectFromJobManagerCommand
|
||||
func RemoteDisconnectFromJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteDisconnectFromJobManagerData, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "remotedisconnectfromjobmanager", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "remotefilecopy", wshserver.RemoteFileCopyCommand
|
||||
func RemoteFileCopyCommand(w *wshutil.WshRpc, data wshrpc.CommandFileCopyData, opts *wshrpc.RpcOpts) (bool, error) {
|
||||
resp, err := sendRpcRequestCallHelper[bool](w, "remotefilecopy", data, opts)
|
||||
|
|
@ -583,6 +691,18 @@ func RemoteMkdirCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) er
|
|||
return err
|
||||
}
|
||||
|
||||
// command "remotereconnecttojobmanager", wshserver.RemoteReconnectToJobManagerCommand
|
||||
func RemoteReconnectToJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteReconnectToJobManagerData, opts *wshrpc.RpcOpts) (*wshrpc.CommandRemoteReconnectToJobManagerRtnData, error) {
|
||||
resp, err := sendRpcRequestCallHelper[*wshrpc.CommandRemoteReconnectToJobManagerRtnData](w, "remotereconnecttojobmanager", data, opts)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// command "remotestartjob", wshserver.RemoteStartJobCommand
|
||||
func RemoteStartJobCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteStartJobData, opts *wshrpc.RpcOpts) (*wshrpc.CommandStartJobRtnData, error) {
|
||||
resp, err := sendRpcRequestCallHelper[*wshrpc.CommandStartJobRtnData](w, "remotestartjob", data, opts)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// command "remotestreamcpudata", wshserver.RemoteStreamCpuDataCommand
|
||||
func RemoteStreamCpuDataCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[wshrpc.TimeSeriesData] {
|
||||
return sendRpcRequestResponseStreamHelper[wshrpc.TimeSeriesData](w, "remotestreamcpudata", nil, opts)
|
||||
|
|
@ -598,6 +718,12 @@ func RemoteTarStreamCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteStreamTa
|
|||
return sendRpcRequestResponseStreamHelper[iochantypes.Packet](w, "remotetarstream", data, opts)
|
||||
}
|
||||
|
||||
// command "remoteterminatejobmanager", wshserver.RemoteTerminateJobManagerCommand
|
||||
func RemoteTerminateJobManagerCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteTerminateJobManagerData, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "remoteterminatejobmanager", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "remotewritefile", wshserver.RemoteWriteFileCommand
|
||||
func RemoteWriteFileCommand(w *wshutil.WshRpc, data wshrpc.FileData, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "remotewritefile", data, opts)
|
||||
|
|
@ -688,6 +814,12 @@ func StartBuilderCommand(w *wshutil.WshRpc, data wshrpc.CommandStartBuilderData,
|
|||
return err
|
||||
}
|
||||
|
||||
// command "startjob", wshserver.StartJobCommand
|
||||
func StartJobCommand(w *wshutil.WshRpc, data wshrpc.CommandStartJobData, opts *wshrpc.RpcOpts) (*wshrpc.CommandStartJobRtnData, error) {
|
||||
resp, err := sendRpcRequestCallHelper[*wshrpc.CommandStartJobRtnData](w, "startjob", data, opts)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// command "stopbuilder", wshserver.StopBuilderCommand
|
||||
func StopBuilderCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "stopbuilder", data, opts)
|
||||
|
|
@ -727,6 +859,12 @@ func TermGetScrollbackLinesCommand(w *wshutil.WshRpc, data wshrpc.CommandTermGet
|
|||
return resp, err
|
||||
}
|
||||
|
||||
// command "termupdateattachedjob", wshserver.TermUpdateAttachedJobCommand
|
||||
func TermUpdateAttachedJobCommand(w *wshutil.WshRpc, data wshrpc.CommandTermUpdateAttachedJobData, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "termupdateattachedjob", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "test", wshserver.TestCommand
|
||||
func TestCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "test", data, opts)
|
||||
|
|
|
|||
|
|
@ -4,35 +4,44 @@
|
|||
package wshremote
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
"sync"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/remote/connparse"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs"
|
||||
"github.com/wavetermdev/waveterm/pkg/suggestion"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/fileutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/tarcopy"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
)
|
||||
|
||||
type JobManagerConnection struct {
|
||||
JobId string
|
||||
Conn net.Conn
|
||||
WshRpc *wshutil.WshRpc
|
||||
CleanupFn func()
|
||||
}
|
||||
|
||||
type ServerImpl struct {
|
||||
LogWriter io.Writer
|
||||
LogWriter io.Writer
|
||||
Router *wshutil.WshRouter
|
||||
RpcClient *wshutil.WshRpc
|
||||
IsLocal bool
|
||||
JobManagerMap map[string]*JobManagerConnection
|
||||
Lock sync.Mutex
|
||||
}
|
||||
|
||||
func MakeRemoteRpcServerImpl(logWriter io.Writer, router *wshutil.WshRouter, rpcClient *wshutil.WshRpc, isLocal bool) *ServerImpl {
|
||||
return &ServerImpl{
|
||||
LogWriter: logWriter,
|
||||
Router: router,
|
||||
RpcClient: rpcClient,
|
||||
IsLocal: isLocal,
|
||||
JobManagerMap: make(map[string]*JobManagerConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// WshServerImpl is a marker method identifying *ServerImpl as a wsh server implementation.
func (*ServerImpl) WshServerImpl() {}
|
||||
|
|
@ -66,785 +75,6 @@ func (impl *ServerImpl) StreamTestCommand(ctx context.Context) chan wshrpc.RespO
|
|||
return ch
|
||||
}
|
||||
|
||||
// ByteRangeType represents a requested range. All means "everything";
// otherwise [Start, End] bounds apply (interpreted as bytes for regular
// files and entry indices for directories by the callers in this package).
type ByteRangeType struct {
	All   bool
	Start int64
	End   int64
}

// parseByteRange parses a "start-end" range string. An empty string selects
// the entire input. Malformed text, negative bounds, or start > end all
// yield an "invalid byte range" error.
func parseByteRange(rangeStr string) (ByteRangeType, error) {
	if rangeStr == "" {
		return ByteRangeType{All: true}, nil
	}
	var start, end int64
	if _, err := fmt.Sscanf(rangeStr, "%d-%d", &start, &end); err != nil {
		return ByteRangeType{}, errors.New("invalid byte range")
	}
	if start < 0 || end < 0 || start > end {
		return ByteRangeType{}, errors.New("invalid byte range")
	}
	return ByteRangeType{Start: start, End: end}, nil
}
|
||||
|
||||
// remoteStreamFileDir streams a directory listing for path to dataCallback in
// chunks of at most wshrpc.DirChunkSize FileInfo records. For directories the
// byte range is reinterpreted as an entry-index window: with All set the
// listing is capped at wshrpc.MaxDirSize entries, otherwise entries
// [Start, End) are returned (clamped to the listing length).
func (impl *ServerImpl) remoteStreamFileDir(ctx context.Context, path string, byteRange ByteRangeType, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error {
	innerFilesEntries, err := os.ReadDir(path)
	if err != nil {
		return fmt.Errorf("cannot open dir %q: %w", path, err)
	}
	if byteRange.All {
		// cap unbounded listings at MaxDirSize entries
		if len(innerFilesEntries) > wshrpc.MaxDirSize {
			innerFilesEntries = innerFilesEntries[:wshrpc.MaxDirSize]
		}
	} else {
		// select the [Start, End) entry window, clamping End to the listing size
		if byteRange.Start < int64(len(innerFilesEntries)) {
			realEnd := byteRange.End
			if realEnd > int64(len(innerFilesEntries)) {
				realEnd = int64(len(innerFilesEntries))
			}
			innerFilesEntries = innerFilesEntries[byteRange.Start:realEnd]
		} else {
			// window starts past the end of the listing: nothing to return
			innerFilesEntries = []os.DirEntry{}
		}
	}
	var fileInfoArr []*wshrpc.FileInfo
	for _, innerFileEntry := range innerFilesEntries {
		if ctx.Err() != nil {
			return ctx.Err()
		}
		innerFileInfoInt, err := innerFileEntry.Info()
		if err != nil {
			// entry vanished or is unreadable between ReadDir and Info; skip it
			continue
		}
		innerFileInfo := statToFileInfo(filepath.Join(path, innerFileInfoInt.Name()), innerFileInfoInt, false)
		fileInfoArr = append(fileInfoArr, innerFileInfo)
		if len(fileInfoArr) >= wshrpc.DirChunkSize {
			// emit a full chunk and start accumulating a new one
			dataCallback(fileInfoArr, nil, byteRange)
			fileInfoArr = nil
		}
	}
	// flush the final partial chunk, if any
	if len(fileInfoArr) > 0 {
		dataCallback(fileInfoArr, nil, byteRange)
	}
	return nil
}
|
||||
|
||||
// remoteStreamFileRegular streams the contents of the regular file at path to
// dataCallback in wshrpc.FileChunkSize chunks. When byteRange is not All, it
// seeks to byteRange.Start and stops once byteRange.End bytes into the file
// have been delivered (the last chunk is truncated to fit the range).
func (impl *ServerImpl) remoteStreamFileRegular(ctx context.Context, path string, byteRange ByteRangeType, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error {
	fd, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("cannot open file %q: %w", path, err)
	}
	defer utilfn.GracefulClose(fd, "remoteStreamFileRegular", path)
	var filePos int64
	if !byteRange.All && byteRange.Start > 0 {
		_, err := fd.Seek(byteRange.Start, io.SeekStart)
		if err != nil {
			return fmt.Errorf("seeking file %q: %w", path, err)
		}
		filePos = byteRange.Start
	}
	buf := make([]byte, wshrpc.FileChunkSize)
	for {
		if ctx.Err() != nil {
			return ctx.Err()
		}
		n, err := fd.Read(buf)
		if n > 0 {
			// clip the chunk so we never deliver bytes past byteRange.End
			if !byteRange.All && filePos+int64(n) > byteRange.End {
				n = int(byteRange.End - filePos)
			}
			filePos += int64(n)
			dataCallback(nil, buf[:n], byteRange)
		}
		// range exhausted: stop before inspecting the read error
		if !byteRange.All && filePos >= byteRange.End {
			break
		}
		// EOF is checked after delivering any bytes read in the same call
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return fmt.Errorf("reading file %q: %w", path, err)
		}
	}
	return nil
}
|
||||
|
||||
// remoteStreamFileInternal parses the request's byte range, expands the path,
// emits an initial callback carrying the file's FileInfo, then streams the
// content: directory listings via remoteStreamFileDir, regular files via
// remoteStreamFileRegular. A not-found path yields only the FileInfo packet.
func (impl *ServerImpl) remoteStreamFileInternal(ctx context.Context, data wshrpc.CommandRemoteStreamFileData, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error {
	byteRange, err := parseByteRange(data.ByteRange)
	if err != nil {
		return err
	}
	path, err := wavebase.ExpandHomeDir(data.Path)
	if err != nil {
		return err
	}
	finfo, err := impl.fileInfoInternal(path, true)
	if err != nil {
		return fmt.Errorf("cannot stat file %q: %w", path, err)
	}
	// first callback always carries the header FileInfo (even for NotFound)
	dataCallback([]*wshrpc.FileInfo{finfo}, nil, byteRange)
	if finfo.NotFound {
		return nil
	}
	if finfo.IsDir {
		return impl.remoteStreamFileDir(ctx, path, byteRange, dataCallback)
	} else {
		return impl.remoteStreamFileRegular(ctx, path, byteRange, dataCallback)
	}
}
|
||||
|
||||
// RemoteStreamFileCommand streams file (or directory-listing) data as a
// channel of FileData packets. The first packet places a single FileInfo in
// resp.Info; subsequent packets place listings in resp.Entries. File bytes are
// base64-encoded in Data64 with their offset recorded in resp.At. The channel
// is closed when streaming completes; errors are sent as error packets.
func (impl *ServerImpl) RemoteStreamFileCommand(ctx context.Context, data wshrpc.CommandRemoteStreamFileData) chan wshrpc.RespOrErrorUnion[wshrpc.FileData] {
	ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.FileData], 16)
	go func() {
		defer close(ch)
		firstPk := true
		err := impl.remoteStreamFileInternal(ctx, data, func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType) {
			resp := wshrpc.FileData{}
			fileInfoLen := len(fileInfo)
			// only the very first single-FileInfo callback becomes resp.Info;
			// everything else is treated as directory entries
			if fileInfoLen > 1 || !firstPk {
				resp.Entries = fileInfo
			} else if fileInfoLen == 1 {
				resp.Info = fileInfo[0]
			}
			if firstPk {
				firstPk = false
			}
			if len(data) > 0 {
				resp.Data64 = base64.StdEncoding.EncodeToString(data)
				// NOTE(review): Offset is always byteRange.Start here, not the
				// running position of this chunk — confirm consumers expect that.
				resp.At = &wshrpc.FileDataAt{Offset: byteRange.Start, Size: len(data)}
			}
			ch <- wshrpc.RespOrErrorUnion[wshrpc.FileData]{Response: resp}
		})
		if err != nil {
			ch <- wshutil.RespErr[wshrpc.FileData](err)
		}
	}()
	return ch
}
|
||||
|
||||
// RemoteTarStreamCommand streams the file or directory at data.Path as a tar
// archive over the returned packet channel. A trailing slash on a directory
// source means "archive the contents" (paths relative to the directory
// itself); otherwise the directory is included as the top-level entry. The
// walk runs in a goroutine bounded by a timeout context (opts.Timeout in ms,
// defaulting to fstype.DefaultTimeout); tarClose/cancel run when it exits.
func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc.CommandRemoteStreamTarData) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] {
	path := data.Path
	opts := data.Opts
	if opts == nil {
		opts = &wshrpc.FileCopyOpts{}
	}
	log.Printf("RemoteTarStreamCommand: path=%s\n", path)
	// capture the trailing slash before expansion/cleaning strips it
	srcHasSlash := strings.HasSuffix(path, "/")
	path, err := wavebase.ExpandHomeDir(path)
	if err != nil {
		return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot expand path %q: %w", path, err))
	}
	cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
	finfo, err := os.Stat(cleanedPath)
	if err != nil {
		return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot stat file %q: %w", path, err))
	}

	// pathPrefix is stripped from entry names inside the archive
	var pathPrefix string
	singleFile := !finfo.IsDir()
	if !singleFile && srcHasSlash {
		pathPrefix = cleanedPath
	} else {
		pathPrefix = filepath.Dir(cleanedPath)
	}

	timeout := fstype.DefaultTimeout
	if opts.Timeout > 0 {
		timeout = time.Duration(opts.Timeout) * time.Millisecond
	}
	readerCtx, cancel := context.WithTimeout(ctx, timeout)
	rtn, writeHeader, fileWriter, tarClose := tarcopy.TarCopySrc(readerCtx, pathPrefix)

	go func() {
		defer func() {
			// close the tar stream before releasing the timeout context
			tarClose()
			cancel()
		}()
		walkFunc := func(path string, info fs.FileInfo, err error) error {
			if readerCtx.Err() != nil {
				return readerCtx.Err()
			}
			if err != nil {
				return err
			}
			if err = writeHeader(info, path, singleFile); err != nil {
				return err
			}
			// if not a dir, write file content
			if !info.IsDir() {
				data, err := os.Open(path)
				if err != nil {
					return err
				}
				defer utilfn.GracefulClose(data, "RemoteTarStreamCommand", path)
				if _, err := io.Copy(fileWriter, data); err != nil {
					return err
				}
			}
			return nil
		}
		log.Printf("RemoteTarStreamCommand: starting\n")
		err = nil
		if singleFile {
			// single file: invoke the walk callback once, directly
			err = walkFunc(cleanedPath, finfo, nil)
		} else {
			err = filepath.Walk(cleanedPath, walkFunc)
		}
		if err != nil {
			rtn <- wshutil.RespErr[iochantypes.Packet](err)
		}
		log.Printf("RemoteTarStreamCommand: done\n")
	}()
	log.Printf("RemoteTarStreamCommand: returning channel\n")
	return rtn
}
|
||||
|
||||
// RemoteFileCopyCommand copies data.SrcUri to data.DestUri and reports whether
// the source was a directory. Overwrite and Merge options are mutually
// exclusive: Overwrite replaces existing destination files/dirs, Merge allows
// copying into an existing directory tree. Same-host copies walk the local
// filesystem directly; cross-host copies stream a tar archive from the source
// via FileStreamTarCommand and unpack it locally.
func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.CommandFileCopyData) (bool, error) {
	log.Printf("RemoteFileCopyCommand: src=%s, dest=%s\n", data.SrcUri, data.DestUri)
	opts := data.Opts
	if opts == nil {
		opts = &wshrpc.FileCopyOpts{}
	}
	destUri := data.DestUri
	srcUri := data.SrcUri
	merge := opts.Merge
	overwrite := opts.Overwrite
	if overwrite && merge {
		return false, fmt.Errorf("cannot specify both overwrite and merge")
	}

	destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri)
	if err != nil {
		return false, fmt.Errorf("cannot parse destination URI %q: %w", destUri, err)
	}
	destPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path))
	destinfo, err := os.Stat(destPathCleaned)
	if err != nil {
		// ErrNotExist is fine (we may be creating the destination)
		if !errors.Is(err, fs.ErrNotExist) {
			return false, fmt.Errorf("cannot stat destination %q: %w", destPathCleaned, err)
		}
	}

	destExists := destinfo != nil
	destIsDir := destExists && destinfo.IsDir()
	// trailing slash on the dest URI means "copy into this directory"
	destHasSlash := strings.HasSuffix(destUri, "/")

	// an existing destination file must be explicitly overwritten
	if destExists && !destIsDir {
		if !overwrite {
			return false, fmt.Errorf(fstype.OverwriteRequiredError, destPathCleaned)
		} else {
			err := os.Remove(destPathCleaned)
			if err != nil {
				return false, fmt.Errorf("cannot remove file %q: %w", destPathCleaned, err)
			}
		}
	}
	srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri)
	if err != nil {
		return false, fmt.Errorf("cannot parse source URI %q: %w", srcUri, err)
	}

	// copyFileFunc writes one source entry (file or dir) described by finfo to
	// path, enforcing the overwrite/merge rules, and returns the bytes copied.
	copyFileFunc := func(path string, finfo fs.FileInfo, srcFile io.Reader) (int64, error) {
		nextinfo, err := os.Stat(path)
		if err != nil && !errors.Is(err, fs.ErrNotExist) {
			return 0, fmt.Errorf("cannot stat file %q: %w", path, err)
		}

		if nextinfo != nil {
			if nextinfo.IsDir() {
				if !finfo.IsDir() {
					// try to create file in directory
					path = filepath.Join(path, filepath.Base(finfo.Name()))
					newdestinfo, err := os.Stat(path)
					if err != nil && !errors.Is(err, fs.ErrNotExist) {
						return 0, fmt.Errorf("cannot stat file %q: %w", path, err)
					}
					if newdestinfo != nil && !overwrite {
						return 0, fmt.Errorf(fstype.OverwriteRequiredError, path)
					}
				} else if overwrite {
					// dir over dir with overwrite: replace the whole subtree
					err := os.RemoveAll(path)
					if err != nil {
						return 0, fmt.Errorf("cannot remove directory %q: %w", path, err)
					}
				} else if !merge {
					return 0, fmt.Errorf(fstype.MergeRequiredError, path)
				}
			} else {
				if !overwrite {
					return 0, fmt.Errorf(fstype.OverwriteRequiredError, path)
				} else if finfo.IsDir() {
					// dir over existing file with overwrite: remove the file
					err := os.RemoveAll(path)
					if err != nil {
						return 0, fmt.Errorf("cannot remove directory %q: %w", path, err)
					}
				}
			}
		}

		if finfo.IsDir() {
			// directories are created (mode preserved), no bytes copied
			err := os.MkdirAll(path, finfo.Mode())
			if err != nil {
				return 0, fmt.Errorf("cannot create directory %q: %w", path, err)
			}
			return 0, nil
		} else {
			err := os.MkdirAll(filepath.Dir(path), 0755)
			if err != nil {
				return 0, fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(path), err)
			}
		}

		file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode())
		if err != nil {
			return 0, fmt.Errorf("cannot create new file %q: %w", path, err)
		}
		defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", path)
		_, err = io.Copy(file, srcFile)
		if err != nil {
			return 0, fmt.Errorf("cannot write file %q: %w", path, err)
		}

		// NOTE(review): returns finfo.Size(), not the io.Copy byte count —
		// for tar streams the two should match; confirm for truncated reads.
		return finfo.Size(), nil
	}

	srcIsDir := false
	if srcConn.Host == destConn.Host {
		// same-host copy: read directly from the local filesystem
		srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path))

		srcFileStat, err := os.Stat(srcPathCleaned)
		if err != nil {
			return false, fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err)
		}

		if srcFileStat.IsDir() {
			srcIsDir = true
			// prefix stripped from walked paths to compute destination paths
			var srcPathPrefix string
			if destIsDir {
				srcPathPrefix = filepath.Dir(srcPathCleaned)
			} else {
				srcPathPrefix = srcPathCleaned
			}
			err = filepath.Walk(srcPathCleaned, func(path string, info fs.FileInfo, err error) error {
				if err != nil {
					return err
				}
				srcFilePath := path
				destFilePath := filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathPrefix))
				var file *os.File
				if !info.IsDir() {
					file, err = os.Open(srcFilePath)
					if err != nil {
						return fmt.Errorf("cannot open file %q: %w", srcFilePath, err)
					}
					defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", srcFilePath)
				}
				_, err = copyFileFunc(destFilePath, info, file)
				return err
			})
			if err != nil {
				return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
			}
		} else {
			file, err := os.Open(srcPathCleaned)
			if err != nil {
				return false, fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err)
			}
			defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", srcPathCleaned)
			var destFilePath string
			if destHasSlash {
				// copy into the destination directory under the source's name
				destFilePath = filepath.Join(destPathCleaned, filepath.Base(srcPathCleaned))
			} else {
				destFilePath = destPathCleaned
			}
			_, err = copyFileFunc(destFilePath, srcFileStat, file)
			if err != nil {
				return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
			}
		}
	} else {
		// cross-host copy: stream a tar from the source host and unpack here
		timeout := fstype.DefaultTimeout
		if opts.Timeout > 0 {
			timeout = time.Duration(opts.Timeout) * time.Millisecond
		}
		readCtx, cancel := context.WithCancelCause(ctx)
		readCtx, timeoutCancel := context.WithTimeoutCause(readCtx, timeout, fmt.Errorf("timeout copying file %q to %q", srcUri, destUri))
		defer timeoutCancel()
		copyStart := time.Now()
		ioch := wshclient.FileStreamTarCommand(wshfs.RpcClient, wshrpc.CommandRemoteStreamTarData{Path: srcUri, Opts: opts}, &wshrpc.RpcOpts{Timeout: opts.Timeout})
		numFiles := 0
		numSkipped := 0
		totalBytes := int64(0)

		err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader, singleFile bool) error {
			numFiles++
			nextpath := filepath.Join(destPathCleaned, next.Name)
			srcIsDir = !singleFile
			if singleFile && !destHasSlash {
				// custom flag to indicate that the source is a single file, not a directory the contents of a directory
				nextpath = destPathCleaned
			}
			finfo := next.FileInfo()
			n, err := copyFileFunc(nextpath, finfo, reader)
			if err != nil {
				return fmt.Errorf("cannot copy file %q: %w", next.Name, err)
			}
			totalBytes += n
			return nil
		})
		if err != nil {
			return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
		}
		totalTime := time.Since(copyStart).Seconds()
		totalMegaBytes := float64(totalBytes) / 1024 / 1024
		rate := float64(0)
		if totalTime > 0 {
			rate = totalMegaBytes / totalTime
		}
		log.Printf("RemoteFileCopyCommand: done; %d files copied in %.3fs, total of %.4f MB, %.2f MB/s, %d files skipped\n", numFiles, totalTime, totalMegaBytes, rate, numSkipped)
	}
	return srcIsDir, nil
}
|
||||
|
||||
// RemoteListEntriesCommand lists the entries of data.Path, streaming them on
// the returned channel in chunks of wshrpc.DirChunkSize. With Opts.All it
// recursively walks the tree (files only, directories excluded) applying
// Opts.Offset/Limit as an entry window; otherwise it reads the immediate
// directory. The channel is closed when listing finishes.
func (impl *ServerImpl) RemoteListEntriesCommand(ctx context.Context, data wshrpc.CommandRemoteListEntriesData) chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData] {
	ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData], 16)
	go func() {
		defer close(ch)
		path, err := wavebase.ExpandHomeDir(data.Path)
		if err != nil {
			ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](err)
			return
		}
		innerFilesEntries := []os.DirEntry{}
		seen := 0
		if data.Opts.Limit == 0 {
			data.Opts.Limit = wshrpc.MaxDirSize
		}
		if data.Opts.All {
			// NOTE(review): the error returned by fs.WalkDir is discarded;
			// io.EOF is used as a deliberate early-exit sentinel once the
			// limit is reached, but real walk errors are also dropped —
			// confirm that is intended.
			fs.WalkDir(os.DirFS(path), ".", func(path string, d fs.DirEntry, err error) error {
				defer func() {
					seen++
				}()
				if seen < data.Opts.Offset {
					return nil
				}
				if seen >= data.Opts.Offset+data.Opts.Limit {
					return io.EOF
				}
				if err != nil {
					return err
				}
				if d.IsDir() {
					return nil
				}
				innerFilesEntries = append(innerFilesEntries, d)
				return nil
			})
		} else {
			innerFilesEntries, err = os.ReadDir(path)
			if err != nil {
				ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](fmt.Errorf("cannot open dir %q: %w", path, err))
				return
			}
		}
		var fileInfoArr []*wshrpc.FileInfo
		for _, innerFileEntry := range innerFilesEntries {
			if ctx.Err() != nil {
				ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](ctx.Err())
				return
			}
			innerFileInfoInt, err := innerFileEntry.Info()
			if err != nil {
				// unreadable/vanished entry: log and keep listing
				log.Printf("cannot stat file %q: %v\n", innerFileEntry.Name(), err)
				continue
			}
			innerFileInfo := statToFileInfo(filepath.Join(path, innerFileInfoInt.Name()), innerFileInfoInt, false)
			fileInfoArr = append(fileInfoArr, innerFileInfo)
			if len(fileInfoArr) >= wshrpc.DirChunkSize {
				resp := wshrpc.CommandRemoteListEntriesRtnData{FileInfo: fileInfoArr}
				ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: resp}
				fileInfoArr = nil
			}
		}
		// flush the final partial chunk
		if len(fileInfoArr) > 0 {
			resp := wshrpc.CommandRemoteListEntriesRtnData{FileInfo: fileInfoArr}
			ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: resp}
		}
	}()
	return ch
}
|
||||
|
||||
// statToFileInfo converts an fs.FileInfo (plus its full path) into a
// wshrpc.FileInfo, detecting the MIME type (extended enables the deeper
// detection path in fileutil). Directories get Size -1 as a sentinel.
func statToFileInfo(fullPath string, finfo fs.FileInfo, extended bool) *wshrpc.FileInfo {
	mimeType := fileutil.DetectMimeType(fullPath, finfo, extended)
	rtn := &wshrpc.FileInfo{
		Path:          wavebase.ReplaceHomeDir(fullPath),
		Dir:           computeDirPart(fullPath),
		Name:          finfo.Name(),
		Size:          finfo.Size(),
		Mode:          finfo.Mode(),
		ModeStr:       finfo.Mode().String(),
		ModTime:       finfo.ModTime().UnixMilli(),
		IsDir:         finfo.IsDir(),
		MimeType:      mimeType,
		SupportsMkdir: true,
	}
	// directories report Size -1 (sentinel meaning "not a byte size")
	if finfo.IsDir() {
		rtn.Size = -1
	}
	return rtn
}
|
||||
|
||||
// fileInfo might be null
// checkIsReadOnly reports whether path is effectively read-only. For a
// missing path or a directory it probes writability by creating (and
// immediately removing) a temporary file in the parent/target directory;
// for an existing regular file it attempts to open the file for append.
// Failures to probe conclusively err on the side of returning false.
func checkIsReadOnly(path string, fileInfo fs.FileInfo, exists bool) bool {
	if !exists || fileInfo.Mode().IsDir() {
		dirName := filepath.Dir(path)
		randHexStr, err := utilfn.RandomHexString(12)
		if err != nil {
			// we're not sure, just return false
			return false
		}
		tmpFileName := filepath.Join(dirName, "wsh-tmp-"+randHexStr)
		fd, err := os.Create(tmpFileName)
		if err != nil {
			// cannot create a probe file => treat the directory as read-only
			return true
		}
		utilfn.GracefulClose(fd, "checkIsReadOnly", tmpFileName)
		os.Remove(tmpFileName)
		return false
	}
	// try to open for writing, if this fails then it is read-only
	file, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND, 0666)
	if err != nil {
		return true
	}
	utilfn.GracefulClose(file, "checkIsReadOnly", path)
	return false
}
|
||||
|
||||
func computeDirPart(path string) string {
|
||||
path = filepath.Clean(wavebase.ExpandHomeDirSafe(path))
|
||||
path = filepath.ToSlash(path)
|
||||
if path == "/" {
|
||||
return "/"
|
||||
}
|
||||
return filepath.Dir(path)
|
||||
}
|
||||
|
||||
// fileInfoInternal stats path (after home expansion/cleaning) and returns a
// wshrpc.FileInfo. A missing path yields a NotFound FileInfo (not an error)
// with a writability probe of its would-be location. With extended set, the
// result also carries the ReadOnly flag for existing paths.
func (*ServerImpl) fileInfoInternal(path string, extended bool) (*wshrpc.FileInfo, error) {
	cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
	finfo, err := os.Stat(cleanedPath)
	if os.IsNotExist(err) {
		// finfo is nil here; checkIsReadOnly handles that via exists=false
		return &wshrpc.FileInfo{
			Path:          wavebase.ReplaceHomeDir(path),
			Dir:           computeDirPart(path),
			NotFound:      true,
			ReadOnly:      checkIsReadOnly(cleanedPath, finfo, false),
			SupportsMkdir: true,
		}, nil
	}
	if err != nil {
		return nil, fmt.Errorf("cannot stat file %q: %w", path, err)
	}
	rtn := statToFileInfo(cleanedPath, finfo, extended)
	if extended {
		rtn.ReadOnly = checkIsReadOnly(cleanedPath, finfo, true)
	}
	return rtn, nil
}
|
||||
|
||||
func resolvePaths(paths []string) string {
|
||||
if len(paths) == 0 {
|
||||
return wavebase.ExpandHomeDirSafe("~")
|
||||
}
|
||||
rtnPath := wavebase.ExpandHomeDirSafe(paths[0])
|
||||
for _, path := range paths[1:] {
|
||||
path = wavebase.ExpandHomeDirSafe(path)
|
||||
if filepath.IsAbs(path) {
|
||||
rtnPath = path
|
||||
continue
|
||||
}
|
||||
rtnPath = filepath.Join(rtnPath, path)
|
||||
}
|
||||
return rtnPath
|
||||
}
|
||||
|
||||
// RemoteFileJoinCommand resolves the given path segments into one path (see
// resolvePaths) and returns its extended FileInfo.
func (impl *ServerImpl) RemoteFileJoinCommand(ctx context.Context, paths []string) (*wshrpc.FileInfo, error) {
	rtnPath := resolvePaths(paths)
	return impl.fileInfoInternal(rtnPath, true)
}
|
||||
|
||||
// RemoteFileInfoCommand returns the extended FileInfo for path.
func (impl *ServerImpl) RemoteFileInfoCommand(ctx context.Context, path string) (*wshrpc.FileInfo, error) {
	return impl.fileInfoInternal(path, true)
}
|
||||
|
||||
func (impl *ServerImpl) RemoteFileTouchCommand(ctx context.Context, path string) error {
|
||||
cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
|
||||
if _, err := os.Stat(cleanedPath); err == nil {
|
||||
return fmt.Errorf("file %q already exists", path)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(cleanedPath), 0755); err != nil {
|
||||
return fmt.Errorf("cannot create directory %q: %w", filepath.Dir(cleanedPath), err)
|
||||
}
|
||||
if err := os.WriteFile(cleanedPath, []byte{}, 0644); err != nil {
|
||||
return fmt.Errorf("cannot create file %q: %w", cleanedPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (impl *ServerImpl) RemoteFileMoveCommand(ctx context.Context, data wshrpc.CommandFileCopyData) error {
|
||||
opts := data.Opts
|
||||
destUri := data.DestUri
|
||||
srcUri := data.SrcUri
|
||||
overwrite := opts != nil && opts.Overwrite
|
||||
recursive := opts != nil && opts.Recursive
|
||||
|
||||
destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot parse destination URI %q: %w", srcUri, err)
|
||||
}
|
||||
destPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path))
|
||||
destinfo, err := os.Stat(destPathCleaned)
|
||||
if err == nil {
|
||||
if !destinfo.IsDir() {
|
||||
if !overwrite {
|
||||
return fmt.Errorf("destination %q already exists, use overwrite option", destUri)
|
||||
} else {
|
||||
err := os.Remove(destPathCleaned)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot remove file %q: %w", destUri, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("cannot stat destination %q: %w", destUri, err)
|
||||
}
|
||||
srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot parse source URI %q: %w", srcUri, err)
|
||||
}
|
||||
if srcConn.Host == destConn.Host {
|
||||
srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path))
|
||||
finfo, err := os.Stat(srcPathCleaned)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err)
|
||||
}
|
||||
if finfo.IsDir() && !recursive {
|
||||
return fmt.Errorf(fstype.RecursiveRequiredError)
|
||||
}
|
||||
err = os.Rename(srcPathCleaned, destPathCleaned)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot move file %q to %q: %w", srcPathCleaned, destPathCleaned, err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("cannot move file %q to %q: different hosts", srcUri, destUri)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (impl *ServerImpl) RemoteMkdirCommand(ctx context.Context, path string) error {
|
||||
cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
|
||||
if stat, err := os.Stat(cleanedPath); err == nil {
|
||||
if stat.IsDir() {
|
||||
return fmt.Errorf("directory %q already exists", path)
|
||||
} else {
|
||||
return fmt.Errorf("cannot create directory %q, file exists at path", path)
|
||||
}
|
||||
}
|
||||
if err := os.MkdirAll(cleanedPath, 0755); err != nil {
|
||||
return fmt.Errorf("cannot create directory %q: %w", cleanedPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (*ServerImpl) RemoteWriteFileCommand(ctx context.Context, data wshrpc.FileData) error {
|
||||
var truncate, append bool
|
||||
var atOffset int64
|
||||
if data.Info != nil && data.Info.Opts != nil {
|
||||
truncate = data.Info.Opts.Truncate
|
||||
append = data.Info.Opts.Append
|
||||
}
|
||||
if data.At != nil {
|
||||
atOffset = data.At.Offset
|
||||
}
|
||||
if truncate && atOffset > 0 {
|
||||
return fmt.Errorf("cannot specify non-zero offset with truncate option")
|
||||
}
|
||||
if append && atOffset > 0 {
|
||||
return fmt.Errorf("cannot specify non-zero offset with append option")
|
||||
}
|
||||
path, err := wavebase.ExpandHomeDir(data.Info.Path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
createMode := os.FileMode(0644)
|
||||
if data.Info != nil && data.Info.Mode > 0 {
|
||||
createMode = data.Info.Mode
|
||||
}
|
||||
dataSize := base64.StdEncoding.DecodedLen(len(data.Data64))
|
||||
dataBytes := make([]byte, dataSize)
|
||||
n, err := base64.StdEncoding.Decode(dataBytes, []byte(data.Data64))
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot decode base64 data: %w", err)
|
||||
}
|
||||
finfo, err := os.Stat(path)
|
||||
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("cannot stat file %q: %w", path, err)
|
||||
}
|
||||
fileSize := int64(0)
|
||||
if finfo != nil {
|
||||
fileSize = finfo.Size()
|
||||
}
|
||||
if atOffset > fileSize {
|
||||
return fmt.Errorf("cannot write at offset %d, file size is %d", atOffset, fileSize)
|
||||
}
|
||||
openFlags := os.O_CREATE | os.O_WRONLY
|
||||
if truncate {
|
||||
openFlags |= os.O_TRUNC
|
||||
}
|
||||
if append {
|
||||
openFlags |= os.O_APPEND
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(path, openFlags, createMode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open file %q: %w", path, err)
|
||||
}
|
||||
defer utilfn.GracefulClose(file, "RemoteWriteFileCommand", path)
|
||||
if atOffset > 0 && !append {
|
||||
n, err = file.WriteAt(dataBytes[:n], atOffset)
|
||||
} else {
|
||||
n, err = file.Write(dataBytes[:n])
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot write to file %q: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoteFileDeleteCommand deletes the file or directory at data.Path
// (home-expanded and cleaned). Non-empty directories require data.Recursive;
// without it the fstype.RecursiveRequiredError is returned.
func (*ServerImpl) RemoteFileDeleteCommand(ctx context.Context, data wshrpc.CommandDeleteFileData) error {
	expandedPath, err := wavebase.ExpandHomeDir(data.Path)
	if err != nil {
		return fmt.Errorf("cannot delete file %q: %w", data.Path, err)
	}
	cleanedPath := filepath.Clean(expandedPath)

	// optimistic plain Remove first; fall back to directory handling on failure
	err = os.Remove(cleanedPath)
	if err != nil {
		finfo, _ := os.Stat(cleanedPath)
		if finfo != nil && finfo.IsDir() {
			if !data.Recursive {
				// NOTE(review): non-constant format string to fmt.Errorf
				// (go vet warning) — safe only if the sentinel contains no verbs
				return fmt.Errorf(fstype.RecursiveRequiredError)
			}
			err = os.RemoveAll(cleanedPath)
			if err != nil {
				return fmt.Errorf("cannot delete directory %q: %w", data.Path, err)
			}
		} else {
			return fmt.Errorf("cannot delete file %q: %w", data.Path, err)
		}
	}
	return nil
}
|
||||
|
||||
// RemoteGetInfoCommand returns this host's remote info as reported by wshutil.
func (*ServerImpl) RemoteGetInfoCommand(ctx context.Context) (wshrpc.RemoteInfo, error) {
	return wshutil.GetInfo(), nil
}
|
||||
|
|
@ -861,3 +91,14 @@ func (*ServerImpl) DisposeSuggestionsCommand(ctx context.Context, widgetId strin
|
|||
suggestion.DisposeSuggestions(ctx, widgetId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (impl *ServerImpl) getWshPath() (string, error) {
|
||||
if impl.IsLocal {
|
||||
return filepath.Join(wavebase.GetWaveDataDir(), "bin", "wsh"), nil
|
||||
}
|
||||
wshPath, err := wavebase.ExpandHomeDir("~/.waveterm/bin/wsh")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot expand wsh path: %w", err)
|
||||
}
|
||||
return wshPath, nil
|
||||
}
|
||||
|
|
|
|||
810
pkg/wshrpc/wshremote/wshremote_file.go
Normal file
810
pkg/wshrpc/wshremote/wshremote_file.go
Normal file
|
|
@ -0,0 +1,810 @@
|
|||
// Copyright 2026, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wshremote
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/remote/connparse"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/fileutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/tarcopy"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
)
|
||||
|
||||
// ByteRangeType describes a requested range: All selects everything,
// otherwise the inclusive [Start, End] bounds apply.
type ByteRangeType struct {
	All   bool
	Start int64
	End   int64
}

// parseByteRange parses a "start-end" string into a ByteRangeType. The empty
// string means the whole input; anything unscannable, a negative bound, or a
// start greater than end is rejected as an invalid byte range.
func parseByteRange(rangeStr string) (ByteRangeType, error) {
	if len(rangeStr) == 0 {
		return ByteRangeType{All: true}, nil
	}
	var lo, hi int64
	_, scanErr := fmt.Sscanf(rangeStr, "%d-%d", &lo, &hi)
	invalid := scanErr != nil || lo < 0 || hi < 0 || lo > hi
	if invalid {
		return ByteRangeType{}, errors.New("invalid byte range")
	}
	return ByteRangeType{Start: lo, End: hi}, nil
}
|
||||
|
||||
func (impl *ServerImpl) remoteStreamFileDir(ctx context.Context, path string, byteRange ByteRangeType, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error {
|
||||
innerFilesEntries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open dir %q: %w", path, err)
|
||||
}
|
||||
if byteRange.All {
|
||||
if len(innerFilesEntries) > wshrpc.MaxDirSize {
|
||||
innerFilesEntries = innerFilesEntries[:wshrpc.MaxDirSize]
|
||||
}
|
||||
} else {
|
||||
if byteRange.Start < int64(len(innerFilesEntries)) {
|
||||
realEnd := byteRange.End
|
||||
if realEnd > int64(len(innerFilesEntries)) {
|
||||
realEnd = int64(len(innerFilesEntries))
|
||||
}
|
||||
innerFilesEntries = innerFilesEntries[byteRange.Start:realEnd]
|
||||
} else {
|
||||
innerFilesEntries = []os.DirEntry{}
|
||||
}
|
||||
}
|
||||
var fileInfoArr []*wshrpc.FileInfo
|
||||
for _, innerFileEntry := range innerFilesEntries {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
innerFileInfoInt, err := innerFileEntry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
innerFileInfo := statToFileInfo(filepath.Join(path, innerFileInfoInt.Name()), innerFileInfoInt, false)
|
||||
fileInfoArr = append(fileInfoArr, innerFileInfo)
|
||||
if len(fileInfoArr) >= wshrpc.DirChunkSize {
|
||||
dataCallback(fileInfoArr, nil, byteRange)
|
||||
fileInfoArr = nil
|
||||
}
|
||||
}
|
||||
if len(fileInfoArr) > 0 {
|
||||
dataCallback(fileInfoArr, nil, byteRange)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// remoteStreamFileRegular streams the contents of a regular file to
// dataCallback in chunks of up to wshrpc.FileChunkSize bytes. When
// byteRange.All is false, only bytes in [Start, End) are streamed: the file
// is seeked to Start first and the final chunk is truncated at End.
//
// NOTE: the same buf is reused for every read, so dataCallback must fully
// consume (or copy) the data slice before returning.
func (impl *ServerImpl) remoteStreamFileRegular(ctx context.Context, path string, byteRange ByteRangeType, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error {
	fd, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("cannot open file %q: %w", path, err)
	}
	defer utilfn.GracefulClose(fd, "remoteStreamFileRegular", path)
	var filePos int64
	if !byteRange.All && byteRange.Start > 0 {
		_, err := fd.Seek(byteRange.Start, io.SeekStart)
		if err != nil {
			return fmt.Errorf("seeking file %q: %w", path, err)
		}
		filePos = byteRange.Start
	}
	buf := make([]byte, wshrpc.FileChunkSize)
	for {
		// honor cancellation between reads
		if ctx.Err() != nil {
			return ctx.Err()
		}
		// process any bytes read before examining err (Read may return both)
		n, err := fd.Read(buf)
		if n > 0 {
			// truncate the final chunk so we never stream past End
			if !byteRange.All && filePos+int64(n) > byteRange.End {
				n = int(byteRange.End - filePos)
			}
			filePos += int64(n)
			dataCallback(nil, buf[:n], byteRange)
		}
		if !byteRange.All && filePos >= byteRange.End {
			break
		}
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return fmt.Errorf("reading file %q: %w", path, err)
		}
	}
	return nil
}
|
||||
|
||||
func (impl *ServerImpl) remoteStreamFileInternal(ctx context.Context, data wshrpc.CommandRemoteStreamFileData, dataCallback func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType)) error {
|
||||
byteRange, err := parseByteRange(data.ByteRange)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path, err := wavebase.ExpandHomeDir(data.Path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
finfo, err := impl.fileInfoInternal(path, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot stat file %q: %w", path, err)
|
||||
}
|
||||
dataCallback([]*wshrpc.FileInfo{finfo}, nil, byteRange)
|
||||
if finfo.NotFound {
|
||||
return nil
|
||||
}
|
||||
if finfo.IsDir {
|
||||
return impl.remoteStreamFileDir(ctx, path, byteRange, dataCallback)
|
||||
} else {
|
||||
return impl.remoteStreamFileRegular(ctx, path, byteRange, dataCallback)
|
||||
}
|
||||
}
|
||||
|
||||
// RemoteStreamFileCommand streams file info and content for data.Path over
// the returned channel. Protocol: the first packet carries a single stat
// result in resp.Info; subsequent packets carry directory listings in
// resp.Entries and/or base64 file data in resp.Data64. The channel is closed
// when streaming ends; an error (if any) is sent as the final item.
func (impl *ServerImpl) RemoteStreamFileCommand(ctx context.Context, data wshrpc.CommandRemoteStreamFileData) chan wshrpc.RespOrErrorUnion[wshrpc.FileData] {
	ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.FileData], 16)
	go func() {
		defer close(ch)
		firstPk := true
		err := impl.remoteStreamFileInternal(ctx, data, func(fileInfo []*wshrpc.FileInfo, data []byte, byteRange ByteRangeType) {
			resp := wshrpc.FileData{}
			fileInfoLen := len(fileInfo)
			// only the very first single-info packet uses resp.Info; any
			// later packets (directory chunks) go in resp.Entries
			if fileInfoLen > 1 || !firstPk {
				resp.Entries = fileInfo
			} else if fileInfoLen == 1 {
				resp.Info = fileInfo[0]
			}
			if firstPk {
				firstPk = false
			}
			if len(data) > 0 {
				resp.Data64 = base64.StdEncoding.EncodeToString(data)
				// NOTE(review): Offset is byteRange.Start for every chunk,
				// not the running file position of this chunk — confirm
				// consumers reassemble sequentially
				resp.At = &wshrpc.FileDataAt{Offset: byteRange.Start, Size: len(data)}
			}
			ch <- wshrpc.RespOrErrorUnion[wshrpc.FileData]{Response: resp}
		})
		if err != nil {
			ch <- wshutil.RespErr[wshrpc.FileData](err)
		}
	}()
	return ch
}
|
||||
|
||||
// RemoteTarStreamCommand streams data.Path (a file or directory tree) as a
// tar archive over the returned packet channel. A trailing "/" on the source
// path of a directory means "stream the directory's contents" (entries are
// named relative to the directory itself); otherwise entries are named
// relative to the parent directory. Streaming is bounded by opts.Timeout
// (fstype.DefaultTimeout when unset).
func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc.CommandRemoteStreamTarData) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] {
	path := data.Path
	opts := data.Opts
	if opts == nil {
		opts = &wshrpc.FileCopyOpts{}
	}
	log.Printf("RemoteTarStreamCommand: path=%s\n", path)
	// capture the trailing slash before expansion/cleaning strips it
	srcHasSlash := strings.HasSuffix(path, "/")
	path, err := wavebase.ExpandHomeDir(path)
	if err != nil {
		return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot expand path %q: %w", path, err))
	}
	cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
	finfo, err := os.Stat(cleanedPath)
	if err != nil {
		return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot stat file %q: %w", path, err))
	}

	// pathPrefix is stripped from entry names inside the archive
	var pathPrefix string
	singleFile := !finfo.IsDir()
	if !singleFile && srcHasSlash {
		pathPrefix = cleanedPath
	} else {
		pathPrefix = filepath.Dir(cleanedPath)
	}

	timeout := fstype.DefaultTimeout
	if opts.Timeout > 0 {
		timeout = time.Duration(opts.Timeout) * time.Millisecond
	}
	readerCtx, cancel := context.WithTimeout(ctx, timeout)
	rtn, writeHeader, fileWriter, tarClose := tarcopy.TarCopySrc(readerCtx, pathPrefix)

	// producer goroutine: walks the tree and writes tar entries; tarClose
	// finalizes the archive and cancel releases the timeout context
	go func() {
		defer func() {
			tarClose()
			cancel()
		}()
		walkFunc := func(path string, info fs.FileInfo, err error) error {
			if readerCtx.Err() != nil {
				return readerCtx.Err()
			}
			if err != nil {
				return err
			}
			if err = writeHeader(info, path, singleFile); err != nil {
				return err
			}
			// if not a dir, write file content
			if !info.IsDir() {
				data, err := os.Open(path)
				if err != nil {
					return err
				}
				// closes when walkFunc returns, i.e. once per file
				defer utilfn.GracefulClose(data, "RemoteTarStreamCommand", path)
				if _, err := io.Copy(fileWriter, data); err != nil {
					return err
				}
			}
			return nil
		}
		log.Printf("RemoteTarStreamCommand: starting\n")
		err = nil
		if singleFile {
			err = walkFunc(cleanedPath, finfo, nil)
		} else {
			err = filepath.Walk(cleanedPath, walkFunc)
		}
		if err != nil {
			rtn <- wshutil.RespErr[iochantypes.Packet](err)
		}
		log.Printf("RemoteTarStreamCommand: done\n")
	}()
	log.Printf("RemoteTarStreamCommand: returning channel\n")
	return rtn
}
|
||||
|
||||
// RemoteFileCopyCommand copies data.SrcUri to data.DestUri. When both URIs
// resolve to the same host the copy is done directly on the local
// filesystem; otherwise the source is streamed over RPC as a tar archive and
// unpacked at the destination. opts.Overwrite and opts.Merge control
// conflict handling (mutually exclusive). Returns whether the source was a
// directory, and an error.
func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.CommandFileCopyData) (bool, error) {
	log.Printf("RemoteFileCopyCommand: src=%s, dest=%s\n", data.SrcUri, data.DestUri)
	opts := data.Opts
	if opts == nil {
		opts = &wshrpc.FileCopyOpts{}
	}
	destUri := data.DestUri
	srcUri := data.SrcUri
	merge := opts.Merge
	overwrite := opts.Overwrite
	if overwrite && merge {
		return false, fmt.Errorf("cannot specify both overwrite and merge")
	}

	destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri)
	if err != nil {
		return false, fmt.Errorf("cannot parse destination URI %q: %w", destUri, err)
	}
	destPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path))
	// destinfo stays nil when the destination does not exist
	destinfo, err := os.Stat(destPathCleaned)
	if err != nil {
		if !errors.Is(err, fs.ErrNotExist) {
			return false, fmt.Errorf("cannot stat destination %q: %w", destPathCleaned, err)
		}
	}

	destExists := destinfo != nil
	destIsDir := destExists && destinfo.IsDir()
	// trailing slash on the dest URI means "copy into this directory"
	destHasSlash := strings.HasSuffix(destUri, "/")

	// an existing regular file at the destination must be explicitly
	// overwritten (and is removed up front)
	if destExists && !destIsDir {
		if !overwrite {
			return false, fmt.Errorf(fstype.OverwriteRequiredError, destPathCleaned)
		} else {
			err := os.Remove(destPathCleaned)
			if err != nil {
				return false, fmt.Errorf("cannot remove file %q: %w", destPathCleaned, err)
			}
		}
	}
	srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri)
	if err != nil {
		return false, fmt.Errorf("cannot parse source URI %q: %w", srcUri, err)
	}

	// copyFileFunc writes one entry (file or directory) to path, applying the
	// overwrite/merge policy against whatever currently exists there. Returns
	// the source's reported size (0 for directories).
	copyFileFunc := func(path string, finfo fs.FileInfo, srcFile io.Reader) (int64, error) {
		nextinfo, err := os.Stat(path)
		if err != nil && !errors.Is(err, fs.ErrNotExist) {
			return 0, fmt.Errorf("cannot stat file %q: %w", path, err)
		}

		if nextinfo != nil {
			if nextinfo.IsDir() {
				if !finfo.IsDir() {
					// try to create file in directory
					path = filepath.Join(path, filepath.Base(finfo.Name()))
					newdestinfo, err := os.Stat(path)
					if err != nil && !errors.Is(err, fs.ErrNotExist) {
						return 0, fmt.Errorf("cannot stat file %q: %w", path, err)
					}
					if newdestinfo != nil && !overwrite {
						return 0, fmt.Errorf(fstype.OverwriteRequiredError, path)
					}
				} else if overwrite {
					// dir onto dir with overwrite: replace wholesale
					err := os.RemoveAll(path)
					if err != nil {
						return 0, fmt.Errorf("cannot remove directory %q: %w", path, err)
					}
				} else if !merge {
					return 0, fmt.Errorf(fstype.MergeRequiredError, path)
				}
			} else {
				// destination entry is a regular file
				if !overwrite {
					return 0, fmt.Errorf(fstype.OverwriteRequiredError, path)
				} else if finfo.IsDir() {
					err := os.RemoveAll(path)
					if err != nil {
						return 0, fmt.Errorf("cannot remove directory %q: %w", path, err)
					}
				}
			}
		}

		if finfo.IsDir() {
			err := os.MkdirAll(path, finfo.Mode())
			if err != nil {
				return 0, fmt.Errorf("cannot create directory %q: %w", path, err)
			}
			return 0, nil
		} else {
			err := os.MkdirAll(filepath.Dir(path), 0755)
			if err != nil {
				return 0, fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(path), err)
			}
		}

		file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode())
		if err != nil {
			return 0, fmt.Errorf("cannot create new file %q: %w", path, err)
		}
		defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", path)
		_, err = io.Copy(file, srcFile)
		if err != nil {
			return 0, fmt.Errorf("cannot write file %q: %w", path, err)
		}

		return finfo.Size(), nil
	}

	srcIsDir := false
	if srcConn.Host == destConn.Host {
		// same-host copy: walk the local filesystem directly
		srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path))

		srcFileStat, err := os.Stat(srcPathCleaned)
		if err != nil {
			return false, fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err)
		}

		if srcFileStat.IsDir() {
			srcIsDir = true
			// srcPathPrefix is trimmed from each walked path to get the
			// destination-relative path
			var srcPathPrefix string
			if destIsDir {
				srcPathPrefix = filepath.Dir(srcPathCleaned)
			} else {
				srcPathPrefix = srcPathCleaned
			}
			err = filepath.Walk(srcPathCleaned, func(path string, info fs.FileInfo, err error) error {
				if err != nil {
					return err
				}
				srcFilePath := path
				destFilePath := filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathPrefix))
				// file stays nil for directories; copyFileFunc only reads it
				// for regular files
				var file *os.File
				if !info.IsDir() {
					file, err = os.Open(srcFilePath)
					if err != nil {
						return fmt.Errorf("cannot open file %q: %w", srcFilePath, err)
					}
					defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", srcFilePath)
				}
				_, err = copyFileFunc(destFilePath, info, file)
				return err
			})
			if err != nil {
				return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
			}
		} else {
			file, err := os.Open(srcPathCleaned)
			if err != nil {
				return false, fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err)
			}
			defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", srcPathCleaned)
			var destFilePath string
			if destHasSlash {
				destFilePath = filepath.Join(destPathCleaned, filepath.Base(srcPathCleaned))
			} else {
				destFilePath = destPathCleaned
			}
			_, err = copyFileFunc(destFilePath, srcFileStat, file)
			if err != nil {
				return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
			}
		}
	} else {
		// cross-host copy: stream the source as a tar archive over RPC
		timeout := fstype.DefaultTimeout
		if opts.Timeout > 0 {
			timeout = time.Duration(opts.Timeout) * time.Millisecond
		}
		readCtx, cancel := context.WithCancelCause(ctx)
		readCtx, timeoutCancel := context.WithTimeoutCause(readCtx, timeout, fmt.Errorf("timeout copying file %q to %q", srcUri, destUri))
		defer timeoutCancel()
		copyStart := time.Now()
		ioch := wshclient.FileStreamTarCommand(wshfs.RpcClient, wshrpc.CommandRemoteStreamTarData{Path: srcUri, Opts: opts}, &wshrpc.RpcOpts{Timeout: opts.Timeout})
		numFiles := 0
		// NOTE(review): numSkipped is never incremented; the final log line
		// always reports 0 skipped
		numSkipped := 0
		totalBytes := int64(0)

		err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader, singleFile bool) error {
			numFiles++
			nextpath := filepath.Join(destPathCleaned, next.Name)
			srcIsDir = !singleFile
			if singleFile && !destHasSlash {
				// singleFile flags that the source is one file rather than a
				// directory's contents; write it directly to the dest path
				nextpath = destPathCleaned
			}
			finfo := next.FileInfo()
			n, err := copyFileFunc(nextpath, finfo, reader)
			if err != nil {
				return fmt.Errorf("cannot copy file %q: %w", next.Name, err)
			}
			totalBytes += n
			return nil
		})
		if err != nil {
			return false, fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
		}
		totalTime := time.Since(copyStart).Seconds()
		totalMegaBytes := float64(totalBytes) / 1024 / 1024
		rate := float64(0)
		if totalTime > 0 {
			rate = totalMegaBytes / totalTime
		}
		log.Printf("RemoteFileCopyCommand: done; %d files copied in %.3fs, total of %.4f MB, %.2f MB/s, %d files skipped\n", numFiles, totalTime, totalMegaBytes, rate, numSkipped)
	}
	return srcIsDir, nil
}
|
||||
|
||||
// RemoteListEntriesCommand lists directory entries for data.Path, streaming
// results over the returned channel in chunks of wshrpc.DirChunkSize. With
// Opts.All set it recursively walks the tree (files only, honoring
// Opts.Offset/Opts.Limit); otherwise it lists the single directory level.
func (impl *ServerImpl) RemoteListEntriesCommand(ctx context.Context, data wshrpc.CommandRemoteListEntriesData) chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData] {
	ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData], 16)
	go func() {
		defer close(ch)
		path, err := wavebase.ExpandHomeDir(data.Path)
		if err != nil {
			ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](err)
			return
		}
		innerFilesEntries := []os.DirEntry{}
		seen := 0
		if data.Opts.Limit == 0 {
			data.Opts.Limit = wshrpc.MaxDirSize
		}
		if data.Opts.All {
			// recursive walk; io.EOF is used as a sentinel to stop once the
			// offset+limit window is exhausted.
			// NOTE(review): the WalkDir return value is discarded, so walk
			// errors other than the EOF sentinel are silently dropped —
			// confirm this is intended
			fs.WalkDir(os.DirFS(path), ".", func(path string, d fs.DirEntry, err error) error {
				defer func() {
					seen++
				}()
				if seen < data.Opts.Offset {
					return nil
				}
				if seen >= data.Opts.Offset+data.Opts.Limit {
					return io.EOF
				}
				if err != nil {
					return err
				}
				if d.IsDir() {
					return nil
				}
				innerFilesEntries = append(innerFilesEntries, d)
				return nil
			})
		} else {
			innerFilesEntries, err = os.ReadDir(path)
			if err != nil {
				ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](fmt.Errorf("cannot open dir %q: %w", path, err))
				return
			}
		}
		var fileInfoArr []*wshrpc.FileInfo
		for _, innerFileEntry := range innerFilesEntries {
			if ctx.Err() != nil {
				ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](ctx.Err())
				return
			}
			innerFileInfoInt, err := innerFileEntry.Info()
			if err != nil {
				// entry disappeared mid-listing; log and skip
				log.Printf("cannot stat file %q: %v\n", innerFileEntry.Name(), err)
				continue
			}
			// NOTE(review): for recursive (All) listings this joins only the
			// entry's base name onto the root path, so nested entries lose
			// their intermediate directories — confirm against consumers
			innerFileInfo := statToFileInfo(filepath.Join(path, innerFileInfoInt.Name()), innerFileInfoInt, false)
			fileInfoArr = append(fileInfoArr, innerFileInfo)
			if len(fileInfoArr) >= wshrpc.DirChunkSize {
				resp := wshrpc.CommandRemoteListEntriesRtnData{FileInfo: fileInfoArr}
				ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: resp}
				fileInfoArr = nil
			}
		}
		if len(fileInfoArr) > 0 {
			resp := wshrpc.CommandRemoteListEntriesRtnData{FileInfo: fileInfoArr}
			ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: resp}
		}
	}()
	return ch
}
|
||||
|
||||
// statToFileInfo converts an fs.FileInfo for fullPath into a wshrpc.FileInfo,
// detecting the mime type (extended enables deeper detection in
// fileutil.DetectMimeType). Directories are reported with Size = -1.
func statToFileInfo(fullPath string, finfo fs.FileInfo, extended bool) *wshrpc.FileInfo {
	mimeType := fileutil.DetectMimeType(fullPath, finfo, extended)
	rtn := &wshrpc.FileInfo{
		Path:          wavebase.ReplaceHomeDir(fullPath),
		Dir:           computeDirPart(fullPath),
		Name:          finfo.Name(),
		Size:          finfo.Size(),
		Mode:          finfo.Mode(),
		ModeStr:       finfo.Mode().String(),
		ModTime:       finfo.ModTime().UnixMilli(),
		IsDir:         finfo.IsDir(),
		MimeType:      mimeType,
		SupportsMkdir: true,
	}
	// directory sizes are filesystem-dependent and not meaningful; use -1
	if finfo.IsDir() {
		rtn.Size = -1
	}
	return rtn
}
|
||||
|
||||
// checkIsReadOnly reports whether path is effectively read-only for this
// process. For directories (or when the path does not exist, exists=false,
// in which case fileInfo is nil), it probes by creating and removing a
// temporary file in the parent directory; for regular files it probes by
// opening the file for append. On indeterminate errors it errs toward
// returning false (writable).
func checkIsReadOnly(path string, fileInfo fs.FileInfo, exists bool) bool {
	if !exists || fileInfo.Mode().IsDir() {
		dirName := filepath.Dir(path)
		randHexStr, err := utilfn.RandomHexString(12)
		if err != nil {
			// we're not sure, just return false
			return false
		}
		// probe writability by creating a uniquely-named temp file
		tmpFileName := filepath.Join(dirName, "wsh-tmp-"+randHexStr)
		fd, err := os.Create(tmpFileName)
		if err != nil {
			return true
		}
		utilfn.GracefulClose(fd, "checkIsReadOnly", tmpFileName)
		os.Remove(tmpFileName)
		return false
	}
	// try to open for writing, if this fails then it is read-only
	file, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND, 0666)
	if err != nil {
		return true
	}
	utilfn.GracefulClose(file, "checkIsReadOnly", path)
	return false
}
|
||||
|
||||
func computeDirPart(path string) string {
|
||||
path = filepath.Clean(wavebase.ExpandHomeDirSafe(path))
|
||||
path = filepath.ToSlash(path)
|
||||
if path == "/" {
|
||||
return "/"
|
||||
}
|
||||
return filepath.Dir(path)
|
||||
}
|
||||
|
||||
// fileInfoInternal stats path (after home-dir expansion) and converts the
// result to a wshrpc.FileInfo. A missing path is not an error: it yields an
// info with NotFound set (note finfo is nil in that branch, which
// checkIsReadOnly handles via exists=false). extended additionally fills in
// ReadOnly and enables deeper mime detection.
func (*ServerImpl) fileInfoInternal(path string, extended bool) (*wshrpc.FileInfo, error) {
	cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
	finfo, err := os.Stat(cleanedPath)
	if os.IsNotExist(err) {
		return &wshrpc.FileInfo{
			Path:          wavebase.ReplaceHomeDir(path),
			Dir:           computeDirPart(path),
			NotFound:      true,
			ReadOnly:      checkIsReadOnly(cleanedPath, finfo, false),
			SupportsMkdir: true,
		}, nil
	}
	if err != nil {
		return nil, fmt.Errorf("cannot stat file %q: %w", path, err)
	}
	rtn := statToFileInfo(cleanedPath, finfo, extended)
	if extended {
		rtn.ReadOnly = checkIsReadOnly(cleanedPath, finfo, true)
	}
	return rtn, nil
}
|
||||
|
||||
func resolvePaths(paths []string) string {
|
||||
if len(paths) == 0 {
|
||||
return wavebase.ExpandHomeDirSafe("~")
|
||||
}
|
||||
rtnPath := wavebase.ExpandHomeDirSafe(paths[0])
|
||||
for _, path := range paths[1:] {
|
||||
path = wavebase.ExpandHomeDirSafe(path)
|
||||
if filepath.IsAbs(path) {
|
||||
rtnPath = path
|
||||
continue
|
||||
}
|
||||
rtnPath = filepath.Join(rtnPath, path)
|
||||
}
|
||||
return rtnPath
|
||||
}
|
||||
|
||||
func (impl *ServerImpl) RemoteFileJoinCommand(ctx context.Context, paths []string) (*wshrpc.FileInfo, error) {
|
||||
rtnPath := resolvePaths(paths)
|
||||
return impl.fileInfoInternal(rtnPath, true)
|
||||
}
|
||||
|
||||
// RemoteFileInfoCommand returns extended file info for path (missing paths
// yield a NotFound info rather than an error).
func (impl *ServerImpl) RemoteFileInfoCommand(ctx context.Context, path string) (*wshrpc.FileInfo, error) {
	return impl.fileInfoInternal(path, true)
}
|
||||
|
||||
func (impl *ServerImpl) RemoteFileTouchCommand(ctx context.Context, path string) error {
|
||||
cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
|
||||
if _, err := os.Stat(cleanedPath); err == nil {
|
||||
return fmt.Errorf("file %q already exists", path)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(cleanedPath), 0755); err != nil {
|
||||
return fmt.Errorf("cannot create directory %q: %w", filepath.Dir(cleanedPath), err)
|
||||
}
|
||||
if err := os.WriteFile(cleanedPath, []byte{}, 0644); err != nil {
|
||||
return fmt.Errorf("cannot create file %q: %w", cleanedPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (impl *ServerImpl) RemoteFileMoveCommand(ctx context.Context, data wshrpc.CommandFileCopyData) error {
|
||||
opts := data.Opts
|
||||
destUri := data.DestUri
|
||||
srcUri := data.SrcUri
|
||||
overwrite := opts != nil && opts.Overwrite
|
||||
recursive := opts != nil && opts.Recursive
|
||||
|
||||
destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot parse destination URI %q: %w", srcUri, err)
|
||||
}
|
||||
destPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path))
|
||||
destinfo, err := os.Stat(destPathCleaned)
|
||||
if err == nil {
|
||||
if !destinfo.IsDir() {
|
||||
if !overwrite {
|
||||
return fmt.Errorf("destination %q already exists, use overwrite option", destUri)
|
||||
} else {
|
||||
err := os.Remove(destPathCleaned)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot remove file %q: %w", destUri, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("cannot stat destination %q: %w", destUri, err)
|
||||
}
|
||||
srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot parse source URI %q: %w", srcUri, err)
|
||||
}
|
||||
if srcConn.Host == destConn.Host {
|
||||
srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path))
|
||||
finfo, err := os.Stat(srcPathCleaned)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err)
|
||||
}
|
||||
if finfo.IsDir() && !recursive {
|
||||
return fmt.Errorf(fstype.RecursiveRequiredError)
|
||||
}
|
||||
err = os.Rename(srcPathCleaned, destPathCleaned)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot move file %q to %q: %w", srcPathCleaned, destPathCleaned, err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("cannot move file %q to %q: different hosts", srcUri, destUri)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (impl *ServerImpl) RemoteMkdirCommand(ctx context.Context, path string) error {
|
||||
cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path))
|
||||
if stat, err := os.Stat(cleanedPath); err == nil {
|
||||
if stat.IsDir() {
|
||||
return fmt.Errorf("directory %q already exists", path)
|
||||
} else {
|
||||
return fmt.Errorf("cannot create directory %q, file exists at path", path)
|
||||
}
|
||||
}
|
||||
if err := os.MkdirAll(cleanedPath, 0755); err != nil {
|
||||
return fmt.Errorf("cannot create directory %q: %w", cleanedPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (*ServerImpl) RemoteWriteFileCommand(ctx context.Context, data wshrpc.FileData) error {
|
||||
var truncate, append bool
|
||||
var atOffset int64
|
||||
if data.Info != nil && data.Info.Opts != nil {
|
||||
truncate = data.Info.Opts.Truncate
|
||||
append = data.Info.Opts.Append
|
||||
}
|
||||
if data.At != nil {
|
||||
atOffset = data.At.Offset
|
||||
}
|
||||
if truncate && atOffset > 0 {
|
||||
return fmt.Errorf("cannot specify non-zero offset with truncate option")
|
||||
}
|
||||
if append && atOffset > 0 {
|
||||
return fmt.Errorf("cannot specify non-zero offset with append option")
|
||||
}
|
||||
path, err := wavebase.ExpandHomeDir(data.Info.Path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
createMode := os.FileMode(0644)
|
||||
if data.Info != nil && data.Info.Mode > 0 {
|
||||
createMode = data.Info.Mode
|
||||
}
|
||||
dataSize := base64.StdEncoding.DecodedLen(len(data.Data64))
|
||||
dataBytes := make([]byte, dataSize)
|
||||
n, err := base64.StdEncoding.Decode(dataBytes, []byte(data.Data64))
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot decode base64 data: %w", err)
|
||||
}
|
||||
finfo, err := os.Stat(path)
|
||||
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("cannot stat file %q: %w", path, err)
|
||||
}
|
||||
fileSize := int64(0)
|
||||
if finfo != nil {
|
||||
fileSize = finfo.Size()
|
||||
}
|
||||
if atOffset > fileSize {
|
||||
return fmt.Errorf("cannot write at offset %d, file size is %d", atOffset, fileSize)
|
||||
}
|
||||
openFlags := os.O_CREATE | os.O_WRONLY
|
||||
if truncate {
|
||||
openFlags |= os.O_TRUNC
|
||||
}
|
||||
if append {
|
||||
openFlags |= os.O_APPEND
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(path, openFlags, createMode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open file %q: %w", path, err)
|
||||
}
|
||||
defer utilfn.GracefulClose(file, "RemoteWriteFileCommand", path)
|
||||
if atOffset > 0 && !append {
|
||||
n, err = file.WriteAt(dataBytes[:n], atOffset)
|
||||
} else {
|
||||
n, err = file.Write(dataBytes[:n])
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot write to file %q: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*ServerImpl) RemoteFileDeleteCommand(ctx context.Context, data wshrpc.CommandDeleteFileData) error {
|
||||
expandedPath, err := wavebase.ExpandHomeDir(data.Path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot delete file %q: %w", data.Path, err)
|
||||
}
|
||||
cleanedPath := filepath.Clean(expandedPath)
|
||||
|
||||
err = os.Remove(cleanedPath)
|
||||
if err != nil {
|
||||
finfo, _ := os.Stat(cleanedPath)
|
||||
if finfo != nil && finfo.IsDir() {
|
||||
if !data.Recursive {
|
||||
return fmt.Errorf(fstype.RecursiveRequiredError)
|
||||
}
|
||||
err = os.RemoveAll(cleanedPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot delete directory %q: %w", data.Path, err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("cannot delete file %q: %w", data.Path, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
352
pkg/wshrpc/wshremote/wshremote_job.go
Normal file
352
pkg/wshrpc/wshremote/wshremote_job.go
Normal file
|
|
@ -0,0 +1,352 @@
|
|||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wshremote
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/shirou/gopsutil/v4/process"
|
||||
"github.com/wavetermdev/waveterm/pkg/jobmanager"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
)
|
||||
|
||||
func isProcessRunning(pid int, pidStartTs int64) (*process.Process, error) {
|
||||
if pid <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
proc, err := process.NewProcess(int32(pid))
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
createTime, err := proc.CreateTime()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if createTime != pidStartTs {
|
||||
return nil, nil
|
||||
}
|
||||
return proc, nil
|
||||
}
|
||||
|
||||
// connectToJobManager dials the job manager's unix socket for jobId,
// registers it as an untrusted RPC link, authenticates with
// mainServerJwtToken, and waits (up to 500ms) for the job's route to
// register. On success the connection is recorded in the JobManagerMap and
// (jobRouteId, cleanupFunc, nil) is returned; cleanup closes the socket,
// unregisters the link, and removes the map entry, and is safe to call more
// than once (sync.Once). On any failure cleanup runs before returning the
// error.
func (impl *ServerImpl) connectToJobManager(ctx context.Context, jobId string, mainServerJwtToken string) (string, func(), error) {
	socketPath := jobmanager.GetJobSocketPath(jobId)
	log.Printf("connectToJobManager: connecting to socket: %s\n", socketPath)
	conn, err := net.Dial("unix", socketPath)
	if err != nil {
		log.Printf("connectToJobManager: error connecting to socket: %v\n", err)
		return "", nil, fmt.Errorf("cannot connect to job manager socket: %w", err)
	}
	log.Printf("connectToJobManager: connected to socket\n")

	proxy := wshutil.MakeRpcProxy("jobmanager")
	linkId := impl.Router.RegisterUntrustedLink(proxy)

	// idempotent teardown: close socket, drop the router link, forget the job
	var cleanupOnce sync.Once
	cleanup := func() {
		cleanupOnce.Do(func() {
			conn.Close()
			impl.Router.UnregisterLink(linkId)
			impl.removeJobManagerConnection(jobId)
		})
	}

	// writer: pump outgoing proxy messages onto the socket
	go func() {
		writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn)
		if writeErr != nil {
			log.Printf("connectToJobManager: error writing to job manager socket: %v\n", writeErr)
		}
	}()
	// reader: pump incoming socket data into the proxy; when the stream ends
	// (socket closed), close the channel and tear everything down
	go func() {
		defer func() {
			close(proxy.FromRemoteCh)
			cleanup()
		}()
		wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh)
	}()

	// authenticate over the newly registered link route
	routeId := wshutil.MakeLinkRouteId(linkId)
	authData := wshrpc.CommandAuthenticateToJobData{
		JobAccessToken: mainServerJwtToken,
	}
	err = wshclient.AuthenticateToJobManagerCommand(impl.RpcClient, authData, &wshrpc.RpcOpts{Route: routeId})
	if err != nil {
		cleanup()
		return "", nil, fmt.Errorf("authentication to job manager failed: %w", err)
	}

	// after auth, the job manager announces its job route; wait briefly for it
	jobRouteId := wshutil.MakeJobRouteId(jobId)
	waitCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
	defer cancel()
	err = impl.Router.WaitForRegister(waitCtx, jobRouteId)
	if err != nil {
		cleanup()
		return "", nil, fmt.Errorf("timeout waiting for job route to register: %w", err)
	}

	jobConn := &JobManagerConnection{
		JobId:     jobId,
		Conn:      conn,
		CleanupFn: cleanup,
	}
	impl.addJobManagerConnection(jobConn)

	log.Printf("connectToJobManager: successfully connected and authenticated\n")
	return jobRouteId, cleanup, nil
}
|
||||
|
||||
// addJobManagerConnection records conn in JobManagerMap keyed by its job id,
// replacing any existing entry for that job. Guarded by impl.Lock.
func (impl *ServerImpl) addJobManagerConnection(conn *JobManagerConnection) {
	impl.Lock.Lock()
	defer impl.Lock.Unlock()
	impl.JobManagerMap[conn.JobId] = conn
	log.Printf("addJobManagerConnection: added job manager connection for jobid=%s\n", conn.JobId)
}
|
||||
|
||||
// removeJobManagerConnection drops the JobManagerMap entry for jobId if
// present; a missing entry is a no-op. Guarded by impl.Lock.
func (impl *ServerImpl) removeJobManagerConnection(jobId string) {
	impl.Lock.Lock()
	defer impl.Lock.Unlock()
	if _, exists := impl.JobManagerMap[jobId]; exists {
		delete(impl.JobManagerMap, jobId)
		log.Printf("removeJobManagerConnection: removed job manager connection for jobid=%s\n", jobId)
	}
}
|
||||
|
||||
// getJobManagerConnection returns the recorded connection for jobId, or nil
// if there is none. Guarded by impl.Lock.
func (impl *ServerImpl) getJobManagerConnection(jobId string) *JobManagerConnection {
	impl.Lock.Lock()
	defer impl.Lock.Unlock()
	return impl.JobManagerMap[jobId]
}
|
||||
|
||||
// RemoteStartJobCommand launches a new persistent job manager process
// (via "wsh jobmanager") on this remote, waits for its ready signal,
// connects and authenticates to it through the router, and then issues
// the StartJob RPC so the manager spawns the actual command.
//
// Startup handshake:
//   - a pipe is handed to the child via ExtraFiles (fd 3); the child
//     signals readiness by writing a "Wave-JobManagerStart" line on it
//   - the job auth token is written to the child's stdin as a
//     "Wave-JobAccessToken:<token>" line, then stdin is closed
//   - stdout/stderr are drained in goroutines purely for logging
//
// Returns pid/start-timestamp info for both the spawned command and the
// job manager process (used later by reconnect/terminate liveness checks).
func (impl *ServerImpl) RemoteStartJobCommand(ctx context.Context, data wshrpc.CommandRemoteStartJobData) (*wshrpc.CommandStartJobRtnData, error) {
	log.Printf("RemoteStartJobCommand: starting, jobid=%s, clientid=%s\n", data.JobId, data.ClientId)
	if impl.Router == nil {
		return nil, fmt.Errorf("cannot start remote job: no router available")
	}

	wshPath, err := impl.getWshPath()
	if err != nil {
		return nil, err
	}
	log.Printf("RemoteStartJobCommand: wshPath=%s\n", wshPath)

	// ready pipe: the write end is inherited by the child; the parent
	// reads the start signal from the read end
	readyPipeRead, readyPipeWrite, err := os.Pipe()
	if err != nil {
		return nil, fmt.Errorf("cannot create ready pipe: %w", err)
	}
	defer readyPipeRead.Close()
	defer readyPipeWrite.Close()

	cmd := exec.Command(wshPath, "jobmanager", "--jobid", data.JobId, "--clientid", data.ClientId)
	if data.PublicKeyBase64 != "" {
		// pass the public key to the child through the environment
		cmd.Env = append(os.Environ(), "WAVETERM_PUBLICKEY="+data.PublicKeyBase64)
	}
	// ExtraFiles[0] becomes fd 3 in the child
	cmd.ExtraFiles = []*os.File{readyPipeWrite}
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return nil, fmt.Errorf("cannot create stdin pipe: %w", err)
	}
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return nil, fmt.Errorf("cannot create stdout pipe: %w", err)
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return nil, fmt.Errorf("cannot create stderr pipe: %w", err)
	}
	log.Printf("RemoteStartJobCommand: created pipes\n")

	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("cannot start job manager: %w", err)
	}
	// close the parent's copy of the write end so the read end sees EOF
	// if the child exits without ever writing the start signal
	readyPipeWrite.Close()
	log.Printf("RemoteStartJobCommand: job manager process started\n")

	// hand the job auth token to the child over stdin, then close stdin
	jobAuthTokenLine := fmt.Sprintf("Wave-JobAccessToken:%s\n", data.JobAuthToken)
	if _, err := stdin.Write([]byte(jobAuthTokenLine)); err != nil {
		cmd.Process.Kill()
		return nil, fmt.Errorf("cannot write job auth token: %w", err)
	}
	stdin.Close()
	log.Printf("RemoteStartJobCommand: wrote auth token to stdin\n")

	// drain stderr for logging only
	go func() {
		scanner := bufio.NewScanner(stderr)
		for scanner.Scan() {
			line := scanner.Text()
			log.Printf("RemoteStartJobCommand: stderr: %s\n", line)
		}
		if err := scanner.Err(); err != nil {
			log.Printf("RemoteStartJobCommand: error reading stderr: %v\n", err)
		} else {
			log.Printf("RemoteStartJobCommand: stderr EOF\n")
		}
	}()

	// drain stdout for logging only
	go func() {
		scanner := bufio.NewScanner(stdout)
		for scanner.Scan() {
			line := scanner.Text()
			log.Printf("RemoteStartJobCommand: stdout: %s\n", line)
		}
		if err := scanner.Err(); err != nil {
			log.Printf("RemoteStartJobCommand: error reading stdout: %v\n", err)
		} else {
			log.Printf("RemoteStartJobCommand: stdout EOF\n")
		}
	}()

	// watch the ready pipe for the start signal; startCh is buffered so
	// this goroutine never blocks even if the select below times out
	startCh := make(chan error, 1)
	go func() {
		scanner := bufio.NewScanner(readyPipeRead)
		for scanner.Scan() {
			line := scanner.Text()
			log.Printf("RemoteStartJobCommand: ready pipe line: %s\n", line)
			if strings.Contains(line, "Wave-JobManagerStart") {
				startCh <- nil
				return
			}
		}
		if err := scanner.Err(); err != nil {
			startCh <- fmt.Errorf("error reading ready pipe: %w", err)
		} else {
			// EOF with no signal: the child died before starting up
			log.Printf("RemoteStartJobCommand: ready pipe EOF\n")
			startCh <- fmt.Errorf("job manager exited without start signal")
		}
	}()

	timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	log.Printf("RemoteStartJobCommand: waiting for start signal\n")
	select {
	case err := <-startCh:
		if err != nil {
			cmd.Process.Kill()
			log.Printf("RemoteStartJobCommand: error from start signal: %v\n", err)
			return nil, err
		}
		log.Printf("RemoteStartJobCommand: received start signal\n")
	case <-timeoutCtx.Done():
		cmd.Process.Kill()
		log.Printf("RemoteStartJobCommand: timeout waiting for start signal\n")
		return nil, fmt.Errorf("timeout waiting for job manager to start")
	}

	// reap the child when it eventually exits (avoids a zombie); the job
	// manager is expected to outlive this call
	go func() {
		cmd.Wait()
	}()

	jobRouteId, cleanup, err := impl.connectToJobManager(ctx, data.JobId, data.MainServerJwtToken)
	if err != nil {
		return nil, err
	}

	// finally, ask the job manager itself to start the actual command
	startJobData := wshrpc.CommandStartJobData{
		Cmd:        data.Cmd,
		Args:       data.Args,
		Env:        data.Env,
		TermSize:   data.TermSize,
		StreamMeta: data.StreamMeta,
	}
	rtnData, err := wshclient.StartJobCommand(impl.RpcClient, startJobData, &wshrpc.RpcOpts{Route: jobRouteId})
	if err != nil {
		cleanup()
		return nil, fmt.Errorf("failed to start job: %w", err)
	}

	return rtnData, nil
}
|
||||
|
||||
// RemoteReconnectToJobManagerCommand re-establishes the connection to an
// already-running job manager process. It verifies the process is still
// alive (pid + start timestamp, guarding against pid reuse), tears down
// any stale tracked connection for the job, and then dials and
// authenticates a fresh one via connectToJobManager.
//
// Failures are reported inside the return struct (Success /
// JobManagerGone / Error) rather than as a Go error, so callers can
// distinguish "the manager is permanently gone" from transient failures.
func (impl *ServerImpl) RemoteReconnectToJobManagerCommand(ctx context.Context, data wshrpc.CommandRemoteReconnectToJobManagerData) (*wshrpc.CommandRemoteReconnectToJobManagerRtnData, error) {
	log.Printf("RemoteReconnectToJobManagerCommand: reconnecting, jobid=%s\n", data.JobId)
	if impl.Router == nil {
		return &wshrpc.CommandRemoteReconnectToJobManagerRtnData{
			Success: false,
			Error:   "cannot reconnect to job manager: no router available",
		}, nil
	}

	// liveness check: pid must exist AND match the recorded start time
	proc, err := isProcessRunning(data.JobManagerPid, data.JobManagerStartTs)
	if err != nil {
		return &wshrpc.CommandRemoteReconnectToJobManagerRtnData{
			Success: false,
			Error:   fmt.Sprintf("error checking job manager process: %v", err),
		}, nil
	}
	if proc == nil {
		// process is gone; signal the caller that reconnect can never succeed
		return &wshrpc.CommandRemoteReconnectToJobManagerRtnData{
			Success:        false,
			JobManagerGone: true,
			Error:          fmt.Sprintf("job manager process (pid=%d) is not running", data.JobManagerPid),
		}, nil
	}

	// drop any existing (stale) connection before dialing a new one
	existingConn := impl.getJobManagerConnection(data.JobId)
	if existingConn != nil {
		log.Printf("RemoteReconnectToJobManagerCommand: closing existing connection for jobid=%s\n", data.JobId)
		if existingConn.CleanupFn != nil {
			existingConn.CleanupFn()
		}
	}

	_, _, err = impl.connectToJobManager(ctx, data.JobId, data.MainServerJwtToken)
	if err != nil {
		return &wshrpc.CommandRemoteReconnectToJobManagerRtnData{
			Success: false,
			Error:   err.Error(),
		}, nil
	}

	log.Printf("RemoteReconnectToJobManagerCommand: successfully reconnected to job manager\n")
	return &wshrpc.CommandRemoteReconnectToJobManagerRtnData{
		Success: true,
	}, nil
}
|
||||
|
||||
func (impl *ServerImpl) RemoteDisconnectFromJobManagerCommand(ctx context.Context, data wshrpc.CommandRemoteDisconnectFromJobManagerData) error {
|
||||
log.Printf("RemoteDisconnectFromJobManagerCommand: disconnecting, jobid=%s\n", data.JobId)
|
||||
conn := impl.getJobManagerConnection(data.JobId)
|
||||
if conn == nil {
|
||||
log.Printf("RemoteDisconnectFromJobManagerCommand: no connection found for jobid=%s\n", data.JobId)
|
||||
return nil
|
||||
}
|
||||
|
||||
if conn.CleanupFn != nil {
|
||||
conn.CleanupFn()
|
||||
log.Printf("RemoteDisconnectFromJobManagerCommand: cleanup completed for jobid=%s\n", data.JobId)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (impl *ServerImpl) RemoteTerminateJobManagerCommand(ctx context.Context, data wshrpc.CommandRemoteTerminateJobManagerData) error {
|
||||
log.Printf("RemoteTerminateJobManagerCommand: terminating job manager, jobid=%s, pid=%d\n", data.JobId, data.JobManagerPid)
|
||||
proc, err := isProcessRunning(data.JobManagerPid, data.JobManagerStartTs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking job manager process: %w", err)
|
||||
}
|
||||
if proc == nil {
|
||||
log.Printf("RemoteTerminateJobManagerCommand: job manager process not running, jobid=%s\n", data.JobId)
|
||||
return nil
|
||||
}
|
||||
err = proc.SendSignal(syscall.SIGTERM)
|
||||
if err != nil {
|
||||
log.Printf("failed to send SIGTERM to job manager: %v", err)
|
||||
} else {
|
||||
log.Printf("RemoteTerminateJobManagerCommand: sent SIGTERM to job manager process, jobid=%s, pid=%d\n", data.JobId, data.JobManagerPid)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -22,10 +22,18 @@ type RespOrErrorUnion[T any] struct {
|
|||
Error error
|
||||
}
|
||||
|
||||
// Instructions for adding a new RPC call
|
||||
// * methods must end with Command
|
||||
// * methods must take context as their first parameter
|
||||
// * methods may take up to one parameter, and may return either just an error, or one return value plus an error
|
||||
// * after modifying WshRpcInterface, run `task generate` to regnerate bindings
|
||||
|
||||
type WshRpcInterface interface {
|
||||
AuthenticateCommand(ctx context.Context, data string) (CommandAuthenticateRtnData, error)
|
||||
AuthenticateTokenCommand(ctx context.Context, data CommandAuthenticateTokenData) (CommandAuthenticateRtnData, error)
|
||||
AuthenticateTokenVerifyCommand(ctx context.Context, data CommandAuthenticateTokenData) (CommandAuthenticateRtnData, error) // (special) validates token without binding, root router only
|
||||
AuthenticateJobManagerCommand(ctx context.Context, data CommandAuthenticateJobManagerData) error
|
||||
AuthenticateJobManagerVerifyCommand(ctx context.Context, data CommandAuthenticateJobManagerData) error // (special) validates job auth token without binding, root router only
|
||||
DisposeCommand(ctx context.Context, data CommandDisposeData) error
|
||||
RouteAnnounceCommand(ctx context.Context) error // (special) announces a new route to the main router
|
||||
RouteUnannounceCommand(ctx context.Context) error // (special) unannounces a route to the main router
|
||||
|
|
@ -100,6 +108,10 @@ type WshRpcInterface interface {
|
|||
RemoteStreamCpuDataCommand(ctx context.Context) chan RespOrErrorUnion[TimeSeriesData]
|
||||
RemoteGetInfoCommand(ctx context.Context) (RemoteInfo, error)
|
||||
RemoteInstallRcFilesCommand(ctx context.Context) error
|
||||
RemoteStartJobCommand(ctx context.Context, data CommandRemoteStartJobData) (*CommandStartJobRtnData, error)
|
||||
RemoteReconnectToJobManagerCommand(ctx context.Context, data CommandRemoteReconnectToJobManagerData) (*CommandRemoteReconnectToJobManagerRtnData, error)
|
||||
RemoteDisconnectFromJobManagerCommand(ctx context.Context, data CommandRemoteDisconnectFromJobManagerData) error
|
||||
RemoteTerminateJobManagerCommand(ctx context.Context, data CommandRemoteTerminateJobManagerData) error
|
||||
|
||||
// emain
|
||||
WebSelectorCommand(ctx context.Context, data CommandWebSelectorData) ([]string, error)
|
||||
|
|
@ -140,6 +152,7 @@ type WshRpcInterface interface {
|
|||
|
||||
// terminal
|
||||
TermGetScrollbackLinesCommand(ctx context.Context, data CommandTermGetScrollbackLinesData) (*CommandTermGetScrollbackLinesRtnData, error)
|
||||
TermUpdateAttachedJobCommand(ctx context.Context, data CommandTermUpdateAttachedJobData) error
|
||||
|
||||
// file
|
||||
WshRpcFileInterface
|
||||
|
|
@ -154,6 +167,26 @@ type WshRpcInterface interface {
|
|||
// streams
|
||||
StreamDataCommand(ctx context.Context, data CommandStreamData) error
|
||||
StreamDataAckCommand(ctx context.Context, data CommandStreamAckData) error
|
||||
|
||||
// jobs
|
||||
AuthenticateToJobManagerCommand(ctx context.Context, data CommandAuthenticateToJobData) error
|
||||
StartJobCommand(ctx context.Context, data CommandStartJobData) (*CommandStartJobRtnData, error)
|
||||
JobPrepareConnectCommand(ctx context.Context, data CommandJobPrepareConnectData) (*CommandJobConnectRtnData, error)
|
||||
JobStartStreamCommand(ctx context.Context, data CommandJobStartStreamData) error
|
||||
JobInputCommand(ctx context.Context, data CommandJobInputData) error
|
||||
JobCmdExitedCommand(ctx context.Context, data CommandJobCmdExitedData) error // this is sent FROM the job manager => main server
|
||||
|
||||
// job controller
|
||||
JobControllerDeleteJobCommand(ctx context.Context, jobId string) error
|
||||
JobControllerListCommand(ctx context.Context) ([]*waveobj.Job, error)
|
||||
JobControllerStartJobCommand(ctx context.Context, data CommandJobControllerStartJobData) (string, error)
|
||||
JobControllerExitJobCommand(ctx context.Context, jobId string) error
|
||||
JobControllerDisconnectJobCommand(ctx context.Context, jobId string) error
|
||||
JobControllerReconnectJobCommand(ctx context.Context, jobId string) error
|
||||
JobControllerReconnectJobsForConnCommand(ctx context.Context, connName string) error
|
||||
JobControllerConnectedJobsCommand(ctx context.Context) ([]string, error)
|
||||
JobControllerAttachJobCommand(ctx context.Context, data CommandJobControllerAttachJobData) error
|
||||
JobControllerDetachJobCommand(ctx context.Context, jobId string) error
|
||||
}
|
||||
|
||||
// for frontend
|
||||
|
|
@ -250,6 +283,13 @@ type CommandBlockInputData struct {
|
|||
TermSize *waveobj.TermSize `json:"termsize,omitempty"`
|
||||
}
|
||||
|
||||
// CommandJobInputData carries terminal input for a running job: raw
// input bytes, an optional signal, and/or a terminal resize.
type CommandJobInputData struct {
	JobId       string            `json:"jobid"`
	InputData64 string            `json:"inputdata64,omitempty"` // base64-encoded input bytes
	SigName     string            `json:"signame,omitempty"`     // signal name to deliver (e.g. "SIGINT")
	TermSize    *waveobj.TermSize `json:"termsize,omitempty"`    // new terminal size, when resizing
}
|
||||
|
||||
type CommandWaitForRouteData struct {
|
||||
RouteId string `json:"routeid"`
|
||||
WaitMs int `json:"waitms"`
|
||||
|
|
@ -614,6 +654,11 @@ type CommandTermGetScrollbackLinesRtnData struct {
|
|||
LastUpdated int64 `json:"lastupdated"`
|
||||
}
|
||||
|
||||
// CommandTermUpdateAttachedJobData tells a terminal block which job it
// is attached to; an empty JobId detaches the block from any job.
type CommandTermUpdateAttachedJobData struct {
	BlockId string `json:"blockid"`
	JobId   string `json:"jobid,omitempty"`
}
|
||||
|
||||
// CommandElectronEncryptData is the request payload for the electron
// encrypt command (plaintext to be encrypted).
type CommandElectronEncryptData struct {
	PlainText string `json:"plaintext"`
}
|
||||
|
|
@ -633,7 +678,7 @@ type CommandElectronDecryptRtnData struct {
|
|||
}
|
||||
|
||||
type CommandStreamData struct {
|
||||
Id int64 `json:"id"` // streamid
|
||||
Id string `json:"id"` // streamid
|
||||
Seq int64 `json:"seq"` // start offset (bytes)
|
||||
Data64 string `json:"data64,omitempty"`
|
||||
Eof bool `json:"eof,omitempty"` // can be set with data or without
|
||||
|
|
@ -641,7 +686,7 @@ type CommandStreamData struct {
|
|||
}
|
||||
|
||||
type CommandStreamAckData struct {
|
||||
Id int64 `json:"id"` // streamid
|
||||
Id string `json:"id"` // streamid
|
||||
Seq int64 `json:"seq"` // next expected byte
|
||||
RWnd int64 `json:"rwnd"` // receive window size
|
||||
Fin bool `json:"fin,omitempty"` // observed end-of-stream (eof or error)
|
||||
|
|
@ -651,8 +696,108 @@ type CommandStreamAckData struct {
|
|||
}
|
||||
|
||||
type StreamMeta struct {
|
||||
Id int64 `json:"id"` // streamid
|
||||
Id string `json:"id"` // streamid
|
||||
RWnd int64 `json:"rwnd"` // initial receive window size
|
||||
ReaderRouteId string `json:"readerrouteid"`
|
||||
WriterRouteId string `json:"writerrouteid"`
|
||||
}
|
||||
|
||||
// CommandAuthenticateToJobData is sent by the main server to a job
// manager to authenticate itself on a newly established connection.
type CommandAuthenticateToJobData struct {
	JobAccessToken string `json:"jobaccesstoken"`
}
|
||||
|
||||
// CommandAuthenticateJobManagerData authenticates a job manager to the
// router using the per-job auth token it was given at startup.
type CommandAuthenticateJobManagerData struct {
	JobId        string `json:"jobid"`
	JobAuthToken string `json:"jobauthtoken"`
}
|
||||
|
||||
// CommandStartJobData is the request sent to a job manager asking it to
// spawn the job's command.
type CommandStartJobData struct {
	Cmd        string            `json:"cmd"`                  // executable / command to run
	Args       []string          `json:"args"`                 // command arguments
	Env        map[string]string `json:"env"`                  // environment variables for the command
	TermSize   waveobj.TermSize  `json:"termsize"`             // initial terminal size
	StreamMeta *StreamMeta       `json:"streammeta,omitempty"` // output stream configuration (optional)
}
|
||||
|
||||
// CommandRemoteStartJobData is the request to spawn a new job manager
// process on a remote and start a command inside it (handled by
// RemoteStartJobCommand). It bundles the StartJob fields with the
// credentials needed to bootstrap the manager.
type CommandRemoteStartJobData struct {
	Cmd                string            `json:"cmd"`
	Args               []string          `json:"args"`
	Env                map[string]string `json:"env"`
	TermSize           waveobj.TermSize  `json:"termsize"`
	StreamMeta         *StreamMeta       `json:"streammeta,omitempty"`
	JobAuthToken       string            `json:"jobauthtoken"` // handed to the job manager over stdin at startup
	JobId              string            `json:"jobid"`
	MainServerJwtToken string            `json:"mainserverjwttoken"` // used to authenticate to the job manager after connect
	ClientId           string            `json:"clientid"`
	PublicKeyBase64    string            `json:"publickeybase64"` // passed to the child via the WAVETERM_PUBLICKEY env var
}
|
||||
|
||||
// CommandRemoteReconnectToJobManagerData asks a remote to reconnect to
// an existing job manager process. Pid + start timestamp identify the
// process and guard against pid reuse.
type CommandRemoteReconnectToJobManagerData struct {
	JobId              string `json:"jobid"`
	JobAuthToken       string `json:"jobauthtoken"`
	MainServerJwtToken string `json:"mainserverjwttoken"`
	JobManagerPid      int    `json:"jobmanagerpid"`
	JobManagerStartTs  int64  `json:"jobmanagerstartts"`
}
|
||||
|
||||
// CommandRemoteReconnectToJobManagerRtnData reports the outcome of a
// reconnect attempt. JobManagerGone distinguishes a permanently dead
// manager process from a transient connection failure.
type CommandRemoteReconnectToJobManagerRtnData struct {
	Success        bool   `json:"success"`
	JobManagerGone bool   `json:"jobmanagergone"`
	Error          string `json:"error,omitempty"`
}
|
||||
|
||||
// CommandRemoteDisconnectFromJobManagerData asks a remote to drop its
// connection to the given job's job manager (the manager keeps running).
type CommandRemoteDisconnectFromJobManagerData struct {
	JobId string `json:"jobid"`
}
|
||||
|
||||
// CommandRemoteTerminateJobManagerData asks a remote to SIGTERM the
// given job manager process (identified by pid + start timestamp).
type CommandRemoteTerminateJobManagerData struct {
	JobId             string `json:"jobid"`
	JobManagerPid     int    `json:"jobmanagerpid"`
	JobManagerStartTs int64  `json:"jobmanagerstartts"`
}
|
||||
|
||||
// CommandStartJobRtnData is returned after a job starts: pid and start
// timestamp for both the spawned command and its job manager process
// (used later for reconnect/terminate liveness checks).
type CommandStartJobRtnData struct {
	CmdPid            int   `json:"cmdpid"`
	CmdStartTs        int64 `json:"cmdstartts"`
	JobManagerPid     int   `json:"jobmanagerpid"`
	JobManagerStartTs int64 `json:"jobmanagerstartts"`
}
|
||||
|
||||
// CommandJobPrepareConnectData prepares an output-stream (re)connect to
// a job. Seq is presumably the byte offset to resume the stream from —
// TODO confirm against the job manager implementation.
type CommandJobPrepareConnectData struct {
	StreamMeta StreamMeta `json:"streammeta"`
	Seq        int64      `json:"seq"`
}
|
||||
|
||||
// CommandJobStartStreamData is the (currently empty) payload for the
// job start-stream command; defined as a struct for forward compatibility.
type CommandJobStartStreamData struct {
}
|
||||
|
||||
// CommandJobConnectRtnData describes the job/stream state at connect
// time: where the stream stands (Seq / done / error) and, if the
// command has already exited, how it exited.
type CommandJobConnectRtnData struct {
	Seq         int64  `json:"seq"` // stream position (presumably bytes — see CommandJobPrepareConnectData)
	StreamDone  bool   `json:"streamdone,omitempty"`
	StreamError string `json:"streamerror,omitempty"`
	HasExited   bool   `json:"hasexited,omitempty"`
	ExitCode    *int   `json:"exitcode,omitempty"` // nil when no exit code is available
	ExitSignal  string `json:"exitsignal,omitempty"`
	ExitErr     string `json:"exiterr,omitempty"`
}
|
||||
|
||||
// CommandJobCmdExitedData is sent from a job manager to the main server
// when the job's command exits, carrying the exit disposition.
type CommandJobCmdExitedData struct {
	JobId      string `json:"jobid"`
	ExitCode   *int   `json:"exitcode,omitempty"` // nil when no exit code is available
	ExitSignal string `json:"exitsignal,omitempty"`
	ExitErr    string `json:"exiterr,omitempty"`
	ExitTs     int64  `json:"exitts,omitempty"`
}
|
||||
|
||||
// CommandJobControllerStartJobData is the main-server-side request to
// start a new persistent job on a named connection.
type CommandJobControllerStartJobData struct {
	ConnName string            `json:"connname"` // connection to run the job on
	Cmd      string            `json:"cmd"`
	Args     []string          `json:"args"`
	Env      map[string]string `json:"env"`
	TermSize *waveobj.TermSize `json:"termsize,omitempty"`
}
|
||||
|
||||
// CommandJobControllerAttachJobData attaches an existing job's output
// to a block (e.g. a terminal block).
type CommandJobControllerAttachJobData struct {
	JobId   string `json:"jobid"`
	BlockId string `json:"blockid"`
}
|
||||
|
|
|
|||
|
|
@ -35,13 +35,16 @@ const (
|
|||
// we only need consts for special commands handled in the router or
|
||||
// in the RPC code / WPS code directly. other commands go through the clients
|
||||
const (
|
||||
Command_Authenticate = "authenticate" // $control
|
||||
Command_AuthenticateToken = "authenticatetoken" // $control
|
||||
Command_AuthenticateTokenVerify = "authenticatetokenverify" // $control:root (internal, for token validation only)
|
||||
Command_RouteAnnounce = "routeannounce" // $control (for routing)
|
||||
Command_RouteUnannounce = "routeunannounce" // $control (for routing)
|
||||
Command_Ping = "ping" // $control
|
||||
Command_ControllerInput = "controllerinput"
|
||||
Command_EventRecv = "eventrecv"
|
||||
Command_Message = "message"
|
||||
Command_Authenticate = "authenticate" // $control
|
||||
Command_AuthenticateToken = "authenticatetoken" // $control
|
||||
Command_AuthenticateTokenVerify = "authenticatetokenverify" // $control:root (internal, for token validation only)
|
||||
Command_AuthenticateJobManagerVerify = "authenticatejobmanagerverify" // $control:root (internal, for job auth token validation only)
|
||||
Command_RouteAnnounce = "routeannounce" // $control (for routing)
|
||||
Command_RouteUnannounce = "routeunannounce" // $control (for routing)
|
||||
Command_Ping = "ping" // $control
|
||||
Command_ControllerInput = "controllerinput"
|
||||
Command_EventRecv = "eventrecv"
|
||||
Command_Message = "message"
|
||||
Command_StreamData = "streamdata"
|
||||
Command_StreamDataAck = "streamdataack"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/wavetermdev/waveterm/pkg/filebackup"
|
||||
"github.com/wavetermdev/waveterm/pkg/filestore"
|
||||
"github.com/wavetermdev/waveterm/pkg/genconn"
|
||||
"github.com/wavetermdev/waveterm/pkg/jobcontroller"
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote/awsconn"
|
||||
|
|
@ -294,6 +295,21 @@ func (ws *WshServer) ControllerResyncCommand(ctx context.Context, data wshrpc.Co
|
|||
}
|
||||
|
||||
func (ws *WshServer) ControllerInputCommand(ctx context.Context, data wshrpc.CommandBlockInputData) error {
|
||||
block, err := wstore.DBMustGet[*waveobj.Block](ctx, data.BlockId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting block: %w", err)
|
||||
}
|
||||
|
||||
if block.JobId != "" {
|
||||
jobInputData := wshrpc.CommandJobInputData{
|
||||
JobId: block.JobId,
|
||||
InputData64: data.InputData64,
|
||||
SigName: data.SigName,
|
||||
TermSize: data.TermSize,
|
||||
}
|
||||
return jobcontroller.SendInput(ctx, jobInputData)
|
||||
}
|
||||
|
||||
inputUnion := &blockcontroller.BlockInputUnion{
|
||||
SigName: data.SigName,
|
||||
TermSize: data.TermSize,
|
||||
|
|
@ -1430,3 +1446,54 @@ func (ws *WshServer) GetSecretsLinuxStorageBackendCommand(ctx context.Context) (
|
|||
}
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
// JobCmdExitedCommand is sent from a job manager to the main server when
// the job's command exits; forwards the exit info to the job controller.
func (ws *WshServer) JobCmdExitedCommand(ctx context.Context, data wshrpc.CommandJobCmdExitedData) error {
	return jobcontroller.HandleCmdJobExited(ctx, data.JobId, data)
}
|
||||
|
||||
// JobControllerListCommand returns all persisted job objects from the store.
func (ws *WshServer) JobControllerListCommand(ctx context.Context) ([]*waveobj.Job, error) {
	return wstore.DBGetAllObjsByType[*waveobj.Job](ctx, waveobj.OType_Job)
}
|
||||
|
||||
// JobControllerDeleteJobCommand deletes the given job via the job controller.
func (ws *WshServer) JobControllerDeleteJobCommand(ctx context.Context, jobId string) error {
	return jobcontroller.DeleteJob(ctx, jobId)
}
|
||||
|
||||
// JobControllerStartJobCommand starts a new persistent job on the named
// connection via the job controller; returns the new job's id.
func (ws *WshServer) JobControllerStartJobCommand(ctx context.Context, data wshrpc.CommandJobControllerStartJobData) (string, error) {
	params := jobcontroller.StartJobParams{
		ConnName: data.ConnName,
		Cmd:      data.Cmd,
		Args:     data.Args,
		Env:      data.Env,
		TermSize: data.TermSize,
	}
	return jobcontroller.StartJob(ctx, params)
}
|
||||
|
||||
// JobControllerExitJobCommand exits a job by terminating its job
// manager process.
func (ws *WshServer) JobControllerExitJobCommand(ctx context.Context, jobId string) error {
	return jobcontroller.TerminateJobManager(ctx, jobId)
}
|
||||
|
||||
// JobControllerDisconnectJobCommand disconnects from a job's job manager
// (the job keeps running and can be reconnected later).
func (ws *WshServer) JobControllerDisconnectJobCommand(ctx context.Context, jobId string) error {
	return jobcontroller.DisconnectJob(ctx, jobId)
}
|
||||
|
||||
// JobControllerReconnectJobCommand reconnects to a previously
// disconnected job via the job controller.
func (ws *WshServer) JobControllerReconnectJobCommand(ctx context.Context, jobId string) error {
	return jobcontroller.ReconnectJob(ctx, jobId)
}
|
||||
|
||||
// JobControllerReconnectJobsForConnCommand reconnects all jobs
// associated with the named connection.
func (ws *WshServer) JobControllerReconnectJobsForConnCommand(ctx context.Context, connName string) error {
	return jobcontroller.ReconnectJobsForConn(ctx, connName)
}
|
||||
|
||||
// JobControllerConnectedJobsCommand returns the ids of all currently
// connected jobs.
func (ws *WshServer) JobControllerConnectedJobsCommand(ctx context.Context) ([]string, error) {
	return jobcontroller.GetConnectedJobIds(), nil
}
|
||||
|
||||
// JobControllerAttachJobCommand attaches a job's output to a block.
func (ws *WshServer) JobControllerAttachJobCommand(ctx context.Context, data wshrpc.CommandJobControllerAttachJobData) error {
	return jobcontroller.AttachJobToBlock(ctx, data.JobId, data.BlockId)
}
|
||||
|
||||
// JobControllerDetachJobCommand detaches a job from its block. The
// hard-coded true flag's semantics are defined by
// jobcontroller.DetachJobFromBlock — NOTE(review): confirm its meaning.
func (ws *WshServer) JobControllerDetachJobCommand(ctx context.Context, jobId string) error {
	return jobcontroller.DetachJobFromBlock(ctx, jobId, true)
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
|
@ -34,6 +35,9 @@ const (
|
|||
RoutePrefix_Tab = "tab:"
|
||||
RoutePrefix_FeBlock = "feblock:"
|
||||
RoutePrefix_Builder = "builder:"
|
||||
RoutePrefix_Link = "link:"
|
||||
RoutePrefix_Job = "job:"
|
||||
RoutePrefix_Bare = "bare:"
|
||||
)
|
||||
|
||||
// this works like a network switch
|
||||
|
|
@ -118,6 +122,14 @@ func MakeBuilderRouteId(builderId string) string {
|
|||
return "builder:" + builderId
|
||||
}
|
||||
|
||||
func MakeJobRouteId(jobId string) string {
|
||||
return "job:" + jobId
|
||||
}
|
||||
|
||||
// MakeLinkRouteId returns the router route id for a link
// ("link:<numeric link id>").
func MakeLinkRouteId(linkId baseds.LinkId) string {
	return fmt.Sprintf("%s%d", RoutePrefix_Link, linkId)
}
|
||||
|
||||
var DefaultRouter *WshRouter
|
||||
|
||||
func NewWshRouter() *WshRouter {
|
||||
|
|
@ -245,6 +257,13 @@ func (router *WshRouter) getRouteInfo(rpcId string) *rpcRoutingInfo {
|
|||
|
||||
// returns true if message was sent, false if failed
|
||||
func (router *WshRouter) sendRoutedMessage(msgBytes []byte, routeId string, commandName string, ingressLinkId baseds.LinkId) bool {
|
||||
if strings.HasPrefix(routeId, RoutePrefix_Link) {
|
||||
linkIdStr := strings.TrimPrefix(routeId, RoutePrefix_Link)
|
||||
linkIdInt, err := strconv.ParseInt(linkIdStr, 10, 32)
|
||||
if err == nil {
|
||||
return router.sendMessageToLink(msgBytes, baseds.LinkId(linkIdInt), ingressLinkId)
|
||||
}
|
||||
}
|
||||
lm := router.getLinkForRoute(routeId)
|
||||
if lm != nil {
|
||||
lm.client.SendRpcMessage(msgBytes, ingressLinkId, "route")
|
||||
|
|
@ -448,8 +467,10 @@ func (router *WshRouter) runLinkClientRecvLoop(linkId baseds.LinkId, client Abst
|
|||
} else {
|
||||
// non-request messages (responses)
|
||||
if !lm.trusted {
|
||||
// drop responses from untrusted links
|
||||
continue
|
||||
// allow responses to RPCs we initiated
|
||||
if rpcMsg.ResId == "" || router.getRouteInfo(rpcMsg.ResId) == nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
router.inputCh <- baseds.RpcInputChType{MsgBytes: msgBytes, IngressLinkId: linkId}
|
||||
|
|
@ -596,7 +617,7 @@ func (router *WshRouter) UnregisterLink(linkId baseds.LinkId) {
|
|||
}
|
||||
|
||||
func isBindableRouteId(routeId string) bool {
|
||||
if routeId == "" || strings.HasPrefix(routeId, ControlPrefix) {
|
||||
if routeId == "" || strings.HasPrefix(routeId, ControlPrefix) || strings.HasPrefix(routeId, RoutePrefix_Link) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
|
@ -676,6 +697,9 @@ func (router *WshRouter) bindRoute(linkId baseds.LinkId, routeId string, isSourc
|
|||
if !strings.HasPrefix(routeId, ControlPrefix) {
|
||||
router.announceUpstream(routeId)
|
||||
}
|
||||
if router.IsRootRouter() {
|
||||
router.publishRouteToBroker(routeId)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -692,12 +716,19 @@ func (router *WshRouter) getUpstreamClient() AbstractRpcClient {
|
|||
return lm.client
|
||||
}
|
||||
|
||||
// publishRouteToBroker publishes a RouteUp event scoped to routeId so
// event subscribers learn the route is available. A panic from the
// broker is caught by the panic handler rather than propagated.
func (router *WshRouter) publishRouteToBroker(routeId string) {
	defer func() {
		panichandler.PanicHandler("WshRouter:publishRouteToBroker", recover())
	}()
	wps.Broker.Publish(wps.WaveEvent{Event: wps.Event_RouteUp, Scopes: []string{routeId}})
}
||||
|
||||
func (router *WshRouter) unsubscribeFromBroker(routeId string) {
|
||||
defer func() {
|
||||
panichandler.PanicHandler("WshRouter:unregisterRoute:routegone", recover())
|
||||
panichandler.PanicHandler("WshRouter:unregisterRoute:routedown", recover())
|
||||
}()
|
||||
wps.Broker.UnsubscribeAll(routeId)
|
||||
wps.Broker.Publish(wps.WaveEvent{Event: wps.Event_RouteGone, Scopes: []string{routeId}})
|
||||
wps.Broker.Publish(wps.WaveEvent{Event: wps.Event_RouteDown, Scopes: []string{routeId}})
|
||||
}
|
||||
|
||||
func sendControlUnauthenticatedErrorResponse(cmdMsg RpcMessage, linkMeta linkMeta) {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,9 @@ import (
|
|||
"github.com/wavetermdev/waveterm/pkg/baseds"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||
)
|
||||
|
||||
type WshRouterControlImpl struct {
|
||||
|
|
@ -102,6 +104,46 @@ func (impl *WshRouterControlImpl) AuthenticateCommand(ctx context.Context, data
|
|||
return rtnData, nil
|
||||
}
|
||||
|
||||
func extractTokenData(token string) (wshrpc.CommandAuthenticateRtnData, error) {
|
||||
entry := shellutil.GetAndRemoveTokenSwapEntry(token)
|
||||
if entry == nil {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token entry found")
|
||||
}
|
||||
_, err := validateRpcContextFromAuth(entry.RpcContext)
|
||||
if err != nil {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, err
|
||||
}
|
||||
if entry.RpcContext.IsRouter {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("cannot auth router via token")
|
||||
}
|
||||
if entry.RpcContext.RouteId == "" {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no routeid")
|
||||
}
|
||||
return wshrpc.CommandAuthenticateRtnData{
|
||||
Env: entry.Env,
|
||||
InitScriptText: entry.ScriptText,
|
||||
RpcContext: entry.RpcContext,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (impl *WshRouterControlImpl) AuthenticateTokenVerifyCommand(ctx context.Context, data wshrpc.CommandAuthenticateTokenData) (wshrpc.CommandAuthenticateRtnData, error) {
|
||||
if !impl.Router.IsRootRouter() {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("authenticatetokenverify can only be called on root router")
|
||||
}
|
||||
if data.Token == "" {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token in authenticatetoken message")
|
||||
}
|
||||
|
||||
rtnData, err := extractTokenData(data.Token)
|
||||
if err != nil {
|
||||
log.Printf("wshrouter authenticate-token-verify error: %v", err)
|
||||
return wshrpc.CommandAuthenticateRtnData{}, err
|
||||
}
|
||||
|
||||
log.Printf("wshrouter authenticate-token-verify success routeid=%q", rtnData.RpcContext.RouteId)
|
||||
return rtnData, nil
|
||||
}
|
||||
|
||||
func (impl *WshRouterControlImpl) AuthenticateTokenCommand(ctx context.Context, data wshrpc.CommandAuthenticateTokenData) (wshrpc.CommandAuthenticateRtnData, error) {
|
||||
handler := GetRpcResponseHandlerFromContext(ctx)
|
||||
if handler == nil {
|
||||
|
|
@ -117,29 +159,14 @@ func (impl *WshRouterControlImpl) AuthenticateTokenCommand(ctx context.Context,
|
|||
}
|
||||
|
||||
var rtnData wshrpc.CommandAuthenticateRtnData
|
||||
var rpcContext *wshrpc.RpcContext
|
||||
var err error
|
||||
|
||||
if impl.Router.IsRootRouter() {
|
||||
entry := shellutil.GetAndRemoveTokenSwapEntry(data.Token)
|
||||
if entry == nil {
|
||||
log.Printf("wshrouter authenticate-token error linkid=%d: no token entry found", linkId)
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token entry found")
|
||||
}
|
||||
_, err := validateRpcContextFromAuth(entry.RpcContext)
|
||||
rtnData, err = extractTokenData(data.Token)
|
||||
if err != nil {
|
||||
log.Printf("wshrouter authenticate-token error linkid=%d: %v", linkId, err)
|
||||
return wshrpc.CommandAuthenticateRtnData{}, err
|
||||
}
|
||||
if entry.RpcContext.IsRouter {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("cannot auth router via token")
|
||||
}
|
||||
if entry.RpcContext.RouteId == "" {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no routeid")
|
||||
}
|
||||
rpcContext = entry.RpcContext
|
||||
rtnData = wshrpc.CommandAuthenticateRtnData{
|
||||
Env: entry.Env,
|
||||
InitScriptText: entry.ScriptText,
|
||||
RpcContext: rpcContext,
|
||||
}
|
||||
} else {
|
||||
wshRpc := GetWshRpcFromContext(ctx)
|
||||
if wshRpc == nil {
|
||||
|
|
@ -154,51 +181,91 @@ func (impl *WshRouterControlImpl) AuthenticateTokenCommand(ctx context.Context,
|
|||
if err != nil {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
rpcContext = rtnData.RpcContext
|
||||
}
|
||||
|
||||
if rpcContext == nil {
|
||||
if rtnData.RpcContext == nil {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no rpccontext in token response")
|
||||
}
|
||||
log.Printf("wshrouter authenticate-token success linkid=%d routeid=%q", linkId, rpcContext.RouteId)
|
||||
log.Printf("wshrouter authenticate-token success linkid=%d routeid=%q", linkId, rtnData.RpcContext.RouteId)
|
||||
impl.Router.trustLink(linkId, LinkKind_Leaf)
|
||||
impl.Router.bindRoute(linkId, rpcContext.RouteId, true)
|
||||
impl.Router.bindRoute(linkId, rtnData.RpcContext.RouteId, true)
|
||||
|
||||
return rtnData, nil
|
||||
}
|
||||
|
||||
func (impl *WshRouterControlImpl) AuthenticateTokenVerifyCommand(ctx context.Context, data wshrpc.CommandAuthenticateTokenData) (wshrpc.CommandAuthenticateRtnData, error) {
|
||||
func (impl *WshRouterControlImpl) AuthenticateJobManagerVerifyCommand(ctx context.Context, data wshrpc.CommandAuthenticateJobManagerData) error {
|
||||
if !impl.Router.IsRootRouter() {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("authenticatetokenverify can only be called on root router")
|
||||
return fmt.Errorf("authenticatejobmanagerverify can only be called on root router")
|
||||
}
|
||||
|
||||
if data.Token == "" {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token in authenticatetoken message")
|
||||
if data.JobId == "" {
|
||||
return fmt.Errorf("no jobid in authenticatejobmanager message")
|
||||
}
|
||||
entry := shellutil.GetAndRemoveTokenSwapEntry(data.Token)
|
||||
if entry == nil {
|
||||
log.Printf("wshrouter authenticate-token-verify error: no token entry found")
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token entry found")
|
||||
if data.JobAuthToken == "" {
|
||||
return fmt.Errorf("no jobauthtoken in authenticatejobmanager message")
|
||||
}
|
||||
_, err := validateRpcContextFromAuth(entry.RpcContext)
|
||||
|
||||
job, err := wstore.DBMustGet[*waveobj.Job](ctx, data.JobId)
|
||||
if err != nil {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, err
|
||||
}
|
||||
if entry.RpcContext.IsRouter {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("cannot auth router via token")
|
||||
}
|
||||
if entry.RpcContext.RouteId == "" {
|
||||
return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no routeid")
|
||||
log.Printf("wshrouter authenticate-jobmanager-verify error jobid=%q: failed to get job: %v", data.JobId, err)
|
||||
return fmt.Errorf("failed to get job: %w", err)
|
||||
}
|
||||
|
||||
rtnData := wshrpc.CommandAuthenticateRtnData{
|
||||
Env: entry.Env,
|
||||
InitScriptText: entry.ScriptText,
|
||||
RpcContext: entry.RpcContext,
|
||||
if job.JobAuthToken != data.JobAuthToken {
|
||||
log.Printf("wshrouter authenticate-jobmanager-verify error jobid=%q: invalid jobauthtoken", data.JobId)
|
||||
return fmt.Errorf("invalid jobauthtoken")
|
||||
}
|
||||
|
||||
log.Printf("wshrouter authenticate-token-verify success routeid=%q", entry.RpcContext.RouteId)
|
||||
return rtnData, nil
|
||||
log.Printf("wshrouter authenticate-jobmanager-verify success jobid=%q", data.JobId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (impl *WshRouterControlImpl) AuthenticateJobManagerCommand(ctx context.Context, data wshrpc.CommandAuthenticateJobManagerData) error {
|
||||
handler := GetRpcResponseHandlerFromContext(ctx)
|
||||
if handler == nil {
|
||||
return fmt.Errorf("no response handler in context")
|
||||
}
|
||||
linkId := handler.GetIngressLinkId()
|
||||
if linkId == baseds.NoLinkId {
|
||||
return fmt.Errorf("no ingress link found")
|
||||
}
|
||||
|
||||
if data.JobId == "" {
|
||||
return fmt.Errorf("no jobid in authenticatejobmanager message")
|
||||
}
|
||||
if data.JobAuthToken == "" {
|
||||
return fmt.Errorf("no jobauthtoken in authenticatejobmanager message")
|
||||
}
|
||||
|
||||
if impl.Router.IsRootRouter() {
|
||||
job, err := wstore.DBMustGet[*waveobj.Job](ctx, data.JobId)
|
||||
if err != nil {
|
||||
log.Printf("wshrouter authenticate-jobmanager error linkid=%d jobid=%q: failed to get job: %v", linkId, data.JobId, err)
|
||||
return fmt.Errorf("failed to get job: %w", err)
|
||||
}
|
||||
|
||||
if job.JobAuthToken != data.JobAuthToken {
|
||||
log.Printf("wshrouter authenticate-jobmanager error linkid=%d jobid=%q: invalid jobauthtoken", linkId, data.JobId)
|
||||
return fmt.Errorf("invalid jobauthtoken")
|
||||
}
|
||||
} else {
|
||||
wshRpc := GetWshRpcFromContext(ctx)
|
||||
if wshRpc == nil {
|
||||
return fmt.Errorf("no wshrpc in context")
|
||||
}
|
||||
_, err := wshRpc.SendRpcRequest(wshrpc.Command_AuthenticateJobManagerVerify, data, &wshrpc.RpcOpts{Route: ControlRootRoute})
|
||||
if err != nil {
|
||||
log.Printf("wshrouter authenticate-jobmanager error linkid=%d jobid=%q: failed to verify job auth token: %v", linkId, data.JobId, err)
|
||||
return fmt.Errorf("failed to verify job auth token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
routeId := MakeJobRouteId(data.JobId)
|
||||
log.Printf("wshrouter authenticate-jobmanager success linkid=%d jobid=%q routeid=%q", linkId, data.JobId, routeId)
|
||||
impl.Router.trustLink(linkId, LinkKind_Leaf)
|
||||
impl.Router.bindRoute(linkId, routeId, true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateRpcContextFromAuth(newCtx *wshrpc.RpcContext) (string, error) {
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/wavetermdev/waveterm/pkg/baseds"
|
||||
"github.com/wavetermdev/waveterm/pkg/panichandler"
|
||||
"github.com/wavetermdev/waveterm/pkg/streamclient"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/ds"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||
|
|
@ -56,6 +57,7 @@ type WshRpc struct {
|
|||
ServerImpl ServerImpl
|
||||
EventListener *EventListener
|
||||
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
|
||||
StreamBroker *streamclient.Broker
|
||||
Debug bool
|
||||
DebugName string
|
||||
ServerDone bool
|
||||
|
|
@ -226,6 +228,7 @@ func MakeWshRpcWithChannels(inputCh chan baseds.RpcInputChType, outputCh chan []
|
|||
ResponseHandlerMap: make(map[string]*RpcResponseHandler),
|
||||
}
|
||||
rtn.RpcContext.Store(&rpcCtx)
|
||||
rtn.StreamBroker = streamclient.NewBroker(AdaptWshRpc(rtn))
|
||||
go rtn.runServer()
|
||||
return rtn
|
||||
}
|
||||
|
|
@ -286,6 +289,36 @@ func (w *WshRpc) handleEventRecv(req *RpcMessage) {
|
|||
w.EventListener.RecvEvent(&waveEvent)
|
||||
}
|
||||
|
||||
func (w *WshRpc) handleStreamData(req *RpcMessage) {
|
||||
if w.StreamBroker == nil {
|
||||
return
|
||||
}
|
||||
if req.Data == nil {
|
||||
return
|
||||
}
|
||||
var dataPk wshrpc.CommandStreamData
|
||||
err := utilfn.ReUnmarshal(&dataPk, req.Data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.StreamBroker.RecvData(dataPk)
|
||||
}
|
||||
|
||||
func (w *WshRpc) handleStreamAck(req *RpcMessage) {
|
||||
if w.StreamBroker == nil {
|
||||
return
|
||||
}
|
||||
if req.Data == nil {
|
||||
return
|
||||
}
|
||||
var ackPk wshrpc.CommandStreamAckData
|
||||
err := utilfn.ReUnmarshal(&ackPk, req.Data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.StreamBroker.RecvAck(ackPk)
|
||||
}
|
||||
|
||||
func (w *WshRpc) handleRequestInternal(req *RpcMessage, ingressLinkId baseds.LinkId, pprofCtx context.Context) {
|
||||
if req.Command == wshrpc.Command_EventRecv {
|
||||
w.handleEventRecv(req)
|
||||
|
|
@ -381,6 +414,17 @@ outer:
|
|||
continue
|
||||
}
|
||||
if msg.IsRpcRequest() {
|
||||
// Handle stream commands synchronously since the broker is designed to be non-blocking.
|
||||
// RecvData/RecvAck just enqueue to work queues, so there's no risk of blocking the main loop.
|
||||
if msg.Command == wshrpc.Command_StreamData {
|
||||
w.handleStreamData(&msg)
|
||||
continue
|
||||
}
|
||||
if msg.Command == wshrpc.Command_StreamDataAck {
|
||||
w.handleStreamAck(&msg)
|
||||
continue
|
||||
}
|
||||
|
||||
ingressLinkId := inputVal.IngressLinkId
|
||||
go func() {
|
||||
defer func() {
|
||||
|
|
|
|||
24
pkg/wshutil/wshstreamadapter.go
Normal file
24
pkg/wshutil/wshstreamadapter.go
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright 2025, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wshutil
|
||||
|
||||
import (
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
)
|
||||
|
||||
type WshRpcStreamClientAdapter struct {
|
||||
rpc *WshRpc
|
||||
}
|
||||
|
||||
func (a *WshRpcStreamClientAdapter) StreamDataAckCommand(data wshrpc.CommandStreamAckData, opts *wshrpc.RpcOpts) error {
|
||||
return a.rpc.SendCommand("streamdataack", data, opts)
|
||||
}
|
||||
|
||||
func (a *WshRpcStreamClientAdapter) StreamDataCommand(data wshrpc.CommandStreamData, opts *wshrpc.RpcOpts) error {
|
||||
return a.rpc.SendCommand("streamdata", data, opts)
|
||||
}
|
||||
|
||||
func AdaptWshRpc(rpc *WshRpc) *WshRpcStreamClientAdapter {
|
||||
return &WshRpcStreamClientAdapter{rpc: rpc}
|
||||
}
|
||||
|
|
@ -317,6 +317,31 @@ func DBUpdate(ctx context.Context, val waveobj.WaveObj) error {
|
|||
})
|
||||
}
|
||||
|
||||
func DBUpdateFn[T waveobj.WaveObj](ctx context.Context, id string, updateFn func(T)) error {
|
||||
return WithTx(ctx, func(tx *TxWrap) error {
|
||||
val, err := DBMustGet[T](tx.Context(), id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updateFn(val)
|
||||
return DBUpdate(tx.Context(), val)
|
||||
})
|
||||
}
|
||||
|
||||
func DBUpdateFnErr[T waveobj.WaveObj](ctx context.Context, id string, updateFn func(T) error) error {
|
||||
return WithTx(ctx, func(tx *TxWrap) error {
|
||||
val, err := DBMustGet[T](tx.Context(), id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = updateFn(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return DBUpdate(tx.Context(), val)
|
||||
})
|
||||
}
|
||||
|
||||
func DBInsert(ctx context.Context, val waveobj.WaveObj) error {
|
||||
oid := waveobj.GetOID(val)
|
||||
if oid == "" {
|
||||
|
|
|
|||
Loading…
Reference in a new issue