2024-08-19 04:26:44 +00:00
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package conncontroller
import (
"context"
2024-08-24 01:12:40 +00:00
"errors"
2024-08-19 04:26:44 +00:00
"fmt"
"io"
2024-08-24 01:12:40 +00:00
"io/fs"
2024-08-19 04:26:44 +00:00
"log"
"net"
2024-08-24 01:12:40 +00:00
"os"
"path/filepath"
2024-08-19 04:26:44 +00:00
"strings"
"sync"
2024-08-24 01:12:40 +00:00
"sync/atomic"
"time"
2024-08-19 04:26:44 +00:00
2024-08-24 01:12:40 +00:00
"github.com/kevinburke/ssh_config"
2024-08-19 04:26:44 +00:00
"github.com/wavetermdev/thenextwave/pkg/remote"
"github.com/wavetermdev/thenextwave/pkg/userinput"
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
2024-08-24 01:12:40 +00:00
"github.com/wavetermdev/thenextwave/pkg/wps"
2024-08-19 04:26:44 +00:00
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
"golang.org/x/crypto/ssh"
)
2024-08-24 01:12:40 +00:00
// Connection lifecycle states stored in SSHConn.Status.
const (
	// Status_Init is the state of a freshly registered SSHConn that has never connected.
	Status_Init = "init"
	// Status_Connecting is set by Connect while connectInternal runs.
	Status_Connecting = "connecting"
	// Status_Connected is set once connectInternal has succeeded.
	Status_Connected = "connected"
	// Status_Disconnected is set by an explicit Close (or a clean client.Wait return).
	Status_Disconnected = "disconnected"
	// Status_Error is set when connecting fails or the client drops unexpectedly.
	Status_Error = "error"
)
2024-08-19 04:26:44 +00:00
var (
	// globalLock guards clientControllerMap. When both locks are needed it is
	// taken before an individual SSHConn.Lock (see GetAllConnStatus).
	globalLock = &sync.Mutex{}
	// clientControllerMap holds one shared SSHConn per distinct remote.SSHOpts value.
	clientControllerMap = map[remote.SSHOpts]*SSHConn{}
)
type SSHConn struct {
Lock * sync . Mutex
2024-08-24 01:12:40 +00:00
Status string
2024-08-19 04:26:44 +00:00
Opts * remote . SSHOpts
Client * ssh . Client
SockName string
DomainSockListener net . Listener
ConnController * ssh . Session
2024-08-24 01:12:40 +00:00
Error string
HasWaiter * atomic . Bool
}
// GetAllConnStatus returns a status snapshot for every SSHConn in the global
// registry. The order follows Go map iteration and is therefore unspecified;
// the result is nil when the registry is empty.
func GetAllConnStatus() []wshrpc.ConnStatus {
	globalLock.Lock()
	defer globalLock.Unlock()
	var statuses []wshrpc.ConnStatus
	for _, sshConn := range clientControllerMap {
		// DeriveConnStatus takes sshConn.Lock while globalLock is held (lock order: global, then conn).
		snapshot := sshConn.DeriveConnStatus()
		statuses = append(statuses, snapshot)
	}
	return statuses
}
func ( conn * SSHConn ) DeriveConnStatus ( ) wshrpc . ConnStatus {
conn . Lock . Lock ( )
defer conn . Lock . Unlock ( )
return wshrpc . ConnStatus {
Status : conn . Status ,
2024-08-26 23:19:03 +00:00
Connected : conn . Status == Status_Connected ,
2024-08-24 01:12:40 +00:00
Connection : conn . Opts . String ( ) ,
Error : conn . Error ,
}
}
func ( conn * SSHConn ) FireConnChangeEvent ( ) {
status := conn . DeriveConnStatus ( )
event := wshrpc . WaveEvent {
Event : wshrpc . Event_ConnChange ,
Scopes : [ ] string {
fmt . Sprintf ( "connection:%s" , conn . GetName ( ) ) ,
} ,
Data : status ,
}
log . Printf ( "sending event: %+#v" , event )
wps . Broker . Publish ( event )
2024-08-19 04:26:44 +00:00
}
func ( conn * SSHConn ) Close ( ) error {
2024-08-24 01:12:40 +00:00
defer conn . FireConnChangeEvent ( )
conn . WithLock ( func ( ) {
if conn . Status == Status_Connected || conn . Status == Status_Connecting {
// if status is init, disconnected, or error don't change it
conn . Status = Status_Disconnected
}
conn . close_nolock ( )
} )
// we must wait for the waiter to complete
startTime := time . Now ( )
for conn . HasWaiter . Load ( ) {
time . Sleep ( 10 * time . Millisecond )
if time . Since ( startTime ) > 2 * time . Second {
return fmt . Errorf ( "timeout waiting for waiter to complete" )
}
}
return nil
}
func ( conn * SSHConn ) close_nolock ( ) {
// does not set status (that should happen at another level)
2024-08-19 04:26:44 +00:00
if conn . DomainSockListener != nil {
conn . DomainSockListener . Close ( )
conn . DomainSockListener = nil
}
if conn . ConnController != nil {
conn . ConnController . Close ( )
conn . ConnController = nil
}
2024-08-24 01:12:40 +00:00
if conn . Client != nil {
conn . Client . Close ( )
conn . Client = nil
}
}
// GetDomainSocketName returns the remote socket path recorded by
// OpenDomainSocketListener, read under conn.Lock ("" if none is open).
func (conn *SSHConn) GetDomainSocketName() string {
	var name string
	conn.WithLock(func() {
		name = conn.SockName
	})
	return name
}
// GetStatus returns the current Status_* value, read under conn.Lock.
func (conn *SSHConn) GetStatus() string {
	var status string
	conn.WithLock(func() {
		status = conn.Status
	})
	return status
}
func ( conn * SSHConn ) GetName ( ) string {
// no lock required because opts is immutable
return conn . Opts . String ( )
2024-08-19 04:26:44 +00:00
}
func ( conn * SSHConn ) OpenDomainSocketListener ( ) error {
2024-08-24 01:12:40 +00:00
var allowed bool
conn . WithLock ( func ( ) {
if conn . Status != Status_Connecting {
allowed = false
} else {
allowed = true
}
} )
if ! allowed {
return fmt . Errorf ( "cannot open domain socket for %q when status is %q" , conn . GetName ( ) , conn . GetStatus ( ) )
2024-08-19 04:26:44 +00:00
}
2024-08-24 01:12:40 +00:00
client := conn . GetClient ( )
2024-08-19 04:26:44 +00:00
randStr , err := utilfn . RandomHexString ( 16 ) // 64-bits of randomness
if err != nil {
return fmt . Errorf ( "error generating random string: %w" , err )
}
sockName := fmt . Sprintf ( "/tmp/waveterm-%s.sock" , randStr )
2024-08-24 01:12:40 +00:00
log . Printf ( "remote domain socket %s %q\n" , conn . GetName ( ) , sockName )
listener , err := client . ListenUnix ( sockName )
2024-08-19 04:26:44 +00:00
if err != nil {
return fmt . Errorf ( "unable to request connection domain socket: %v" , err )
}
2024-08-24 01:12:40 +00:00
conn . WithLock ( func ( ) {
conn . SockName = sockName
conn . DomainSockListener = listener
} )
2024-08-19 04:26:44 +00:00
go func ( ) {
2024-08-24 01:12:40 +00:00
defer conn . WithLock ( func ( ) {
2024-08-19 04:26:44 +00:00
conn . DomainSockListener = nil
2024-08-24 01:12:40 +00:00
conn . SockName = ""
} )
2024-08-19 04:26:44 +00:00
wshutil . RunWshRpcOverListener ( listener )
} ( )
return nil
}
func ( conn * SSHConn ) StartConnServer ( ) error {
2024-08-24 01:12:40 +00:00
var allowed bool
conn . WithLock ( func ( ) {
if conn . Status != Status_Connecting {
allowed = false
} else {
allowed = true
}
} )
if ! allowed {
return fmt . Errorf ( "cannot start conn server for %q when status is %q" , conn . GetName ( ) , conn . GetStatus ( ) )
2024-08-19 04:26:44 +00:00
}
2024-08-24 01:12:40 +00:00
client := conn . GetClient ( )
wshPath := remote . GetWshPath ( client )
2024-08-19 04:26:44 +00:00
rpcCtx := wshrpc . RpcContext {
2024-08-20 21:56:48 +00:00
ClientType : wshrpc . ClientType_ConnServer ,
2024-08-24 01:12:40 +00:00
Conn : conn . GetName ( ) ,
2024-08-19 04:26:44 +00:00
}
2024-08-24 01:12:40 +00:00
sockName := conn . GetDomainSocketName ( )
jwtToken , err := wshutil . MakeClientJWTToken ( rpcCtx , sockName )
2024-08-19 04:26:44 +00:00
if err != nil {
return fmt . Errorf ( "unable to create jwt token for conn controller: %w" , err )
}
2024-08-24 01:12:40 +00:00
sshSession , err := client . NewSession ( )
2024-08-19 04:26:44 +00:00
if err != nil {
return fmt . Errorf ( "unable to create ssh session for conn controller: %w" , err )
}
pipeRead , pipeWrite := io . Pipe ( )
sshSession . Stdout = pipeWrite
sshSession . Stderr = pipeWrite
cmdStr := fmt . Sprintf ( "%s=\"%s\" %s connserver" , wshutil . WaveJwtTokenVarName , jwtToken , wshPath )
log . Printf ( "starting conn controller: %s\n" , cmdStr )
err = sshSession . Start ( cmdStr )
if err != nil {
return fmt . Errorf ( "unable to start conn controller: %w" , err )
}
2024-08-24 01:12:40 +00:00
conn . WithLock ( func ( ) {
conn . ConnController = sshSession
} )
2024-08-19 04:26:44 +00:00
// service the I/O
go func ( ) {
// wait for termination, clear the controller
2024-08-24 01:12:40 +00:00
defer conn . WithLock ( func ( ) {
conn . ConnController = nil
} )
2024-08-19 04:26:44 +00:00
waitErr := sshSession . Wait ( )
2024-08-24 01:12:40 +00:00
log . Printf ( "conn controller (%q) terminated: %v" , conn . GetName ( ) , waitErr )
2024-08-19 04:26:44 +00:00
} ( )
go func ( ) {
readErr := wshutil . StreamToLines ( pipeRead , func ( line [ ] byte ) {
lineStr := string ( line )
if ! strings . HasSuffix ( lineStr , "\n" ) {
lineStr += "\n"
}
2024-08-24 01:12:40 +00:00
log . Printf ( "[conncontroller:%s:output] %s" , conn . GetName ( ) , lineStr )
2024-08-19 04:26:44 +00:00
} )
if readErr != nil && readErr != io . EOF {
2024-08-24 01:12:40 +00:00
log . Printf ( "[conncontroller:%s] error reading output: %v\n" , conn . GetName ( ) , readErr )
2024-08-19 04:26:44 +00:00
}
} ( )
return nil
}
func ( conn * SSHConn ) checkAndInstallWsh ( ctx context . Context ) error {
2024-08-24 01:12:40 +00:00
client := conn . GetClient ( )
if client == nil {
return fmt . Errorf ( "client is nil" )
}
2024-08-19 04:26:44 +00:00
// check that correct wsh extensions are installed
expectedVersion := fmt . Sprintf ( "wsh v%s" , wavebase . WaveVersion )
clientVersion , err := remote . GetWshVersion ( client )
if err == nil && clientVersion == expectedVersion {
return nil
}
2024-08-24 01:12:40 +00:00
// TODO add some progress to SSHConn about install status
2024-08-19 04:26:44 +00:00
var queryText string
var title string
if err != nil {
queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?"
title = "Install Wsh Shell Extensions"
} else {
queryText = fmt . Sprintf ( "Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?" , clientVersion , expectedVersion )
title = "Update Wsh Shell Extensions"
}
request := & userinput . UserInputRequest {
ResponseType : "confirm" ,
QueryText : queryText ,
Title : title ,
CheckBoxMsg : "Don't show me this again" ,
}
response , err := userinput . GetUserInput ( ctx , request )
if err != nil || ! response . Confirm {
return err
}
log . Printf ( "attempting to install wsh to `%s@%s`" , client . User ( ) , client . RemoteAddr ( ) . String ( ) )
clientOs , err := remote . GetClientOs ( client )
if err != nil {
return err
}
clientArch , err := remote . GetClientArch ( client )
if err != nil {
return err
}
// attempt to install extension
wshLocalPath := shellutil . GetWshBinaryPath ( wavebase . WaveVersion , clientOs , clientArch )
err = remote . CpHostToRemote ( client , wshLocalPath , "~/.waveterm/bin/wsh" )
if err != nil {
return err
}
2024-08-24 01:12:40 +00:00
log . Printf ( "successfully installed wsh on %s\n" , conn . GetName ( ) )
2024-08-19 04:26:44 +00:00
return nil
}
2024-08-24 01:12:40 +00:00
// GetClient returns the current *ssh.Client (nil when not connected), read
// under conn.Lock.
func (conn *SSHConn) GetClient() *ssh.Client {
	var client *ssh.Client
	conn.WithLock(func() {
		client = conn.Client
	})
	return client
}
2024-08-19 04:26:44 +00:00
2024-08-24 01:12:40 +00:00
func ( conn * SSHConn ) Reconnect ( ctx context . Context ) error {
err := conn . Close ( )
if err != nil {
return err
2024-08-19 04:26:44 +00:00
}
2024-08-24 01:12:40 +00:00
return conn . Connect ( ctx )
}
2024-08-19 04:26:44 +00:00
2024-08-24 01:12:40 +00:00
// Connect establishes the SSH connection and brings the status to Status_Connected.
//
// If the connection is already connecting or connected, it returns an error and
// changes nothing. Otherwise it sets Status_Connecting, clears Error, fires a
// conn-change event and runs connectInternal.
//
// On failure the status becomes Status_Error, conn.Error records the message,
// partially opened resources are closed, and the same error is returned to the
// caller. A second conn-change event is fired once the outcome is recorded.
func (conn *SSHConn) Connect(ctx context.Context) error {
	var connectAllowed bool
	conn.WithLock(func() {
		if conn.Status == Status_Connecting || conn.Status == Status_Connected {
			connectAllowed = false
		} else {
			conn.Status = Status_Connecting
			conn.Error = ""
			connectAllowed = true
		}
	})
	if !connectAllowed {
		return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
	}
	conn.FireConnChangeEvent()
	err := conn.connectInternal(ctx)
	conn.WithLock(func() {
		if err != nil {
			conn.Status = Status_Error
			conn.Error = err.Error()
			conn.close_nolock()
			return
		}
		// connectInternal has already started waitForDisconnect, and Close may run
		// concurrently. If either has moved the status away from Connecting (to
		// Error or Disconnected), don't overwrite that with Connected.
		if conn.Status == Status_Connecting {
			conn.Status = Status_Connected
		}
	})
	conn.FireConnChangeEvent()
	return err
}
// WithLock runs fn while holding conn.Lock. The lock is released even if fn
// panics. fn must not call helpers that take conn.Lock themselves (the mutex is
// not reentrant).
func (conn *SSHConn) WithLock(fn func()) {
	mu := conn.Lock
	mu.Lock()
	defer mu.Unlock()
	fn()
}
func ( conn * SSHConn ) connectInternal ( ctx context . Context ) error {
client , err := remote . ConnectToClient ( ctx , conn . Opts ) //todo specify or remove opts
2024-08-19 04:26:44 +00:00
if err != nil {
2024-08-24 01:12:40 +00:00
return err
2024-08-19 04:26:44 +00:00
}
2024-08-24 01:12:40 +00:00
conn . WithLock ( func ( ) {
conn . Client = client
} )
2024-08-19 04:26:44 +00:00
err = conn . OpenDomainSocketListener ( )
if err != nil {
2024-08-24 01:12:40 +00:00
return err
2024-08-19 04:26:44 +00:00
}
installErr := conn . checkAndInstallWsh ( ctx )
if installErr != nil {
2024-08-24 01:12:40 +00:00
return fmt . Errorf ( "conncontroller %s wsh install error: %v" , conn . GetName ( ) , installErr )
2024-08-19 04:26:44 +00:00
}
csErr := conn . StartConnServer ( )
if csErr != nil {
2024-08-24 01:12:40 +00:00
return fmt . Errorf ( "conncontroller %s start wsh connserver error: %v" , conn . GetName ( ) , csErr )
2024-08-19 04:26:44 +00:00
}
2024-08-24 01:12:40 +00:00
conn . HasWaiter . Store ( true )
go conn . waitForDisconnect ( )
return nil
}
2024-08-19 04:26:44 +00:00
2024-08-24 01:12:40 +00:00
func ( conn * SSHConn ) waitForDisconnect ( ) {
defer conn . FireConnChangeEvent ( )
defer conn . HasWaiter . Store ( false )
client := conn . GetClient ( )
if client == nil {
return
}
err := client . Wait ( )
conn . WithLock ( func ( ) {
if err != nil {
if conn . Status != Status_Disconnected {
// don't set the error if our status is disconnected (because this error was caused by an explicit close)
conn . Status = Status_Error
conn . Error = err . Error ( )
}
} else {
// not sure if this is possible, because I think Wait() always returns an error (although that's not in the docs)
conn . Status = Status_Disconnected
}
conn . close_nolock ( )
} )
2024-08-19 04:26:44 +00:00
}
2024-08-24 01:12:40 +00:00
func getConnInternal ( opts * remote . SSHOpts ) * SSHConn {
2024-08-19 04:26:44 +00:00
globalLock . Lock ( )
defer globalLock . Unlock ( )
2024-08-24 01:12:40 +00:00
rtn := clientControllerMap [ * opts ]
if rtn == nil {
rtn = & SSHConn { Lock : & sync . Mutex { } , Status : Status_Init , Opts : opts , HasWaiter : & atomic . Bool { } }
clientControllerMap [ * opts ] = rtn
}
return rtn
}
// GetConn returns the shared SSHConn for opts, creating and registering it if
// needed. When shouldConnect is true and the connection has no live client, it
// first runs Connect synchronously. Connect's error is not returned here: the
// outcome is recorded on the SSHConn (Status and Error) and can be read via
// DeriveConnStatus.
func GetConn(ctx context.Context, opts *remote.SSHOpts, shouldConnect bool) *SSHConn {
	conn := getConnInternal(opts)
	// read Client under conn.Lock; a bare conn.Client read races with connect/close goroutines
	if shouldConnect && conn.GetClient() == nil {
		_ = conn.Connect(ctx) // error intentionally dropped; stored in conn.Error/conn.Status
	}
	return conn
}
// DisconnectClient closes the registered connection for opts (see SSHConn.Close).
// It returns an error if no connection has been registered for opts. Unlike
// getConnInternal, the lookup never creates a registry entry.
func DisconnectClient(opts *remote.SSHOpts) error {
	// look up without creating: getConnInternal would register a fresh SSHConn for
	// unknown opts, making the "not found" case unreachable and polluting the map
	globalLock.Lock()
	conn := clientControllerMap[*opts]
	globalLock.Unlock()
	if conn == nil {
		return fmt.Errorf("client %q not found", opts.String())
	}
	return conn.Close()
}
func resolveSshConfigPatterns ( configFiles [ ] string ) ( [ ] string , error ) {
// using two separate containers to track order and have O(1) lookups
// since go does not have an ordered map primitive
var discoveredPatterns [ ] string
alreadyUsed := make ( map [ string ] bool )
alreadyUsed [ "" ] = true // this excludes the empty string from potential alias
var openedFiles [ ] fs . File
defer func ( ) {
for _ , openedFile := range openedFiles {
openedFile . Close ( )
}
} ( )
var errs [ ] error
for _ , configFile := range configFiles {
fd , openErr := os . Open ( configFile )
openedFiles = append ( openedFiles , fd )
if fd == nil {
errs = append ( errs , openErr )
continue
}
2024-08-19 04:26:44 +00:00
2024-08-24 01:12:40 +00:00
cfg , _ := ssh_config . Decode ( fd )
for _ , host := range cfg . Hosts {
// for each host, find the first good alias
for _ , hostPattern := range host . Patterns {
hostPatternStr := hostPattern . String ( )
if ! strings . Contains ( hostPatternStr , "*" ) || alreadyUsed [ hostPatternStr ] {
discoveredPatterns = append ( discoveredPatterns , hostPatternStr )
alreadyUsed [ hostPatternStr ] = true
break
}
}
}
2024-08-19 04:26:44 +00:00
}
2024-08-24 01:12:40 +00:00
if len ( errs ) == len ( configFiles ) {
errs = append ( [ ] error { fmt . Errorf ( "no ssh config files could be opened:\n" ) } , errs ... )
return nil , errors . Join ( errs ... )
}
if len ( discoveredPatterns ) == 0 {
return nil , fmt . Errorf ( "no compatible hostnames found in ssh config files" )
}
return discoveredPatterns , nil
}
func GetConnectionsFromConfig ( ) ( [ ] string , error ) {
home := wavebase . GetHomeDir ( )
localConfig := filepath . Join ( home , ".ssh" , "config" )
systemConfig := filepath . Join ( "/etc" , "ssh" , "config" )
sshConfigFiles := [ ] string { localConfig , systemConfig }
ssh_config . ReloadConfigs ( )
return resolveSshConfigPatterns ( sshConfigFiles )
2024-08-19 04:26:44 +00:00
}