diff --git a/frontend/app/element/button.less b/frontend/app/element/button.less new file mode 100644 index 000000000..05df3e9a2 --- /dev/null +++ b/frontend/app/element/button.less @@ -0,0 +1,135 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +.wave-button { + background: none; + border: none; + cursor: pointer; + outline: inherit; + display: flex; + padding: 6px 16px; + align-items: center; + gap: 4px; + border-radius: 6px; + height: auto; + line-height: 1.5; + white-space: nowrap; + user-select: none; + + color: var(--main-text-color); + background: var(--accent-color); + i { + fill: var(--main-text-color); + } + + &.primary { + color: var(--main-text-color); + background: var(--accent-color); + i { + fill: var(--main-text-color); + } + } + + &.primary.danger { + background: var(--error-color); + } + + &.primary.outlined { + background: none; + border: 1px solid var(--accent-color); + + i { + fill: var(--accent-color); + } + } + + &.primary.greyoutlined { + background: none; + border: 1px solid var(--secondary-text-color); + + i { + fill: var(--secondary-text-color); + } + } + + &.primary.outlined, + &.primary.greyoutlined { + &.hover-danger:hover { + color: var(--main-text-color); + border: 1px solid var(--error-color); + background: var(--error-color); + } + } + + &.primary.outlined.danger { + background: none; + border: 1px solid var(--error-color); + + i { + fill: var(--error-color); + } + } + + &.greytext { + color: var(--secondary-text-color); + } + + &.primary.ghost { + background: none; + i { + fill: var(--accent-color); + } + } + + &.primary.ghost.danger { + background: none; + i { + fill: var(--app-error-color); + } + } + + &.secondary { + color: var(--main-text-color); + background: var(--highlight-bg-color); + i { + fill: var(--main-text-color); + } + } + + &.secondary.outlined { + background: none; + border: 1px solid var(--main-text-color); + } + + &.secondary.outlined.danger { + background: none; + border: 1px solid var(--error-color); + } + + &.secondary.ghost { + background: none; + } + + &.secondary.danger { + color: var(--error-color); + } + + &.small { + padding: 4px 8px; + font-size: 12px; + border-radius: 3.6px; + } + + &.term-inline { + padding: 2px 8px; + border-radius: 3px; + } + + &.disabled { + opacity: 0.5; + } + + &.link-button { + cursor: pointer; + } +} diff --git a/frontend/app/element/button.tsx b/frontend/app/element/button.tsx new file mode 100644 index 000000000..18a49bdaf --- /dev/null +++ b/frontend/app/element/button.tsx @@ -0,0 +1,55 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +import * as React from "react"; +import { clsx } from "clsx"; + +import "./button.less"; + +interface ButtonProps { + children: React.ReactNode; + onClick?: (e: React.MouseEvent) => void; + disabled?: boolean; + leftIcon?: React.ReactNode; + rightIcon?: React.ReactNode; + style?: React.CSSProperties; + autoFocus?: boolean; + className?: string; + termInline?: boolean; + title?: string; +} + +class Button extends React.Component { + static defaultProps = { + style: {}, + className: "primary", + }; + + handleClick(e) { + if (this.props.onClick && !this.props.disabled) { + this.props.onClick(e); + } + } + + render() { + const { leftIcon, rightIcon, children, disabled, style, autoFocus, termInline, className, title } = this.props; + + return ( + + ); + } +} + +export { Button }; +export type { ButtonProps }; diff --git a/frontend/app/element/markdown.less b/frontend/app/element/markdown.less index a2c0b4c89..21c460461 100644 --- a/frontend/app/element/markdown.less +++ b/frontend/app/element/markdown.less @@ -1,5 +1,8 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + .markdown { - color: var(--main-color); + color: var(--main-text-color); font-family: var(--markdown-font); font-size: 14px; overflow-wrap: break-word; @@ -9,13 +12,13 @@ } .title { - color: var(--main-color); + color: var(--main-text-color); margin-top: 16px; margin-bottom: 8px; } strong { - color: var(--main-color); + color: var(--main-text-color); } a { @@ -24,7 +27,7 @@ table { tr th { - color: var(--main-color); + color: var(--main-text-color); } } @@ -61,7 +64,7 @@ code { font: var(--fixed-font); - color: var(--main-color); + color: var(--main-text-color); border-radius: 4px; background-color: var(--panel-bg-color); padding: 0.15em 0.5em; diff --git a/frontend/app/store/global.ts b/frontend/app/store/global.ts index 55825f033..c9e64297f 100644 --- a/frontend/app/store/global.ts +++ b/frontend/app/store/global.ts @@ -4,6 +4,9 @@ import * as jotai from "jotai"; import { atomFamily } from "jotai/utils"; import { v4 as uuidv4 } from "uuid"; +import * as rxjs from "rxjs"; +import type { WailsEvent } from "@wailsio/runtime/types/events"; +import { Events } from "@wailsio/runtime"; const globalStore = jotai.createStore(); @@ -42,4 +45,40 @@ const atoms = { blockAtomFamily, }; -export { globalStore, atoms }; +type SubjectWithRef = rxjs.Subject & { refCount: number; release: () => void }; + +const blockSubjects = new Map>(); + +function getBlockSubject(blockId: string): SubjectWithRef { + let subject = blockSubjects.get(blockId); + if (subject == null) { + subject = new rxjs.Subject() as any; + subject.refCount = 0; + subject.release = () => { + subject.refCount--; + if (subject.refCount === 0) { + subject.complete(); + blockSubjects.delete(blockId); + } + }; + blockSubjects.set(blockId, subject); + } + subject.refCount++; + return subject; +} + +Events.On("block:ptydata", (event: any) => { + const data = event?.data; + if (data?.blockid == null) { + console.log("block:ptydata with null blockid"); + return; + } + // we don't use getBlockSubject here because we don't want to create a new subject + const subject = blockSubjects.get(data.blockid); + if (subject == null) { + return; + } + subject.next(data); +}); + +export { globalStore, atoms, getBlockSubject }; diff --git a/frontend/app/view/term.tsx b/frontend/app/view/term.tsx index a8387f634..2cc1f81c8 100644 --- a/frontend/app/view/term.tsx +++ b/frontend/app/view/term.tsx @@ -6,6 +6,10 @@ import * as jotai from "jotai"; import { Terminal } from "@xterm/xterm"; import type { ITheme } from "@xterm/xterm"; import { FitAddon } from "@xterm/addon-fit"; +import { Button } from "@/element/button"; +import * as BlockService from "@/bindings/pkg/service/blockservice/BlockService"; +import { getBlockSubject } from "@/store/global"; +import { base64ToArray } from "@/util/util"; import "./view.less"; import "/public/xterm.css"; @@ -40,6 +44,8 @@ function getThemeFromCSSVars(el: Element): ITheme { const TerminalView = ({ blockId }: { blockId: string }) => { const connectElemRef = React.useRef(null); + const [term, setTerm] = React.useState(null); + const [blockStarted, setBlockStarted] = React.useState(false); React.useEffect(() => { if (!connectElemRef.current) { @@ -53,17 +59,85 @@ const TerminalView = ({ blockId }: { blockId: string }) => { fontWeight: "normal", fontWeightBold: "bold", }); + setTerm(term); const fitAddon = new FitAddon(); term.loadAddon(fitAddon); term.open(connectElemRef.current); fitAddon.fit(); - term.write("Hello, world!"); + term.write("Hello, world!\r\n"); + console.log(term); + term.onData((data) => { + const b64data = btoa(data); + const inputCmd = { command: "input", blockid: blockId, inputdata64: b64data }; + BlockService.SendCommand(blockId, inputCmd); + }); + + // resize observer + const rszObs = new ResizeObserver(() => { + const oldRows = term.rows; + const oldCols = term.cols; + fitAddon.fit(); + if (oldRows !== term.rows || oldCols !== term.cols) { + BlockService.SendCommand(blockId, { command: "input", termsize: { rows: term.rows, cols: term.cols } }); + } + }); + rszObs.observe(connectElemRef.current); + + // block subject + const blockSubject = getBlockSubject(blockId); + blockSubject.subscribe((data) => { + // base64 decode + const decodedData = base64ToArray(data.ptydata); + term.write(decodedData); + }); + return () => { term.dispose(); + rszObs.disconnect(); + blockSubject.release(); }; }, [connectElemRef.current]); - return
; + async function handleRunClick() { + try { + if (!blockStarted) { + await BlockService.StartBlock(blockId); + setBlockStarted(true); + } + let termSize = { rows: term.rows, cols: term.cols }; + await BlockService.SendCommand(blockId, { command: "run", cmdstr: "ls -l", termsize: termSize }); + } catch (e) { + console.log("run click error: ", e); + } + } + + async function handleStartTerminalClick() { + try { + if (!blockStarted) { + await BlockService.StartBlock(blockId); + setBlockStarted(true); + } + let termSize = { rows: term.rows, cols: term.cols }; + await BlockService.SendCommand(blockId, { command: "runshell", termsize: termSize }); + } catch (e) { + console.log("start terminal click error: ", e); + } + } + + return ( +
+
+
Terminal
+ + +
+
+
+ ); }; export { TerminalView }; diff --git a/frontend/app/view/view.less b/frontend/app/view/view.less index 7be9e9179..b75ad0643 100644 --- a/frontend/app/view/view.less +++ b/frontend/app/view/view.less @@ -7,6 +7,23 @@ width: 100%; height: 100%; overflow: hidden; + + .term-header { + display: flex; + flex-direction: row; + padding: 4px 10px; + height: 35px; + gap: 10px; + align-items: center; + flex-shrink: 0; + border-bottom: 1px solid var(--border-color); + } + + .term-connectelem { + flex-grow: 1; + min-height: 0; + overflow: hidden; + } } .view-preview { diff --git a/go.mod b/go.mod index 82c6688e5..44ac8e0c5 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22 toolchain go1.22.1 require ( + github.com/creack/pty v1.1.18 github.com/golang-migrate/migrate/v4 v4.17.1 github.com/google/uuid v1.4.0 github.com/jmoiron/sqlx v1.4.0 diff --git a/go.sum b/go.sum index 9c2e268d0..43b01693a 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/bwesterb/go-ristretto v1.2.3/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7N github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA= +github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= +github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53EtKeQYTC3kyg= github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/pkg/blockcontroller/blockcommand.go b/pkg/blockcontroller/blockcommand.go new file mode 100644 index 000000000..bdff3f0f1 --- /dev/null +++ b/pkg/blockcontroller/blockcommand.go @@ -0,0 +1,92 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package blockcontroller + +import ( + "encoding/json" + "fmt" + "reflect" + + "github.com/wavetermdev/thenextwave/pkg/shellexec" +) + +const CommandKey = "command" + +const ( + BlockCommand_Message = "message" + BlockCommand_Run = "run" + BlockCommand_Input = "input" + BlockCommand_RunShell = "runshell" +) + +var CommandToTypeMap = map[string]reflect.Type{ + BlockCommand_Message: reflect.TypeOf(MessageCommand{}), + BlockCommand_Run: reflect.TypeOf(RunCommand{}), + BlockCommand_Input: reflect.TypeOf(InputCommand{}), + BlockCommand_RunShell: reflect.TypeOf(RunShellCommand{}), +} + +type BlockCommand interface { + GetCommand() string +} + +func ParseCmdMap(cmdMap map[string]any) (BlockCommand, error) { + cmdType, ok := cmdMap[CommandKey].(string) + if !ok { + return nil, fmt.Errorf("no %s field in command map", CommandKey) + } + mapJson, err := json.Marshal(cmdMap) + if err != nil { + return nil, fmt.Errorf("error marshalling command map: %w", err) + } + rtype := CommandToTypeMap[cmdType] + if rtype == nil { + return nil, fmt.Errorf("unknown command type %q", cmdType) + } + cmd := reflect.New(rtype).Interface() + err = json.Unmarshal(mapJson, cmd) + if err != nil { + return nil, fmt.Errorf("error unmarshalling command: %w", err) + } + return cmd.(BlockCommand), nil +} + +type MessageCommand struct { + Command string `json:"command"` + Message string `json:"message"` +} + +func (mc *MessageCommand) GetCommand() string { + return BlockCommand_Message +} + +type RunCommand struct { + Command string `json:"command"` + CmdStr string `json:"cmdstr"` + TermSize shellexec.TermSize `json:"termsize"` +} + +func (rc *RunCommand) GetCommand() string { + return BlockCommand_Run +} + +type InputCommand struct { + Command string `json:"command"` + InputData64 string `json:"inputdata64"` + SigName string `json:"signame,omitempty"` + TermSize *shellexec.TermSize `json:"termsize,omitempty"` +} + +func (ic *InputCommand) GetCommand() string { + return BlockCommand_Input +} + +type RunShellCommand struct { + Command string `json:"command"` + TermSize shellexec.TermSize `json:"termsize"` +} + +func (rsc *RunShellCommand) GetCommand() string { + return BlockCommand_RunShell +} diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index 22e2e9864..6bf657b84 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -4,54 +4,137 @@ package blockcontroller import ( - "encoding/json" + "encoding/base64" "fmt" + "io" + "log" + "os/exec" "sync" + "github.com/creack/pty" "github.com/wailsapp/wails/v3/pkg/application" "github.com/wavetermdev/thenextwave/pkg/eventbus" + "github.com/wavetermdev/thenextwave/pkg/shellexec" + "github.com/wavetermdev/thenextwave/pkg/util/shellutil" ) var globalLock = &sync.Mutex{} var blockControllerMap = make(map[string]*BlockController) -type BlockCommand interface { - GetType() string -} - -type MessageCommand struct { - Message string `json:"message"` -} - -func (mc *MessageCommand) GetType() string { - return "message" -} - type BlockController struct { - BlockId string - InputCh chan BlockCommand + Lock *sync.Mutex + BlockId string + InputCh chan BlockCommand + ShellProc *shellexec.ShellProc + ShellInputCh chan *InputCommand } -func ParseCmdMap(cmdMap map[string]any) (BlockCommand, error) { - cmdType, ok := cmdMap["type"].(string) - if !ok { - return nil, fmt.Errorf("no type field in command map") +func (bc *BlockController) setShellProc(shellProc *shellexec.ShellProc) error { + bc.Lock.Lock() + defer bc.Lock.Unlock() + if bc.ShellProc != nil { + return fmt.Errorf("shell process already running") } - mapJson, err := json.Marshal(cmdMap) + bc.ShellProc = shellProc + return nil +} + +func (bc *BlockController) getShellProc() *shellexec.ShellProc { + bc.Lock.Lock() + defer bc.Lock.Unlock() + return bc.ShellProc +} + +func (bc *BlockController) DoRunCommand(rc *RunCommand) error { + cmdStr := rc.CmdStr + shellPath := shellutil.DetectLocalShellPath() + ecmd := exec.Command(shellPath, "-c", cmdStr) + log.Printf("running shell command: %q %q\n", shellPath, cmdStr) + barr, err := shellexec.RunSimpleCmdInPty(ecmd, rc.TermSize) if err != nil { - return nil, fmt.Errorf("error marshalling command map: %w", err) + return err } - switch cmdType { - case "message": - var cmd MessageCommand - err := json.Unmarshal(mapJson, &cmd) - if err != nil { - return nil, fmt.Errorf("error unmarshalling message command: %w", err) + for len(barr) > 0 { + part := barr + if len(part) > 4096 { + part = part[:4096] } - return &cmd, nil - default: - return nil, fmt.Errorf("unknown command type %q", cmdType) + eventbus.SendEvent(application.WailsEvent{ + Name: "block:ptydata", + Data: map[string]any{ + "blockid": bc.BlockId, + "blockfile": "main", + "ptydata": base64.StdEncoding.EncodeToString(part), + }, + }) + barr = barr[len(part):] } + return nil +} + +func (bc *BlockController) DoRunShellCommand(rc *RunShellCommand) error { + if bc.getShellProc() != nil { + return nil + } + shellProc, err := shellexec.StartShellProc(rc.TermSize) + if err != nil { + return err + } + err = bc.setShellProc(shellProc) + if err != nil { + bc.ShellProc.Close() + return err + } + shellInputCh := make(chan *InputCommand) + bc.ShellInputCh = shellInputCh + go func() { + defer func() { + // needs synchronization + bc.ShellProc.Close() + close(bc.ShellInputCh) + bc.ShellProc = nil + bc.ShellInputCh = nil + }() + buf := make([]byte, 4096) + for { + nr, err := bc.ShellProc.Pty.Read(buf) + eventbus.SendEvent(application.WailsEvent{ + Name: "block:ptydata", + Data: map[string]any{ + "blockid": bc.BlockId, + "blockfile": "main", + "ptydata": base64.StdEncoding.EncodeToString(buf[:nr]), + }, + }) + if err == io.EOF { + break + } + if err != nil { + log.Printf("error reading from shell: %v\n", err) + break + } + } + }() + go func() { + for ic := range shellInputCh { + if ic.InputData64 != "" { + inputBuf := make([]byte, base64.StdEncoding.DecodedLen(len(ic.InputData64))) + nw, err := base64.StdEncoding.Decode(inputBuf, []byte(ic.InputData64)) + if err != nil { + log.Printf("error decoding input data: %v\n", err) + continue + } + bc.ShellProc.Pty.Write(inputBuf[:nw]) + } + if ic.TermSize != nil { + err := pty.Setsize(bc.ShellProc.Pty, &pty.Winsize{Rows: uint16(ic.TermSize.Rows), Cols: uint16(ic.TermSize.Cols)}) + if err != nil { + log.Printf("error setting term size: %v\n", err) + } + } + } + }() + return nil } func (bc *BlockController) Run() { @@ -65,21 +148,59 @@ func (bc *BlockController) Run() { delete(blockControllerMap, bc.BlockId) }() + messageCount := 0 for genCmd := range bc.InputCh { switch cmd := genCmd.(type) { case *MessageCommand: fmt.Printf("MESSAGE: %s | %q\n", bc.BlockId, cmd.Message) + messageCount++ + eventbus.SendEvent(application.WailsEvent{ + Name: "block:ptydata", + Data: map[string]any{ + "blockid": bc.BlockId, + "blockfile": "main", + "ptydata": base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("message %d\r\n", messageCount))), + }, + }) + case *RunCommand: + fmt.Printf("RUN: %s | %q\n", bc.BlockId, cmd.CmdStr) + go func() { + err := bc.DoRunCommand(cmd) + if err != nil { + log.Printf("error running shell command: %v\n", err) + } + }() + case *InputCommand: + fmt.Printf("INPUT: %s | %q\n", bc.BlockId, cmd.InputData64) + if bc.ShellInputCh != nil { + bc.ShellInputCh <- cmd + } + case *RunShellCommand: + fmt.Printf("RUNSHELL: %s\n", bc.BlockId) + if bc.ShellProc != nil { + continue + } + go func() { + err := bc.DoRunShellCommand(cmd) + if err != nil { + log.Printf("error running shell: %v\n", err) + } + }() default: fmt.Printf("unknown command type %T\n", cmd) } } } -func NewBlockController(blockId string) *BlockController { +func StartBlockController(blockId string) *BlockController { globalLock.Lock() defer globalLock.Unlock() + if existingBC, ok := blockControllerMap[blockId]; ok { + return existingBC + } bc := &BlockController{ + Lock: &sync.Mutex{}, BlockId: blockId, InputCh: make(chan BlockCommand), } diff --git a/pkg/service/blockservice/blockservice.go b/pkg/service/blockservice/blockservice.go index 55aff6b17..257563b3a 100644 --- a/pkg/service/blockservice/blockservice.go +++ b/pkg/service/blockservice/blockservice.go @@ -11,6 +11,11 @@ import ( type BlockService struct{} +func (bs *BlockService) StartBlock(blockId string) error { + blockcontroller.StartBlockController(blockId) + return nil +} + func (bs *BlockService) SendCommand(blockId string, cmdMap map[string]any) error { bc := blockcontroller.GetBlockController(blockId) if bc == nil { diff --git a/pkg/shellexec/shellexec.go b/pkg/shellexec/shellexec.go new file mode 100644 index 000000000..ef041327f --- /dev/null +++ b/pkg/shellexec/shellexec.go @@ -0,0 +1,109 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package shellexec + +import ( + "bytes" + "fmt" + "io" + "os" + "os/exec" + "syscall" + + "github.com/creack/pty" + "github.com/wavetermdev/thenextwave/pkg/util/shellutil" +) + +type TermSize struct { + Rows int `json:"rows"` + Cols int `json:"cols"` +} + +type ShellProc struct { + Cmd *exec.Cmd + Pty *os.File +} + +func (sp *ShellProc) Close() { + sp.Cmd.Process.Kill() + go func() { + sp.Cmd.Process.Wait() + sp.Pty.Close() + }() +} + +func StartShellProc(termSize TermSize) (*ShellProc, error) { + shellPath := shellutil.DetectLocalShellPath() + ecmd := exec.Command(shellPath, "-i", "-l") + ecmd.Env = os.Environ() + shellutil.UpdateCmdEnv(ecmd, shellutil.WaveshellEnvVars(shellutil.DefaultTermType)) + cmdPty, cmdTty, err := pty.Open() + if err != nil { + return nil, fmt.Errorf("opening new pty: %w", err) + } + if termSize.Rows == 0 || termSize.Cols == 0 { + termSize.Rows = shellutil.DefaultTermRows + termSize.Cols = shellutil.DefaultTermCols + } + if termSize.Rows <= 0 || termSize.Cols <= 0 { + return nil, fmt.Errorf("invalid term size: %v", termSize) + } + pty.Setsize(cmdPty, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)}) + ecmd.Stdin = cmdTty + ecmd.Stdout = cmdTty + ecmd.Stderr = cmdTty + ecmd.SysProcAttr = &syscall.SysProcAttr{} + ecmd.SysProcAttr.Setsid = true + ecmd.SysProcAttr.Setctty = true + err = ecmd.Start() + cmdTty.Close() + if err != nil { + cmdPty.Close() + return nil, err + } + return &ShellProc{Cmd: ecmd, Pty: cmdPty}, nil +} + +func RunSimpleCmdInPty(ecmd *exec.Cmd, termSize TermSize) ([]byte, error) { + ecmd.Env = os.Environ() + shellutil.UpdateCmdEnv(ecmd, shellutil.WaveshellEnvVars(shellutil.DefaultTermType)) + cmdPty, cmdTty, err := pty.Open() + if err != nil { + return nil, fmt.Errorf("opening new pty: %w", err) + } + if termSize.Rows == 0 || termSize.Cols == 0 { + termSize.Rows = shellutil.DefaultTermRows + termSize.Cols = shellutil.DefaultTermCols + } + if termSize.Rows <= 0 || termSize.Cols <= 0 { + return nil, fmt.Errorf("invalid term size: %v", termSize) + } + pty.Setsize(cmdPty, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)}) + ecmd.Stdin = cmdTty + ecmd.Stdout = cmdTty + ecmd.Stderr = cmdTty + ecmd.SysProcAttr = &syscall.SysProcAttr{} + ecmd.SysProcAttr.Setsid = true + ecmd.SysProcAttr.Setctty = true + err = ecmd.Start() + cmdTty.Close() + if err != nil { + cmdPty.Close() + return nil, err + } + defer cmdPty.Close() + ioDone := make(chan bool) + var outputBuf bytes.Buffer + go func() { + // ignore error (/dev/ptmx has read error when process is done) + defer close(ioDone) + io.Copy(&outputBuf, cmdPty) + }() + exitErr := ecmd.Wait() + if exitErr != nil { + return nil, exitErr + } + <-ioDone + return outputBuf.Bytes(), nil +} diff --git a/pkg/util/shellutil/shellutil.go b/pkg/util/shellutil/shellutil.go new file mode 100644 index 000000000..ce635826f --- /dev/null +++ b/pkg/util/shellutil/shellutil.go @@ -0,0 +1,117 @@ +// Copyright 2023, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package shellutil + +import ( + "context" + "os" + "os/exec" + "os/user" + "regexp" + "runtime" + "strings" + "sync" + "time" + + "github.com/wavetermdev/thenextwave/pkg/wavebase" +) + +const DefaultTermType = "xterm-256color" +const DefaultTermRows = 24 +const DefaultTermCols = 80 + +var cachedMacUserShell string +var macUserShellOnce = &sync.Once{} +var userShellRegexp = regexp.MustCompile(`^UserShell: (.*)$`) + +const DefaultShellPath = "/bin/bash" + +func DetectLocalShellPath() string { + shellPath := GetMacUserShell() + if shellPath == "" { + shellPath = os.Getenv("SHELL") + } + if shellPath == "" { + return DefaultShellPath + } + return shellPath +} + +func GetMacUserShell() string { + if runtime.GOOS != "darwin" { + return "" + } + macUserShellOnce.Do(func() { + cachedMacUserShell = internalMacUserShell() + }) + return cachedMacUserShell +} + +// dscl . -read /User/[username] UserShell +// defaults to /bin/bash +func internalMacUserShell() string { + osUser, err := user.Current() + if err != nil { + return DefaultShellPath + } + ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelFn() + userStr := "/Users/" + osUser.Username + out, err := exec.CommandContext(ctx, "dscl", ".", "-read", userStr, "UserShell").CombinedOutput() + if err != nil { + return DefaultShellPath + } + outStr := strings.TrimSpace(string(out)) + m := userShellRegexp.FindStringSubmatch(outStr) + if m == nil { + return DefaultShellPath + } + return m[1] +} + +func WaveshellEnvVars(termType string) map[string]string { + rtn := make(map[string]string) + if termType != "" { + rtn["TERM"] = termType + } + rtn["WAVETERM"], _ = os.Executable() + rtn["WAVETERM_VERSION"] = wavebase.WaveVersion + return rtn +} + +func UpdateCmdEnv(cmd *exec.Cmd, envVars map[string]string) { + if len(envVars) == 0 { + return + } + found := make(map[string]bool) + var newEnv []string + for _, envStr := range cmd.Env { + envKey := GetEnvStrKey(envStr) + newEnvVal, ok := envVars[envKey] + if ok { + if newEnvVal == "" { + continue + } + newEnv = append(newEnv, envKey+"="+newEnvVal) + found[envKey] = true + } else { + newEnv = append(newEnv, envStr) + } + } + for envKey, envVal := range envVars { + if found[envKey] { + continue + } + newEnv = append(newEnv, envKey+"="+envVal) + } + cmd.Env = newEnv +} + +func GetEnvStrKey(envStr string) string { + eqIdx := strings.Index(envStr, "=") + if eqIdx == -1 { + return envStr + } + return envStr[0:eqIdx] +} diff --git a/pkg/util/utilfn/utilfn.go b/pkg/util/utilfn/utilfn.go new file mode 100644 index 000000000..d83bf8574 --- /dev/null +++ b/pkg/util/utilfn/utilfn.go @@ -0,0 +1,675 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package utilfn + +import ( + "bytes" + "crypto/sha1" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "math" + mathrand "math/rand" + "net/http" + "os" + "os/exec" + "regexp" + "sort" + "strings" + "syscall" + "unicode/utf8" +) + +var HexDigits = []byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'} + +func GetStrArr(v interface{}, field string) []string { + if v == nil { + return nil + } + m, ok := v.(map[string]interface{}) + if !ok { + return nil + } + fieldVal := m[field] + if fieldVal == nil { + return nil + } + iarr, ok := fieldVal.([]interface{}) + if !ok { + return nil + } + var sarr []string + for _, iv := range iarr { + if sv, ok := iv.(string); ok { + sarr = append(sarr, sv) + } + } + return sarr +} + +func GetBool(v interface{}, field string) bool { + if v == nil { + return false + } + m, ok := v.(map[string]interface{}) + if !ok { + return false + } + fieldVal := m[field] + if fieldVal == nil { + return false + } + bval, ok := fieldVal.(bool) + if !ok { + return false + } + return bval +} + +var needsQuoteRe = regexp.MustCompile(`[^\w@%:,./=+-]`) + +// minimum maxlen=6 +func ShellQuote(val string, forceQuote bool, maxLen int) string { + if maxLen < 6 { + maxLen = 6 + } + rtn := val + if needsQuoteRe.MatchString(val) { + rtn = "'" + strings.ReplaceAll(val, "'", `'"'"'`) + "'" + } + if strings.HasPrefix(rtn, "\"") || strings.HasPrefix(rtn, "'") { + if len(rtn) > maxLen { + return rtn[0:maxLen-4] + "..." + rtn[0:1] + } + return rtn + } + if forceQuote { + if len(rtn) > maxLen-2 { + return "\"" + rtn[0:maxLen-5] + "...\"" + } + return "\"" + rtn + "\"" + } else { + if len(rtn) > maxLen { + return rtn[0:maxLen-3] + "..." + } + return rtn + } +} + +func EllipsisStr(s string, maxLen int) string { + if maxLen < 4 { + maxLen = 4 + } + if len(s) > maxLen { + return s[0:maxLen-3] + "..." + } + return s +} + +func LongestPrefix(root string, strs []string) string { + if len(strs) == 0 { + return root + } + if len(strs) == 1 { + comp := strs[0] + if len(comp) >= len(root) && strings.HasPrefix(comp, root) { + if strings.HasSuffix(comp, "/") { + return strs[0] + } + return strs[0] + } + } + lcp := strs[0] + for i := 1; i < len(strs); i++ { + s := strs[i] + for j := 0; j < len(lcp); j++ { + if j >= len(s) || lcp[j] != s[j] { + lcp = lcp[0:j] + break + } + } + } + if len(lcp) < len(root) || !strings.HasPrefix(lcp, root) { + return root + } + return lcp +} + +func ContainsStr(strs []string, test string) bool { + for _, s := range strs { + if s == test { + return true + } + } + return false +} + +func IsPrefix(strs []string, test string) bool { + for _, s := range strs { + if len(s) > len(test) && strings.HasPrefix(s, test) { + return true + } + } + return false +} + +// sentinel value for StrWithPos.Pos to indicate no position +const NoStrPos = -1 + +type StrWithPos struct { + Str string `json:"str"` + Pos int `json:"pos"` // this is a 'rune' position (not a byte position) +} + +func (sp StrWithPos) String() string { + return strWithCursor(sp.Str, sp.Pos) +} + +func ParseToSP(s string) StrWithPos { + idx := strings.Index(s, "[*]") + if idx == -1 { + return StrWithPos{Str: s, Pos: NoStrPos} + } + return StrWithPos{Str: s[0:idx] + s[idx+3:], Pos: utf8.RuneCountInString(s[0:idx])} +} + +func strWithCursor(str string, pos int) string { + if pos == NoStrPos { + return str + } + if pos < 0 { + // invalid position + return "[*]_" + str + } + if pos > len(str) { + // invalid position + return str + "_[*]" + } + if pos == len(str) { + return str + "[*]" + } + var rtn []rune + for _, ch := range str { + if len(rtn) == pos { + rtn = append(rtn, '[', '*', ']') + } + rtn = append(rtn, ch) + } + return string(rtn) +} + +func (sp StrWithPos) Prepend(str string) StrWithPos { + return StrWithPos{Str: str + sp.Str, Pos: utf8.RuneCountInString(str) + sp.Pos} +} + +func (sp StrWithPos) Append(str string) StrWithPos { + return StrWithPos{Str: sp.Str + str, Pos: sp.Pos} +} + +// returns base64 hash of data +func Sha1Hash(data []byte) string { + hvalRaw := sha1.Sum(data) + hval := base64.StdEncoding.EncodeToString(hvalRaw[:]) + return hval +} + +func ChunkSlice[T any](s []T, chunkSize int) [][]T { + var rtn [][]T + for len(rtn) > 0 { + if len(s) <= chunkSize { + rtn = append(rtn, s) + break + } + rtn = append(rtn, s[:chunkSize]) + s = s[chunkSize:] + } + return rtn +} + +var ErrOverflow = errors.New("integer overflow") + +// Add two int values, returning an error if the result overflows. +func AddInt(left, right int) (int, error) { + if right > 0 { + if left > math.MaxInt-right { + return 0, ErrOverflow + } + } else { + if left < math.MinInt-right { + return 0, ErrOverflow + } + } + return left + right, nil +} + +// Add a slice of ints, returning an error if the result overflows. +func AddIntSlice(vals ...int) (int, error) { + var rtn int + for _, v := range vals { + var err error + rtn, err = AddInt(rtn, v) + if err != nil { + return 0, err + } + } + return rtn, nil +} + +func StrsEqual(s1arr []string, s2arr []string) bool { + if len(s1arr) != len(s2arr) { + return false + } + for i, s1 := range s1arr { + s2 := s2arr[i] + if s1 != s2 { + return false + } + } + return true +} + +func StrMapsEqual(m1 map[string]string, m2 map[string]string) bool { + if len(m1) != len(m2) { + return false + } + for key, val1 := range m1 { + val2, found := m2[key] + if !found || val1 != val2 { + return false + } + } + for key := range m2 { + _, found := m1[key] + if !found { + return false + } + } + return true +} + +func ByteMapsEqual(m1 map[string][]byte, m2 map[string][]byte) bool { + if len(m1) != len(m2) { + return false + } + for key, val1 := range m1 { + val2, found := m2[key] + if !found || !bytes.Equal(val1, val2) { + return false + } + } + for key := range m2 { + _, found := m1[key] + if !found { + return false + } + } + return true +} + +func GetOrderedStringerMapKeys[K interface { + comparable + fmt.Stringer +}, V any](m map[K]V) []K { + keyStrMap := make(map[K]string) + keys := make([]K, 0, len(m)) + for key := range m { + keys = append(keys, key) + keyStrMap[key] = key.String() + } + sort.Slice(keys, func(i, j int) bool { + return keyStrMap[keys[i]] < keyStrMap[keys[j]] + }) + return keys +} + +func GetOrderedMapKeys[V any](m map[string]V) []string { + keys := make([]string, 0, len(m)) + for key := range m { + keys = append(keys, key) + } + sort.Strings(keys) + return keys +} + +const ( + nullEncodeEscByte = '\\' + nullEncodeSepByte = '|' + nullEncodeEqByte = '=' + nullEncodeZeroByteEsc = '0' + nullEncodeEscByteEsc = '\\' + nullEncodeSepByteEsc = 's' + nullEncodeEqByteEsc = 'e' +) + +func EncodeStringMap(m map[string]string) []byte { + var buf bytes.Buffer + for idx, key := range GetOrderedMapKeys(m) { + val := m[key] + buf.Write(NullEncodeStr(key)) + buf.WriteByte(nullEncodeEqByte) + buf.Write(NullEncodeStr(val)) + if idx < len(m)-1 { + buf.WriteByte(nullEncodeSepByte) + } + } + return buf.Bytes() +} + +func DecodeStringMap(barr []byte) (map[string]string, error) { + if len(barr) == 0 { + return nil, nil + } + var rtn = make(map[string]string) + for _, b := range bytes.Split(barr, []byte{nullEncodeSepByte}) { + keyVal := bytes.SplitN(b, []byte{nullEncodeEqByte}, 2) + if len(keyVal) != 2 { + return nil, fmt.Errorf("invalid null encoding: %s", string(b)) + } + key, err := NullDecodeStr(keyVal[0]) + if err != nil { + return nil, err + } + val, err := NullDecodeStr(keyVal[1]) + if err != nil { + return nil, err + } + rtn[key] = val + } + return rtn, nil +} + +func EncodeStringArray(arr []string) []byte { + var buf bytes.Buffer + for idx, s := range arr { + buf.Write(NullEncodeStr(s)) + if idx < len(arr)-1 { + buf.WriteByte(nullEncodeSepByte) + } + } + return buf.Bytes() +} + +func DecodeStringArray(barr []byte) ([]string, error) { + if len(barr) == 0 { + return nil, nil + } + var rtn []string + for _, b := range bytes.Split(barr, []byte{nullEncodeSepByte}) { + s, err := NullDecodeStr(b) + if err != nil { + return nil, err + } + rtn = append(rtn, s) + } + return rtn, nil +} + +func EncodedStringArrayHasFirstVal(encoded []byte, firstKey string) bool { + firstKeyBytes := NullEncodeStr(firstKey) + if !bytes.HasPrefix(encoded, firstKeyBytes) { + return false + } + if len(encoded) == len(firstKeyBytes) || encoded[len(firstKeyBytes)] == nullEncodeSepByte { + return true + } + return false +} + +// on encoding error returns "" +// this is used to perform logic on first value without decoding the entire array +func EncodedStringArrayGetFirstVal(encoded []byte) string { + sepIdx := bytes.IndexByte(encoded, nullEncodeSepByte) + if sepIdx == -1 { + str, _ := NullDecodeStr(encoded) + return str + } + str, _ := NullDecodeStr(encoded[0:sepIdx]) + return str +} + +// encodes a string, removing null/zero bytes (and separators '|') +// a zero byte is encoded as "\0", a '\' is encoded as "\\", sep is encoded as "\s" +// allows for easy double splitting (first on \x00, and next on "|") +func NullEncodeStr(s string) []byte { + strBytes := []byte(s) + if bytes.IndexByte(strBytes, 0) == -1 && + bytes.IndexByte(strBytes, nullEncodeEscByte) == -1 && + bytes.IndexByte(strBytes, nullEncodeSepByte) == -1 && + bytes.IndexByte(strBytes, nullEncodeEqByte) == -1 { + return strBytes + } + var rtn []byte + for _, b := range strBytes { + if b == 0 { + rtn = append(rtn, nullEncodeEscByte, nullEncodeZeroByteEsc) + } else if b == nullEncodeEscByte { + rtn = append(rtn, nullEncodeEscByte, nullEncodeEscByteEsc) + } else if b == nullEncodeSepByte { + rtn = append(rtn, nullEncodeEscByte, nullEncodeSepByteEsc) + } else if b == nullEncodeEqByte { + rtn = append(rtn, nullEncodeEscByte, nullEncodeEqByteEsc) + } else { + rtn = append(rtn, b) + } + } + return rtn +} + +func NullDecodeStr(barr []byte) (string, error) { + if bytes.IndexByte(barr, nullEncodeEscByte) == -1 { + return string(barr), nil + } + var rtn []byte + for i := 0; i < len(barr); i++ { + curByte := barr[i] + if curByte == nullEncodeEscByte { + i++ + nextByte := barr[i] + if nextByte == nullEncodeZeroByteEsc { + rtn = append(rtn, 0) + } else if nextByte == nullEncodeEscByteEsc { + rtn = append(rtn, nullEncodeEscByte) + } else if nextByte == nullEncodeSepByteEsc { + rtn = append(rtn, nullEncodeSepByte) + } else if nextByte == nullEncodeEqByteEsc { + rtn = append(rtn, nullEncodeEqByte) + } else { + // invalid encoding + return "", fmt.Errorf("invalid null encoding: %d", nextByte) + } + } else { + rtn = append(rtn, curByte) + } + } + return string(rtn), nil +} + +func SortStringRunes(s string) string { + runes := []rune(s) + sort.Slice(runes, func(i, j int) bool { + return runes[i] < runes[j] + }) + return string(runes) +} + +// will overwrite m1 with m2's values +func CombineMaps[V any](m1 map[string]V, m2 map[string]V) { + for key, val := range m2 { + m1[key] = val + } +} + +// returns hex escaped string (\xNN for each byte) +func ShellHexEscape(s string) string { + var rtn []byte + for _, ch := range []byte(s) { + rtn = append(rtn, []byte(fmt.Sprintf("\\x%02x", ch))...) + } + return string(rtn) +} + +func GetMapKeys[K comparable, V any](m map[K]V) []K { + var rtn []K + for key := range m { + rtn = append(rtn, key) + } + return rtn +} + +// combines string arrays and removes duplicates (returns a new array) +func CombineStrArrays(sarr1 []string, sarr2 []string) []string { + var rtn []string + m := make(map[string]struct{}) + for _, s := range sarr1 { + if _, found := m[s]; found { + continue + } + m[s] = struct{}{} + rtn = append(rtn, s) + } + for _, s := range sarr2 { + if _, found := m[s]; found { + continue + } + m[s] = struct{}{} + rtn = append(rtn, s) + } + return rtn +} + +func QuickJson(v interface{}) string { + barr, _ := json.Marshal(v) + return string(barr) +} + +func QuickParseJson[T any](s string) T { + var v T + _ = json.Unmarshal([]byte(s), &v) + return v +} + +func StrArrayToMap(sarr []string) map[string]bool { + m := make(map[string]bool) + for _, s := range sarr { + m[s] = true + } + return m +} + +func AppendNonZeroRandomBytes(b []byte, randLen int) []byte { + if randLen <= 0 { + return b + } + numAdded := 0 + for numAdded < randLen { + rn := mathrand.Intn(256) + if rn > 0 && rn < 256 { // exclude 0, also helps to suppress security warning to have a guard here + b = append(b, byte(rn)) + numAdded++ + } + } + return b +} + +// returns (isEOF, error) +func CopyWithEndBytes(outputBuf *bytes.Buffer, reader io.Reader, endBytes []byte) (bool, error) { + buf := make([]byte, 4096) + for { + n, err := reader.Read(buf) + if n > 0 { + outputBuf.Write(buf[:n]) + obytes := outputBuf.Bytes() + if bytes.HasSuffix(obytes, endBytes) { + outputBuf.Truncate(len(obytes) - len(endBytes)) + return (err == io.EOF), nil + } + } + if err == io.EOF { + return true, nil + } + if err != nil { + return false, err + } + } +} + +// does *not* close outputCh on EOF or error +func CopyToChannel(outputCh chan<- []byte, reader io.Reader) error { + buf := make([]byte, 4096) + for { + n, err := reader.Read(buf) + if n > 0 { + // copy so client can use []byte without it being overwritten + bufCopy := make([]byte, n) + copy(bufCopy, buf[:n]) + outputCh <- bufCopy + } + if err == io.EOF { + return nil + } + if err != nil { + return err + } + } +} + +// on error just returns "" +// does not return "application/octet-stream" as this is considered a detection failure +func DetectMimeType(path string) string { + fd, err := os.Open(path) + if err != nil { + return "" + } + defer fd.Close() + buf := make([]byte, 512) + // ignore the error (EOF / UnexpectedEOF is fine, just process how much we got back) + n, _ := io.ReadAtLeast(fd, buf, 512) + if n == 0 { + return "" + } + buf = buf[:n] + rtn := http.DetectContentType(buf) + if rtn == "application/octet-stream" { + return "" + } + return rtn +} + +func GetCmdExitCode(cmd *exec.Cmd, err error) int { + if cmd == nil || cmd.ProcessState == nil { + return GetExitCode(err) + } + status, ok := cmd.ProcessState.Sys().(syscall.WaitStatus) + if !ok { + return cmd.ProcessState.ExitCode() + } + signaled := status.Signaled() + if signaled { + signal := status.Signal() + return 128 + int(signal) + } + exitStatus := status.ExitStatus() + return exitStatus +} + +func GetExitCode(err error) int { + if err == nil { + return 0 + } + if exitErr, ok := err.(*exec.ExitError); ok { + return exitErr.ExitCode() + } else { + return -1 + } +} + +func GetFirstLine(s string) string { + idx := strings.Index(s, "\n") + if idx == -1 { + return s + } + return s[0:idx] +} diff --git a/pkg/wavebase/wavebase.go b/pkg/wavebase/wavebase.go index d6a82b48e..d828e5800 100644 --- a/pkg/wavebase/wavebase.go +++ b/pkg/wavebase/wavebase.go @@ -13,6 +13,7 @@ import ( "sync" ) +const WaveVersion = "v0.1.0" const DefaultWaveHome = "~/.w2" const WaveHomeVarName = "WAVETERM_HOME" const WaveDevVarName = "WAVETERM_DEV" diff --git a/public/style.less b/public/style.less index 955d81e17..2af4d1919 100644 --- a/public/style.less +++ b/public/style.less @@ -2,19 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 @import "./reset.less"; - -:root { - --main-bg-color: #000000; - --border-color: #333333; - --main-color: #f7f7f7; - --base-font: normal 15px / normal "Lato", sans-serif; - --fixed-font: normal 12px / normal "Hack", monospace; - --app-accent-color: rgb(88, 193, 66); - --panel-bg-color: rgba(31, 33, 31, 1); - --highlight-bg-color: rgba(255, 255, 255, 0.2); - --markdown-font: -apple-system, BlinkMacSystemFont, "Segoe UI", "Noto Sans", Helvetica, Arial, sans-serif, - "Apple Color Emoji", "Segoe UI Emoji"; -} +@import "./theme.less"; body { display: flex; @@ -22,7 +10,7 @@ body { width: 100vw; height: 100vh; background-color: var(--main-bg-color); - color: var(--main-color); + color: var(--main-text-color); font: var(--base-font); overflow: hidden; } diff --git a/public/theme.less b/public/theme.less new file mode 100644 index 000000000..682cf8128 --- /dev/null +++ b/public/theme.less @@ -0,0 +1,19 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +:root { + --main-text-color: #f7f7f7; + --secondary-text-color: rgb(195, 200, 194); + --main-bg-color: #000000; + --border-color: #333333; + --base-font: normal 15px / normal "Lato", sans-serif; + --fixed-font: normal 12px / normal "Hack", monospace; + --accent-color: rgb(88, 193, 66); + --panel-bg-color: rgba(31, 33, 31, 1); + --highlight-bg-color: rgba(255, 255, 255, 0.2); + --markdown-font: -apple-system, BlinkMacSystemFont, "Segoe UI", "Noto Sans", Helvetica, Arial, sans-serif, + "Apple Color Emoji", "Segoe UI Emoji"; + --error-color: rgb(229, 77, 46); + --warning-color: rgb(224, 185, 86); + --success-color: rgb(78, 154, 6); +}