Bug 7874: Adding SCM calls to register Orbit as a windows service (#7934)

* Bug 7874: Adding SCM calls to register Orbit as a windows service
This commit is contained in:
Marcos Oviedo 2022-09-27 11:52:41 -03:00 committed by GitHub
parent 142e298631
commit 381f628be7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 426 additions and 42 deletions

View file

@ -383,6 +383,53 @@ jobs:
run: |
"C:\Program Files\Orbit\bin\orbit\orbit.exe" shell -- --json "select * from osquery_info;" | jq -e "if (.[0]) then true else false end"
- name: Fleet Service Tests
shell: powershell
run: |
#Tests setup - service under test and wait times (in seconds, used with Start-Sleep)
$serviceName = "Fleet osquery"
$defaultWaitTime = 2
$orbitRequiredInitExtraTime = 20
#Test 1 - Check that the service starts without issues
Stop-Service -Name $serviceName
Start-Service -Name $serviceName
Get-Service -Name $serviceName | %{ if ($_.Status -ne "Running") { throw "Test #1 failed" } }
#Test 2 - Check that the service stops without issues
Stop-Service -Name $serviceName
Get-Service -Name $serviceName | %{ if ($_.Status -ne "Stopped") { throw "Test #2 failed" } }
#Test 3 - Check that no orbit.exe is running after service stop
Start-Service -Name $serviceName
Stop-Service -Name $serviceName
Start-Sleep -Seconds $defaultWaitTime # shutdown wait time
Get-Process | %{ if ($_.Name -eq "orbit") { throw "Test #3 failed" } }
#Test 4 - Check that service starts in less than 3 secs
#The start is issued from a background job and the status is checked after $defaultWaitTime seconds
Start-Job { Start-Service -Name $args[0] } -ArgumentList $serviceName | Out-Null #async operation
Start-Sleep -Seconds $defaultWaitTime
Get-Service -Name $serviceName | %{ if ($_.Status -ne "Running") { throw "Test #4 failed" } }
#Test 5 - Check that service stops in less than 3 secs
#The stop is issued from a background job and the status is checked after $defaultWaitTime seconds
Start-Job { Stop-Service -Name $args[0] } -ArgumentList $serviceName | Out-Null #async operation
Start-Sleep -Seconds $defaultWaitTime
Get-Service -Name $serviceName | %{ if ($_.Status -ne "Stopped") { throw "Test #5 failed" } }
#Test 6 - Check that no osqueryd process is running once service stops
Start-Service -Name $serviceName
Start-Sleep -Seconds $orbitRequiredInitExtraTime # orbit takes some time to spawn osquery and desktop app due to update check
Stop-Service -Name $serviceName
Start-Sleep -Seconds $defaultWaitTime
Get-Process | %{ if ($_.Name -eq "osqueryd") { throw "Test #6 failed" } }
#Test 7 - Check that no fleet-desktop process is running once service stops
Start-Service -Name $serviceName
Start-Sleep -Seconds $orbitRequiredInitExtraTime # orbit takes some time to spawn osquery and desktop app due to update check
Stop-Service -Name $serviceName
Start-Sleep -Seconds $defaultWaitTime
Get-Process | %{ if ($_.Name -eq "fleet-desktop") { throw "Test #7 failed" } }
- name: Upload Orbit logs
if: always()
uses: actions/upload-artifact@3cea5372237819ed00197afe530f5a7ea3e805c8 # v2

View file

@ -0,0 +1,3 @@
* Fixed a bug on Windows where the Fleet service was killed by the OS when
service start took longer than 30 seconds, caused by missing calls to the
Service Control Manager (SCM) APIs.

View file

@ -26,6 +26,8 @@ import (
"github.com/fleetdm/fleet/v4/orbit/pkg/execuser"
"github.com/fleetdm/fleet/v4/orbit/pkg/insecure"
"github.com/fleetdm/fleet/v4/orbit/pkg/osquery"
"github.com/fleetdm/fleet/v4/orbit/pkg/osservice"
"github.com/fleetdm/fleet/v4/orbit/pkg/platform"
"github.com/fleetdm/fleet/v4/orbit/pkg/table"
"github.com/fleetdm/fleet/v4/orbit/pkg/update"
"github.com/fleetdm/fleet/v4/orbit/pkg/update/filestore"
@ -35,7 +37,6 @@ import (
"github.com/oklog/run"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
gopsutil_process "github.com/shirou/gopsutil/v3/process"
"github.com/urfave/cli/v2"
"gopkg.in/natefinch/lumberjack.v2"
)
@ -260,6 +261,12 @@ func main() {
g run.Group
)
// List of interrupt functions to call during service teardown
var interruptFunctions []func(err error)
// Setting up the system service management early on the process lifetime
go osservice.SetupServiceManagement(constant.SystemServiceName, c.Bool("fleet-desktop"), &interruptFunctions)
// NOTE: When running in dev-mode, even if `disable-updates` is set,
// it fetches osqueryd once as part of initialization.
if !c.Bool("disable-updates") || c.Bool("dev-mode") {
@ -297,6 +304,7 @@ func main() {
log.Info().Msg("exiting due to successful early update")
return nil
}
g.Add(updateRunner.Execute, updateRunner.Interrupt)
osquerydLocalTarget, err := updater.Get("osqueryd")
@ -487,7 +495,6 @@ func main() {
capabilities := fleet.CapabilityMap{}
orbitClient, err := service.NewOrbitClient(fleetURL, c.String("fleet-certificate"), c.Bool("insecure"), enrollSecret, uuidStr, capabilities)
if err != nil {
return fmt.Errorf("error new orbit client: %w", err)
}
@ -552,6 +559,10 @@ func main() {
}
g.Add(r.Execute, r.Interrupt)
// Only osquery runner is being interrupted
// This ends up forcing the rest of the interrupt functions in the runner group to get called
interruptFunctions = append(interruptFunctions, r.Interrupt)
registerExtensionRunner(&g, r.ExtensionSocketPath(), deviceAuthToken)
checkerClient, err := service.NewOrbitClient(fleetURL, c.String("fleet-certificate"), c.Bool("insecure"), enrollSecret, uuidStr, capabilities)
@ -663,10 +674,10 @@ func (d *desktopRunner) execute() error {
// Second retry logic to monitor fleet-desktop.
// Call with waitFirst=true to give some time for the process to start.
if done := retry(30*time.Second, true, d.interruptCh, func() bool {
switch _, err := getProcessByName(constant.DesktopAppExecName); {
switch _, err := platform.GetProcessByName(constant.DesktopAppExecName); {
case err == nil:
return true // all good, process is running, retry.
case errors.Is(err, errProcessNotFound):
case errors.Is(err, platform.ErrProcessNotFound):
log.Debug().Msgf("%s process not running", constant.DesktopAppExecName)
return false // process is not running, do not retry.
default:
@ -705,7 +716,7 @@ func (d *desktopRunner) interrupt(err error) {
close(d.interruptCh) // Signal execute to return.
<-d.executeDoneCh // Wait for execute to return.
if err := killProcessByName(constant.DesktopAppExecName); err != nil {
if err := platform.KillProcessByName(constant.DesktopAppExecName); err != nil {
log.Error().Err(err).Msg("killProcess")
}
}
@ -747,7 +758,6 @@ func getUUID(osqueryPath string) (string, error) {
return "", fmt.Errorf("invalid number of rows from system_info query: %d", len(uuids))
}
return uuids[0].UuidString, nil
}
// getOrbitNodeKeyOrEnroll attempts to read the orbit node key if the file exists on disk
@ -806,42 +816,6 @@ func loadOrGenerateToken(rootDir string) (string, error) {
}
}
// killProcessByName looks up a process through getProcessByName and kills it.
// Lookup failures are returned wrapped with "get process"; a failed kill is
// returned wrapped together with the PID of the process.
func killProcessByName(name string) error {
	target, lookupErr := getProcessByName(name)
	if lookupErr != nil {
		return fmt.Errorf("get process: %w", lookupErr)
	}

	killErr := target.Kill()
	if killErr == nil {
		return nil
	}
	return fmt.Errorf("kill process %d: %w", target.Pid, killErr)
}
var errProcessNotFound = errors.New("process not found")
// getProcessByName returns the first process, in the order reported by
// gopsutil, whose name starts with name (case-sensitive prefix match).
// Processes whose name cannot be read are logged at debug level and skipped.
// errProcessNotFound is returned when nothing matches.
func getProcessByName(name string) (*gopsutil_process.Process, error) {
	all, err := gopsutil_process.Processes()
	if err != nil {
		return nil, err
	}

	for _, candidate := range all {
		candidateName, nameErr := candidate.Name()
		if nameErr != nil {
			log.Debug().Err(nameErr).Int32("pid", candidate.Pid).Msg("get process name")
			continue
		}
		if strings.HasPrefix(candidateName, name) {
			return candidate, nil
		}
	}
	return nil, errProcessNotFound
}
var versionCommand = &cli.Command{
Name: "version",
Usage: "Get the orbit version",

View file

@ -20,4 +20,13 @@ const (
OrbitEnrollMaxRetries = 10
// OrbitEnrollRetrySleep is the time duration to sleep between retries
OrbitEnrollRetrySleep = 5 * time.Second
// OsquerydName is the name of osqueryd binary
// We use osqueryd as name to properly identify the process when listing
// running processes/tasks.
OsquerydName = "osqueryd"
// OsqueryPidfile is the file containing the PID of the running osqueryd process
OsqueryPidfile = "osquery.pid"
// SystemServiceName is the name of Orbit system service
// The service name is used by the OS service management framework
SystemServiceName = "Fleet osquery"
)

View file

@ -0,0 +1,9 @@
//go:build !windows
// +build !windows
package osservice
// SetupServiceManagement is a no-op on non-Windows OSes (this file is built
// only when the target is not windows). It exists so callers can invoke it
// unconditionally; all three arguments are ignored and it returns immediately.
func SetupServiceManagement(serviceName string, fleetDesktopPresent bool, shutdownFunctions *[]func(err error)) {
}

View file

@ -0,0 +1,112 @@
//go:build windows
// +build windows
package osservice
import (
"errors"
"os"
"time"
"github.com/fleetdm/fleet/v4/orbit/pkg/constant"
"github.com/fleetdm/fleet/v4/orbit/pkg/platform"
"github.com/rs/zerolog/log"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
)
// windowsService is the handler handed to svc.Run; it carries the state
// needed to tear Orbit down when the SCM requests a stop or shutdown.
type windowsService struct {
// shutdownFunctions points at the caller-owned list of interrupt functions
// that bestEffortShutdown invokes, in order, during teardown.
// NOTE(review): the list is read through the pointer at shutdown time while
// the caller may still be appending to it - confirm there is no data race.
shutdownFunctions *[]func(err error)
// fleetDesktopPresent controls whether bestEffortShutdown also tries to
// kill the desktop app process.
fleetDesktopPresent bool
}
// bestEffortShutdown tears down everything the service started: it first
// invokes every registered interrupt function with a "service is shutting
// down" error, then kills the desktop app (only when fleetDesktopPresent is
// set) and every osqueryd process. Kill failures are logged, never returned.
func (m *windowsService) bestEffortShutdown() {
	shutdownErr := errors.New("service is shutting down")

	// Ask the runners to stop gracefully, in registration order.
	interrupts := *m.shutdownFunctions
	for i := range interrupts {
		interrupts[i](shutdownErr)
	}

	// Then make sure no child process outlives the service.
	if m.fleetDesktopPresent {
		if killErr := platform.KillProcessByName(constant.DesktopAppExecName); killErr != nil {
			log.Error().Err(killErr).Msg("The desktop app couldn't be killed")
		}
	}

	if killErr := platform.KillAllProcessByName(constant.OsquerydName); killErr != nil {
		log.Error().Err(killErr).Msg("The child osqueryd processes cannot be killed")
	}
}
// Execute is the service handler entry point used by svc.Run (see
// SetupServiceManagement). It reports StartPending and then Running to the
// SCM, and afterwards serves control requests until the requests channel is
// closed.
//
// On Stop or Shutdown it reports Stopped, runs bestEffortShutdown and then
// terminates the whole process with os.Exit, so it never returns on that path.
// Any control other than Interrogate, Stop or Shutdown makes it return
// (false, ERROR_INVALID_SERVICE_CONTROL). If the requests channel is closed
// it returns (false, 0). The args parameter is not used.
func (m *windowsService) Execute(args []string, requests <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) {
// Accepted service operations
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
// Expected service status update during initialization.
// Running is reported right away; it is not tied to the progress of the
// rest of Orbit's startup.
changes <- svc.Status{State: svc.StartPending}
changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
// The listening loop below will keep listening for new SCM requests
for req := range requests {
switch req.Cmd {
case svc.Interrogate:
// Echo back the status supplied with the request
changes <- req.CurrentStatus
case svc.Stop, svc.Shutdown:
// Service shutdown was requested
// Updating the service state to indicate stop.
// Note that Stopped is reported before the teardown below runs.
changes <- svc.Status{State: svc.Stopped, Win32ExitCode: 0}
// Best effort tear down
// Runner group's interrupt functions will be called here
m.bestEffortShutdown()
// Dummy delay to allow the SCM to pick up the changes
time.Sleep(500 * time.Millisecond)
// Drastic teardown: exits the process instead of returning from Execute
os.Exit(windows.NO_ERROR)
default:
// Unexpected control code: leave the handler reporting an error code
return false, uint32(windows.ERROR_INVALID_SERVICE_CONTROL)
}
}
return false, 0
}
// SetupServiceManagement hooks the process up to the Windows Service Control
// Manager (SCM). It does nothing beyond logging when serviceName is empty,
// when it cannot be determined whether the process runs as a service, or
// when the process is not running as a Windows service. Otherwise it builds
// a windowsService handler from fleetDesktopPresent and shutdownFunctions
// and hands it to svc.Run; a failure of svc.Run is only logged.
func SetupServiceManagement(serviceName string, fleetDesktopPresent bool, shutdownFunctions *[]func(err error)) {
	if serviceName == "" {
		log.Error().Msg(" service name should not be empty")
		return
	}

	// Only talk to the SCM when actually running as a service.
	runningAsService, detectErr := svc.IsWindowsService()
	switch {
	case detectErr != nil:
		log.Error().Err(detectErr).Msg("couldn't determine if running as a service")
		return
	case !runningAsService:
		return
	}

	handler := &windowsService{
		shutdownFunctions:   shutdownFunctions,
		fleetDesktopPresent: fleetDesktopPresent,
	}

	// Register the handler with the SCM.
	if runErr := svc.Run(serviceName, handler); runErr != nil {
		log.Info().Err(runErr).Msg("SCM registration failed")
	}
}

View file

@ -0,0 +1,230 @@
package platform
import (
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/mitchellh/go-ps"
"github.com/rs/zerolog/log"
gopsutil_process "github.com/shirou/gopsutil/v3/process"
)
var ErrProcessNotFound = errors.New("process not found")
// readPidFromFile returns the process ID stored as base-10 text in the file
// destFile inside directory destDir. Surrounding whitespace in the file is
// ignored and the value must fit in 32 bits.
//
// An error is returned when either argument is empty, when the file cannot
// be read (the underlying error is wrapped, so errors.Is works against it),
// or when the content is not a valid 32-bit integer.
func readPidFromFile(destDir string, destFile string) (int32, error) {
	// Sanity checks on inputs.
	switch {
	case destDir == "":
		return 0, errors.New(" destination directory should not be empty")
	case destFile == "":
		return 0, errors.New(" destination file should not be empty")
	}

	path := filepath.Join(destDir, destFile)

	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return 0, fmt.Errorf("error reading pidfile %s: %w", path, readErr)
	}

	text := strings.TrimSpace(string(raw))
	parsed, parseErr := strconv.ParseInt(text, 10, 32)
	if parseErr != nil {
		return 0, fmt.Errorf("error converting pidfile %s: %w", path, parseErr)
	}

	return int32(parsed), nil
}
// processNameMatches reports whether the executable name of the process with
// the given pid starts with expectedPrefix. The comparison lower-cases both
// sides, so it is a case-insensitive prefix match.
// If no process is found for the pid then (false, nil) is returned.
// A zero pid or an empty expectedPrefix is rejected with an error.
func processNameMatches(pid int, expectedPrefix string) (bool, error) {
	switch {
	case pid == 0:
		return false, errors.New("process id should not be zero")
	case expectedPrefix == "":
		return false, errors.New("expected prefix should not be empty")
	}

	proc, err := ps.FindProcess(pid)
	if err != nil {
		return false, fmt.Errorf("find process: %d: %w", pid, err)
	}
	if proc == nil {
		// Nothing is running with this pid.
		return false, nil
	}

	execName := strings.ToLower(proc.Executable())
	wantPrefix := strings.ToLower(expectedPrefix)
	return strings.HasPrefix(execName, wantPrefix), nil
}
// killPID kills the process whose ID equals pid.
//
// It returns an error when pid is zero, when the process list cannot be
// obtained, or when killing the matching process fails (the kill error is
// wrapped together with the pid). If no running process has the given pid
// there is nothing to kill and nil is returned.
func killPID(pid int32) error {
	if pid == 0 {
		return errors.New("process id should not be zero")
	}

	processes, err := gopsutil_process.Processes()
	if err != nil {
		return err
	}

	for _, process := range processes {
		if process.Pid != pid {
			continue
		}
		// Surface a failed kill instead of silently reporting success.
		if err := process.Kill(); err != nil {
			return fmt.Errorf("kill process %d: %w", pid, err)
		}
		return nil
	}

	return nil
}
// KillProcessByName kills one process: the process that GetProcessByName
// returns for name. An empty name is rejected. Lookup errors are returned
// wrapped with "get process", and a failed kill is returned wrapped together
// with the PID of the process.
func KillProcessByName(name string) error {
	if name == "" {
		return errors.New("process name should not be empty")
	}

	target, lookupErr := GetProcessByName(name)
	if lookupErr != nil {
		return fmt.Errorf("get process: %w", lookupErr)
	}

	killErr := target.Kill()
	if killErr == nil {
		return nil
	}
	return fmt.Errorf("kill process %d: %w", target.Pid, killErr)
}
// getProcessesByName returns every running process whose name starts with
// name (case-sensitive prefix match). Processes whose name cannot be read
// are logged at debug level and skipped. An empty name is rejected, and
// ErrProcessNotFound is returned when no process matches.
func getProcessesByName(name string) ([]*gopsutil_process.Process, error) {
	if name == "" {
		return nil, errors.New("process name should not be empty")
	}

	all, err := gopsutil_process.Processes()
	if err != nil {
		return nil, err
	}

	var matches []*gopsutil_process.Process
	for _, candidate := range all {
		candidateName, nameErr := candidate.Name()
		switch {
		case nameErr != nil:
			log.Debug().Err(nameErr).Int32("pid", candidate.Pid).Msg("get process name")
		case strings.HasPrefix(candidateName, name):
			matches = append(matches, candidate)
		}
	}

	if len(matches) > 0 {
		return matches, nil
	}
	return nil, ErrProcessNotFound
}
// KillAllProcessByName kills every process returned by getProcessesByName
// for name.
//
// An empty name is rejected, and lookup errors (including
// ErrProcessNotFound) are returned wrapped with "get process". A failure to
// kill one process does not stop the loop: the remaining processes are still
// killed, and the first kill error encountered is returned afterwards,
// wrapped together with the PID it belongs to.
func KillAllProcessByName(name string) error {
	if name == "" {
		return errors.New("process name should not be empty")
	}

	foundProcesses, err := getProcessesByName(name)
	if err != nil {
		return fmt.Errorf("get process: %w", err)
	}

	// Keep going after a failure so one stubborn process does not leave
	// the others running; remember only the first error to report it.
	var firstErr error
	for _, foundProcess := range foundProcesses {
		if err := foundProcess.Kill(); err != nil && firstErr == nil {
			firstErr = fmt.Errorf("kill process %d: %w", foundProcess.Pid, err)
		}
	}

	return firstErr
}
// GetProcessByName returns the first process, in the order reported by
// gopsutil, whose name starts with name (case-sensitive prefix match).
// Processes whose name cannot be read are logged at debug level and skipped.
// An empty name is rejected, and ErrProcessNotFound is returned when no
// process matches.
func GetProcessByName(name string) (*gopsutil_process.Process, error) {
	if name == "" {
		return nil, errors.New("process name should not be empty")
	}

	all, err := gopsutil_process.Processes()
	if err != nil {
		return nil, err
	}

	for _, candidate := range all {
		candidateName, nameErr := candidate.Name()
		if nameErr != nil {
			log.Debug().Err(nameErr).Int32("pid", candidate.Pid).Msg("get process name")
			continue
		}
		if strings.HasPrefix(candidateName, name) {
			return candidate, nil
		}
	}

	return nil, ErrProcessNotFound
}
// KillFromPIDFile kills the process whose PID is stored in the file
// pidFileName inside destDir, but only if that process's executable name
// matches expectedExecName according to processNameMatches.
//
// It returns nil without killing anything when the PID file does not exist
// or when the process with that PID does not match. Empty arguments, other
// read errors, inspection errors and killPID errors are returned.
func KillFromPIDFile(destDir string, pidFileName string, expectedExecName string) error {
	switch {
	case destDir == "":
		return errors.New("destination directory should not be empty")
	case pidFileName == "":
		return errors.New("PID file name should not be empty")
	case expectedExecName == "":
		return errors.New("expected executable name should not be empty")
	}

	pid, readErr := readPidFromFile(destDir, pidFileName)
	if readErr != nil {
		if errors.Is(readErr, os.ErrNotExist) {
			// No PID file: we assume it's not running.
			return nil
		}
		return fmt.Errorf("reading pid from: %s: %w", destDir, readErr)
	}

	sameExec, inspectErr := processNameMatches(int(pid), expectedExecName)
	if inspectErr != nil {
		return fmt.Errorf("inspecting process %d: %w", pid, inspectErr)
	}
	if !sameExec {
		// Nothing to do, another process may be running with this pid
		// (e.g. could happen after a restart).
		return nil
	}

	if killErr := killPID(pid); killErr != nil {
		return fmt.Errorf("killing %d: %w", pid, killErr)
	}
	return nil
}