-
+
+
`, map[string]any{"clName": clName, "ref": ref, "onClick": props["onClick"], "children": props["children"]})
}
@@ -85,10 +85,10 @@ func Test1(t *testing.T) {
t.Fatalf("root.Root is nil")
}
printVDom(root)
- root.runWork()
+ root.RunWork()
printVDom(root)
- root.Event(testContext.ButtonId, "onClick")
- root.runWork()
+ root.Event(testContext.ButtonId, "onClick", nil)
+ root.RunWork()
printVDom(root)
}
@@ -111,8 +111,8 @@ func TestBind(t *testing.T) {
elem = Bind(`
hello world
-
-
+
+
`, nil)
jsonBytes, _ = json.MarshalIndent(elem, "", " ")
diff --git a/pkg/vdom/vdom_types.go b/pkg/vdom/vdom_types.go
new file mode 100644
index 000000000..1c09d2817
--- /dev/null
+++ b/pkg/vdom/vdom_types.go
@@ -0,0 +1,201 @@
+// Copyright 2024, Command Line Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package vdom
+
+import (
+ "time"
+
+ "github.com/wavetermdev/waveterm/pkg/waveobj"
+)
+
+const TextTag = "#text"
+const WaveTextTag = "wave:text"
+const WaveNullTag = "wave:null"
+const FragmentTag = "#fragment"
+const BindTag = "#bind"
+
+const ChildrenPropKey = "children"
+const KeyPropKey = "key"
+
+const ObjectType_Ref = "ref"
+const ObjectType_Binding = "binding"
+const ObjectType_Func = "func"
+
+// vdom element
+type VDomElem struct {
+ WaveId string `json:"waveid,omitempty"` // required, except for #text nodes
+ Tag string `json:"tag"`
+ Props map[string]any `json:"props,omitempty"`
+ Children []VDomElem `json:"children,omitempty"`
+ Text string `json:"text,omitempty"`
+}
+
+//// protocol messages
+
+type VDomCreateContext struct {
+ Type string `json:"type" tstype:"\"createcontext\""`
+ Ts int64 `json:"ts"`
+ Meta waveobj.MetaMapType `json:"meta,omitempty"`
+ Target *VDomTarget `json:"target,omitempty"`
+ Persist bool `json:"persist,omitempty"`
+}
+
+type VDomAsyncInitiationRequest struct {
+ Type string `json:"type" tstype:"\"asyncinitiationrequest\""`
+ Ts int64 `json:"ts"`
+ BlockId string `json:"blockid,omitempty"`
+}
+
+func MakeAsyncInitiationRequest(blockId string) VDomAsyncInitiationRequest {
+ return VDomAsyncInitiationRequest{
+ Type: "asyncinitiationrequest",
+ Ts: time.Now().UnixMilli(),
+ BlockId: blockId,
+ }
+}
+
+type VDomFrontendUpdate struct {
+ Type string `json:"type" tstype:"\"frontendupdate\""`
+ Ts int64 `json:"ts"`
+ BlockId string `json:"blockid"`
+ CorrelationId string `json:"correlationid,omitempty"`
+ Dispose bool `json:"dispose,omitempty"` // the vdom context was closed
+ Resync bool `json:"resync,omitempty"` // resync (send all backend data). useful when the FE reloads
+ RenderContext VDomRenderContext `json:"rendercontext,omitempty"`
+ Events []VDomEvent `json:"events,omitempty"`
+ StateSync []VDomStateSync `json:"statesync,omitempty"`
+ RefUpdates []VDomRefUpdate `json:"refupdates,omitempty"`
+ Messages []VDomMessage `json:"messages,omitempty"`
+}
+
+type VDomBackendUpdate struct {
+ Type string `json:"type" tstype:"\"backendupdate\""`
+ Ts int64 `json:"ts"`
+ BlockId string `json:"blockid"`
+ Opts *VDomBackendOpts `json:"opts,omitempty"`
+ RenderUpdates []VDomRenderUpdate `json:"renderupdates,omitempty"`
+ StateSync []VDomStateSync `json:"statesync,omitempty"`
+ RefOperations []VDomRefOperation `json:"refoperations,omitempty"`
+ Messages []VDomMessage `json:"messages,omitempty"`
+}
+
+///// prop types
+
+// used in props
+type VDomBinding struct {
+ Type string `json:"type" tstype:"\"binding\""`
+ Bind string `json:"bind"`
+}
+
+// used in props
+type VDomFunc struct {
+ Fn any `json:"-"` // server side function (called with reflection)
+ Type string `json:"type" tstype:"\"func\""`
+ StopPropagation bool `json:"stoppropagation,omitempty"`
+ PreventDefault bool `json:"preventdefault,omitempty"`
+ GlobalEvent string `json:"globalevent,omitempty"`
+ Keys []string `json:"keys,omitempty"` // special for keyDown events a list of keys to "capture"
+}
+
+// used in props
+type VDomRef struct {
+ Type string `json:"type" tstype:"\"ref\""`
+ RefId string `json:"refid"`
+ TrackPosition bool `json:"trackposition,omitempty"`
+ Position *VDomRefPosition `json:"position,omitempty"`
+ HasCurrent bool `json:"hascurrent,omitempty"`
+}
+
+type DomRect struct {
+ Top float64 `json:"top"`
+ Left float64 `json:"left"`
+ Right float64 `json:"right"`
+ Bottom float64 `json:"bottom"`
+ Width float64 `json:"width"`
+ Height float64 `json:"height"`
+}
+
+type VDomRefPosition struct {
+ OffsetHeight int `json:"offsetheight"`
+ OffsetWidth int `json:"offsetwidth"`
+ ScrollHeight int `json:"scrollheight"`
+ ScrollWidth int `json:"scrollwidth"`
+ ScrollTop int `json:"scrolltop"`
+ BoundingClientRect DomRect `json:"boundingclientrect"`
+}
+
+///// subbordinate protocol types
+
+type VDomEvent struct {
+ WaveId string `json:"waveid"` // empty for global events
+ EventType string `json:"eventtype"`
+ EventData any `json:"eventdata"`
+}
+
+type VDomRenderContext struct {
+ BlockId string `json:"blockid"`
+ Focused bool `json:"focused"`
+ Width int `json:"width"`
+ Height int `json:"height"`
+ RootRefId string `json:"rootrefid"`
+ Background bool `json:"background,omitempty"`
+}
+
+type VDomStateSync struct {
+ Atom string `json:"atom"`
+ Value any `json:"value"`
+}
+
+type VDomRefUpdate struct {
+ RefId string `json:"refid"`
+ HasCurrent bool `json:"hascurrent"`
+ Position *VDomRefPosition `json:"position,omitempty"`
+}
+
+type VDomBackendOpts struct {
+ CloseOnCtrlC bool `json:"closeonctrlc,omitempty"`
+ GlobalKeyboardEvents bool `json:"globalkeyboardevents,omitempty"`
+}
+
+type VDomRenderUpdate struct {
+ UpdateType string `json:"updatetype" tstype:"\"root\"|\"append\"|\"replace\"|\"remove\"|\"insert\""`
+ WaveId string `json:"waveid,omitempty"`
+ VDom VDomElem `json:"vdom"`
+ Index *int `json:"index,omitempty"`
+}
+
+type VDomRefOperation struct {
+ RefId string `json:"refid"`
+ Op string `json:"op" tsype:"\"focus\""`
+ Params []any `json:"params,omitempty"`
+}
+
+type VDomMessage struct {
+ MessageType string `json:"messagetype"`
+ Message string `json:"message"`
+ StackTrace string `json:"stacktrace,omitempty"`
+ Params []any `json:"params,omitempty"`
+}
+
+// target -- to support new targets in the future, like toolbars, partial blocks, splits, etc.
+// default is vdom context inside of a terminal block
+type VDomTarget struct {
+ NewBlock bool `json:"newblock,omitempty"`
+ Magnified bool `json:"magnified,omitempty"`
+}
+
+// matches WaveKeyboardEvent
+type VDomKeyboardEvent struct {
+ Type string `json:"type"`
+ Key string `json:"key"`
+ Code string `json:"code"`
+ Shift bool `json:"shift,omitempty"`
+ Control bool `json:"ctrl,omitempty"`
+ Alt bool `json:"alt,omitempty"`
+ Meta bool `json:"meta,omitempty"`
+ Cmd bool `json:"cmd,omitempty"`
+ Option bool `json:"option,omitempty"`
+ Repeat bool `json:"repeat,omitempty"`
+ Location int `json:"location,omitempty"`
+}
diff --git a/pkg/vdom/vdomclient/vdomclient.go b/pkg/vdom/vdomclient/vdomclient.go
new file mode 100644
index 000000000..79ee5d743
--- /dev/null
+++ b/pkg/vdom/vdomclient/vdomclient.go
@@ -0,0 +1,238 @@
+// Copyright 2024, Command Line Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package vdomclient
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "os"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/wavetermdev/waveterm/pkg/vdom"
+ "github.com/wavetermdev/waveterm/pkg/wps"
+ "github.com/wavetermdev/waveterm/pkg/wshrpc"
+ "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
+ "github.com/wavetermdev/waveterm/pkg/wshutil"
+)
+
+type Client struct {
+ Lock *sync.Mutex
+ Root *vdom.RootElem
+ RootElem *vdom.VDomElem
+ RpcClient *wshutil.WshRpc
+ RpcContext *wshrpc.RpcContext
+ ServerImpl *VDomServerImpl
+ IsDone bool
+ RouteId string
+ VDomContextBlockId string
+ DoneReason string
+ DoneCh chan struct{}
+ Opts vdom.VDomBackendOpts
+ GlobalEventHandler func(client *Client, event vdom.VDomEvent)
+}
+
+type VDomServerImpl struct {
+ Client *Client
+ BlockId string
+}
+
+func (*VDomServerImpl) WshServerImpl() {}
+
+func (impl *VDomServerImpl) VDomRenderCommand(ctx context.Context, feUpdate vdom.VDomFrontendUpdate) (*vdom.VDomBackendUpdate, error) {
+ if feUpdate.Dispose {
+ log.Printf("got dispose from frontend\n")
+ impl.Client.doShutdown("got dispose from frontend")
+ return nil, nil
+ }
+ if impl.Client.GetIsDone() {
+ return nil, nil
+ }
+ // set atoms
+ for _, ss := range feUpdate.StateSync {
+ impl.Client.Root.SetAtomVal(ss.Atom, ss.Value, false)
+ }
+ // run events
+ for _, event := range feUpdate.Events {
+ if event.WaveId == "" {
+ if impl.Client.GlobalEventHandler != nil {
+ impl.Client.GlobalEventHandler(impl.Client, event)
+ }
+ } else {
+ impl.Client.Root.Event(event.WaveId, event.EventType, event.EventData)
+ }
+ }
+ if feUpdate.Resync {
+ return impl.Client.fullRender()
+ }
+ return impl.Client.incrementalRender()
+}
+
+func (c *Client) GetIsDone() bool {
+ c.Lock.Lock()
+ defer c.Lock.Unlock()
+ return c.IsDone
+}
+
+func (c *Client) doShutdown(reason string) {
+ c.Lock.Lock()
+ defer c.Lock.Unlock()
+ if c.IsDone {
+ return
+ }
+ c.DoneReason = reason
+ c.IsDone = true
+ close(c.DoneCh)
+}
+
+func (c *Client) SetGlobalEventHandler(handler func(client *Client, event vdom.VDomEvent)) {
+ c.GlobalEventHandler = handler
+}
+
+func MakeClient(opts *vdom.VDomBackendOpts) (*Client, error) {
+ client := &Client{
+ Lock: &sync.Mutex{},
+ Root: vdom.MakeRoot(),
+ DoneCh: make(chan struct{}),
+ }
+ if opts != nil {
+ client.Opts = *opts
+ }
+ jwtToken := os.Getenv(wshutil.WaveJwtTokenVarName)
+ if jwtToken == "" {
+ return nil, fmt.Errorf("no %s env var set", wshutil.WaveJwtTokenVarName)
+ }
+ rpcCtx, err := wshutil.ExtractUnverifiedRpcContext(jwtToken)
+ if err != nil {
+ return nil, fmt.Errorf("error extracting rpc context from %s: %v", wshutil.WaveJwtTokenVarName, err)
+ }
+ client.RpcContext = rpcCtx
+ if client.RpcContext == nil || client.RpcContext.BlockId == "" {
+ return nil, fmt.Errorf("no block id in rpc context")
+ }
+ client.ServerImpl = &VDomServerImpl{BlockId: client.RpcContext.BlockId, Client: client}
+ sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken)
+ if err != nil {
+ return nil, fmt.Errorf("error extracting socket name from %s: %v", wshutil.WaveJwtTokenVarName, err)
+ }
+ rpcClient, err := wshutil.SetupDomainSocketRpcClient(sockName, client.ServerImpl)
+ if err != nil {
+ return nil, fmt.Errorf("error setting up domain socket rpc client: %v", err)
+ }
+ client.RpcClient = rpcClient
+ authRtn, err := wshclient.AuthenticateCommand(client.RpcClient, jwtToken, &wshrpc.RpcOpts{NoResponse: true})
+ if err != nil {
+ return nil, fmt.Errorf("error authenticating rpc connection: %v", err)
+ }
+ client.RouteId = authRtn.RouteId
+ return client, nil
+}
+
+func (c *Client) SetRootElem(elem *vdom.VDomElem) {
+ c.RootElem = elem
+}
+
+func (c *Client) CreateVDomContext(target *vdom.VDomTarget) error {
+ blockORef, err := wshclient.VDomCreateContextCommand(
+ c.RpcClient,
+ vdom.VDomCreateContext{Target: target},
+ &wshrpc.RpcOpts{Route: wshutil.MakeFeBlockRouteId(c.RpcContext.BlockId)},
+ )
+ if err != nil {
+ return err
+ }
+ c.VDomContextBlockId = blockORef.OID
+ log.Printf("created vdom context: %v\n", blockORef)
+ gotRoute, err := wshclient.WaitForRouteCommand(c.RpcClient, wshrpc.CommandWaitForRouteData{
+ RouteId: wshutil.MakeFeBlockRouteId(blockORef.OID),
+ WaitMs: 4000,
+ }, &wshrpc.RpcOpts{Timeout: 5000})
+ if err != nil {
+ return fmt.Errorf("error waiting for vdom context route: %v", err)
+ }
+ if !gotRoute {
+ return fmt.Errorf("vdom context route could not be established")
+ }
+ wshclient.EventSubCommand(c.RpcClient, wps.SubscriptionRequest{Event: wps.Event_BlockClose, Scopes: []string{
+ blockORef.String(),
+ }}, nil)
+ c.RpcClient.EventListener.On("blockclose", func(event *wps.WaveEvent) {
+ c.doShutdown("got blockclose event")
+ })
+ return nil
+}
+
+func (c *Client) SendAsyncInitiation() error {
+ if c.VDomContextBlockId == "" {
+ return fmt.Errorf("no vdom context block id")
+ }
+ if c.GetIsDone() {
+ return fmt.Errorf("client is done")
+ }
+ return wshclient.VDomAsyncInitiationCommand(
+ c.RpcClient,
+ vdom.MakeAsyncInitiationRequest(c.RpcContext.BlockId),
+ &wshrpc.RpcOpts{Route: wshutil.MakeFeBlockRouteId(c.VDomContextBlockId)},
+ )
+}
+
+func (c *Client) SetAtomVals(m map[string]any) {
+ for k, v := range m {
+ c.Root.SetAtomVal(k, v, true)
+ }
+}
+
+func (c *Client) SetAtomVal(name string, val any) {
+ c.Root.SetAtomVal(name, val, true)
+}
+
+func (c *Client) GetAtomVal(name string) any {
+ return c.Root.GetAtomVal(name)
+}
+
+func makeNullVDom() *vdom.VDomElem {
+ return &vdom.VDomElem{WaveId: uuid.New().String(), Tag: vdom.WaveNullTag}
+}
+
+func (c *Client) RegisterComponent(name string, cfunc vdom.CFunc) {
+ c.Root.RegisterComponent(name, cfunc)
+}
+
+func (c *Client) fullRender() (*vdom.VDomBackendUpdate, error) {
+ c.Root.RunWork()
+ c.Root.Render(c.RootElem)
+ renderedVDom := c.Root.MakeVDom()
+ if renderedVDom == nil {
+ renderedVDom = makeNullVDom()
+ }
+ return &vdom.VDomBackendUpdate{
+ Type: "backendupdate",
+ Ts: time.Now().UnixMilli(),
+ BlockId: c.RpcContext.BlockId,
+ Opts: &c.Opts,
+ RenderUpdates: []vdom.VDomRenderUpdate{
+ {UpdateType: "root", VDom: *renderedVDom},
+ },
+ StateSync: c.Root.GetStateSync(true),
+ }, nil
+}
+
+func (c *Client) incrementalRender() (*vdom.VDomBackendUpdate, error) {
+ c.Root.RunWork()
+ renderedVDom := c.Root.MakeVDom()
+ if renderedVDom == nil {
+ renderedVDom = makeNullVDom()
+ }
+ return &vdom.VDomBackendUpdate{
+ Type: "backendupdate",
+ Ts: time.Now().UnixMilli(),
+ BlockId: c.RpcContext.BlockId,
+ RenderUpdates: []vdom.VDomRenderUpdate{
+ {UpdateType: "root", VDom: *renderedVDom},
+ },
+ StateSync: c.Root.GetStateSync(false),
+ }, nil
+}
diff --git a/pkg/wavebase/wavebase-posix.go b/pkg/wavebase/wavebase-posix.go
index 352302221..12d2e7f75 100644
--- a/pkg/wavebase/wavebase-posix.go
+++ b/pkg/wavebase/wavebase-posix.go
@@ -14,8 +14,8 @@ import (
)
func AcquireWaveLock() (FDLock, error) {
- homeDir := GetWaveHomeDir()
- lockFileName := filepath.Join(homeDir, WaveLockFile)
+ dataHomeDir := GetWaveDataDir()
+ lockFileName := filepath.Join(dataHomeDir, WaveLockFile)
log.Printf("[base] acquiring lock on %s\n", lockFileName)
fd, err := os.OpenFile(lockFileName, os.O_RDWR|os.O_CREATE, 0600)
if err != nil {
diff --git a/pkg/wavebase/wavebase-win.go b/pkg/wavebase/wavebase-win.go
index a22ac2f85..31bdab821 100644
--- a/pkg/wavebase/wavebase-win.go
+++ b/pkg/wavebase/wavebase-win.go
@@ -14,8 +14,8 @@ import (
)
func AcquireWaveLock() (FDLock, error) {
- homeDir := GetWaveHomeDir()
- lockFileName := filepath.Join(homeDir, WaveLockFile)
+ dataHomeDir := GetWaveDataDir()
+ lockFileName := filepath.Join(dataHomeDir, WaveLockFile)
log.Printf("[base] acquiring lock on %s\n", lockFileName)
m, err := filemutex.New(lockFileName)
if err != nil {
diff --git a/pkg/wavebase/wavebase.go b/pkg/wavebase/wavebase.go
index fd0cb5738..805386d52 100644
--- a/pkg/wavebase/wavebase.go
+++ b/pkg/wavebase/wavebase.go
@@ -17,22 +17,26 @@ import (
"strings"
"sync"
"time"
+
+ "github.com/wavetermdev/waveterm/pkg/util/panic"
)
// set by main-server.go
var WaveVersion = "0.0.0"
var BuildTime = "0"
-const DefaultWaveHome = "~/.waveterm"
-const DevWaveHome = "~/.waveterm-dev"
-const WaveHomeVarName = "WAVETERM_HOME"
+const WaveConfigHomeEnvVar = "WAVETERM_CONFIG_HOME"
+const WaveDataHomeEnvVar = "WAVETERM_DATA_HOME"
const WaveDevVarName = "WAVETERM_DEV"
const WaveLockFile = "wave.lock"
const DomainSocketBaseName = "wave.sock"
+const RemoteDomainSocketBaseName = "wave-remote.sock"
const WaveDBDir = "db"
const JwtSecret = "waveterm" // TODO generate and store this
const ConfigDir = "config"
+var RemoteWaveHome = ExpandHomeDirSafe("~/.waveterm")
+
const WaveAppPathVarName = "WAVETERM_APP_PATH"
const AppPathBinDir = "bin"
@@ -97,30 +101,43 @@ func ReplaceHomeDir(pathStr string) string {
}
func GetDomainSocketName() string {
- return filepath.Join(GetWaveHomeDir(), DomainSocketBaseName)
+ return filepath.Join(GetWaveDataDir(), DomainSocketBaseName)
}
-func GetWaveHomeDir() string {
- homeVar := os.Getenv(WaveHomeVarName)
- if homeVar != "" {
- return ExpandHomeDirSafe(homeVar)
- }
- if IsDevMode() {
- return ExpandHomeDirSafe(DevWaveHome)
- }
- return ExpandHomeDirSafe(DefaultWaveHome)
+func GetRemoteDomainSocketName() string {
+ return filepath.Join(RemoteWaveHome, RemoteDomainSocketBaseName)
}
-func EnsureWaveHomeDir() error {
- return CacheEnsureDir(GetWaveHomeDir(), "wavehome", 0700, "wave home directory")
+func GetWaveDataDir() string {
+ retVal, found := os.LookupEnv(WaveDataHomeEnvVar)
+ if !found {
+ panic.Panic(WaveDataHomeEnvVar + " not set")
+ }
+ return retVal
+}
+
+func GetWaveConfigDir() string {
+ retVal, found := os.LookupEnv(WaveConfigHomeEnvVar)
+ if !found {
+ panic.Panic(WaveConfigHomeEnvVar + " not set")
+ }
+ return retVal
+}
+
+func EnsureWaveDataDir() error {
+ return CacheEnsureDir(GetWaveDataDir(), "wavehome", 0700, "wave home directory")
}
func EnsureWaveDBDir() error {
- return CacheEnsureDir(filepath.Join(GetWaveHomeDir(), WaveDBDir), "wavedb", 0700, "wave db directory")
+ return CacheEnsureDir(filepath.Join(GetWaveDataDir(), WaveDBDir), "wavedb", 0700, "wave db directory")
}
func EnsureWaveConfigDir() error {
- return CacheEnsureDir(filepath.Join(GetWaveHomeDir(), ConfigDir), "waveconfig", 0700, "wave config directory")
+ return CacheEnsureDir(GetWaveConfigDir(), "waveconfig", 0700, "wave config directory")
+}
+
+func EnsureWavePresetsDir() error {
+ return CacheEnsureDir(filepath.Join(GetWaveConfigDir(), "presets"), "wavepresets", 0700, "wave presets directory")
}
func CacheEnsureDir(dirName string, cacheKey string, perm os.FileMode, dirDesc string) error {
diff --git a/pkg/waveobj/metaconsts.go b/pkg/waveobj/metaconsts.go
index e904c3def..fc3a73584 100644
--- a/pkg/waveobj/metaconsts.go
+++ b/pkg/waveobj/metaconsts.go
@@ -79,6 +79,13 @@ const (
MetaKey_TermLocalShellPath = "term:localshellpath"
MetaKey_TermLocalShellOpts = "term:localshellopts"
MetaKey_TermScrollback = "term:scrollback"
+ MetaKey_TermVDomSubBlockId = "term:vdomblockid"
+
+ MetaKey_VDomClear = "vdom:*"
+ MetaKey_VDomInitialized = "vdom:initialized"
+ MetaKey_VDomCorrelationId = "vdom:correlationid"
+ MetaKey_VDomRoute = "vdom:route"
+ MetaKey_VDomPersist = "vdom:persist"
MetaKey_Count = "count"
)
diff --git a/pkg/waveobj/waveobj.go b/pkg/waveobj/waveobj.go
index ba5931316..13111b54a 100644
--- a/pkg/waveobj/waveobj.go
+++ b/pkg/waveobj/waveobj.go
@@ -94,6 +94,14 @@ func ParseORef(orefStr string) (ORef, error) {
return ORef{OType: otype, OID: oid}, nil
}
+func ParseORefNoErr(orefStr string) *ORef {
+ oref, err := ParseORef(orefStr)
+ if err != nil {
+ return nil
+ }
+ return &oref
+}
+
type WaveObj interface {
GetOType() string // should not depend on object state (should work with nil value)
}
diff --git a/pkg/waveobj/wtype.go b/pkg/waveobj/wtype.go
index 2f3e87717..5953a22b3 100644
--- a/pkg/waveobj/wtype.go
+++ b/pkg/waveobj/wtype.go
@@ -118,6 +118,7 @@ type Client struct {
Meta MetaMapType `json:"meta"`
TosAgreed int64 `json:"tosagreed,omitempty"`
HasOldHistory bool `json:"hasoldhistory,omitempty"`
+ NextTabId int `json:"nexttabid,omitempty"`
}
func (*Client) GetOType() string {
@@ -252,11 +253,13 @@ type WinSize struct {
type Block struct {
OID string `json:"oid"`
+ ParentORef string `json:"parentoref,omitempty"`
Version int `json:"version"`
BlockDef *BlockDef `json:"blockdef"`
RuntimeOpts *RuntimeOpts `json:"runtimeopts,omitempty"`
Stickers []*StickerType `json:"stickers,omitempty"`
Meta MetaMapType `json:"meta"`
+ SubBlockIds []string `json:"subblockids,omitempty"`
}
func (*Block) GetOType() string {
diff --git a/pkg/waveobj/wtypemeta.go b/pkg/waveobj/wtypemeta.go
index 93c1919f0..367f44bc4 100644
--- a/pkg/waveobj/wtypemeta.go
+++ b/pkg/waveobj/wtypemeta.go
@@ -80,6 +80,13 @@ type MetaTSType struct {
TermLocalShellPath string `json:"term:localshellpath,omitempty"` // matches settings
TermLocalShellOpts []string `json:"term:localshellopts,omitempty"` // matches settings
TermScrollback *int `json:"term:scrollback,omitempty"`
+ TermVDomSubBlockId string `json:"term:vdomblockid,omitempty"`
+
+ VDomClear bool `json:"vdom:*,omitempty"`
+ VDomInitialized bool `json:"vdom:initialized,omitempty"`
+ VDomCorrelationId string `json:"vdom:correlationid,omitempty"`
+ VDomRoute string `json:"vdom:route,omitempty"`
+ VDomPersist bool `json:"vdom:persist,omitempty"`
Count int `json:"count,omitempty"` // temp for cpu plot. will remove later
}
diff --git a/pkg/wconfig/defaultconfig/defaultconfig.go b/pkg/wconfig/defaultconfig/defaultconfig.go
index bc28a9557..9527a069c 100644
--- a/pkg/wconfig/defaultconfig/defaultconfig.go
+++ b/pkg/wconfig/defaultconfig/defaultconfig.go
@@ -5,5 +5,5 @@ package defaultconfig
import "embed"
-//go:embed *.json
+//go:embed *.json all:*/*.json
var ConfigFS embed.FS
diff --git a/pkg/wconfig/defaultconfig/presets.json b/pkg/wconfig/defaultconfig/presets.json
index 30dc05942..3f1a38135 100644
--- a/pkg/wconfig/defaultconfig/presets.json
+++ b/pkg/wconfig/defaultconfig/presets.json
@@ -94,23 +94,5 @@
"bg": "linear-gradient(120deg,hsla(350, 65%, 57%, 1),hsla(30,60%,60%, .75), hsla(208,69%,50%,.15), hsl(230,60%,40%)),radial-gradient(at top right,hsla(300,60%,70%,0.3),transparent),radial-gradient(at top left,hsla(330,100%,70%,.20),transparent),radial-gradient(at top right,hsla(190,100%,40%,.20),transparent),radial-gradient(at bottom left,hsla(323,54%,50%,.5),transparent),radial-gradient(at bottom left,hsla(144,54%,50%,.25),transparent)",
"bg:blendmode": "overlay",
"bg:text": "rgb(200, 200, 200)"
- },
- "ai@global": {
- "display:name": "Global default",
- "display:order": -1,
- "ai:*": true
- },
- "ai@wave": {
- "display:name": "Wave Proxy - gpt-4o-mini",
- "display:order": 0,
- "ai:*": true,
- "ai:apitype": "",
- "ai:baseurl": "",
- "ai:apitoken": "",
- "ai:name": "",
- "ai:orgid": "",
- "ai:model": "gpt-4o-mini",
- "ai:maxtokens": 2048,
- "ai:timeoutms": 60000
}
}
diff --git a/pkg/wconfig/defaultconfig/presets/ai.json b/pkg/wconfig/defaultconfig/presets/ai.json
new file mode 100644
index 000000000..11c0b848e
--- /dev/null
+++ b/pkg/wconfig/defaultconfig/presets/ai.json
@@ -0,0 +1,20 @@
+{
+ "ai@global": {
+ "display:name": "Global default",
+ "display:order": -1,
+ "ai:*": true
+ },
+ "ai@wave": {
+ "display:name": "Wave Proxy - gpt-4o-mini",
+ "display:order": 0,
+ "ai:*": true,
+ "ai:apitype": "",
+ "ai:baseurl": "",
+ "ai:apitoken": "",
+ "ai:name": "",
+ "ai:orgid": "",
+ "ai:model": "gpt-4o-mini",
+ "ai:maxtokens": 2048,
+ "ai:timeoutms": 60000
+ }
+}
diff --git a/pkg/wconfig/defaultconfig/settings.json b/pkg/wconfig/defaultconfig/settings.json
index b5c821ecf..a8f591d0a 100644
--- a/pkg/wconfig/defaultconfig/settings.json
+++ b/pkg/wconfig/defaultconfig/settings.json
@@ -11,6 +11,7 @@
"web:defaulturl": "https://github.com/wavetermdev/waveterm",
"web:defaultsearch": "https://www.google.com/search?q={query}",
"window:tilegapsize": 3,
+ "window:maxtabcachesize": 10,
"telemetry:enabled": true,
"term:copyonselect": true
}
diff --git a/pkg/wconfig/filewatcher.go b/pkg/wconfig/filewatcher.go
index f8ce9fd93..3478b4a4a 100644
--- a/pkg/wconfig/filewatcher.go
+++ b/pkg/wconfig/filewatcher.go
@@ -14,8 +14,6 @@ import (
"github.com/wavetermdev/waveterm/pkg/wps"
)
-var configDirAbsPath = filepath.Join(wavebase.GetWaveHomeDir(), wavebase.ConfigDir)
-
var instance *Watcher
var once sync.Once
@@ -38,10 +36,20 @@ func GetWatcher() *Watcher {
log.Printf("failed to create file watcher: %v", err)
return
}
+ configDirAbsPath := wavebase.GetWaveConfigDir()
instance = &Watcher{watcher: watcher}
err = instance.watcher.Add(configDirAbsPath)
+ const failedStr = "failed to add path %s to watcher: %v"
if err != nil {
- log.Printf("failed to add path %s to watcher: %v", configDirAbsPath, err)
+ log.Printf(failedStr, configDirAbsPath, err)
+ }
+
+ subdirs := GetConfigSubdirs()
+ for _, dir := range subdirs {
+ err = instance.watcher.Add(dir)
+ if err != nil {
+ log.Printf(failedStr, dir, err)
+ }
}
})
return instance
diff --git a/pkg/wconfig/metaconsts.go b/pkg/wconfig/metaconsts.go
index 57518e734..16442a692 100644
--- a/pkg/wconfig/metaconsts.go
+++ b/pkg/wconfig/metaconsts.go
@@ -60,6 +60,7 @@ const (
ConfigKey_WindowShowMenuBar = "window:showmenubar"
ConfigKey_WindowNativeTitleBar = "window:nativetitlebar"
ConfigKey_WindowDisableHardwareAcceleration = "window:disablehardwareacceleration"
+ ConfigKey_WindowMaxTabCacheSize = "window:maxtabcachesize"
ConfigKey_TelemetryClear = "telemetry:*"
ConfigKey_TelemetryEnabled = "telemetry:enabled"
diff --git a/pkg/wconfig/settingsconfig.go b/pkg/wconfig/settingsconfig.go
index c54de90e9..8593af40b 100644
--- a/pkg/wconfig/settingsconfig.go
+++ b/pkg/wconfig/settingsconfig.go
@@ -7,6 +7,8 @@ import (
"bytes"
"encoding/json"
"fmt"
+ "io/fs"
+ "log"
"os"
"path/filepath"
"reflect"
@@ -14,6 +16,7 @@ import (
"strings"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
+ "github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wconfig/defaultconfig"
)
@@ -101,6 +104,7 @@ type SettingsType struct {
WindowShowMenuBar bool `json:"window:showmenubar,omitempty"`
WindowNativeTitleBar bool `json:"window:nativetitlebar,omitempty"`
WindowDisableHardwareAcceleration bool `json:"window:disablehardwareacceleration,omitempty"`
+ WindowMaxTabCacheSize int `json:"window:maxtabcachesize,omitempty"`
TelemetryClear bool `json:"telemetry:*,omitempty"`
TelemetryEnabled bool `json:"telemetry:enabled,omitempty"`
@@ -180,18 +184,23 @@ func readConfigHelper(fileName string, barr []byte, readErr error) (waveobj.Meta
return rtn, cerrs
}
+func readConfigFileFS(fsys fs.FS, logPrefix string, fileName string) (waveobj.MetaMapType, []ConfigError) {
+ barr, readErr := fs.ReadFile(fsys, fileName)
+ return readConfigHelper(logPrefix+fileName, barr, readErr)
+}
+
func ReadDefaultsConfigFile(fileName string) (waveobj.MetaMapType, []ConfigError) {
- barr, readErr := defaultconfig.ConfigFS.ReadFile(fileName)
- return readConfigHelper("defaults:"+fileName, barr, readErr)
+ return readConfigFileFS(defaultconfig.ConfigFS, "defaults:", fileName)
}
func ReadWaveHomeConfigFile(fileName string) (waveobj.MetaMapType, []ConfigError) {
- fullFileName := filepath.Join(configDirAbsPath, fileName)
- barr, err := os.ReadFile(fullFileName)
- return readConfigHelper(fullFileName, barr, err)
+ configDirAbsPath := wavebase.GetWaveConfigDir()
+ configDirFsys := os.DirFS(configDirAbsPath)
+ return readConfigFileFS(configDirFsys, "", fileName)
}
func WriteWaveHomeConfigFile(fileName string, m waveobj.MetaMapType) error {
+ configDirAbsPath := wavebase.GetWaveConfigDir()
fullFileName := filepath.Join(configDirAbsPath, fileName)
barr, err := jsonMarshalConfigInOrder(m)
if err != nil {
@@ -221,17 +230,71 @@ func mergeMetaMapSimple(m waveobj.MetaMapType, toMerge waveobj.MetaMapType) wave
return m
}
-func ReadConfigPart(partName string, simpleMerge bool) (waveobj.MetaMapType, []ConfigError) {
- defConfig, cerrs1 := ReadDefaultsConfigFile(partName)
- userConfig, cerrs2 := ReadWaveHomeConfigFile(partName)
- allErrs := append(cerrs1, cerrs2...)
+func mergeMetaMap(m waveobj.MetaMapType, toMerge waveobj.MetaMapType, simpleMerge bool) waveobj.MetaMapType {
if simpleMerge {
- return mergeMetaMapSimple(defConfig, userConfig), allErrs
+ return mergeMetaMapSimple(m, toMerge)
} else {
- return waveobj.MergeMeta(defConfig, userConfig, true), allErrs
+ return waveobj.MergeMeta(m, toMerge, true)
}
}
+func selectDirEntsBySuffix(dirEnts []fs.DirEntry, fileNameSuffix string) []fs.DirEntry {
+ var rtn []fs.DirEntry
+ for _, ent := range dirEnts {
+ if ent.IsDir() {
+ continue
+ }
+ if !strings.HasSuffix(ent.Name(), fileNameSuffix) {
+ continue
+ }
+ rtn = append(rtn, ent)
+ }
+ return rtn
+}
+
+func SortFileNameDescend(files []fs.DirEntry) {
+ sort.Slice(files, func(i, j int) bool {
+ return files[i].Name() > files[j].Name()
+ })
+}
+
+// Read and merge all files in the specified directory matching the supplied suffix
+func readConfigFilesForDir(fsys fs.FS, logPrefix string, dirName string, fileName string, simpleMerge bool) (waveobj.MetaMapType, []ConfigError) {
+ dirEnts, _ := fs.ReadDir(fsys, dirName)
+ suffixEnts := selectDirEntsBySuffix(dirEnts, fileName+".json")
+ SortFileNameDescend(suffixEnts)
+ var rtn waveobj.MetaMapType
+ var errs []ConfigError
+ for _, ent := range suffixEnts {
+ fileVal, cerrs := readConfigFileFS(fsys, logPrefix, filepath.Join(dirName, ent.Name()))
+ rtn = mergeMetaMap(rtn, fileVal, simpleMerge)
+ errs = append(errs, cerrs...)
+ }
+ return rtn, errs
+}
+
+// Read and merge all files in the specified config filesystem matching the patterns `
.json` and `/*.json`
+func readConfigPartForFS(fsys fs.FS, logPrefix string, partName string, simpleMerge bool) (waveobj.MetaMapType, []ConfigError) {
+ config, errs := readConfigFilesForDir(fsys, logPrefix, partName, "", simpleMerge)
+ allErrs := errs
+ rtn := config
+ config, errs = readConfigFileFS(fsys, logPrefix, partName+".json")
+ allErrs = append(allErrs, errs...)
+ return mergeMetaMap(rtn, config, simpleMerge), allErrs
+}
+
+// Combine files from the defaults and home directory for the specified config part name
+func readConfigPart(partName string, simpleMerge bool) (waveobj.MetaMapType, []ConfigError) {
+ configDirAbsPath := wavebase.GetWaveConfigDir()
+ configDirFsys := os.DirFS(configDirAbsPath)
+ defaultConfigs, cerrs := readConfigPartForFS(defaultconfig.ConfigFS, "defaults:", partName, simpleMerge)
+ homeConfigs, cerrs1 := readConfigPartForFS(configDirFsys, "", partName, simpleMerge)
+
+ rtn := defaultConfigs
+ allErrs := append(cerrs, cerrs1...)
+ return mergeMetaMap(rtn, homeConfigs, simpleMerge), allErrs
+}
+
func ReadFullConfig() FullConfigType {
var fullConfig FullConfigType
configRType := reflect.TypeOf(fullConfig)
@@ -246,13 +309,15 @@ func ReadFullConfig() FullConfigType {
continue
}
jsonTag := utilfn.GetJsonTag(field)
+ simpleMerge := field.Tag.Get("merge") == ""
+ var configPart waveobj.MetaMapType
+ var errs []ConfigError
if jsonTag == "-" || jsonTag == "" {
continue
+ } else {
+ configPart, errs = readConfigPart(jsonTag, simpleMerge)
}
- simpleMerge := field.Tag.Get("merge") == ""
- fileName := jsonTag + ".json"
- configPart, cerrs := ReadConfigPart(fileName, simpleMerge)
- fullConfig.ConfigErrors = append(fullConfig.ConfigErrors, cerrs...)
+ fullConfig.ConfigErrors = append(fullConfig.ConfigErrors, errs...)
if configPart != nil {
fieldPtr := configRVal.Field(fieldIdx).Addr().Interface()
utilfn.ReUnmarshal(fieldPtr, configPart)
@@ -261,6 +326,29 @@ func ReadFullConfig() FullConfigType {
return fullConfig
}
+func GetConfigSubdirs() []string {
+ var fullConfig FullConfigType
+ configRType := reflect.TypeOf(fullConfig)
+ var retVal []string
+ configDirAbsPath := wavebase.GetWaveConfigDir()
+ for fieldIdx := 0; fieldIdx < configRType.NumField(); fieldIdx++ {
+ field := configRType.Field(fieldIdx)
+ if field.PkgPath != "" {
+ continue
+ }
+ configFile := field.Tag.Get("configfile")
+ if configFile == "-" {
+ continue
+ }
+ jsonTag := utilfn.GetJsonTag(field)
+ if jsonTag != "-" && jsonTag != "" && jsonTag != "settings" {
+ retVal = append(retVal, filepath.Join(configDirAbsPath, jsonTag))
+ }
+ }
+ log.Printf("subdirs: %v\n", retVal)
+ return retVal
+}
+
func getConfigKeyType(configKey string) reflect.Type {
ctype := reflect.TypeOf(SettingsType{})
for i := 0; i < ctype.NumField(); i++ {
diff --git a/pkg/wcore/wcore.go b/pkg/wcore/wcore.go
index e2967b7f0..78761bc63 100644
--- a/pkg/wcore/wcore.go
+++ b/pkg/wcore/wcore.go
@@ -26,21 +26,35 @@ import (
const DefaultTimeout = 2 * time.Second
const DefaultActivateBlockTimeout = 60 * time.Second
-func DeleteBlock(ctx context.Context, tabId string, blockId string) error {
- err := wstore.DeleteBlock(ctx, tabId, blockId)
+func DeleteBlock(ctx context.Context, blockId string) error {
+ block, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId)
+ if err != nil {
+ return fmt.Errorf("error getting block: %w", err)
+ }
+ if block == nil {
+ return nil
+ }
+ if len(block.SubBlockIds) > 0 {
+ for _, subBlockId := range block.SubBlockIds {
+ err := DeleteBlock(ctx, subBlockId)
+ if err != nil {
+ return fmt.Errorf("error deleting subblock %s: %w", subBlockId, err)
+ }
+ }
+ }
+ err = wstore.DeleteBlock(ctx, blockId)
if err != nil {
return fmt.Errorf("error deleting block: %w", err)
}
go blockcontroller.StopBlockController(blockId)
- sendBlockCloseEvent(tabId, blockId)
+ sendBlockCloseEvent(blockId)
return nil
}
-func sendBlockCloseEvent(tabId string, blockId string) {
+func sendBlockCloseEvent(blockId string) {
waveEvent := wps.WaveEvent{
Event: wps.Event_BlockClose,
Scopes: []string{
- waveobj.MakeORef(waveobj.OType_Tab, tabId).String(),
waveobj.MakeORef(waveobj.OType_Block, blockId).String(),
},
Data: blockId,
@@ -58,7 +72,7 @@ func DeleteTab(ctx context.Context, workspaceId string, tabId string) error {
}
// close blocks (sends events + stops block controllers)
for _, blockId := range tabData.BlockIds {
- err := DeleteBlock(ctx, tabId, blockId)
+ err := DeleteBlock(ctx, blockId)
if err != nil {
return fmt.Errorf("error deleting block %s: %w", blockId, err)
}
@@ -78,6 +92,18 @@ func CreateTab(ctx context.Context, windowId string, tabName string, activateTab
if err != nil {
return "", fmt.Errorf("error getting window: %w", err)
}
+ if tabName == "" {
+ client, err := wstore.DBGetSingleton[*waveobj.Client](ctx)
+ if err != nil {
+ return "", fmt.Errorf("error getting client: %w", err)
+ }
+ tabName = "T" + fmt.Sprint(client.NextTabId)
+ client.NextTabId++
+ err = wstore.DBUpdate(ctx, client)
+ if err != nil {
+ return "", fmt.Errorf("error updating client: %w", err)
+ }
+ }
tab, err := wstore.CreateTab(ctx, windowData.WorkspaceId, tabName)
if err != nil {
return "", fmt.Errorf("error creating tab: %w", err)
@@ -122,7 +148,7 @@ func CreateWindow(ctx context.Context, winSize *waveobj.WinSize) (*waveobj.Windo
if err != nil {
return nil, fmt.Errorf("error inserting workspace: %w", err)
}
- _, err = CreateTab(ctx, windowId, "T1", true)
+ _, err = CreateTab(ctx, windowId, "", true)
if err != nil {
return nil, fmt.Errorf("error inserting tab: %w", err)
}
@@ -151,7 +177,7 @@ func checkAndFixWindow(ctx context.Context, windowId string) {
}
if len(workspace.TabIds) == 0 {
log.Printf("fixing workspace with no tabs %q (in checkAndFixWindow)\n", workspace.OID)
- _, err = CreateTab(ctx, windowId, "T1", true)
+ _, err = CreateTab(ctx, windowId, "", true)
if err != nil {
log.Printf("error creating tab (in checkAndFixWindow): %v\n", err)
}
@@ -172,6 +198,17 @@ func EnsureInitialData() (*waveobj.Window, bool, error) {
}
firstRun = true
}
+ if client.NextTabId == 0 {
+ tabCount, err := wstore.DBGetCount[*waveobj.Tab](ctx)
+ if err != nil {
+ return nil, false, fmt.Errorf("error getting tab count: %w", err)
+ }
+ client.NextTabId = tabCount + 1
+ err = wstore.DBUpdate(ctx, client)
+ if err != nil {
+ return nil, false, fmt.Errorf("error updating client: %w", err)
+ }
+ }
log.Printf("clientid: %s\n", client.OID)
if len(client.WindowIds) == 1 {
checkAndFixWindow(ctx, client.WindowIds[0])
@@ -190,6 +227,7 @@ func CreateClient(ctx context.Context) (*waveobj.Client, error) {
client := &waveobj.Client{
OID: uuid.NewString(),
WindowIds: []string{},
+ NextTabId: 1,
}
err := wstore.DBInsert(ctx, client)
if err != nil {
@@ -198,6 +236,20 @@ func CreateClient(ctx context.Context) (*waveobj.Client, error) {
return client, nil
}
+func CreateSubBlock(ctx context.Context, blockId string, blockDef *waveobj.BlockDef) (*waveobj.Block, error) {
+ if blockDef == nil {
+ return nil, fmt.Errorf("blockDef is nil")
+ }
+ if blockDef.Meta == nil || blockDef.Meta.GetString(waveobj.MetaKey_View, "") == "" {
+ return nil, fmt.Errorf("no view provided for new block")
+ }
+ blockData, err := wstore.CreateSubBlock(ctx, blockId, blockDef)
+ if err != nil {
+ return nil, fmt.Errorf("error creating sub block: %w", err)
+ }
+ return blockData, nil
+}
+
func CreateBlock(ctx context.Context, tabId string, blockDef *waveobj.BlockDef, rtOpts *waveobj.RuntimeOpts) (*waveobj.Block, error) {
if blockDef == nil {
return nil, fmt.Errorf("blockDef is nil")
diff --git a/pkg/web/web.go b/pkg/web/web.go
index 0e470490b..695c73bbf 100644
--- a/pkg/web/web.go
+++ b/pkg/web/web.go
@@ -431,7 +431,7 @@ func MakeTCPListener(serviceName string) (net.Listener, error) {
}
func MakeUnixListener() (net.Listener, error) {
- serverAddr := wavebase.GetWaveHomeDir() + "/wave.sock"
+ serverAddr := wavebase.GetDomainSocketName()
os.Remove(serverAddr) // ignore error
rtn, err := net.Listen("unix", serverAddr)
if err != nil {
diff --git a/pkg/web/ws.go b/pkg/web/ws.go
index 5ef0c0344..c0374efff 100644
--- a/pkg/web/ws.go
+++ b/pkg/web/ws.go
@@ -252,7 +252,7 @@ func registerConn(wsConnId string, routeId string, wproxy *wshutil.WshRpcProxy)
wshutil.DefaultRouter.UnregisterRoute(routeId)
}
RouteToConnMap[routeId] = wsConnId
- wshutil.DefaultRouter.RegisterRoute(routeId, wproxy)
+ wshutil.DefaultRouter.RegisterRoute(routeId, wproxy, true)
}
func unregisterConn(wsConnId string, routeId string) {
@@ -269,11 +269,10 @@ func unregisterConn(wsConnId string, routeId string) {
}
func HandleWsInternal(w http.ResponseWriter, r *http.Request) error {
- windowId := r.URL.Query().Get("windowid")
- if windowId == "" {
- return fmt.Errorf("windowid is required")
+ tabId := r.URL.Query().Get("tabid")
+ if tabId == "" {
+ return fmt.Errorf("tabid is required")
}
-
err := authkey.ValidateIncomingRequest(r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
@@ -290,13 +289,13 @@ func HandleWsInternal(w http.ResponseWriter, r *http.Request) error {
outputCh := make(chan any, 100)
closeCh := make(chan any)
var routeId string
- if windowId == wshutil.ElectronRoute {
+ if tabId == wshutil.ElectronRoute {
routeId = wshutil.ElectronRoute
} else {
- routeId = wshutil.MakeWindowRouteId(windowId)
+ routeId = wshutil.MakeTabRouteId(tabId)
}
- log.Printf("[websocket] new connection: windowid:%s connid:%s routeid:%s\n", windowId, wsConnId, routeId)
- eventbus.RegisterWSChannel(wsConnId, windowId, outputCh)
+ log.Printf("[websocket] new connection: tabid:%s connid:%s routeid:%s\n", tabId, wsConnId, routeId)
+ eventbus.RegisterWSChannel(wsConnId, tabId, outputCh)
defer eventbus.UnregisterWSChannel(wsConnId)
wproxy := wshutil.MakeRpcProxy() // we create a wshproxy to handle rpc messages to/from the window
defer close(wproxy.ToRemoteCh)
diff --git a/pkg/wps/wps.go b/pkg/wps/wps.go
index 87e00faa2..c742dc3fd 100644
--- a/pkg/wps/wps.go
+++ b/pkg/wps/wps.go
@@ -5,7 +5,6 @@
package wps
import (
- "log"
"strings"
"sync"
@@ -76,7 +75,7 @@ func (b *BrokerType) GetClient() Client {
// if already subscribed, this will *resubscribe* with the new subscription (remove the old one, and replace with this one)
func (b *BrokerType) Subscribe(subRouteId string, sub SubscriptionRequest) {
- log.Printf("[wps] sub %s %s\n", subRouteId, sub.Event)
+ // log.Printf("[wps] sub %s %s\n", subRouteId, sub.Event)
if sub.Event == "" {
return
}
@@ -138,7 +137,7 @@ func addStrToScopeMap(scopeMap map[string][]string, scope string, routeId string
}
func (b *BrokerType) Unsubscribe(subRouteId string, eventName string) {
- log.Printf("[wps] unsub %s %s\n", subRouteId, eventName)
+ // log.Printf("[wps] unsub %s %s\n", subRouteId, eventName)
b.Lock.Lock()
defer b.Lock.Unlock()
b.unsubscribe_nolock(subRouteId, eventName)
diff --git a/pkg/wps/wpstypes.go b/pkg/wps/wpstypes.go
index d925f3eeb..d0e6d4202 100644
--- a/pkg/wps/wpstypes.go
+++ b/pkg/wps/wpstypes.go
@@ -11,6 +11,7 @@ const (
Event_BlockFile = "blockfile"
Event_Config = "config"
Event_UserInput = "userinput"
+ Event_RouteGone = "route:gone"
)
type WaveEvent struct {
diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go
index 382f1041a..2539742e7 100644
--- a/pkg/wshrpc/wshclient/wshclient.go
+++ b/pkg/wshrpc/wshclient/wshclient.go
@@ -11,6 +11,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wconfig"
"github.com/wavetermdev/waveterm/pkg/wps"
+ "github.com/wavetermdev/waveterm/pkg/vdom"
)
// command "authenticate", wshserver.AuthenticateCommand
@@ -85,12 +86,30 @@ func CreateBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandCreateBlockData, o
return resp, err
}
+// command "createsubblock", wshserver.CreateSubBlockCommand
+func CreateSubBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandCreateSubBlockData, opts *wshrpc.RpcOpts) (waveobj.ORef, error) {
+ resp, err := sendRpcRequestCallHelper[waveobj.ORef](w, "createsubblock", data, opts)
+ return resp, err
+}
+
// command "deleteblock", wshserver.DeleteBlockCommand
func DeleteBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandDeleteBlockData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "deleteblock", data, opts)
return err
}
+// command "deletesubblock", wshserver.DeleteSubBlockCommand
+func DeleteSubBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandDeleteBlockData, opts *wshrpc.RpcOpts) error {
+ _, err := sendRpcRequestCallHelper[any](w, "deletesubblock", data, opts)
+ return err
+}
+
+// command "dispose", wshserver.DisposeCommand
+func DisposeCommand(w *wshutil.WshRpc, data wshrpc.CommandDisposeData, opts *wshrpc.RpcOpts) error {
+ _, err := sendRpcRequestCallHelper[any](w, "dispose", data, opts)
+ return err
+}
+
// command "eventpublish", wshserver.EventPublishCommand
func EventPublishCommand(w *wshutil.WshRpc, data wps.WaveEvent, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "eventpublish", data, opts)
@@ -260,10 +279,52 @@ func TestCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error {
return err
}
+// command "vdomasyncinitiation", wshserver.VDomAsyncInitiationCommand
+func VDomAsyncInitiationCommand(w *wshutil.WshRpc, data vdom.VDomAsyncInitiationRequest, opts *wshrpc.RpcOpts) error {
+ _, err := sendRpcRequestCallHelper[any](w, "vdomasyncinitiation", data, opts)
+ return err
+}
+
+// command "vdomcreatecontext", wshserver.VDomCreateContextCommand
+func VDomCreateContextCommand(w *wshutil.WshRpc, data vdom.VDomCreateContext, opts *wshrpc.RpcOpts) (*waveobj.ORef, error) {
+ resp, err := sendRpcRequestCallHelper[*waveobj.ORef](w, "vdomcreatecontext", data, opts)
+ return resp, err
+}
+
+// command "vdomrender", wshserver.VDomRenderCommand
+func VDomRenderCommand(w *wshutil.WshRpc, data vdom.VDomFrontendUpdate, opts *wshrpc.RpcOpts) (*vdom.VDomBackendUpdate, error) {
+ resp, err := sendRpcRequestCallHelper[*vdom.VDomBackendUpdate](w, "vdomrender", data, opts)
+ return resp, err
+}
+
+// command "waitforroute", wshserver.WaitForRouteCommand
+func WaitForRouteCommand(w *wshutil.WshRpc, data wshrpc.CommandWaitForRouteData, opts *wshrpc.RpcOpts) (bool, error) {
+ resp, err := sendRpcRequestCallHelper[bool](w, "waitforroute", data, opts)
+ return resp, err
+}
+
// command "webselector", wshserver.WebSelectorCommand
func WebSelectorCommand(w *wshutil.WshRpc, data wshrpc.CommandWebSelectorData, opts *wshrpc.RpcOpts) ([]string, error) {
resp, err := sendRpcRequestCallHelper[[]string](w, "webselector", data, opts)
return resp, err
}
+// command "wsldefaultdistro", wshserver.WslDefaultDistroCommand
+func WslDefaultDistroCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) (string, error) {
+ resp, err := sendRpcRequestCallHelper[string](w, "wsldefaultdistro", nil, opts)
+ return resp, err
+}
+
+// command "wsllist", wshserver.WslListCommand
+func WslListCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]string, error) {
+ resp, err := sendRpcRequestCallHelper[[]string](w, "wsllist", nil, opts)
+ return resp, err
+}
+
+// command "wslstatus", wshserver.WslStatusCommand
+func WslStatusCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]wshrpc.ConnStatus, error) {
+ resp, err := sendRpcRequestCallHelper[[]wshrpc.ConnStatus](w, "wslstatus", nil, opts)
+ return resp, err
+}
+
diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go
index fbd06a27d..19754bec7 100644
--- a/pkg/wshrpc/wshrpctypes.go
+++ b/pkg/wshrpc/wshrpctypes.go
@@ -11,6 +11,7 @@ import (
"reflect"
"github.com/wavetermdev/waveterm/pkg/ijson"
+ "github.com/wavetermdev/waveterm/pkg/vdom"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wconfig"
"github.com/wavetermdev/waveterm/pkg/wps"
@@ -27,6 +28,7 @@ const (
const (
Command_Authenticate = "authenticate" // special
+ Command_Dispose = "dispose" // special (disposes of the route, for multiproxy only)
Command_RouteAnnounce = "routeannounce" // special (for routing)
Command_RouteUnannounce = "routeunannounce" // special (for routing)
Command_Message = "message"
@@ -61,14 +63,22 @@ const (
Command_RemoteFileDelete = "remotefiledelete"
Command_RemoteFileJoiin = "remotefilejoin"
+ Command_ConnStatus = "connstatus"
+ Command_WslStatus = "wslstatus"
Command_ConnEnsure = "connensure"
Command_ConnReinstallWsh = "connreinstallwsh"
Command_ConnConnect = "connconnect"
Command_ConnDisconnect = "conndisconnect"
Command_ConnList = "connlist"
+ Command_WslList = "wsllist"
+ Command_WslDefaultDistro = "wsldefaultdistro"
Command_WebSelector = "webselector"
Command_Notify = "notify"
+
+ Command_VDomCreateContext = "vdomcreatecontext"
+ Command_VDomAsyncInitiation = "vdomasyncinitiation"
+ Command_VDomRender = "vdomrender"
)
type RespOrErrorUnion[T any] struct {
@@ -78,6 +88,7 @@ type RespOrErrorUnion[T any] struct {
type WshRpcInterface interface {
AuthenticateCommand(ctx context.Context, data string) (CommandAuthenticateRtnData, error)
+ DisposeCommand(ctx context.Context, data CommandDisposeData) error
RouteAnnounceCommand(ctx context.Context) error // (special) announces a new route to the main router
RouteUnannounceCommand(ctx context.Context) error // (special) unannounces a route to the main router
@@ -92,7 +103,10 @@ type WshRpcInterface interface {
FileAppendIJsonCommand(ctx context.Context, data CommandAppendIJsonData) error
ResolveIdsCommand(ctx context.Context, data CommandResolveIdsData) (CommandResolveIdsRtnData, error)
CreateBlockCommand(ctx context.Context, data CommandCreateBlockData) (waveobj.ORef, error)
+ CreateSubBlockCommand(ctx context.Context, data CommandCreateSubBlockData) (waveobj.ORef, error)
DeleteBlockCommand(ctx context.Context, data CommandDeleteBlockData) error
+ DeleteSubBlockCommand(ctx context.Context, data CommandDeleteBlockData) error
+ WaitForRouteCommand(ctx context.Context, data CommandWaitForRouteData) (bool, error)
FileWriteCommand(ctx context.Context, data CommandFileData) error
FileReadCommand(ctx context.Context, data CommandFileData) (string, error)
EventPublishCommand(ctx context.Context, data wps.WaveEvent) error
@@ -109,11 +123,14 @@ type WshRpcInterface interface {
// connection functions
ConnStatusCommand(ctx context.Context) ([]ConnStatus, error)
+ WslStatusCommand(ctx context.Context) ([]ConnStatus, error)
ConnEnsureCommand(ctx context.Context, connName string) error
ConnReinstallWshCommand(ctx context.Context, connName string) error
ConnConnectCommand(ctx context.Context, connName string) error
ConnDisconnectCommand(ctx context.Context, connName string) error
ConnListCommand(ctx context.Context) ([]string, error)
+ WslListCommand(ctx context.Context) ([]string, error)
+ WslDefaultDistroCommand(ctx context.Context) (string, error)
// eventrecv is special, it's handled internally by WshRpc with EventListener
EventRecvCommand(ctx context.Context, data wps.WaveEvent) error
@@ -126,8 +143,16 @@ type WshRpcInterface interface {
RemoteFileJoinCommand(ctx context.Context, paths []string) (*FileInfo, error)
RemoteStreamCpuDataCommand(ctx context.Context) chan RespOrErrorUnion[TimeSeriesData]
+ // emain
WebSelectorCommand(ctx context.Context, data CommandWebSelectorData) ([]string, error)
NotifyCommand(ctx context.Context, notificationOptions WaveNotificationOptions) error
+
+ // terminal
+ VDomCreateContextCommand(ctx context.Context, data vdom.VDomCreateContext) (*waveobj.ORef, error)
+ VDomAsyncInitiationCommand(ctx context.Context, data vdom.VDomAsyncInitiationRequest) error
+
+ // proc
+ VDomRenderCommand(ctx context.Context, data vdom.VDomFrontendUpdate) (*vdom.VDomBackendUpdate, error)
}
// for frontend
@@ -187,7 +212,13 @@ func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) {
}
type CommandAuthenticateRtnData struct {
+ RouteId string `json:"routeid"`
+ AuthToken string `json:"authtoken,omitempty"`
+}
+
+type CommandDisposeData struct {
RouteId string `json:"routeid"`
+ // auth token travels in the packet directly
}
type CommandMessageData struct {
@@ -220,6 +251,11 @@ type CommandCreateBlockData struct {
Magnified bool `json:"magnified,omitempty"`
}
+type CommandCreateSubBlockData struct {
+ ParentBlockId string `json:"parentblockid"`
+ BlockDef *waveobj.BlockDef `json:"blockdef"`
+}
+
type CommandBlockSetViewData struct {
BlockId string `json:"blockid" wshcontext:"BlockId"`
View string `json:"view"`
@@ -251,6 +287,11 @@ type CommandAppendIJsonData struct {
Data ijson.Command `json:"data"`
}
+type CommandWaitForRouteData struct {
+ RouteId string `json:"routeid"`
+ WaitMs int `json:"waitms"`
+}
+
type CommandDeleteBlockData struct {
BlockId string `json:"blockid" wshcontext:"BlockId"`
}
@@ -374,10 +415,10 @@ type CommandWebSelectorData struct {
}
type BlockInfoData struct {
- BlockId string `json:"blockid"`
- TabId string `json:"tabid"`
- WindowId string `json:"windowid"`
- Meta waveobj.MetaMapType `json:"meta"`
+ BlockId string `json:"blockid"`
+ TabId string `json:"tabid"`
+ WindowId string `json:"windowid"`
+ Block *waveobj.Block `json:"block"`
}
type WaveNotificationOptions struct {
diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go
index 91facb368..b08de1f45 100644
--- a/pkg/wshrpc/wshserver/wshserver.go
+++ b/pkg/wshrpc/wshserver/wshserver.go
@@ -21,6 +21,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/filestore"
"github.com/wavetermdev/waveterm/pkg/remote"
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
+ "github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/waveai"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wconfig"
@@ -29,6 +30,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshutil"
+ "github.com/wavetermdev/waveterm/pkg/wsl"
"github.com/wavetermdev/waveterm/pkg/wstore"
)
@@ -36,6 +38,7 @@ const SimpleId_This = "this"
const SimpleId_Tab = "tab"
var SimpleId_BlockNum_Regex = regexp.MustCompile(`^\d+$`)
+var InvalidWslDistroNames = []string{"docker-desktop", "docker-desktop-data"}
type WshServer struct{}
@@ -120,7 +123,7 @@ func (ws *WshServer) GetMetaCommand(ctx context.Context, data wshrpc.CommandGetM
}
func (ws *WshServer) SetMetaCommand(ctx context.Context, data wshrpc.CommandSetMetaData) error {
- log.Printf("SETMETA: %s | %v\n", data.ORef, data.Meta)
+ log.Printf("SetMetaCommand: %s | %v\n", data.ORef, data.Meta)
oref := data.ORef
err := wstore.UpdateObjectMeta(ctx, oref, data.Meta)
if err != nil {
@@ -247,6 +250,16 @@ func (ws *WshServer) CreateBlockCommand(ctx context.Context, data wshrpc.Command
return &waveobj.ORef{OType: waveobj.OType_Block, OID: blockRef.OID}, nil
}
+func (ws *WshServer) CreateSubBlockCommand(ctx context.Context, data wshrpc.CommandCreateSubBlockData) (*waveobj.ORef, error) {
+ parentBlockId := data.ParentBlockId
+ blockData, err := wcore.CreateSubBlock(ctx, parentBlockId, data.BlockDef)
+ if err != nil {
+ return nil, fmt.Errorf("error creating block: %w", err)
+ }
+ blockRef := &waveobj.ORef{OType: waveobj.OType_Block, OID: blockData.OID}
+ return blockRef, nil
+}
+
func (ws *WshServer) SetViewCommand(ctx context.Context, data wshrpc.CommandBlockSetViewData) error {
log.Printf("SETVIEW: %s | %q\n", data.BlockId, data.View)
ctx = waveobj.ContextWithUpdates(ctx)
@@ -353,10 +366,10 @@ func (ws *WshServer) FileAppendCommand(ctx context.Context, data wshrpc.CommandF
func (ws *WshServer) FileAppendIJsonCommand(ctx context.Context, data wshrpc.CommandAppendIJsonData) error {
tryCreate := true
- if data.FileName == blockcontroller.BlockFile_Html && tryCreate {
+ if data.FileName == blockcontroller.BlockFile_VDom && tryCreate {
err := filestore.WFS.MakeFile(ctx, data.ZoneId, data.FileName, nil, filestore.FileOptsType{MaxSize: blockcontroller.DefaultHtmlMaxFileSize, IJson: true})
if err != nil && err != fs.ErrExist {
- return fmt.Errorf("error creating blockfile[html]: %w", err)
+ return fmt.Errorf("error creating blockfile[vdom]: %w", err)
}
}
err := filestore.WFS.AppendIJson(ctx, data.ZoneId, data.FileName, data.Data)
@@ -376,6 +389,14 @@ func (ws *WshServer) FileAppendIJsonCommand(ctx context.Context, data wshrpc.Com
return nil
}
+func (ws *WshServer) DeleteSubBlockCommand(ctx context.Context, data wshrpc.CommandDeleteBlockData) error {
+ err := wcore.DeleteBlock(ctx, data.BlockId)
+ if err != nil {
+ return fmt.Errorf("error deleting block: %w", err)
+ }
+ return nil
+}
+
func (ws *WshServer) DeleteBlockCommand(ctx context.Context, data wshrpc.CommandDeleteBlockData) error {
ctx = waveobj.ContextWithUpdates(ctx)
tabId, err := wstore.DBFindTabForBlockId(ctx, data.BlockId)
@@ -392,7 +413,7 @@ func (ws *WshServer) DeleteBlockCommand(ctx context.Context, data wshrpc.Command
if windowId == "" {
return fmt.Errorf("no window found for tab")
}
- err = wcore.DeleteBlock(ctx, tabId, data.BlockId)
+ err = wcore.DeleteBlock(ctx, data.BlockId)
if err != nil {
return fmt.Errorf("error deleting block: %w", err)
}
@@ -405,6 +426,13 @@ func (ws *WshServer) DeleteBlockCommand(ctx context.Context, data wshrpc.Command
return nil
}
+func (ws *WshServer) WaitForRouteCommand(ctx context.Context, data wshrpc.CommandWaitForRouteData) (bool, error) {
+ waitCtx, cancelFn := context.WithTimeout(ctx, time.Duration(data.WaitMs)*time.Millisecond)
+ defer cancelFn()
+ err := wshutil.DefaultRouter.WaitForRegister(waitCtx, data.RouteId)
+ return err == nil, nil
+}
+
func (ws *WshServer) EventRecvCommand(ctx context.Context, data wps.WaveEvent) error {
return nil
}
@@ -463,11 +491,28 @@ func (ws *WshServer) ConnStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus
return rtn, nil
}
+func (ws *WshServer) WslStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus, error) {
+ rtn := wsl.GetAllConnStatus()
+ return rtn, nil
+}
+
func (ws *WshServer) ConnEnsureCommand(ctx context.Context, connName string) error {
+ if strings.HasPrefix(connName, "wsl://") {
+ distroName := strings.TrimPrefix(connName, "wsl://")
+ return wsl.EnsureConnection(ctx, distroName)
+ }
return conncontroller.EnsureConnection(ctx, connName)
}
func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error {
+ if strings.HasPrefix(connName, "wsl://") {
+ distroName := strings.TrimPrefix(connName, "wsl://")
+ conn := wsl.GetWslConn(ctx, distroName, false)
+ if conn == nil {
+ return fmt.Errorf("distro not found: %s", connName)
+ }
+ return conn.Close()
+ }
connOpts, err := remote.ParseOpts(connName)
if err != nil {
return fmt.Errorf("error parsing connection name: %w", err)
@@ -480,6 +525,14 @@ func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string)
}
func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) error {
+ if strings.HasPrefix(connName, "wsl://") {
+ distroName := strings.TrimPrefix(connName, "wsl://")
+ conn := wsl.GetWslConn(ctx, distroName, false)
+ if conn == nil {
+ return fmt.Errorf("connection not found: %s", connName)
+ }
+ return conn.Connect(ctx)
+ }
connOpts, err := remote.ParseOpts(connName)
if err != nil {
return fmt.Errorf("error parsing connection name: %w", err)
@@ -492,6 +545,14 @@ func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) er
}
func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, connName string) error {
+ if strings.HasPrefix(connName, "wsl://") {
+ distroName := strings.TrimPrefix(connName, "wsl://")
+ conn := wsl.GetWslConn(ctx, distroName, false)
+ if conn == nil {
+ return fmt.Errorf("connection not found: %s", connName)
+ }
+ return conn.CheckAndInstallWsh(ctx, connName, &wsl.WshInstallOpts{Force: true, NoUserPrompt: true})
+ }
connOpts, err := remote.ParseOpts(connName)
if err != nil {
return fmt.Errorf("error parsing connection name: %w", err)
@@ -507,6 +568,33 @@ func (ws *WshServer) ConnListCommand(ctx context.Context) ([]string, error) {
return conncontroller.GetConnectionsList()
}
+func (ws *WshServer) WslListCommand(ctx context.Context) ([]string, error) {
+ distros, err := wsl.RegisteredDistros(ctx)
+ if err != nil {
+ return nil, err
+ }
+ var distroNames []string
+ for _, distro := range distros {
+ distroName := distro.Name()
+ if utilfn.ContainsStr(InvalidWslDistroNames, distroName) {
+ continue
+ }
+ distroNames = append(distroNames, distroName)
+ }
+ return distroNames, nil
+}
+
+func (ws *WshServer) WslDefaultDistroCommand(ctx context.Context) (string, error) {
+ distro, ok, err := wsl.DefaultDistro(ctx)
+ if err != nil {
+ return "", fmt.Errorf("unable to determine default distro: %w", err)
+ }
+ if !ok {
+ return "", fmt.Errorf("unable to determine default distro")
+ }
+ return distro.Name(), nil
+}
+
func (ws *WshServer) BlockInfoCommand(ctx context.Context, blockId string) (*wshrpc.BlockInfoData, error) {
blockData, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId)
if err != nil {
@@ -524,6 +612,6 @@ func (ws *WshServer) BlockInfoCommand(ctx context.Context, blockId string) (*wsh
BlockId: blockId,
TabId: tabId,
WindowId: windowId,
- Meta: blockData.Meta,
+ Block: blockData,
}, nil
}
diff --git a/pkg/wshutil/wshmultiproxy.go b/pkg/wshutil/wshmultiproxy.go
new file mode 100644
index 000000000..be2888bf1
--- /dev/null
+++ b/pkg/wshutil/wshmultiproxy.go
@@ -0,0 +1,151 @@
+// Copyright 2024, Command Line Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package wshutil
+
+import (
+ "encoding/json"
+ "fmt"
+ "sync"
+
+ "github.com/google/uuid"
+ "github.com/wavetermdev/waveterm/pkg/wshrpc"
+)
+
+type multiProxyRouteInfo struct {
+ RouteId string
+ AuthToken string
+ Proxy *WshRpcProxy
+ RpcContext *wshrpc.RpcContext
+}
+
+// handles messages from multiple unauthenitcated clients
+type WshRpcMultiProxy struct {
+ Lock *sync.Mutex
+ RouteInfo map[string]*multiProxyRouteInfo // authtoken to info
+ ToRemoteCh chan []byte
+ FromRemoteRawCh chan []byte // raw message from the remote
+}
+
+func MakeRpcMultiProxy() *WshRpcMultiProxy {
+ return &WshRpcMultiProxy{
+ Lock: &sync.Mutex{},
+ RouteInfo: make(map[string]*multiProxyRouteInfo),
+ ToRemoteCh: make(chan []byte, DefaultInputChSize),
+ FromRemoteRawCh: make(chan []byte, DefaultOutputChSize),
+ }
+}
+
+func (p *WshRpcMultiProxy) DisposeRoutes() {
+ p.Lock.Lock()
+ defer p.Lock.Unlock()
+ for authToken, routeInfo := range p.RouteInfo {
+ DefaultRouter.UnregisterRoute(routeInfo.RouteId)
+ delete(p.RouteInfo, authToken)
+ }
+}
+
+func (p *WshRpcMultiProxy) getRouteInfo(authToken string) *multiProxyRouteInfo {
+ p.Lock.Lock()
+ defer p.Lock.Unlock()
+ return p.RouteInfo[authToken]
+}
+
+func (p *WshRpcMultiProxy) setRouteInfo(authToken string, routeInfo *multiProxyRouteInfo) {
+ p.Lock.Lock()
+ defer p.Lock.Unlock()
+ p.RouteInfo[authToken] = routeInfo
+}
+
+func (p *WshRpcMultiProxy) removeRouteInfo(authToken string) {
+ p.Lock.Lock()
+ defer p.Lock.Unlock()
+ delete(p.RouteInfo, authToken)
+}
+
+func (p *WshRpcMultiProxy) sendResponseError(msg RpcMessage, sendErr error) {
+ if msg.ReqId == "" {
+ // no response needed
+ return
+ }
+ resp := RpcMessage{
+ ResId: msg.ReqId,
+ Error: sendErr.Error(),
+ }
+ respBytes, _ := json.Marshal(resp)
+ p.ToRemoteCh <- respBytes
+}
+
+func (p *WshRpcMultiProxy) sendAuthResponse(msg RpcMessage, routeId string, authToken string) {
+ if msg.ReqId == "" {
+ // no response needed
+ return
+ }
+ resp := RpcMessage{
+ ResId: msg.ReqId,
+ Data: wshrpc.CommandAuthenticateRtnData{RouteId: routeId, AuthToken: authToken},
+ }
+ respBytes, _ := json.Marshal(resp)
+ p.ToRemoteCh <- respBytes
+}
+
+func (p *WshRpcMultiProxy) handleUnauthMessage(msgBytes []byte) {
+ var msg RpcMessage
+ err := json.Unmarshal(msgBytes, &msg)
+ if err != nil {
+ // nothing to do here, malformed message
+ return
+ }
+ if msg.Command == wshrpc.Command_Authenticate {
+ rpcContext, routeId, err := handleAuthenticationCommand(msg)
+ if err != nil {
+ p.sendResponseError(msg, err)
+ return
+ }
+ routeInfo := &multiProxyRouteInfo{
+ RouteId: routeId,
+ AuthToken: uuid.New().String(),
+ RpcContext: rpcContext,
+ }
+ routeInfo.Proxy = MakeRpcProxy()
+ routeInfo.Proxy.SetRpcContext(rpcContext)
+ p.setRouteInfo(routeInfo.AuthToken, routeInfo)
+ p.sendAuthResponse(msg, routeId, routeInfo.AuthToken)
+ go func() {
+ for msgBytes := range routeInfo.Proxy.ToRemoteCh {
+ p.ToRemoteCh <- msgBytes
+ }
+ }()
+ DefaultRouter.RegisterRoute(routeId, routeInfo.Proxy, true)
+ return
+ }
+ if msg.AuthToken == "" {
+ p.sendResponseError(msg, fmt.Errorf("no auth token"))
+ return
+ }
+ routeInfo := p.getRouteInfo(msg.AuthToken)
+ if routeInfo == nil {
+ p.sendResponseError(msg, fmt.Errorf("invalid auth token"))
+ return
+ }
+ if msg.Command != "" && msg.Source != routeInfo.RouteId {
+ p.sendResponseError(msg, fmt.Errorf("invalid source route for auth token"))
+ return
+ }
+ if msg.Command == wshrpc.Command_Dispose {
+ DefaultRouter.UnregisterRoute(routeInfo.RouteId)
+ p.removeRouteInfo(msg.AuthToken)
+ close(routeInfo.Proxy.ToRemoteCh)
+ close(routeInfo.Proxy.FromRemoteCh)
+ return
+ }
+ routeInfo.Proxy.FromRemoteCh <- msgBytes
+}
+
+func (p *WshRpcMultiProxy) RunUnauthLoop() {
+ // loop over unauthenticated message
+ // handle Authenicate commands, and pass authenticated messages to the AuthCh
+ for msgBytes := range p.FromRemoteRawCh {
+ p.handleUnauthMessage(msgBytes)
+ }
+}
diff --git a/pkg/wshutil/wshproxy.go b/pkg/wshutil/wshproxy.go
index c6a1ecf9f..c919b5d07 100644
--- a/pkg/wshutil/wshproxy.go
+++ b/pkg/wshutil/wshproxy.go
@@ -6,7 +6,6 @@ package wshutil
import (
"encoding/json"
"fmt"
- "log"
"sync"
"github.com/google/uuid"
@@ -18,6 +17,7 @@ type WshRpcProxy struct {
RpcContext *wshrpc.RpcContext
ToRemoteCh chan []byte
FromRemoteCh chan []byte
+ AuthToken string
}
func MakeRpcProxy() *WshRpcProxy {
@@ -40,6 +40,18 @@ func (p *WshRpcProxy) GetRpcContext() *wshrpc.RpcContext {
return p.RpcContext
}
+func (p *WshRpcProxy) SetAuthToken(authToken string) {
+ p.Lock.Lock()
+ defer p.Lock.Unlock()
+ p.AuthToken = authToken
+}
+
+func (p *WshRpcProxy) GetAuthToken() string {
+ p.Lock.Lock()
+ defer p.Lock.Unlock()
+ return p.AuthToken
+}
+
func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) {
if msg.ReqId == "" {
// no response needed
@@ -54,7 +66,7 @@ func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) {
p.SendRpcMessage(respBytes)
}
-func (p *WshRpcProxy) sendResponse(msg RpcMessage, routeId string) {
+func (p *WshRpcProxy) sendAuthenticateResponse(msg RpcMessage, routeId string) {
if msg.ReqId == "" {
// no response needed
return
@@ -98,6 +110,49 @@ func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, string, er
return newCtx, routeId, nil
}
+// runs on the client (stdio client)
+func (p *WshRpcProxy) HandleClientProxyAuth(router *WshRouter) (string, error) {
+ for {
+ msgBytes, ok := <-p.FromRemoteCh
+ if !ok {
+ return "", fmt.Errorf("remote closed, not authenticated")
+ }
+ var origMsg RpcMessage
+ err := json.Unmarshal(msgBytes, &origMsg)
+ if err != nil {
+ // nothing to do, can't even send a response since we don't have Source or ReqId
+ continue
+ }
+ if origMsg.Command == "" {
+ // this message is not allowed (protocol error at this point), ignore
+ continue
+ }
+ // we only allow one command "authenticate", everything else returns an error
+ if origMsg.Command != wshrpc.Command_Authenticate {
+ respErr := fmt.Errorf("connection not authenticated")
+ p.sendResponseError(origMsg, respErr)
+ continue
+ }
+ authRtn, err := router.HandleProxyAuth(origMsg.Data)
+ if err != nil {
+ respErr := fmt.Errorf("error handling proxy auth: %w", err)
+ p.sendResponseError(origMsg, respErr)
+ return "", respErr
+ }
+ p.SetAuthToken(authRtn.AuthToken)
+ announceMsg := RpcMessage{
+ Command: wshrpc.Command_RouteAnnounce,
+ Source: authRtn.RouteId,
+ AuthToken: authRtn.AuthToken,
+ }
+ announceBytes, _ := json.Marshal(announceMsg)
+ router.InjectMessage(announceBytes, authRtn.RouteId)
+ p.sendAuthenticateResponse(origMsg, authRtn.RouteId)
+ return authRtn.RouteId, nil
+ }
+}
+
+// runs on the server
func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
for {
msgBytes, ok := <-p.FromRemoteCh
@@ -122,11 +177,10 @@ func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
}
newCtx, routeId, err := handleAuthenticationCommand(msg)
if err != nil {
- log.Printf("error handling authentication: %v\n", err)
p.sendResponseError(msg, err)
continue
}
- p.sendResponse(msg, routeId)
+ p.sendAuthenticateResponse(msg, routeId)
return newCtx, nil
}
}
@@ -136,9 +190,10 @@ func (p *WshRpcProxy) SendRpcMessage(msg []byte) {
}
func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) {
- msgBytes, ok := <-p.FromRemoteCh
- if !ok || p.RpcContext == nil {
- return msgBytes, ok
+ msgBytes, more := <-p.FromRemoteCh
+ authToken := p.GetAuthToken()
+ if !more || (p.RpcContext == nil && authToken == "") {
+ return msgBytes, more
}
var msg RpcMessage
err := json.Unmarshal(msgBytes, &msg)
@@ -146,10 +201,15 @@ func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) {
// nothing to do here -- will error out at another level
return msgBytes, true
}
- msg.Data, err = recodeCommandData(msg.Command, msg.Data, p.RpcContext)
- if err != nil {
- // nothing to do here -- will error out at another level
- return msgBytes, true
+ if p.RpcContext != nil {
+ msg.Data, err = recodeCommandData(msg.Command, msg.Data, p.RpcContext)
+ if err != nil {
+ // nothing to do here -- will error out at another level
+ return msgBytes, true
+ }
+ }
+ if msg.AuthToken == "" {
+ msg.AuthToken = authToken
}
newBytes, err := json.Marshal(msg)
if err != nil {
diff --git a/pkg/wshutil/wshrouter.go b/pkg/wshutil/wshrouter.go
index 6389e3d99..10b5df517 100644
--- a/pkg/wshutil/wshrouter.go
+++ b/pkg/wshutil/wshrouter.go
@@ -12,11 +12,14 @@ import (
"sync"
"time"
+ "github.com/google/uuid"
+ "github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
)
const DefaultRoute = "wavesrv"
+const UpstreamRoute = "upstream"
const SysRoute = "sys" // this route doesn't exist, just a placeholder for system messages
const ElectronRoute = "electron"
@@ -36,12 +39,13 @@ type msgAndRoute struct {
}
type WshRouter struct {
- Lock *sync.Mutex
- RouteMap map[string]AbstractRpcClient // routeid => client
- UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router)
- AnnouncedRoutes map[string]string // routeid => local routeid
- RpcMap map[string]*routeInfo // rpcid => routeinfo
- InputCh chan msgAndRoute
+ Lock *sync.Mutex
+ RouteMap map[string]AbstractRpcClient // routeid => client
+ UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router)
+ AnnouncedRoutes map[string]string // routeid => local routeid
+ RpcMap map[string]*routeInfo // rpcid => routeinfo
+ SimpleRequestMap map[string]chan *RpcMessage // simple reqid => response channel
+ InputCh chan msgAndRoute
}
func MakeConnectionRouteId(connId string) string {
@@ -52,23 +56,28 @@ func MakeControllerRouteId(blockId string) string {
return "controller:" + blockId
}
-func MakeWindowRouteId(windowId string) string {
- return "window:" + windowId
-}
-
func MakeProcRouteId(procId string) string {
return "proc:" + procId
}
+func MakeTabRouteId(tabId string) string {
+ return "tab:" + tabId
+}
+
+func MakeFeBlockRouteId(blockId string) string {
+ return "feblock:" + blockId
+}
+
var DefaultRouter = NewWshRouter()
func NewWshRouter() *WshRouter {
rtn := &WshRouter{
- Lock: &sync.Mutex{},
- RouteMap: make(map[string]AbstractRpcClient),
- AnnouncedRoutes: make(map[string]string),
- RpcMap: make(map[string]*routeInfo),
- InputCh: make(chan msgAndRoute, DefaultInputChSize),
+ Lock: &sync.Mutex{},
+ RouteMap: make(map[string]AbstractRpcClient),
+ AnnouncedRoutes: make(map[string]string),
+ RpcMap: make(map[string]*routeInfo),
+ SimpleRequestMap: make(map[string]chan *RpcMessage),
+ InputCh: make(chan msgAndRoute, DefaultInputChSize),
}
go rtn.runServer()
return rtn
@@ -233,6 +242,10 @@ func (router *WshRouter) runServer() {
router.sendRoutedMessage(msgBytes, routeInfo.DestRouteId)
continue
} else if msg.ResId != "" {
+ ok := router.trySimpleResponse(&msg)
+ if ok {
+ continue
+ }
routeInfo := router.getRouteInfo(msg.ResId)
if routeInfo == nil {
// no route info, nothing to do
@@ -255,6 +268,9 @@ func (router *WshRouter) WaitForRegister(ctx context.Context, routeId string) er
if router.GetRpc(routeId) != nil {
return nil
}
+ if router.getAnnouncedRoute(routeId) != "" {
+ return nil
+ }
select {
case <-ctx.Done():
return ctx.Err()
@@ -265,10 +281,10 @@ func (router *WshRouter) WaitForRegister(ctx context.Context, routeId string) er
}
// this will also consume the output channel of the abstract client
-func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) {
- if routeId == SysRoute {
+func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, shouldAnnounce bool) {
+ if routeId == SysRoute || routeId == UpstreamRoute {
// cannot register sys route
- log.Printf("error: WshRouter cannot register sys route\n")
+ log.Printf("error: WshRouter cannot register %s route\n", routeId)
return
}
log.Printf("[router] registering wsh route %q\n", routeId)
@@ -281,7 +297,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) {
router.RouteMap[routeId] = rpc
go func() {
// announce
- if !alreadyExists && router.GetUpstreamClient() != nil {
+ if shouldAnnounce && !alreadyExists && router.GetUpstreamClient() != nil {
announceMsg := RpcMessage{Command: wshrpc.Command_RouteAnnounce, Source: routeId}
announceBytes, _ := json.Marshal(announceMsg)
router.GetUpstreamClient().SendRpcMessage(announceBytes)
@@ -326,6 +342,7 @@ func (router *WshRouter) UnregisterRoute(routeId string) {
}
go func() {
wps.Broker.UnsubscribeAll(routeId)
+ wps.Broker.Publish(wps.WaveEvent{Event: wps.Event_RouteGone, Scopes: []string{routeId}})
}()
}
@@ -347,3 +364,97 @@ func (router *WshRouter) GetUpstreamClient() AbstractRpcClient {
defer router.Lock.Unlock()
return router.UpstreamClient
}
+
+func (router *WshRouter) InjectMessage(msgBytes []byte, fromRouteId string) {
+ router.InputCh <- msgAndRoute{msgBytes: msgBytes, fromRouteId: fromRouteId}
+}
+
+func (router *WshRouter) registerSimpleRequest(reqId string) chan *RpcMessage {
+ router.Lock.Lock()
+ defer router.Lock.Unlock()
+ rtn := make(chan *RpcMessage, 1)
+ router.SimpleRequestMap[reqId] = rtn
+ return rtn
+}
+
+func (router *WshRouter) trySimpleResponse(msg *RpcMessage) bool {
+ router.Lock.Lock()
+ defer router.Lock.Unlock()
+ respCh := router.SimpleRequestMap[msg.ResId]
+ if respCh == nil {
+ return false
+ }
+ respCh <- msg
+ delete(router.SimpleRequestMap, msg.ResId)
+ return true
+}
+
+func (router *WshRouter) clearSimpleRequest(reqId string) {
+ router.Lock.Lock()
+ defer router.Lock.Unlock()
+ delete(router.SimpleRequestMap, reqId)
+}
+
+func (router *WshRouter) RunSimpleRawCommand(ctx context.Context, msg RpcMessage, fromRouteId string) (*RpcMessage, error) {
+ if msg.Command == "" {
+ return nil, errors.New("no command")
+ }
+ msgBytes, err := json.Marshal(msg)
+ if err != nil {
+ return nil, err
+ }
+ var respCh chan *RpcMessage
+ if msg.ReqId != "" {
+ respCh = router.registerSimpleRequest(msg.ReqId)
+ }
+ router.InjectMessage(msgBytes, fromRouteId)
+ if respCh == nil {
+ return nil, nil
+ }
+ select {
+ case <-ctx.Done():
+ router.clearSimpleRequest(msg.ReqId)
+ return nil, ctx.Err()
+ case resp := <-respCh:
+ if resp.Error != "" {
+ return nil, errors.New(resp.Error)
+ }
+ return resp, nil
+ }
+}
+
+func (router *WshRouter) HandleProxyAuth(jwtTokenAny any) (*wshrpc.CommandAuthenticateRtnData, error) {
+ if jwtTokenAny == nil {
+ return nil, errors.New("no jwt token")
+ }
+ jwtToken, ok := jwtTokenAny.(string)
+ if !ok {
+ return nil, errors.New("jwt token not a string")
+ }
+ if jwtToken == "" {
+ return nil, errors.New("empty jwt token")
+ }
+ msg := RpcMessage{
+ Command: wshrpc.Command_Authenticate,
+ ReqId: uuid.New().String(),
+ Data: jwtToken,
+ }
+ ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeoutMs*time.Millisecond)
+ defer cancelFn()
+ resp, err := router.RunSimpleRawCommand(ctx, msg, "")
+ if err != nil {
+ return nil, err
+ }
+ if resp == nil || resp.Data == nil {
+ return nil, errors.New("no data in authenticate response")
+ }
+ var respData wshrpc.CommandAuthenticateRtnData
+ err = utilfn.ReUnmarshal(&respData, resp.Data)
+ if err != nil {
+ return nil, fmt.Errorf("error unmarshalling authenticate response: %v", err)
+ }
+ if respData.AuthToken == "" {
+ return nil, errors.New("no auth token in authenticate response")
+ }
+ return &respData, nil
+}
diff --git a/pkg/wshutil/wshrpc.go b/pkg/wshutil/wshrpc.go
index cccc353e7..7d31246ac 100644
--- a/pkg/wshutil/wshrpc.go
+++ b/pkg/wshutil/wshrpc.go
@@ -45,10 +45,13 @@ type WshRpc struct {
InputCh chan []byte
OutputCh chan []byte
RpcContext *atomic.Pointer[wshrpc.RpcContext]
+ AuthToken string
RpcMap map[string]*rpcData
ServerImpl ServerImpl
EventListener *EventListener
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
+ Debug bool
+ DebugName string
}
type wshRpcContextKey struct{}
@@ -104,17 +107,18 @@ func (w *WshRpc) RecvRpcMessage() ([]byte, bool) {
}
type RpcMessage struct {
- Command string `json:"command,omitempty"`
- ReqId string `json:"reqid,omitempty"`
- ResId string `json:"resid,omitempty"`
- Timeout int `json:"timeout,omitempty"`
- Route string `json:"route,omitempty"` // to route/forward requests to alternate servers
- Source string `json:"source,omitempty"` // source route id
- Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming
- Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming)
- Error string `json:"error,omitempty"`
- DataType string `json:"datatype,omitempty"`
- Data any `json:"data,omitempty"`
+ Command string `json:"command,omitempty"`
+ ReqId string `json:"reqid,omitempty"`
+ ResId string `json:"resid,omitempty"`
+ Timeout int `json:"timeout,omitempty"`
+ Route string `json:"route,omitempty"` // to route/forward requests to alternate servers
+ AuthToken string `json:"authtoken,omitempty"` // needed for routing unauthenticated requests (WshRpcMultiProxy)
+ Source string `json:"source,omitempty"` // source route id
+ Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming
+ Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming)
+ Error string `json:"error,omitempty"`
+ DataType string `json:"datatype,omitempty"`
+ Data any `json:"data,omitempty"`
}
func (r *RpcMessage) IsRpcRequest() bool {
@@ -226,6 +230,14 @@ func (w *WshRpc) SetRpcContext(ctx wshrpc.RpcContext) {
w.RpcContext.Store(&ctx)
}
+func (w *WshRpc) SetAuthToken(token string) {
+ w.AuthToken = token
+}
+
+func (w *WshRpc) GetAuthToken() string {
+ return w.AuthToken
+}
+
func (w *WshRpc) registerResponseHandler(reqId string, handler *RpcResponseHandler) {
w.Lock.Lock()
defer w.Lock.Unlock()
@@ -323,6 +335,9 @@ func (w *WshRpc) handleRequest(req *RpcMessage) {
func (w *WshRpc) runServer() {
defer close(w.OutputCh)
for msgBytes := range w.InputCh {
+ if w.Debug {
+ log.Printf("[%s] received message: %s\n", w.DebugName, string(msgBytes))
+ }
var msg RpcMessage
err := json.Unmarshal(msgBytes, &msg)
if err != nil {
@@ -455,8 +470,9 @@ func (handler *RpcRequestHandler) SendCancel() {
}
}()
msg := &RpcMessage{
- Cancel: true,
- ReqId: handler.reqId,
+ Cancel: true,
+ ReqId: handler.reqId,
+ AuthToken: handler.w.GetAuthToken(),
}
barr, _ := json.Marshal(msg) // will never fail
handler.w.OutputCh <- barr
@@ -550,6 +566,7 @@ func (handler *RpcResponseHandler) SendMessage(msg string) {
Data: wshrpc.CommandMessageData{
Message: msg,
},
+ AuthToken: handler.w.GetAuthToken(),
}
msgBytes, _ := json.Marshal(rpcMsg) // will never fail
handler.w.OutputCh <- msgBytes
@@ -573,9 +590,10 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error {
defer handler.close()
}
msg := &RpcMessage{
- ResId: handler.reqId,
- Data: data,
- Cont: !done,
+ ResId: handler.reqId,
+ Data: data,
+ Cont: !done,
+ AuthToken: handler.w.GetAuthToken(),
}
barr, err := json.Marshal(msg)
if err != nil {
@@ -598,8 +616,9 @@ func (handler *RpcResponseHandler) SendResponseError(err error) {
}
defer handler.close()
msg := &RpcMessage{
- ResId: handler.reqId,
- Error: err.Error(),
+ ResId: handler.reqId,
+ Error: err.Error(),
+ AuthToken: handler.w.GetAuthToken(),
}
barr, _ := json.Marshal(msg) // will never fail
handler.w.OutputCh <- barr
@@ -660,11 +679,12 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp
handler.reqId = uuid.New().String()
}
req := &RpcMessage{
- Command: command,
- ReqId: handler.reqId,
- Data: data,
- Timeout: timeoutMs,
- Route: opts.Route,
+ Command: command,
+ ReqId: handler.reqId,
+ Data: data,
+ Timeout: timeoutMs,
+ Route: opts.Route,
+ AuthToken: w.GetAuthToken(),
}
barr, err := json.Marshal(req)
if err != nil {
diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go
index 79cdc6080..8be9c908a 100644
--- a/pkg/wshutil/wshutil.go
+++ b/pkg/wshutil/wshutil.go
@@ -19,6 +19,7 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
+ "github.com/wavetermdev/waveterm/pkg/util/packetparser"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"golang.org/x/term"
@@ -204,11 +205,26 @@ func SetupTerminalRpcClient(serverImpl ServerImpl) (*WshRpc, io.Reader) {
continue
}
os.Stdout.Write(barr)
+ os.Stdout.Write([]byte{'\n'})
}
}()
return rpcClient, ptyBuf
}
+func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerImpl) (*WshRpc, chan []byte) {
+ messageCh := make(chan []byte, DefaultInputChSize)
+ outputCh := make(chan []byte, DefaultOutputChSize)
+ rawCh := make(chan []byte, DefaultOutputChSize)
+ rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl)
+ go packetparser.Parse(input, messageCh, rawCh)
+ go func() {
+ for msg := range outputCh {
+ packetparser.WritePacket(output, msg)
+ }
+ }()
+ return rpcClient, rawCh
+}
+
func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan error, error) {
inputCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
@@ -229,10 +245,22 @@ func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan err
return rtn, writeErrCh, nil
}
-func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc, error) {
- conn, err := net.Dial("unix", sockName)
+func tryTcpSocket(sockName string) (net.Conn, error) {
+ addr, err := net.ResolveTCPAddr("tcp", sockName)
if err != nil {
- return nil, fmt.Errorf("failed to connect to Unix domain socket: %w", err)
+ return nil, err
+ }
+ return net.DialTCP("tcp", nil, addr)
+}
+
+func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc, error) {
+ conn, tcpErr := tryTcpSocket(sockName)
+ var unixErr error
+ if tcpErr != nil {
+ conn, unixErr = net.Dial("unix", sockName)
+ }
+ if tcpErr != nil && unixErr != nil {
+ return nil, fmt.Errorf("failed to connect to tcp or unix domain socket: tcp err:%w: unix socket err: %w", tcpErr, unixErr)
}
rtn, errCh, err := SetupConnRpcClient(conn, serverImpl)
go func() {
@@ -363,6 +391,46 @@ func MakeRouteIdFromCtx(rpcCtx *wshrpc.RpcContext) (string, error) {
return MakeProcRouteId(procId), nil
}
+type WriteFlusher interface {
+ Write([]byte) (int, error)
+ Flush() error
+}
+
+// blocking, returns if there is an error, or on EOF of input
+func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
+ proxy := MakeRpcMultiProxy()
+ rawCh := make(chan []byte, DefaultInputChSize)
+ go packetparser.Parse(input, proxy.FromRemoteRawCh, rawCh)
+ doneCh := make(chan struct{})
+ var doneOnce sync.Once
+ closeDoneCh := func() {
+ doneOnce.Do(func() {
+ close(doneCh)
+ })
+ proxy.DisposeRoutes()
+ }
+ go func() {
+ proxy.RunUnauthLoop()
+ }()
+ go func() {
+ defer closeDoneCh()
+ for msg := range proxy.ToRemoteCh {
+ err := packetparser.WritePacket(output, msg)
+ if err != nil {
+ log.Printf("[%s] error writing to output: %v\n", logName, err)
+ break
+ }
+ }
+ }()
+ go func() {
+ defer closeDoneCh()
+ for msg := range rawCh {
+ log.Printf("[%s:stdout] %s", logName, msg)
+ }
+ }()
+ <-doneCh
+}
+
func handleDomainSocketClient(conn net.Conn) {
var routeIdContainer atomic.Pointer[string]
proxy := MakeRpcProxy()
@@ -399,7 +467,7 @@ func handleDomainSocketClient(conn net.Conn) {
return
}
routeIdContainer.Store(&routeId)
- DefaultRouter.RegisterRoute(routeId, proxy)
+ DefaultRouter.RegisterRoute(routeId, proxy, true)
}
// only for use on client
@@ -433,5 +501,6 @@ func ExtractUnverifiedSocketName(tokenStr string) (string, error) {
if !ok {
return "", fmt.Errorf("sock claim is missing or invalid")
}
+ sockName = wavebase.ExpandHomeDirSafe(sockName)
return sockName, nil
}
diff --git a/pkg/wsl/wsl-unix.go b/pkg/wsl/wsl-unix.go
new file mode 100644
index 000000000..055e46669
--- /dev/null
+++ b/pkg/wsl/wsl-unix.go
@@ -0,0 +1,67 @@
+//go:build !windows
+
+// Copyright 2024, Command Line Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package wsl
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+)
+
+func RegisteredDistros(ctx context.Context) (distros []Distro, err error) {
+ return nil, fmt.Errorf("RegisteredDistros not implemented on this system")
+}
+
+func DefaultDistro(ctx context.Context) (d Distro, ok bool, err error) {
+ return d, false, fmt.Errorf("DefaultDistro not implemented on this system")
+}
+
+type Distro struct{}
+
+func (d *Distro) Name() string {
+ return ""
+}
+
+func (d *Distro) WslCommand(ctx context.Context, cmd string) *WslCmd {
+ return nil
+}
+
+// just use the regular cmd since it's
+// similar enough to not cause issues
+// type WslCmd = exec.Cmd
+type WslCmd struct {
+ exec.Cmd
+}
+
+func (wc *WslCmd) GetProcess() *os.Process {
+ return nil
+}
+
+func (wc *WslCmd) GetProcessState() *os.ProcessState {
+ return nil
+}
+
+func (c *WslCmd) SetStdin(stdin io.Reader) {
+ c.Stdin = stdin
+}
+
+func (c *WslCmd) SetStdout(stdout io.Writer) {
+ c.Stdout = stdout
+}
+
+func (c *WslCmd) SetStderr(stderr io.Writer) {
+ c.Stdout = stderr
+}
+
+func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) {
+ return nil, fmt.Errorf("GetDistroCmd not implemented on this system")
+}
+
+func GetDistro(ctx context.Context, wslDistroName WslName) (*Distro, error) {
+ return nil, fmt.Errorf("GetDistro not implemented on this system")
+}
diff --git a/pkg/wsl/wsl-util.go b/pkg/wsl/wsl-util.go
new file mode 100644
index 000000000..5d1f70d35
--- /dev/null
+++ b/pkg/wsl/wsl-util.go
@@ -0,0 +1,296 @@
+// Copyright 2024, Command Line Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package wsl
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "html/template"
+ "io"
+ "log"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+)
+
+func DetectShell(ctx context.Context, client *Distro) (string, error) {
+ wshPath := GetWshPath(ctx, client)
+
+ cmd := client.WslCommand(ctx, wshPath+" shell")
+ log.Printf("shell detecting using command: %s shell", wshPath)
+ out, err := cmd.Output()
+ if err != nil {
+ log.Printf("unable to determine shell. defaulting to /bin/bash: %s", err)
+ return "/bin/bash", nil
+ }
+ log.Printf("detecting shell: %s", out)
+
+ // quoting breaks this particular case
+ return strings.TrimSpace(string(out)), nil
+}
+
+func GetWshVersion(ctx context.Context, client *Distro) (string, error) {
+ wshPath := GetWshPath(ctx, client)
+
+ cmd := client.WslCommand(ctx, wshPath+" version")
+ out, err := cmd.Output()
+ if err != nil {
+ return "", err
+ }
+
+ return strings.TrimSpace(string(out)), nil
+}
+
+func GetWshPath(ctx context.Context, client *Distro) string {
+ defaultPath := "~/.waveterm/bin/wsh"
+
+ cmd := client.WslCommand(ctx, "which wsh")
+ out, whichErr := cmd.Output()
+ if whichErr == nil {
+ return strings.TrimSpace(string(out))
+ }
+
+ cmd = client.WslCommand(ctx, "where.exe wsh")
+ out, whereErr := cmd.Output()
+ if whereErr == nil {
+ return strings.TrimSpace(string(out))
+ }
+
+ // check cmd on windows since it requires an absolute path with backslashes
+ cmd = client.WslCommand(ctx, "(dir 2>&1 *``|echo %userprofile%\\.waveterm%\\.waveterm\\bin\\wsh.exe);&<# rem #>echo none")
+ out, cmdErr := cmd.Output()
+ if cmdErr == nil && strings.TrimSpace(string(out)) != "none" {
+ return strings.TrimSpace(string(out))
+ }
+
+ // no custom install, use default path
+ return defaultPath
+}
+
+func hasBashInstalled(ctx context.Context, client *Distro) (bool, error) {
+ cmd := client.WslCommand(ctx, "which bash")
+ out, whichErr := cmd.Output()
+ if whichErr == nil && len(out) != 0 {
+ return true, nil
+ }
+
+ cmd = client.WslCommand(ctx, "where.exe bash")
+ out, whereErr := cmd.Output()
+ if whereErr == nil && len(out) != 0 {
+ return true, nil
+ }
+
+ // note: we could also check in /bin/bash explicitly
+ // just in case that wasn't added to the path. but if
+ // that's true, we will most likely have worse
+ // problems going forward
+
+ return false, nil
+}
+
+func GetClientOs(ctx context.Context, client *Distro) (string, error) {
+ cmd := client.WslCommand(ctx, "uname -s")
+ out, unixErr := cmd.Output()
+ if unixErr == nil {
+ formatted := strings.ToLower(string(out))
+ formatted = strings.TrimSpace(formatted)
+ return formatted, nil
+ }
+
+ cmd = client.WslCommand(ctx, "echo %OS%")
+ out, cmdErr := cmd.Output()
+ if cmdErr == nil && strings.TrimSpace(string(out)) != "%OS%" {
+ formatted := strings.ToLower(string(out))
+ formatted = strings.TrimSpace(formatted)
+ return strings.Split(formatted, "_")[0], nil
+ }
+
+ cmd = client.WslCommand(ctx, "echo $env:OS")
+ out, psErr := cmd.Output()
+ if psErr == nil && strings.TrimSpace(string(out)) != "$env:OS" {
+ formatted := strings.ToLower(string(out))
+ formatted = strings.TrimSpace(formatted)
+ return strings.Split(formatted, "_")[0], nil
+ }
+ return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
+}
+
+func GetClientArch(ctx context.Context, client *Distro) (string, error) {
+ cmd := client.WslCommand(ctx, "uname -m")
+ out, unixErr := cmd.Output()
+ if unixErr == nil {
+ formatted := strings.ToLower(string(out))
+ formatted = strings.TrimSpace(formatted)
+ if formatted == "x86_64" {
+ return "x64", nil
+ }
+ return formatted, nil
+ }
+
+ cmd = client.WslCommand(ctx, "echo %PROCESSOR_ARCHITECTURE%")
+ out, cmdErr := cmd.Output()
+ if cmdErr == nil && strings.TrimSpace(string(out)) != "%PROCESSOR_ARCHITECTURE%" {
+ formatted := strings.ToLower(string(out))
+ return strings.TrimSpace(formatted), nil
+ }
+
+ cmd = client.WslCommand(ctx, "echo $env:PROCESSOR_ARCHITECTURE")
+ out, psErr := cmd.Output()
+ if psErr == nil && strings.TrimSpace(string(out)) != "$env:PROCESSOR_ARCHITECTURE" {
+ formatted := strings.ToLower(string(out))
+ return strings.TrimSpace(formatted), nil
+ }
+ return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
+}
+
+type CancellableCmd struct {
+ Cmd *WslCmd
+ Cancel func()
+}
+
+var installTemplatesRawBash = map[string]string{
+ "mkdir": `bash -c 'mkdir -p {{.installDir}}'`,
+ "cat": `bash -c 'cat > {{.tempPath}}'`,
+ "mv": `bash -c 'mv {{.tempPath}} {{.installPath}}'`,
+ "chmod": `bash -c 'chmod a+x {{.installPath}}'`,
+}
+
+var installTemplatesRawDefault = map[string]string{
+ "mkdir": `mkdir -p {{.installDir}}`,
+ "cat": `cat > {{.tempPath}}`,
+ "mv": `mv {{.tempPath}} {{.installPath}}`,
+ "chmod": `chmod a+x {{.installPath}}`,
+}
+
+func makeCancellableCommand(ctx context.Context, client *Distro, cmdTemplateRaw string, words map[string]string) (*CancellableCmd, error) {
+ cmdContext, cmdCancel := context.WithCancel(ctx)
+
+ cmdStr := &bytes.Buffer{}
+ cmdTemplate, err := template.New("").Parse(cmdTemplateRaw)
+ if err != nil {
+ cmdCancel()
+ return nil, err
+ }
+ cmdTemplate.Execute(cmdStr, words)
+
+ cmd := client.WslCommand(cmdContext, cmdStr.String())
+ return &CancellableCmd{cmd, cmdCancel}, nil
+}
+
+func CpHostToRemote(ctx context.Context, client *Distro, sourcePath string, destPath string) error {
+ // warning: does not work on windows remote yet
+ bashInstalled, err := hasBashInstalled(ctx, client)
+ if err != nil {
+ return err
+ }
+
+ var selectedTemplatesRaw map[string]string
+ if bashInstalled {
+ selectedTemplatesRaw = installTemplatesRawBash
+ } else {
+ log.Printf("bash is not installed on remote. attempting with default shell")
+ selectedTemplatesRaw = installTemplatesRawDefault
+ }
+
+ // I need to use toSlash here to force unix keybindings
+ // this means we can't guarantee it will work on a remote windows machine
+ var installWords = map[string]string{
+ "installDir": filepath.ToSlash(filepath.Dir(destPath)),
+ "tempPath": destPath + ".temp",
+ "installPath": destPath,
+ }
+
+ installStepCmds := make(map[string]*CancellableCmd)
+ for cmdName, selectedTemplateRaw := range selectedTemplatesRaw {
+ cancellableCmd, err := makeCancellableCommand(ctx, client, selectedTemplateRaw, installWords)
+ if err != nil {
+ return err
+ }
+ installStepCmds[cmdName] = cancellableCmd
+ }
+
+ _, err = installStepCmds["mkdir"].Cmd.Output()
+ if err != nil {
+ return err
+ }
+
+ // the cat part of this is complicated since it requires stdin
+ catCmd := installStepCmds["cat"].Cmd
+ catStdin, err := catCmd.StdinPipe()
+ if err != nil {
+ return err
+ }
+ err = catCmd.Start()
+ if err != nil {
+ return err
+ }
+ input, err := os.Open(sourcePath)
+ if err != nil {
+ return fmt.Errorf("cannot open local file %s to send to host: %v", sourcePath, err)
+ }
+ go func() {
+ io.Copy(catStdin, input)
+ installStepCmds["cat"].Cancel()
+
+ // backup just in case something weird happens
+ // could cause potential race condition, but very
+ // unlikely
+ time.Sleep(time.Second * 1)
+ process := catCmd.GetProcess()
+ if process != nil {
+ process.Kill()
+ }
+ }()
+ catErr := catCmd.Wait()
+ if catErr != nil && !errors.Is(catErr, context.Canceled) {
+ return catErr
+ }
+
+ _, err = installStepCmds["mv"].Cmd.Output()
+ if err != nil {
+ return err
+ }
+
+ _, err = installStepCmds["chmod"].Cmd.Output()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func InstallClientRcFiles(ctx context.Context, client *Distro) error {
+ path := GetWshPath(ctx, client)
+ log.Printf("path to wsh searched is: %s", path)
+
+ cmd := client.WslCommand(ctx, path+" rcfiles")
+ _, err := cmd.Output()
+ return err
+}
+
+func GetHomeDir(ctx context.Context, client *Distro) string {
+ // note: also works for powershell
+ cmd := client.WslCommand(ctx, `echo "$HOME"`)
+ out, err := cmd.Output()
+ if err == nil {
+ return strings.TrimSpace(string(out))
+ }
+
+ cmd = client.WslCommand(ctx, `echo %userprofile%`)
+ out, err = cmd.Output()
+ if err == nil {
+ return strings.TrimSpace(string(out))
+ }
+
+ return "~"
+}
+
+func IsPowershell(shellPath string) bool {
+ // get the base path, and then check contains
+ shellBase := filepath.Base(shellPath)
+ return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh")
+}
diff --git a/pkg/wsl/wsl-win.go b/pkg/wsl/wsl-win.go
new file mode 100644
index 000000000..782e15719
--- /dev/null
+++ b/pkg/wsl/wsl-win.go
@@ -0,0 +1,125 @@
+//go:build windows
+
+// Copyright 2024, Command Line Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package wsl
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "sync"
+
+ "github.com/ubuntu/gowsl"
+)
+
+var RegisteredDistros = gowsl.RegisteredDistros
+var DefaultDistro = gowsl.DefaultDistro
+
+type Distro struct {
+ gowsl.Distro
+}
+
+type WslCmd struct {
+ c *gowsl.Cmd
+ wg *sync.WaitGroup
+ once *sync.Once
+ lock *sync.Mutex
+ waitErr error
+}
+
+func (d *Distro) WslCommand(ctx context.Context, cmd string) *WslCmd {
+ if ctx == nil {
+ panic("nil Context")
+ }
+ innerCmd := d.Command(ctx, cmd)
+ var wg sync.WaitGroup
+ var lock *sync.Mutex
+ return &WslCmd{innerCmd, &wg, new(sync.Once), lock, nil}
+}
+
+func (c *WslCmd) CombinedOutput() (out []byte, err error) {
+ return c.c.CombinedOutput()
+}
+func (c *WslCmd) Output() (out []byte, err error) {
+ return c.c.Output()
+}
+func (c *WslCmd) Run() error {
+ return c.c.Run()
+}
+func (c *WslCmd) Start() (err error) {
+ return c.c.Start()
+}
+func (c *WslCmd) StderrPipe() (r io.ReadCloser, err error) {
+ return c.c.StderrPipe()
+}
+func (c *WslCmd) StdinPipe() (w io.WriteCloser, err error) {
+ return c.c.StdinPipe()
+}
+func (c *WslCmd) StdoutPipe() (r io.ReadCloser, err error) {
+ return c.c.StdoutPipe()
+}
+func (c *WslCmd) Wait() (err error) {
+ c.wg.Add(1)
+ c.once.Do(func() {
+ c.waitErr = c.c.Wait()
+ })
+ c.wg.Done()
+ c.wg.Wait()
+ if c.waitErr != nil && c.waitErr.Error() == "not started" {
+ c.once = new(sync.Once)
+ return c.waitErr
+ }
+ return c.waitErr
+}
+func (c *WslCmd) GetProcess() *os.Process {
+ return c.c.Process
+}
+
+func (c *WslCmd) GetProcessState() *os.ProcessState {
+ return c.c.ProcessState
+}
+
+func (c *WslCmd) SetStdin(stdin io.Reader) {
+ c.c.Stdin = stdin
+}
+
+func (c *WslCmd) SetStdout(stdout io.Writer) {
+ c.c.Stdout = stdout
+}
+
+func (c *WslCmd) SetStderr(stderr io.Writer) {
+ c.c.Stdout = stderr
+}
+
+func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) {
+ distros, err := RegisteredDistros(ctx)
+ if err != nil {
+ return nil, err
+ }
+ for _, distro := range distros {
+ if distro.Name() != wslDistroName {
+ continue
+ }
+ wrappedDistro := Distro{distro}
+ return wrappedDistro.WslCommand(ctx, cmd), nil
+ }
+ return nil, fmt.Errorf("wsl distro %s not found", wslDistroName)
+}
+
+func GetDistro(ctx context.Context, wslDistroName WslName) (*Distro, error) {
+ distros, err := RegisteredDistros(ctx)
+ if err != nil {
+ return nil, err
+ }
+ for _, distro := range distros {
+ if distro.Name() != wslDistroName.Distro {
+ continue
+ }
+ wrappedDistro := Distro{distro}
+ return &wrappedDistro, nil
+ }
+ return nil, fmt.Errorf("wsl distro %s not found", wslDistroName)
+}
diff --git a/pkg/wsl/wsl.go b/pkg/wsl/wsl.go
new file mode 100644
index 000000000..0f5927ebb
--- /dev/null
+++ b/pkg/wsl/wsl.go
@@ -0,0 +1,494 @@
+// Copyright 2024, Command Line Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package wsl
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/wavetermdev/waveterm/pkg/userinput"
+ "github.com/wavetermdev/waveterm/pkg/util/shellutil"
+ "github.com/wavetermdev/waveterm/pkg/wavebase"
+ "github.com/wavetermdev/waveterm/pkg/waveobj"
+ "github.com/wavetermdev/waveterm/pkg/wconfig"
+ "github.com/wavetermdev/waveterm/pkg/wps"
+ "github.com/wavetermdev/waveterm/pkg/wshrpc"
+ "github.com/wavetermdev/waveterm/pkg/wshutil"
+)
+
+const (
+ Status_Init = "init"
+ Status_Connecting = "connecting"
+ Status_Connected = "connected"
+ Status_Disconnected = "disconnected"
+ Status_Error = "error"
+)
+
+const DefaultConnectionTimeout = 60 * time.Second
+
+var globalLock = &sync.Mutex{}
+var clientControllerMap = make(map[string]*WslConn)
+var activeConnCounter = &atomic.Int32{}
+
+type WslConn struct {
+ Lock *sync.Mutex
+ Status string
+ Name WslName
+ Client *Distro
+ SockName string
+ DomainSockListener net.Listener
+ ConnController *WslCmd
+ Error string
+ HasWaiter *atomic.Bool
+ LastConnectTime int64
+ ActiveConnNum int
+ Context context.Context
+ cancelFn func()
+}
+
+type WslName struct {
+ Distro string `json:"distro"`
+}
+
+func GetAllConnStatus() []wshrpc.ConnStatus {
+ globalLock.Lock()
+ defer globalLock.Unlock()
+
+ var connStatuses []wshrpc.ConnStatus
+ for _, conn := range clientControllerMap {
+ connStatuses = append(connStatuses, conn.DeriveConnStatus())
+ }
+ return connStatuses
+}
+
+func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus {
+ conn.Lock.Lock()
+ defer conn.Lock.Unlock()
+ return wshrpc.ConnStatus{
+ Status: conn.Status,
+ Connected: conn.Status == Status_Connected,
+ Connection: conn.GetName(),
+ HasConnected: (conn.LastConnectTime > 0),
+ ActiveConnNum: conn.ActiveConnNum,
+ Error: conn.Error,
+ }
+}
+
+func (conn *WslConn) FireConnChangeEvent() {
+ status := conn.DeriveConnStatus()
+ event := wps.WaveEvent{
+ Event: wps.Event_ConnChange,
+ Scopes: []string{
+ fmt.Sprintf("connection:%s", conn.GetName()),
+ },
+ Data: status,
+ }
+ log.Printf("sending event: %+#v", event)
+ wps.Broker.Publish(event)
+}
+
+func (conn *WslConn) Close() error {
+ defer conn.FireConnChangeEvent()
+ conn.WithLock(func() {
+ if conn.Status == Status_Connected || conn.Status == Status_Connecting {
+ // if status is init, disconnected, or error don't change it
+ conn.Status = Status_Disconnected
+ }
+ conn.close_nolock()
+ })
+ // we must wait for the waiter to complete
+ startTime := time.Now()
+ for conn.HasWaiter.Load() {
+ time.Sleep(10 * time.Millisecond)
+ if time.Since(startTime) > 2*time.Second {
+ return fmt.Errorf("timeout waiting for waiter to complete")
+ }
+ }
+ return nil
+}
+
+func (conn *WslConn) close_nolock() {
+ // does not set status (that should happen at another level)
+ if conn.DomainSockListener != nil {
+ conn.DomainSockListener.Close()
+ conn.DomainSockListener = nil
+ }
+ if conn.ConnController != nil {
+ conn.cancelFn() // this suspends the conn controller
+ conn.ConnController = nil
+ }
+ if conn.Client != nil {
+ // conn.Client.Close() is not relevant here
+ // we do not want to completely close the wsl in case
+ // other applications are using it
+ conn.Client = nil
+ }
+}
+
+func (conn *WslConn) GetDomainSocketName() string {
+ conn.Lock.Lock()
+ defer conn.Lock.Unlock()
+ return conn.SockName
+}
+
+func (conn *WslConn) GetStatus() string {
+ conn.Lock.Lock()
+ defer conn.Lock.Unlock()
+ return conn.Status
+}
+
+func (conn *WslConn) GetName() string {
+ // no lock required because opts is immutable
+ return "wsl://" + conn.Name.Distro
+}
+
+/**
+ * This function is does not set a listener for WslConn
+ * It is still required in order to set SockName
+**/
+func (conn *WslConn) OpenDomainSocketListener() error {
+ var allowed bool
+ conn.WithLock(func() {
+ if conn.Status != Status_Connecting {
+ allowed = false
+ } else {
+ allowed = true
+ }
+ })
+ if !allowed {
+ return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
+ }
+ conn.WithLock(func() {
+ conn.SockName = "~/.waveterm/wave-remote.sock"
+ })
+ return nil
+}
+
+func (conn *WslConn) StartConnServer() error {
+ var allowed bool
+ conn.WithLock(func() {
+ if conn.Status != Status_Connecting {
+ allowed = false
+ } else {
+ allowed = true
+ }
+ })
+ if !allowed {
+ return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
+ }
+ client := conn.GetClient()
+ wshPath := GetWshPath(conn.Context, client)
+ rpcCtx := wshrpc.RpcContext{
+ ClientType: wshrpc.ClientType_ConnServer,
+ Conn: conn.GetName(),
+ }
+ sockName := conn.GetDomainSocketName()
+ jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
+ if err != nil {
+ return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
+ }
+ shellPath, err := DetectShell(conn.Context, client)
+ if err != nil {
+ return err
+ }
+ var cmdStr string
+ if IsPowershell(shellPath) {
+ cmdStr = fmt.Sprintf("$env:%s=\"%s\"; %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
+ } else {
+ cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
+ }
+ log.Printf("starting conn controller: %s\n", cmdStr)
+ cmd := client.WslCommand(conn.Context, cmdStr)
+ pipeRead, pipeWrite := io.Pipe()
+ inputPipeRead, inputPipeWrite := io.Pipe()
+ cmd.SetStdout(pipeWrite)
+ cmd.SetStderr(pipeWrite)
+ cmd.SetStdin(inputPipeRead)
+ err = cmd.Start()
+ if err != nil {
+ return fmt.Errorf("unable to start conn controller: %w", err)
+ }
+ conn.WithLock(func() {
+ conn.ConnController = cmd
+ })
+ // service the I/O
+ go func() {
+ // wait for termination, clear the controller
+ defer conn.WithLock(func() {
+ conn.ConnController = nil
+ })
+ waitErr := cmd.Wait()
+ log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
+ }()
+ go func() {
+ logName := fmt.Sprintf("conncontroller:%s", conn.GetName())
+ wshutil.HandleStdIOClient(logName, pipeRead, inputPipeWrite)
+ }()
+ regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancelFn()
+ err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
+ if err != nil {
+ return fmt.Errorf("timeout waiting for connserver to register")
+ }
+ time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
+ return nil
+}
+
+type WshInstallOpts struct {
+ Force bool
+ NoUserPrompt bool
+}
+
+func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName string, opts *WshInstallOpts) error {
+ if opts == nil {
+ opts = &WshInstallOpts{}
+ }
+ client := conn.GetClient()
+ if client == nil {
+ return fmt.Errorf("client is nil")
+ }
+ // check that correct wsh extensions are installed
+ expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
+ clientVersion, err := GetWshVersion(ctx, client)
+ if err == nil && clientVersion == expectedVersion && !opts.Force {
+ return nil
+ }
+ var queryText string
+ var title string
+ if opts.Force {
+ queryText = fmt.Sprintf("ReInstalling Wave Shell Extensions (%s) on `%s`\n", wavebase.WaveVersion, clientDisplayName)
+ title = "Install Wave Shell Extensions"
+ } else if err != nil {
+ queryText = fmt.Sprintf("Wave requires Wave Shell Extensions to be \n"+
+ "installed on `%s` \n"+
+ "to ensure a seamless experience. \n\n"+
+ "Would you like to install them?", clientDisplayName)
+ title = "Install Wave Shell Extensions"
+ } else {
+ // don't ask for upgrading the version
+ opts.NoUserPrompt = true
+ }
+ if !opts.NoUserPrompt {
+ request := &userinput.UserInputRequest{
+ ResponseType: "confirm",
+ QueryText: queryText,
+ Title: title,
+ Markdown: true,
+ CheckBoxMsg: "Don't show me this again",
+ }
+ response, err := userinput.GetUserInput(ctx, request)
+ if err != nil || !response.Confirm {
+ return err
+ }
+ if response.CheckboxStat {
+ meta := waveobj.MetaMapType{
+ wconfig.ConfigKey_ConnAskBeforeWshInstall: false,
+ }
+ err := wconfig.SetBaseConfigValue(meta)
+ if err != nil {
+ return fmt.Errorf("error setting conn:askbeforewshinstall value: %w", err)
+ }
+ }
+ }
+ log.Printf("attempting to install wsh to `%s`", clientDisplayName)
+ clientOs, err := GetClientOs(ctx, client)
+ if err != nil {
+ return err
+ }
+ clientArch, err := GetClientArch(ctx, client)
+ if err != nil {
+ return err
+ }
+ // attempt to install extension
+ wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
+ err = CpHostToRemote(ctx, client, wshLocalPath, "~/.waveterm/bin/wsh")
+ if err != nil {
+ return err
+ }
+ log.Printf("successfully installed wsh on %s\n", conn.GetName())
+ return nil
+}
+
+func (conn *WslConn) GetClient() *Distro {
+ conn.Lock.Lock()
+ defer conn.Lock.Unlock()
+ return conn.Client
+}
+
+func (conn *WslConn) Reconnect(ctx context.Context) error {
+ err := conn.Close()
+ if err != nil {
+ return err
+ }
+ return conn.Connect(ctx)
+}
+
+func (conn *WslConn) WaitForConnect(ctx context.Context) error {
+ for {
+ status := conn.DeriveConnStatus()
+ if status.Status == Status_Connected {
+ return nil
+ }
+ if status.Status == Status_Connecting {
+ select {
+ case <-ctx.Done():
+ return fmt.Errorf("context timeout")
+ case <-time.After(100 * time.Millisecond):
+ continue
+ }
+ }
+ if status.Status == Status_Init || status.Status == Status_Disconnected {
+ return fmt.Errorf("disconnected")
+ }
+ if status.Status == Status_Error {
+ return fmt.Errorf("error: %v", status.Error)
+ }
+ return fmt.Errorf("unknown status: %q", status.Status)
+ }
+}
+
+// does not return an error since that error is stored inside of WslConn
+func (conn *WslConn) Connect(ctx context.Context) error {
+ var connectAllowed bool
+ conn.WithLock(func() {
+ if conn.Status == Status_Connecting || conn.Status == Status_Connected {
+ connectAllowed = false
+ } else {
+ conn.Status = Status_Connecting
+ conn.Error = ""
+ connectAllowed = true
+ }
+ })
+ log.Printf("Connect %s\n", conn.GetName())
+ if !connectAllowed {
+ return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
+ }
+ conn.FireConnChangeEvent()
+ err := conn.connectInternal(ctx)
+ conn.WithLock(func() {
+ if err != nil {
+ conn.Status = Status_Error
+ conn.Error = err.Error()
+ conn.close_nolock()
+ } else {
+ conn.Status = Status_Connected
+ conn.LastConnectTime = time.Now().UnixMilli()
+ if conn.ActiveConnNum == 0 {
+ conn.ActiveConnNum = int(activeConnCounter.Add(1))
+ }
+ }
+ })
+ conn.FireConnChangeEvent()
+ return err
+}
+
+func (conn *WslConn) WithLock(fn func()) {
+ conn.Lock.Lock()
+ defer conn.Lock.Unlock()
+ fn()
+}
+
+func (conn *WslConn) connectInternal(ctx context.Context) error {
+ client, err := GetDistro(ctx, conn.Name)
+ if err != nil {
+ return err
+ }
+ conn.WithLock(func() {
+ conn.Client = client
+ })
+ err = conn.OpenDomainSocketListener()
+ if err != nil {
+ return err
+ }
+ config := wconfig.ReadFullConfig()
+ installErr := conn.CheckAndInstallWsh(ctx, conn.GetName(), &WshInstallOpts{NoUserPrompt: !config.Settings.ConnAskBeforeWshInstall})
+ if installErr != nil {
+ return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr)
+ }
+ csErr := conn.StartConnServer()
+ if csErr != nil {
+ return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr)
+ }
+ conn.HasWaiter.Store(true)
+ go conn.waitForDisconnect()
+ return nil
+}
+
+func (conn *WslConn) waitForDisconnect() {
+ defer conn.FireConnChangeEvent()
+ defer conn.HasWaiter.Store(false)
+ err := conn.ConnController.Wait()
+ conn.WithLock(func() {
+ // disconnects happen for a variety of reasons (like network, etc. and are typically transient)
+ // so we just set the status to "disconnected" here (not error)
+ // don't overwrite any existing error (or error status)
+ if err != nil && conn.Error == "" {
+ conn.Error = err.Error()
+ }
+ if conn.Status != Status_Error {
+ conn.Status = Status_Disconnected
+ }
+ conn.close_nolock()
+ })
+}
+
+func getConnInternal(name string) *WslConn {
+ globalLock.Lock()
+ defer globalLock.Unlock()
+ connName := WslName{Distro: name}
+ rtn := clientControllerMap[name]
+ if rtn == nil {
+ ctx, cancelFn := context.WithCancel(context.Background())
+ rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, HasWaiter: &atomic.Bool{}, Context: ctx, cancelFn: cancelFn}
+ clientControllerMap[name] = rtn
+ }
+ return rtn
+}
+
+func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn {
+ conn := getConnInternal(name)
+ if conn.Client == nil && shouldConnect {
+ conn.Connect(ctx)
+ }
+ return conn
+}
+
+// Convenience function for ensuring a connection is established
+func EnsureConnection(ctx context.Context, connName string) error {
+ if connName == "" {
+ return nil
+ }
+ conn := GetWslConn(ctx, connName, false)
+ if conn == nil {
+ return fmt.Errorf("connection not found: %s", connName)
+ }
+ connStatus := conn.DeriveConnStatus()
+ switch connStatus.Status {
+ case Status_Connected:
+ return nil
+ case Status_Connecting:
+ return conn.WaitForConnect(ctx)
+ case Status_Init, Status_Disconnected:
+ return conn.Connect(ctx)
+ case Status_Error:
+ return fmt.Errorf("connection error: %s", connStatus.Error)
+ default:
+ return fmt.Errorf("unknown connection status %q", connStatus.Status)
+ }
+}
+
+func DisconnectClient(connName string) error {
+ conn := getConnInternal(connName)
+ if conn == nil {
+ return fmt.Errorf("client %q not found", connName)
+ }
+ err := conn.Close()
+ return err
+}
diff --git a/pkg/wstore/wstore.go b/pkg/wstore/wstore.go
index 1c9824350..0872ec45b 100644
--- a/pkg/wstore/wstore.go
+++ b/pkg/wstore/wstore.go
@@ -95,6 +95,27 @@ func UpdateTabName(ctx context.Context, tabId, name string) error {
})
}
+func CreateSubBlock(ctx context.Context, parentBlockId string, blockDef *waveobj.BlockDef) (*waveobj.Block, error) {
+ return WithTxRtn(ctx, func(tx *TxWrap) (*waveobj.Block, error) {
+ parentBlock, _ := DBGet[*waveobj.Block](tx.Context(), parentBlockId)
+ if parentBlock == nil {
+ return nil, fmt.Errorf("parent block not found: %q", parentBlockId)
+ }
+ blockId := uuid.NewString()
+ blockData := &waveobj.Block{
+ OID: blockId,
+ ParentORef: waveobj.MakeORef(waveobj.OType_Block, parentBlockId).String(),
+ BlockDef: blockDef,
+ RuntimeOpts: nil,
+ Meta: blockDef.Meta,
+ }
+ DBInsert(tx.Context(), blockData)
+ parentBlock.SubBlockIds = append(parentBlock.SubBlockIds, blockId)
+ DBUpdate(tx.Context(), parentBlock)
+ return blockData, nil
+ })
+}
+
func CreateBlock(ctx context.Context, tabId string, blockDef *waveobj.BlockDef, rtOpts *waveobj.RuntimeOpts) (*waveobj.Block, error) {
return WithTxRtn(ctx, func(tx *TxWrap) (*waveobj.Block, error) {
tab, _ := DBGet[*waveobj.Tab](tx.Context(), tabId)
@@ -104,6 +125,7 @@ func CreateBlock(ctx context.Context, tabId string, blockDef *waveobj.BlockDef,
blockId := uuid.NewString()
blockData := &waveobj.Block{
OID: blockId,
+ ParentORef: waveobj.MakeORef(waveobj.OType_Tab, tabId).String(),
BlockDef: blockDef,
RuntimeOpts: rtOpts,
Meta: blockDef.Meta,
@@ -124,18 +146,34 @@ func findStringInSlice(slice []string, val string) int {
return -1
}
-func DeleteBlock(ctx context.Context, tabId string, blockId string) error {
+func DeleteBlock(ctx context.Context, blockId string) error {
return WithTx(ctx, func(tx *TxWrap) error {
- tab, _ := DBGet[*waveobj.Tab](tx.Context(), tabId)
- if tab == nil {
- return fmt.Errorf("tab not found: %q", tabId)
+ block, err := DBGet[*waveobj.Block](tx.Context(), blockId)
+ if err != nil {
+ return fmt.Errorf("error getting block: %w", err)
}
- blockIdx := findStringInSlice(tab.BlockIds, blockId)
- if blockIdx == -1 {
+ if block == nil {
return nil
}
- tab.BlockIds = append(tab.BlockIds[:blockIdx], tab.BlockIds[blockIdx+1:]...)
- DBUpdate(tx.Context(), tab)
+ if len(block.SubBlockIds) > 0 {
+ return fmt.Errorf("block has subblocks, must delete subblocks first")
+ }
+ parentORef := waveobj.ParseORefNoErr(block.ParentORef)
+ if parentORef != nil {
+ if parentORef.OType == waveobj.OType_Tab {
+ tab, _ := DBGet[*waveobj.Tab](tx.Context(), parentORef.OID)
+ if tab != nil {
+ tab.BlockIds = utilfn.RemoveElemFromSlice(tab.BlockIds, blockId)
+ DBUpdate(tx.Context(), tab)
+ }
+ } else if parentORef.OType == waveobj.OType_Block {
+ parentBlock, _ := DBGet[*waveobj.Block](tx.Context(), parentORef.OID)
+ if parentBlock != nil {
+ parentBlock.SubBlockIds = utilfn.RemoveElemFromSlice(parentBlock.SubBlockIds, blockId)
+ DBUpdate(tx.Context(), parentBlock)
+ }
+ }
+ }
DBDelete(tx.Context(), waveobj.OType_Block, blockId)
return nil
})
@@ -145,23 +183,18 @@ func DeleteBlock(ctx context.Context, tabId string, blockId string) error {
// also deletes LayoutState
func DeleteTab(ctx context.Context, workspaceId string, tabId string) error {
return WithTx(ctx, func(tx *TxWrap) error {
- ws, _ := DBGet[*waveobj.Workspace](tx.Context(), workspaceId)
- if ws == nil {
- return fmt.Errorf("workspace not found: %q", workspaceId)
- }
tab, _ := DBGet[*waveobj.Tab](tx.Context(), tabId)
if tab == nil {
- return fmt.Errorf("tab not found: %q", tabId)
+ return nil
}
if len(tab.BlockIds) != 0 {
return fmt.Errorf("tab has blocks, must delete blocks first")
}
- tabIdx := findStringInSlice(ws.TabIds, tabId)
- if tabIdx == -1 {
- return nil
+ ws, _ := DBGet[*waveobj.Workspace](tx.Context(), workspaceId)
+ if ws != nil {
+ ws.TabIds = utilfn.RemoveElemFromSlice(ws.TabIds, tabId)
+ DBUpdate(tx.Context(), ws)
}
- ws.TabIds = append(ws.TabIds[:tabIdx], ws.TabIds[tabIdx+1:]...)
- DBUpdate(tx.Context(), ws)
DBDelete(tx.Context(), waveobj.OType_Tab, tabId)
DBDelete(tx.Context(), waveobj.OType_LayoutState, tab.LayoutState)
return nil
@@ -190,6 +223,10 @@ func UpdateObjectMeta(ctx context.Context, oref waveobj.ORef, meta waveobj.MetaM
func MoveBlockToTab(ctx context.Context, currentTabId string, newTabId string, blockId string) error {
return WithTx(ctx, func(tx *TxWrap) error {
+ block, _ := DBGet[*waveobj.Block](tx.Context(), blockId)
+ if block == nil {
+ return fmt.Errorf("block not found: %q", blockId)
+ }
currentTab, _ := DBGet[*waveobj.Tab](tx.Context(), currentTabId)
if currentTab == nil {
return fmt.Errorf("current tab not found: %q", currentTabId)
@@ -204,6 +241,8 @@ func MoveBlockToTab(ctx context.Context, currentTabId string, newTabId string, b
}
currentTab.BlockIds = utilfn.RemoveElemFromSlice(currentTab.BlockIds, blockId)
newTab.BlockIds = append(newTab.BlockIds, blockId)
+ block.ParentORef = waveobj.MakeORef(waveobj.OType_Tab, newTabId).String()
+ DBUpdate(tx.Context(), block)
DBUpdate(tx.Context(), currentTab)
DBUpdate(tx.Context(), newTab)
return nil
diff --git a/pkg/wstore/wstore_dbops.go b/pkg/wstore/wstore_dbops.go
index c4cd17ba2..de7aa5c59 100644
--- a/pkg/wstore/wstore_dbops.go
+++ b/pkg/wstore/wstore_dbops.go
@@ -270,10 +270,39 @@ func DBFindWindowForTabId(ctx context.Context, tabId string) (string, error) {
func DBFindTabForBlockId(ctx context.Context, blockId string) (string, error) {
return WithTxRtn(ctx, func(tx *TxWrap) (string, error) {
- query := `
- SELECT t.oid
- FROM db_tab t, json_each(data->'blockids') je
- WHERE je.value = ?;`
- return tx.GetString(query, blockId), nil
+ iterNum := 1
+ for {
+ if iterNum > 5 {
+ return "", fmt.Errorf("too many iterations looking for tab in block parents")
+ }
+ query := `
+ SELECT json_extract(b.data, '$.parentoref') AS parentoref
+ FROM db_block b
+ WHERE b.oid = ?;`
+ parentORef := tx.GetString(query, blockId)
+ oref, err := waveobj.ParseORef(parentORef)
+ if err != nil {
+ return "", fmt.Errorf("bad block parent oref: %v", err)
+ }
+ if oref.OType == "tab" {
+ return oref.OID, nil
+ }
+ if oref.OType == "block" {
+ blockId = oref.OID
+ iterNum++
+ continue
+ }
+ return "", fmt.Errorf("bad parent oref type: %v", oref.OType)
+ }
+ })
+}
+
+func DBFindWorkspaceForTabId(ctx context.Context, tabId string) (string, error) {
+ return WithTxRtn(ctx, func(tx *TxWrap) (string, error) {
+ query := `
+ SELECT w.oid
+ FROM db_workspace w, json_each(data->'tabids') je
+ WHERE je.value = ?`
+ return tx.GetString(query, tabId), nil
})
}
diff --git a/pkg/wstore/wstore_dbsetup.go b/pkg/wstore/wstore_dbsetup.go
index 7df15a021..3a4f83585 100644
--- a/pkg/wstore/wstore_dbsetup.go
+++ b/pkg/wstore/wstore_dbsetup.go
@@ -42,7 +42,7 @@ func InitWStore() error {
}
func GetDBName() string {
- waveHome := wavebase.GetWaveHomeDir()
+ waveHome := wavebase.GetWaveDataDir()
return filepath.Join(waveHome, wavebase.WaveDBDir, WStoreDBName)
}
diff --git a/yarn.lock b/yarn.lock
index 0343ad213..79a3eae44 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -5329,6 +5329,13 @@ __metadata:
languageName: node
linkType: hard
+"env-paths@npm:^3.0.0":
+ version: 3.0.0
+ resolution: "env-paths@npm:3.0.0"
+ checksum: 10c0/76dec878cee47f841103bacd7fae03283af16f0702dad65102ef0a556f310b98a377885e0f32943831eb08b5ab37842a323d02529f3dfd5d0a40ca71b01b435f
+ languageName: node
+ linkType: hard
+
"err-code@npm:^2.0.2":
version: 2.0.3
resolution: "err-code@npm:2.0.3"
@@ -11717,6 +11724,7 @@ __metadata:
electron-builder: "npm:^25.1.8"
electron-updater: "npm:6.3.9"
electron-vite: "npm:^2.3.0"
+ env-paths: "npm:^3.0.0"
eslint: "npm:^9.13.0"
eslint-config-prettier: "npm:^9.1.0"
fast-average-color: "npm:^9.4.0"