diff --git a/.github/workflows/fleet-and-orbit.yml b/.github/workflows/fleet-and-orbit.yml index 997704d371..5ef8688257 100644 --- a/.github/workflows/fleet-and-orbit.yml +++ b/.github/workflows/fleet-and-orbit.yml @@ -383,6 +383,53 @@ jobs: run: | "C:\Program Files\Orbit\bin\orbit\orbit.exe" shell -- --json "select * from osquery_info;" | jq -e "if (.[0]) then true else false end" + - name: Fleet Service Tests + shell: powershell + run: | + #Tests setup + $serviceName = "Fleet osquery" + $defaultWaitTime = 2 + $orbitRequiredInitExtraTime = 20 + + #Test 1 - Check that the service starts without issues + Stop-Service -Name $serviceName + Start-Service -Name $serviceName + Get-Service -Name $serviceName | %{ if ($_.Status -ne "Running") { throw "Test #1 failed" } } + + #Test 2 - Check that the service stops without issues + Stop-Service -Name $serviceName + Get-Service -Name $serviceName | %{ if ($_.Status -ne "Stopped") { throw "Test #2 failed" } } + + #Test 3 - Check that no orbit.exe is running after service stop + Start-Service -Name $serviceName + Stop-Service -Name $serviceName + Start-Sleep -Seconds $defaultWaitTime # shutdown wait time + Get-Process | %{ if ($_.Name -eq "orbit") { throw "Test #3 failed" } } + + #Test 4 - Check that service starts in less than 3 secs + Start-Job { Start-Service -Name $args[0] } -ArgumentList $serviceName | Out-Null #async operation + Start-Sleep -Seconds $defaultWaitTime + Get-Service -Name $serviceName | %{ if ($_.Status -ne "Running") { throw "Test #4 failed" } } + + #Test 5 - Check that service stops in less than 3 secs + Start-Job { Stop-Service -Name $args[0] } -ArgumentList $serviceName | Out-Null #async operation + Start-Sleep -Seconds $defaultWaitTime + Get-Service -Name $serviceName | %{ if ($_.Status -ne "Stopped") { throw "Test #5 failed" } } + + #Test 6 - Check that no osqueryd process is running once service stops + Start-Service -Name $serviceName + Start-Sleep -Seconds $orbitRequiredInitExtraTime # orbit takes some time to spawn osquery and desktop app due to update check + Stop-Service -Name $serviceName + Start-Sleep -Seconds $defaultWaitTime + Get-Process | %{ if ($_.Name -eq "osqueryd") { throw "Test #6 failed" } } + + #Test 7 - Check that no fleet-desktop process is running once service stops + Start-Service -Name $serviceName + Start-Sleep -Seconds $orbitRequiredInitExtraTime # orbit takes some time to spawn osquery and desktop app due to update check + Stop-Service -Name $serviceName + Start-Sleep -Seconds $defaultWaitTime + Get-Process | %{ if ($_.Name -eq "fleet-desktop") { throw "Test #7 failed" } } + - name: Upload Orbit logs if: always() uses: actions/upload-artifact@3cea5372237819ed00197afe530f5a7ea3e805c8 # v2 diff --git a/orbit/changes/bug-7874-call-scm-on-service-start b/orbit/changes/bug-7874-call-scm-on-service-start new file mode 100644 index 0000000000..113f394d37 --- /dev/null +++ b/orbit/changes/bug-7874-call-scm-on-service-start @@ -0,0 +1,3 @@ +* When running on Windows, Fleet service was getting killed by the OS when +service start takes longer than 30 secs due to missing calls to the +Service Control Manager (SCM) APIs. diff --git a/orbit/cmd/orbit/orbit.go b/orbit/cmd/orbit/orbit.go index 5a2db6bd2c..3d71b19b36 100644 --- a/orbit/cmd/orbit/orbit.go +++ b/orbit/cmd/orbit/orbit.go @@ -26,6 +26,8 @@ import ( "github.com/fleetdm/fleet/v4/orbit/pkg/execuser" "github.com/fleetdm/fleet/v4/orbit/pkg/insecure" "github.com/fleetdm/fleet/v4/orbit/pkg/osquery" + "github.com/fleetdm/fleet/v4/orbit/pkg/osservice" + "github.com/fleetdm/fleet/v4/orbit/pkg/platform" "github.com/fleetdm/fleet/v4/orbit/pkg/table" "github.com/fleetdm/fleet/v4/orbit/pkg/update" "github.com/fleetdm/fleet/v4/orbit/pkg/update/filestore" @@ -35,7 +37,6 @@ import ( "github.com/oklog/run" "github.com/rs/zerolog" "github.com/rs/zerolog/log" - gopsutil_process "github.com/shirou/gopsutil/v3/process" "github.com/urfave/cli/v2" "gopkg.in/natefinch/lumberjack.v2" ) @@ -260,6 +261,12 @@ func main() { g run.Group ) + // List of interrupt functions to call during service teardown + var interruptFunctions []func(err error) + + // Setting up the system service management early on the process lifetime + go osservice.SetupServiceManagement(constant.SystemServiceName, c.Bool("fleet-desktop"), &interruptFunctions) + // NOTE: When running in dev-mode, even if `disable-updates` is set, // it fetches osqueryd once as part of initialization. if !c.Bool("disable-updates") || c.Bool("dev-mode") { @@ -297,6 +304,7 @@ func main() { log.Info().Msg("exiting due to successful early update") return nil } + g.Add(updateRunner.Execute, updateRunner.Interrupt) osquerydLocalTarget, err := updater.Get("osqueryd") @@ -487,7 +495,6 @@ func main() { capabilities := fleet.CapabilityMap{} orbitClient, err := service.NewOrbitClient(fleetURL, c.String("fleet-certificate"), c.Bool("insecure"), enrollSecret, uuidStr, capabilities) - if err != nil { return fmt.Errorf("error new orbit client: %w", err) } @@ -552,6 +559,10 @@ func main() { } g.Add(r.Execute, r.Interrupt) + // Only osquery runner is being interrupted + // This ends up forcing the rest of the interrupt functions in the runner group to get called + interruptFunctions = append(interruptFunctions, r.Interrupt) + registerExtensionRunner(&g, r.ExtensionSocketPath(), deviceAuthToken) checkerClient, err := service.NewOrbitClient(fleetURL, c.String("fleet-certificate"), c.Bool("insecure"), enrollSecret, uuidStr, capabilities) @@ -663,10 +674,10 @@ func (d *desktopRunner) execute() error { // Second retry logic to monitor fleet-desktop. // Call with waitFirst=true to give some time for the process to start. if done := retry(30*time.Second, true, d.interruptCh, func() bool { - switch _, err := getProcessByName(constant.DesktopAppExecName); { + switch _, err := platform.GetProcessByName(constant.DesktopAppExecName); { case err == nil: return true // all good, process is running, retry. - case errors.Is(err, errProcessNotFound): + case errors.Is(err, platform.ErrProcessNotFound): log.Debug().Msgf("%s process not running", constant.DesktopAppExecName) return false // process is not running, do not retry. default: @@ -705,7 +716,7 @@ func (d *desktopRunner) interrupt(err error) { close(d.interruptCh) // Signal execute to return. <-d.executeDoneCh // Wait for execute to return. - if err := killProcessByName(constant.DesktopAppExecName); err != nil { + if err := platform.KillProcessByName(constant.DesktopAppExecName); err != nil { log.Error().Err(err).Msg("killProcess") } } @@ -747,7 +758,6 @@ func getUUID(osqueryPath string) (string, error) { return "", fmt.Errorf("invalid number of rows from system_info query: %d", len(uuids)) } return uuids[0].UuidString, nil - } // getOrbitNodeKeyOrEnroll attempts to read the orbit node key if the file exists on disk @@ -806,42 +816,6 @@ func loadOrGenerateToken(rootDir string) (string, error) { } } -func killProcessByName(name string) error { - foundProcess, err := getProcessByName(name) - if err != nil { - return fmt.Errorf("get process: %w", err) - } - if err := foundProcess.Kill(); err != nil { - return fmt.Errorf("kill process %d: %w", foundProcess.Pid, err) - } - return nil -} - -var errProcessNotFound = errors.New("process not found") - -func getProcessByName(name string) (*gopsutil_process.Process, error) { - processes, err := gopsutil_process.Processes() - if err != nil { - return nil, err - } - var foundProcess *gopsutil_process.Process - for _, process := range processes { - processName, err := process.Name() - if err != nil { - log.Debug().Err(err).Int32("pid", process.Pid).Msg("get process name") - continue - } - if strings.HasPrefix(processName, name) { - foundProcess = process - break - } - } - if foundProcess == nil { - return nil, errProcessNotFound - } - return foundProcess, nil -} - var versionCommand = &cli.Command{ Name: "version", Usage: "Get the orbit version", diff --git a/orbit/pkg/constant/constant.go b/orbit/pkg/constant/constant.go index 8395e50943..4ea3ba6285 100644 --- a/orbit/pkg/constant/constant.go +++ b/orbit/pkg/constant/constant.go @@ -20,4 +20,13 @@ const ( OrbitEnrollMaxRetries = 10 // OrbitEnrollRetrySleep is the time duration to sleep between retries OrbitEnrollRetrySleep = 5 * time.Second + // OsquerydName is the name of osqueryd binary + // We use osqueryd as name to properly identify the process when listing + // running processes/tasks. + OsquerydName = "osqueryd" + // OsqueryPidfile is the file containing the PID of the running osqueryd process + OsqueryPidfile = "osquery.pid" + // SystemServiceName is the name of Orbit system service + // The service name is used by the OS service management framework + SystemServiceName = "Fleet osquery" ) diff --git a/orbit/pkg/osservice/osservice_notwindows.go b/orbit/pkg/osservice/osservice_notwindows.go new file mode 100644 index 0000000000..1a3e48437b --- /dev/null +++ b/orbit/pkg/osservice/osservice_notwindows.go @@ -0,0 +1,9 @@ +//go:build !windows +// +build !windows + +package osservice + +// SetupServiceManagement is currently a placeholder for non-windows OSes +// system service configuration +func SetupServiceManagement(serviceName string, fleetDesktopPresent bool, shutdownFunctions *[]func(err error)) { +} diff --git a/orbit/pkg/osservice/osservice_windows.go b/orbit/pkg/osservice/osservice_windows.go new file mode 100644 index 0000000000..2b0ac543a4 --- /dev/null +++ b/orbit/pkg/osservice/osservice_windows.go @@ -0,0 +1,112 @@ +//go:build windows +// +build windows + +package osservice + +import ( + "errors" + "os" + "time" + + "github.com/fleetdm/fleet/v4/orbit/pkg/constant" + "github.com/fleetdm/fleet/v4/orbit/pkg/platform" + + "github.com/rs/zerolog/log" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" +) + +type windowsService struct { + shutdownFunctions *[]func(err error) + fleetDesktopPresent bool +} + +func (m *windowsService) bestEffortShutdown() { + serviceShutdown := errors.New("service is shutting down") + + // Calling interrupt functions to gracefully shutdown runners + for _, interruptFn := range *m.shutdownFunctions { + interruptFn(serviceShutdown) + } + + // Now ensuring that no child process are left + if m.fleetDesktopPresent { + err := platform.KillProcessByName(constant.DesktopAppExecName) + if err != nil { + log.Error().Err(err).Msg("The desktop app couldn't be killed") + } + } + + err := platform.KillAllProcessByName(constant.OsquerydName) + if err != nil { + log.Error().Err(err).Msg("The child osqueryd processes cannot be killed") + } +} + +func (m *windowsService) Execute(args []string, requests <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) { + // Accepted service operations + const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown + + // Expected service status update during initialization + changes <- svc.Status{State: svc.StartPending} + changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} + + // The listening loop below will keep listening for new SCM requests + for req := range requests { + switch req.Cmd { + + case svc.Interrogate: + changes <- req.CurrentStatus + + case svc.Stop, svc.Shutdown: + + // Service shutdown was requested + // Updating the service state to indicate stop + changes <- svc.Status{State: svc.Stopped, Win32ExitCode: 0} + + // Best effort tear down + // Runner group's interrupt functions will be called here + m.bestEffortShutdown() + + // Dummy delay to allow the SCM to pick up the changes + time.Sleep(500 * time.Millisecond) + + // Drastic teardown + os.Exit(windows.NO_ERROR) + + default: + return false, uint32(windows.ERROR_INVALID_SERVICE_CONTROL) + } + } + + return false, 0 +} + +// SetupServiceManagement implements the dispatcher and notification logic to +// interact with the Windows Service Control Manager (SCM) +func SetupServiceManagement(serviceName string, fleetDesktopPresent bool, shutdownFunctions *[]func(err error)) { + if serviceName == "" { + log.Error().Msg(" service name should not be empty") + return + } + + // Ensuring that we are only calling the SCM if running as a service + isWindowsService, err := svc.IsWindowsService() + if err != nil { + log.Error().Err(err).Msg("couldn't determine if running as a service") + return + } + + if isWindowsService { + srvData := windowsService{ + shutdownFunctions: shutdownFunctions, + fleetDesktopPresent: fleetDesktopPresent, + } + + // Registering our service into the SCM + err := svc.Run(serviceName, &srvData) + if err != nil { + log.Info().Err(err).Msg("SCM registration failed") + } + } +} diff --git a/orbit/pkg/platform/platform.go b/orbit/pkg/platform/platform.go new file mode 100644 index 0000000000..2ef662a860 --- /dev/null +++ b/orbit/pkg/platform/platform.go @@ -0,0 +1,230 @@ +package platform + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/mitchellh/go-ps" + "github.com/rs/zerolog/log" + gopsutil_process "github.com/shirou/gopsutil/v3/process" +) + +var ErrProcessNotFound = errors.New("process not found") + +// readPidFromFile reads a PID from a file +func readPidFromFile(destDir string, destFile string) (int32, error) { + // Defense programming - sanity checks on inputs + if destDir == "" { + return 0, errors.New(" destination directory should not be empty") + } + + if destFile == "" { + return 0, errors.New(" destination file should not be empty") + } + + pidFilePath := filepath.Join(destDir, destFile) + data, err := os.ReadFile(pidFilePath) + if err != nil { + return 0, fmt.Errorf("error reading pidfile %s: %w", pidFilePath, err) + } + + intNumber, err := strconv.ParseInt(strings.TrimSpace(string(data)), 10, 32) + if err != nil { + return 0, fmt.Errorf("error converting pidfile %s: %w", pidFilePath, err) + } + + return int32(intNumber), err +} + +// processNameMatches returns whether the process running with the given pid matches +// the executable name (case insensitive). +// If there's no process running with the given pid then (false, nil) is returned. +func processNameMatches(pid int, expectedPrefix string) (bool, error) { + if pid == 0 { + return false, errors.New("process id should not be zero") + } + + if expectedPrefix == "" { + return false, errors.New("expected prefix should not be empty") + } + + process, err := ps.FindProcess(pid) + if err != nil { + return false, fmt.Errorf("find process: %d: %w", pid, err) + } + + if process == nil { + return false, nil + } + + return strings.HasPrefix(strings.ToLower(process.Executable()), strings.ToLower(expectedPrefix)), nil +} + +// killPID kills a process by PID +func killPID(pid int32) error { + if pid == 0 { + return errors.New("process id should not be zero") + } + + processes, err := gopsutil_process.Processes() + if err != nil { + return err + } + + for _, process := range processes { + if pid == process.Pid { + process.Kill() + break + } + } + + return nil +} + +// KillProcessByName kills a single process by its name +func KillProcessByName(name string) error { + if name == "" { + return errors.New("process name should not be empty") + } + + foundProcess, err := GetProcessByName(name) + if err != nil { + return fmt.Errorf("get process: %w", err) + } + + if err := foundProcess.Kill(); err != nil { + return fmt.Errorf("kill process %d: %w", foundProcess.Pid, err) + } + + return nil +} + +// getProcessesByName gets a single process object by its name +func getProcessesByName(name string) ([]*gopsutil_process.Process, error) { + if name == "" { + return nil, errors.New("process name should not be empty") + } + + processes, err := gopsutil_process.Processes() + if err != nil { + return nil, err + } + + var foundProcesses []*gopsutil_process.Process + for _, process := range processes { + processName, err := process.Name() + if err != nil { + log.Debug().Err(err).Int32("pid", process.Pid).Msg("get process name") + continue + } + + if strings.HasPrefix(processName, name) { + foundProcesses = append(foundProcesses, process) + } + } + + if len(foundProcesses) == 0 { + return nil, ErrProcessNotFound + } + + return foundProcesses, nil +} + +// KillAllProcessByName kills all process found by their name +func KillAllProcessByName(name string) error { + if name == "" { + return errors.New("process name should not be empty") + } + + foundProcesses, err := getProcessesByName(name) + if err != nil { + return fmt.Errorf("get process: %w", err) + } + + // Killing found processes + for _, foundProcess := range foundProcesses { + if err := foundProcess.Kill(); err != nil { + return fmt.Errorf("kill process %d: %w", foundProcess.Pid, err) + } + } + + return nil +} + +// GetProcessByName gets a single process object by its name +func GetProcessByName(name string) (*gopsutil_process.Process, error) { + if name == "" { + return nil, errors.New("process name should not be empty") + } + + processes, err := gopsutil_process.Processes() + if err != nil { + return nil, err + } + + var foundProcess *gopsutil_process.Process + for _, process := range processes { + processName, err := process.Name() + if err != nil { + log.Debug().Err(err).Int32("pid", process.Pid).Msg("get process name") + continue + } + + if strings.HasPrefix(processName, name) { + foundProcess = process + break + } + } + + if foundProcess == nil { + return nil, ErrProcessNotFound + } + + return foundProcess, nil +} + +// KillFromPIDFile kills a process taking the PID value from a file +func KillFromPIDFile(destDir string, pidFileName string, expectedExecName string) error { + if destDir == "" { + return errors.New("destination directory should not be empty") + } + + if pidFileName == "" { + return errors.New("PID file name should not be empty") + } + + if expectedExecName == "" { + return errors.New("expected executable name should not be empty") + } + + pid, err := readPidFromFile(destDir, pidFileName) + switch { + case err == nil: + // OK + case errors.Is(err, os.ErrNotExist): + return nil // we assume it's not running + default: + return fmt.Errorf("reading pid from: %s: %w", destDir, err) + } + + matches, err := processNameMatches(int(pid), expectedExecName) + if err != nil { + return fmt.Errorf("inspecting process %d: %w", pid, err) + } + + if !matches { + // Nothing to do, another process may be running with this pid + // (e.g. could happen after a restart). + return nil + } + + if err := killPID(pid); err != nil { + return fmt.Errorf("killing %d: %w", pid, err) + } + + return nil +}