mirror of
https://github.com/fleetdm/fleet
synced 2026-05-24 09:28:54 +00:00
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #40809 **Orbit agent: key rotation replaces decrypt-then-re-encrypt:** - When the disk is already encrypted, orbit now adds a new Fleet-managed recovery key protector, removes old ones, and escrows the new key. The disk is never decrypted. - If key escrow fails, the rotated key is cached in memory and retried on subsequent ticks without rotating again. - Removes `DecryptVolume` and `decrypt()` (no longer called from production code). **Server: osquery query returns both protection_status and conversion_status:** - The `disk_encryption_windows` query now returns both columns instead of just checking `protection_status = 1`. This lets the server correctly identify a disk as encrypted via `conversion_status = 1` even when `protection_status = 0`. - New `directIngestDiskEncryptionWindows` function parses both values, handles parse errors, and normalizes `protection_status = 2` (unknown) to NULL. **Server: new `bitlocker_protection_status` column and status logic:** - Adds `bitlocker_protection_status` column to `host_disks` (DB migration). - When a disk is encrypted and key is escrowed but protection is off, the host shows "Action required" with a detail message explaining the issue, instead of misleadingly showing "Verified." - `protection_status = 2` (unknown) and `NULL` (older orbit hosts) are treated as protection on for backward compatibility. - The `profiles_verified` and `profiles_verifying` branches in the combined profiles+BitLocker summary now handle `bitlocker_action_required`, counting those hosts as "pending". Contributor docs updates: https://github.com/fleetdm/fleet/pull/43241 Public docs updates: https://github.com/fleetdm/fleet/pull/43243/changes # Checklist for submitter If some of the following don't apply, delete the relevant line. 
- [x] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. ## Testing - [x] Added/updated automated tests - [x] QA'd all new/changed functionality manually ## Database migrations - [x] Checked schema for all modified table for columns that will auto-update timestamps during migration. ## fleetd/orbit/Fleet Desktop - [x] Verified compatibility with the latest released version of Fleet (see [Must rule](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/workflows/fleetd-development-and-release-strategy.md)) - [x] If the change applies to only one platform, confirmed that `runtime.GOOS` is used as needed to isolate changes - [x] Verified that fleetd runs on macOS, Linux and Windows - [x] Verified auto-update works from the released version of component to the new version (see [tools/tuf/test](../tools/tuf/test/README.md)) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **Bug Fixes** * Fixed Windows BitLocker encryption/decryption request loop on systems with secondary drives and auto-unlock. * **New Features** * Added BitLocker recovery key rotation capability, allowing safe key updates without full disk re-encryption. * Enhanced BitLocker protection status tracking to correctly display "Action required" when protection is disabled. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
821 lines
29 KiB
Go
821 lines
29 KiB
Go
package update
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/orbit/pkg/bitlocker"
|
|
"github.com/fleetdm/fleet/v4/orbit/pkg/scripts"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/fleet/v4/server/ptr"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestRenewEnrollmentProfile(t *testing.T) {
|
|
var logBuf bytes.Buffer
|
|
|
|
oldLog := log.Logger
|
|
log.Logger = log.Output(&logBuf)
|
|
t.Cleanup(func() { log.Logger = oldLog })
|
|
|
|
cases := []struct {
|
|
desc string
|
|
renewFlag bool
|
|
cmdErr error
|
|
wantCmdCalled bool
|
|
wantLog string
|
|
}{
|
|
{"renew=false", false, nil, false, ""},
|
|
{"renew=true; success", true, nil, true, "successfully called /usr/bin/profiles to renew enrollment profile"},
|
|
{"renew=true; fail", true, io.ErrUnexpectedEOF, true, "calling /usr/bin/profiles to renew enrollment profile failed"},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.desc, func(t *testing.T) {
|
|
logBuf.Reset()
|
|
|
|
testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{RenewEnrollmentProfile: c.renewFlag}}
|
|
|
|
var cmdGotCalled bool
|
|
var depAssignedCheckGotCalled bool
|
|
renewReceiver := &renewEnrollmentProfileConfigReceiver{
|
|
Frequency: time.Hour, // doesn't matter for this test
|
|
runCmdFn: func() error {
|
|
cmdGotCalled = true
|
|
return c.cmdErr
|
|
},
|
|
checkEnrollmentFn: func() (bool, string, error) {
|
|
return false, "", nil
|
|
},
|
|
checkAssignedEnrollmentProfileFn: func(url string) error {
|
|
depAssignedCheckGotCalled = true
|
|
return nil
|
|
},
|
|
}
|
|
|
|
err := renewReceiver.Run(testConfig)
|
|
require.NoError(t, err) // the dummy receiver never returns an error
|
|
|
|
require.Equal(t, c.wantCmdCalled, cmdGotCalled)
|
|
require.Equal(t, c.wantCmdCalled, depAssignedCheckGotCalled)
|
|
require.Contains(t, logBuf.String(), c.wantLog)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRenewEnrollmentProfilePrevented(t *testing.T) {
|
|
var logBuf bytes.Buffer
|
|
|
|
oldLog := log.Logger
|
|
log.Logger = log.Output(&logBuf)
|
|
t.Cleanup(func() { log.Logger = oldLog })
|
|
|
|
testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{RenewEnrollmentProfile: true}}
|
|
|
|
var cmdCallCount int
|
|
isEnrolled := false
|
|
isAssigned := true
|
|
chProceed := make(chan struct{})
|
|
renewReceiver := &renewEnrollmentProfileConfigReceiver{
|
|
Frequency: 2 * time.Second, // just to be safe with slow environments (CI)
|
|
runCmdFn: func() error {
|
|
cmdCallCount++ // no need for sync, single-threaded call of this func is guaranteed by the receiver's mutex
|
|
return nil
|
|
},
|
|
checkEnrollmentFn: func() (bool, string, error) {
|
|
<-chProceed // will be unblocked only when allowed
|
|
return isEnrolled, "", nil
|
|
},
|
|
checkAssignedEnrollmentProfileFn: func(url string) error {
|
|
<-chProceed // will be unblocked only when allowed
|
|
if !isAssigned {
|
|
return errors.New("not assigned")
|
|
}
|
|
return nil
|
|
},
|
|
}
|
|
|
|
// One of the calls to renewReceiver.Run() will run first and get blocked in checkEnrollmentFn. The
|
|
// second won't call the command (won't be able to lock the mutex). So, it will still complete successfully
|
|
// without being blocked by the other call in progress. Whichever one exits first then needs to close
|
|
// chProceed so the other one is unblocked.
|
|
var shouldCloseChProceed atomic.Bool
|
|
shouldCloseChProceed.Store(true)
|
|
|
|
started := make(chan struct{})
|
|
frequencyMu := sync.Mutex{}
|
|
go func() {
|
|
frequencyMu.Lock()
|
|
defer frequencyMu.Unlock()
|
|
close(started)
|
|
|
|
err := renewReceiver.Run(testConfig)
|
|
require.NoError(t, err)
|
|
if shouldCloseChProceed.CompareAndSwap(true, false) {
|
|
close(chProceed)
|
|
t.Logf("%v unblock the first call from the goroutine", time.Now())
|
|
}
|
|
}()
|
|
|
|
<-started
|
|
t.Logf("%v started", time.Now())
|
|
|
|
err := renewReceiver.Run(testConfig)
|
|
require.NoError(t, err)
|
|
|
|
if shouldCloseChProceed.CompareAndSwap(true, false) {
|
|
// unblock the first call
|
|
close(chProceed)
|
|
t.Logf("%v unblock the first call", time.Now())
|
|
}
|
|
|
|
// this next call won't execute the command because of the frequency
|
|
// restriction (it got called less than N seconds ago)
|
|
err = renewReceiver.Run(testConfig)
|
|
require.NoError(t, err)
|
|
t.Logf("%v frequency restriction check done", time.Now())
|
|
|
|
frequencyMu.Lock()
|
|
renewReceiver.Frequency = 200 * time.Millisecond
|
|
frequencyMu.Unlock()
|
|
// wait for the receiver's frequency to pass
|
|
time.Sleep(renewReceiver.Frequency)
|
|
|
|
// this call executes the command
|
|
err = renewReceiver.Run(testConfig)
|
|
require.NoError(t, err)
|
|
|
|
// wait for the receiver's frequency to pass
|
|
time.Sleep(renewReceiver.Frequency)
|
|
|
|
// this call doesn't execute the command since the host is already
|
|
// enrolled
|
|
isEnrolled = true
|
|
err = renewReceiver.Run(testConfig)
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, 2, cmdCallCount) // the initial call and the one after sleep
|
|
|
|
// wait for the receiver's frequency to pass
|
|
time.Sleep(renewReceiver.Frequency)
|
|
|
|
// this call doesn't execute the command since the assigned profile check fails
|
|
isAssigned = false
|
|
isEnrolled = false
|
|
err = renewReceiver.Run(testConfig)
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, 2, cmdCallCount) // the initial call and the one after sleep
|
|
|
|
// wait for the receiver's frequency to pass
|
|
time.Sleep(renewReceiver.Frequency)
|
|
|
|
// this next call won't execute the command because the backoff
|
|
// for a failed assigned check is always 2 minutes
|
|
err = renewReceiver.Run(testConfig)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// mockNodeKeyGetter is a test double that always hands out the same node key
// and never fails.
type mockNodeKeyGetter struct{}

// GetNodeKey returns the constant key "nodekey-test" and a nil error.
func (mockNodeKeyGetter) GetNodeKey() (string, error) {
	const fixedNodeKey = "nodekey-test"
	return fixedNodeKey, nil
}
|
|
|
|
func TestWindowsMDMEnrollment(t *testing.T) {
|
|
var logBuf bytes.Buffer
|
|
|
|
oldLog := log.Logger
|
|
log.Logger = log.Output(&logBuf)
|
|
t.Cleanup(func() { log.Logger = oldLog })
|
|
|
|
cases := []struct {
|
|
desc string
|
|
enrollFlag *bool
|
|
unenrollFlag *bool
|
|
migrateFlag *bool
|
|
discoveryURL string
|
|
apiErr error
|
|
wantAPICalled bool
|
|
wantLog string
|
|
}{
|
|
{"enroll=false", ptr.Bool(false), nil, nil, "", nil, false, ""},
|
|
{"enroll=true,discovery=''", ptr.Bool(true), nil, nil, "", nil, false, "discovery endpoint is empty"},
|
|
{"enroll=true,discovery!='',success", ptr.Bool(true), nil, nil, "http://example.com", nil, true, "successfully called RegisterDeviceWithManagement"},
|
|
{"enroll=true,discovery!='',fail", ptr.Bool(true), nil, nil, "http://example.com", io.ErrUnexpectedEOF, true, "enroll Windows device failed"},
|
|
{"enroll=true,discovery!='',server", ptr.Bool(true), nil, nil, "http://example.com", errIsWindowsServer, true, "device is a Windows Server, skipping enrollment"},
|
|
|
|
{"unenroll=false", nil, ptr.Bool(false), nil, "", nil, false, ""},
|
|
{"unenroll=true,success", nil, ptr.Bool(true), nil, "", nil, true, "successfully called UnregisterDeviceWithManagement to unenroll"},
|
|
{"unenroll=true,fail", nil, ptr.Bool(true), nil, "", io.ErrUnexpectedEOF, true, "unenroll Windows device failed"},
|
|
{"unenroll=true,server", nil, ptr.Bool(true), nil, "", errIsWindowsServer, true, "device is a Windows Server, skipping unenroll"},
|
|
|
|
{"migrate=false", nil, nil, ptr.Bool(false), "", nil, false, ""},
|
|
{"migrate=true,success", nil, nil, ptr.Bool(true), "", nil, true, "successfully called UnregisterDeviceWithManagement to migrate"},
|
|
{"migrate=true,fail", nil, nil, ptr.Bool(true), "", io.ErrUnexpectedEOF, true, "migrate Windows device failed"},
|
|
{"migrate=true,server", nil, nil, ptr.Bool(true), "", errIsWindowsServer, true, "device is a Windows Server, skipping migrate"},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.desc, func(t *testing.T) {
|
|
logBuf.Reset()
|
|
|
|
var (
|
|
enroll = c.enrollFlag != nil && *c.enrollFlag
|
|
unenroll = c.unenrollFlag != nil && *c.unenrollFlag
|
|
migrate = c.migrateFlag != nil && *c.migrateFlag
|
|
isUnenroll = c.unenrollFlag != nil
|
|
)
|
|
|
|
testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
|
|
NeedsProgrammaticWindowsMDMEnrollment: enroll,
|
|
NeedsProgrammaticWindowsMDMUnenrollment: unenroll,
|
|
NeedsMDMMigration: migrate,
|
|
WindowsMDMDiscoveryEndpoint: c.discoveryURL,
|
|
}}
|
|
|
|
var enrollGotCalled, unenrollGotCalled bool
|
|
enrollReceiver := &windowsMDMEnrollmentConfigReceiver{
|
|
Frequency: time.Hour, // doesn't matter for this test
|
|
execEnrollFn: func(args WindowsMDMEnrollmentArgs) error {
|
|
enrollGotCalled = true
|
|
return c.apiErr
|
|
},
|
|
execUnenrollFn: func(args WindowsMDMEnrollmentArgs) error {
|
|
unenrollGotCalled = true
|
|
return c.apiErr
|
|
},
|
|
nodeKeyGetter: mockNodeKeyGetter{},
|
|
}
|
|
|
|
err := enrollReceiver.Run(testConfig)
|
|
require.NoError(t, err) // the dummy receiver never returns an error
|
|
|
|
if isUnenroll || migrate {
|
|
require.Equal(t, c.wantAPICalled, unenrollGotCalled)
|
|
require.False(t, enrollGotCalled)
|
|
} else {
|
|
require.Equal(t, c.wantAPICalled, enrollGotCalled)
|
|
require.False(t, unenrollGotCalled)
|
|
}
|
|
require.Contains(t, logBuf.String(), c.wantLog)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWindowsMDMEnrollmentPrevented(t *testing.T) {
|
|
var logBuf bytes.Buffer
|
|
|
|
oldLog := log.Logger
|
|
log.Logger = log.Output(&logBuf)
|
|
t.Cleanup(func() { log.Logger = oldLog })
|
|
|
|
cfgs := []fleet.OrbitConfigNotifications{
|
|
{
|
|
NeedsProgrammaticWindowsMDMEnrollment: true,
|
|
WindowsMDMDiscoveryEndpoint: "http://example.com",
|
|
},
|
|
{
|
|
NeedsProgrammaticWindowsMDMUnenrollment: true,
|
|
},
|
|
}
|
|
for _, cfg := range cfgs {
|
|
t.Run(fmt.Sprintf("%+v", cfg), func(t *testing.T) {
|
|
testConfig := &fleet.OrbitConfig{Notifications: cfg}
|
|
|
|
var (
|
|
apiCallCount int
|
|
apiErr error
|
|
)
|
|
chProceed := make(chan struct{})
|
|
receiver := &windowsMDMEnrollmentConfigReceiver{
|
|
Frequency: 2 * time.Second, // just to be safe with slow environments (CI)
|
|
nodeKeyGetter: mockNodeKeyGetter{},
|
|
}
|
|
if cfg.NeedsProgrammaticWindowsMDMEnrollment {
|
|
receiver.execEnrollFn = func(args WindowsMDMEnrollmentArgs) error {
|
|
<-chProceed // will be unblocked only when allowed
|
|
apiCallCount++ // no need for sync, single-threaded call of this func is guaranteed by the receiver's mutex
|
|
return apiErr
|
|
}
|
|
receiver.execUnenrollFn = func(args WindowsMDMEnrollmentArgs) error {
|
|
panic("should not be called")
|
|
}
|
|
} else {
|
|
receiver.execUnenrollFn = func(args WindowsMDMEnrollmentArgs) error {
|
|
<-chProceed // will be unblocked only when allowed
|
|
apiCallCount++ // no need for sync, single-threaded call of this func is guaranteed by the receiver's mutex
|
|
return apiErr
|
|
}
|
|
receiver.execEnrollFn = func(args WindowsMDMEnrollmentArgs) error {
|
|
panic("should not be called")
|
|
}
|
|
}
|
|
|
|
go func() {
|
|
// the first call will block in enroll/unenroll func
|
|
err := receiver.Run(testConfig)
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
// wait a little bit to ensure the first `receiver.Run` call runs first.
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// this call will happen while the first call is blocked in
|
|
// enroll/unenrollfn, so it won't call the API (won't be able to lock the
|
|
// mutex). However it will still complete successfully without being
|
|
// blocked by the other call in progress.
|
|
err := receiver.Run(testConfig)
|
|
require.NoError(t, err)
|
|
|
|
// unblock the first call and wait for it to complete
|
|
close(chProceed)
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// this next call won't execute the command because of the frequency
|
|
// restriction (it got called less than N seconds ago)
|
|
err = receiver.Run(testConfig)
|
|
require.NoError(t, err)
|
|
|
|
// wait for the receiver's frequency to pass
|
|
time.Sleep(receiver.Frequency)
|
|
|
|
// this call executes the command, and it returns the Is Windows Server error
|
|
apiErr = errIsWindowsServer
|
|
err = receiver.Run(testConfig)
|
|
require.NoError(t, err)
|
|
|
|
// this next call won't execute the command (both due to frequency and the
|
|
// detection of windows server)
|
|
err = receiver.Run(testConfig)
|
|
require.NoError(t, err)
|
|
|
|
// wait for the receiver's frequency to pass
|
|
time.Sleep(receiver.Frequency)
|
|
|
|
// this next call still won't execute the command (due to the detection of
|
|
// windows server)
|
|
err = receiver.Run(testConfig)
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, 2, apiCallCount) // the initial call and the one that returned errIsWindowsServer after first sleep
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRunScripts(t *testing.T) {
|
|
var logBuf bytes.Buffer
|
|
|
|
oldLog := log.Logger
|
|
log.Logger = log.Output(&logBuf)
|
|
t.Cleanup(func() { log.Logger = oldLog })
|
|
|
|
var (
|
|
callsCount atomic.Int64
|
|
runFailure error
|
|
blockRun chan struct{}
|
|
)
|
|
|
|
mockRun := func(r *scripts.Runner, ids []string) error {
|
|
callsCount.Add(1)
|
|
if blockRun != nil {
|
|
<-blockRun
|
|
}
|
|
return runFailure
|
|
}
|
|
|
|
waitForRun := func(t *testing.T, r *runScriptsConfigReceiver) {
|
|
var ok bool
|
|
for start := time.Now(); !ok && time.Since(start) < time.Second; {
|
|
ok = r.mu.TryLock()
|
|
}
|
|
require.True(t, ok, "timed out waiting for the lock to become available")
|
|
r.mu.Unlock()
|
|
}
|
|
|
|
t.Run("no pending scripts", func(t *testing.T) {
|
|
t.Cleanup(func() { callsCount.Store(0); logBuf.Reset() })
|
|
|
|
testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
|
|
PendingScriptExecutionIDs: nil,
|
|
}}
|
|
|
|
runner := &runScriptsConfigReceiver{
|
|
runScriptsFn: mockRun,
|
|
}
|
|
err := runner.Run(testConfig)
|
|
require.NoError(t, err) // the dummy receiver never returns an error
|
|
|
|
// the lock should be available because no goroutine was started
|
|
require.True(t, runner.mu.TryLock())
|
|
require.Zero(t, callsCount.Load()) // no calls to execute scripts
|
|
require.Empty(t, logBuf.String()) // no logs written
|
|
})
|
|
|
|
t.Run("pending scripts succeed", func(t *testing.T) {
|
|
t.Cleanup(func() { callsCount.Store(0); logBuf.Reset() })
|
|
|
|
testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
|
|
PendingScriptExecutionIDs: []string{"a", "b", "c"},
|
|
}}
|
|
|
|
runner := &runScriptsConfigReceiver{
|
|
runScriptsFn: mockRun,
|
|
}
|
|
err := runner.Run(testConfig)
|
|
require.NoError(t, err) // the dummy receiver never returns an error
|
|
|
|
waitForRun(t, runner)
|
|
require.Equal(t, int64(1), callsCount.Load()) // all scripts executed in a single run
|
|
require.Contains(t, logBuf.String(), "received notification to run scripts [a b c]")
|
|
require.Contains(t, logBuf.String(), "running scripts [a b c] succeeded")
|
|
})
|
|
|
|
t.Run("pending scripts failed", func(t *testing.T) {
|
|
t.Cleanup(func() { callsCount.Store(0); logBuf.Reset(); runFailure = nil })
|
|
|
|
testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
|
|
PendingScriptExecutionIDs: []string{"a", "b", "c"},
|
|
}}
|
|
|
|
runFailure = io.ErrUnexpectedEOF
|
|
runner := &runScriptsConfigReceiver{
|
|
runScriptsFn: mockRun,
|
|
}
|
|
|
|
err := runner.Run(testConfig)
|
|
require.NoError(t, err) // the dummy receiver never returns an error
|
|
|
|
waitForRun(t, runner)
|
|
require.Equal(t, int64(1), callsCount.Load()) // all scripts executed in a single run
|
|
require.Contains(t, logBuf.String(), "received notification to run scripts [a b c]")
|
|
require.Contains(t, logBuf.String(), "running scripts failed")
|
|
require.Contains(t, logBuf.String(), io.ErrUnexpectedEOF.Error())
|
|
})
|
|
|
|
t.Run("concurrent run prevented", func(t *testing.T) {
|
|
t.Cleanup(func() { callsCount.Store(0); logBuf.Reset(); blockRun = nil })
|
|
|
|
testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
|
|
PendingScriptExecutionIDs: []string{"a", "b", "c"},
|
|
}}
|
|
|
|
blockRun = make(chan struct{})
|
|
runner := &runScriptsConfigReceiver{
|
|
runScriptsFn: mockRun,
|
|
}
|
|
|
|
err := runner.Run(testConfig)
|
|
require.NoError(t, err) // the dummy receiver never returns an error
|
|
|
|
// call it again, while the previous run is still running
|
|
err = runner.Run(testConfig)
|
|
require.NoError(t, err) // the dummy receiver never returns an error
|
|
|
|
// unblock the initial run
|
|
close(blockRun)
|
|
|
|
waitForRun(t, runner)
|
|
require.Equal(t, int64(1), callsCount.Load()) // only called once because of mutex
|
|
require.Contains(t, logBuf.String(), "received notification to run scripts [a b c]")
|
|
require.Contains(t, logBuf.String(), "running scripts [a b c] succeeded")
|
|
})
|
|
|
|
t.Run("dynamic enabling of scripts", func(t *testing.T) {
|
|
t.Cleanup(logBuf.Reset)
|
|
|
|
testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
|
|
PendingScriptExecutionIDs: []string{"a"},
|
|
}}
|
|
|
|
var (
|
|
scriptsEnabledCalls []bool
|
|
dynamicEnabled atomic.Bool
|
|
|
|
dynamicInterval = 300 * time.Millisecond
|
|
)
|
|
|
|
runner := &runScriptsConfigReceiver{
|
|
ScriptsExecutionEnabled: false,
|
|
runScriptsFn: func(r *scripts.Runner, s []string) error {
|
|
scriptsEnabledCalls = append(scriptsEnabledCalls, r.ScriptExecutionEnabled)
|
|
return nil
|
|
},
|
|
testGetFleetdConfig: func() (*fleet.MDMAppleFleetdConfig, error) {
|
|
return &fleet.MDMAppleFleetdConfig{
|
|
EnableScripts: dynamicEnabled.Load(),
|
|
}, nil
|
|
},
|
|
dynamicScriptsEnabledCheckInterval: dynamicInterval,
|
|
}
|
|
|
|
// the static Scripts Enabled flag is false, so it relies on the dynamic check
|
|
runner.runDynamicScriptsEnabledCheck()
|
|
|
|
// first call, scripts are disabled
|
|
err := runner.Run(testConfig)
|
|
require.NoError(t, err) // the dummy receiver never returns an error
|
|
waitForRun(t, runner)
|
|
|
|
// swap scripts execution to true and wait to ensure the dynamic check
|
|
// did run.
|
|
dynamicEnabled.Store(true)
|
|
time.Sleep(dynamicInterval + 100*time.Millisecond)
|
|
|
|
// second call, scripts are enabled (change exec ID to "b")
|
|
testConfig.Notifications.PendingScriptExecutionIDs[0] = "b"
|
|
err = runner.Run(testConfig)
|
|
require.NoError(t, err) // the dummy receiver never returns an error
|
|
waitForRun(t, runner)
|
|
|
|
// swap scripts execution back to false and wait to ensure the dynamic
|
|
// check did run.
|
|
dynamicEnabled.Store(false)
|
|
time.Sleep(dynamicInterval + 100*time.Millisecond)
|
|
|
|
// third call, scripts are disabled (change exec ID to "c")
|
|
testConfig.Notifications.PendingScriptExecutionIDs[0] = "c"
|
|
err = runner.Run(testConfig)
|
|
require.NoError(t, err) // the dummy receiver never returns an error
|
|
waitForRun(t, runner)
|
|
|
|
// validate the Scripts Enabled flags that were passed to the runScriptsFn
|
|
require.Equal(t, []bool{false, true, false}, scriptsEnabledCalls)
|
|
require.Contains(t, logBuf.String(), "received notification to run scripts [a]")
|
|
require.Contains(t, logBuf.String(), "running scripts [a] succeeded")
|
|
require.Contains(t, logBuf.String(), "received notification to run scripts [b]")
|
|
require.Contains(t, logBuf.String(), "running scripts [b] succeeded")
|
|
require.Contains(t, logBuf.String(), "received notification to run scripts [c]")
|
|
require.Contains(t, logBuf.String(), "running scripts [c] succeeded")
|
|
})
|
|
}
|
|
|
|
type mockDiskEncryptionKeySetter struct {
|
|
SetOrUpdateDiskEncryptionKeyImpl func(diskEncryptionStatus fleet.OrbitHostDiskEncryptionKeyPayload) error
|
|
SetOrUpdateDiskEncryptionKeyInvoked bool
|
|
}
|
|
|
|
func (m *mockDiskEncryptionKeySetter) SetOrUpdateDiskEncryptionKey(diskEncryptionStatus fleet.OrbitHostDiskEncryptionKeyPayload) error {
|
|
m.SetOrUpdateDiskEncryptionKeyInvoked = true
|
|
return m.SetOrUpdateDiskEncryptionKeyImpl(diskEncryptionStatus)
|
|
}
|
|
|
|
func TestBitlockerOperations(t *testing.T) {
|
|
var logBuf bytes.Buffer
|
|
|
|
oldLog := log.Logger
|
|
log.Logger = log.Output(&logBuf)
|
|
t.Cleanup(func() { log.Logger = oldLog })
|
|
|
|
var (
|
|
shouldEncrypt = true
|
|
shouldFailEncryption = false
|
|
shouldFailServerUpdate = false
|
|
encryptFnCalled = false
|
|
)
|
|
|
|
clientMock := &mockDiskEncryptionKeySetter{}
|
|
clientMock.SetOrUpdateDiskEncryptionKeyImpl = func(diskEncryptionStatus fleet.OrbitHostDiskEncryptionKeyPayload) error {
|
|
if shouldFailServerUpdate {
|
|
return errors.New("server error")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
var rotateKeyFnCalled bool
|
|
var shouldFailKeyRotation bool
|
|
|
|
var enrollReceiver *windowsMDMBitlockerConfigReceiver
|
|
setupTest := func() {
|
|
enrollReceiver = &windowsMDMBitlockerConfigReceiver{
|
|
Frequency: time.Hour, // doesn't matter for this test
|
|
lastRun: time.Now().Add(-2 * time.Hour),
|
|
EncryptionResult: clientMock,
|
|
execGetEncryptionStatusFn: func() ([]bitlocker.VolumeStatus, error) {
|
|
return []bitlocker.VolumeStatus{}, nil
|
|
},
|
|
execEncryptVolumeFn: func(string) (string, error) {
|
|
encryptFnCalled = true
|
|
if shouldFailEncryption {
|
|
return "", errors.New("error encrypting")
|
|
}
|
|
|
|
return "123456", nil
|
|
},
|
|
execRotateRecoveryKeyFn: func(string) (string, error) {
|
|
rotateKeyFnCalled = true
|
|
if shouldFailKeyRotation {
|
|
return "", errors.New("error rotating key")
|
|
}
|
|
return "rotated-key-789", nil
|
|
},
|
|
}
|
|
shouldEncrypt = true
|
|
shouldFailEncryption = false
|
|
shouldFailKeyRotation = false
|
|
shouldFailServerUpdate = false
|
|
encryptFnCalled = false
|
|
rotateKeyFnCalled = false
|
|
clientMock.SetOrUpdateDiskEncryptionKeyInvoked = false
|
|
logBuf.Reset()
|
|
}
|
|
|
|
makeConfig := func() *fleet.OrbitConfig {
|
|
return &fleet.OrbitConfig{
|
|
Notifications: fleet.OrbitConfigNotifications{
|
|
EnforceBitLockerEncryption: shouldEncrypt,
|
|
},
|
|
}
|
|
}
|
|
|
|
t.Run("bitlocker encryption is performed", func(t *testing.T) {
|
|
setupTest()
|
|
// shouldEncrypt defaults to true from setupTest
|
|
err := enrollReceiver.Run(makeConfig())
|
|
require.NoError(t, err) // the dummy receiver never returns an error
|
|
})
|
|
|
|
t.Run("bitlocker encryption is not performed when not enforced", func(t *testing.T) {
|
|
setupTest()
|
|
shouldEncrypt = false
|
|
err := enrollReceiver.Run(makeConfig())
|
|
require.NoError(t, err)
|
|
require.False(t, encryptFnCalled, "encryption function should not be called when not enforced")
|
|
require.False(t, rotateKeyFnCalled, "rotate key function should not be called when not enforced")
|
|
})
|
|
|
|
t.Run("bitlocker encryption returns an error", func(t *testing.T) {
|
|
setupTest()
|
|
shouldFailEncryption = true
|
|
err := enrollReceiver.Run(makeConfig())
|
|
require.NoError(t, err) // the dummy receiver never returns an error
|
|
require.True(t, encryptFnCalled, "encryption function should have been called")
|
|
})
|
|
|
|
t.Run("encryption skipped based on various current statuses", func(t *testing.T) {
|
|
setupTest()
|
|
statusesToTest := []int32{
|
|
bitlocker.ConversionStatusDecryptionInProgress,
|
|
bitlocker.ConversionStatusDecryptionPaused,
|
|
bitlocker.ConversionStatusEncryptionInProgress,
|
|
bitlocker.ConversionStatusEncryptionPaused,
|
|
}
|
|
|
|
for _, status := range statusesToTest {
|
|
t.Run(fmt.Sprintf("status %d", status), func(t *testing.T) {
|
|
mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: status}
|
|
enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
|
|
return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
|
|
}
|
|
|
|
err := enrollReceiver.Run(makeConfig())
|
|
require.NoError(t, err)
|
|
require.Contains(t, logBuf.String(), "skipping encryption as the disk is not available")
|
|
require.False(t, encryptFnCalled, "encryption function should not be called")
|
|
logBuf.Reset() // Reset the log buffer for the next iteration
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("handle misreported decryption error", func(t *testing.T) {
|
|
setupTest()
|
|
mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: bitlocker.ConversionStatusFullyDecrypted}
|
|
enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
|
|
return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
|
|
}
|
|
enrollReceiver.execEncryptVolumeFn = func(string) (string, error) {
|
|
return "", bitlocker.NewEncryptionError("", bitlocker.ErrorCodeNotDecrypted)
|
|
}
|
|
|
|
err := enrollReceiver.Run(makeConfig())
|
|
require.NoError(t, err)
|
|
require.Contains(t, logBuf.String(), "disk encryption failed due to previous unsuccessful attempt, user action required")
|
|
require.False(t, encryptFnCalled, "encryption function should not be called")
|
|
})
|
|
|
|
t.Run("rotates recovery key if disk already encrypted", func(t *testing.T) {
|
|
setupTest()
|
|
mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: bitlocker.ConversionStatusFullyEncrypted}
|
|
enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
|
|
return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
|
|
}
|
|
err := enrollReceiver.Run(makeConfig())
|
|
require.NoError(t, err)
|
|
require.Contains(t, logBuf.String(), "disk is already encrypted, rotating recovery key")
|
|
require.True(t, clientMock.SetOrUpdateDiskEncryptionKeyInvoked, "should escrow the rotated key")
|
|
require.True(t, rotateKeyFnCalled, "rotate key function should have been called")
|
|
require.False(t, encryptFnCalled, "encryption function should not be called")
|
|
})
|
|
|
|
t.Run("reports to the server if key rotation fails", func(t *testing.T) {
|
|
setupTest()
|
|
shouldFailKeyRotation = true
|
|
mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: bitlocker.ConversionStatusFullyEncrypted}
|
|
enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
|
|
return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
|
|
}
|
|
|
|
err := enrollReceiver.Run(makeConfig())
|
|
require.NoError(t, err)
|
|
require.Contains(t, logBuf.String(), "disk is already encrypted, rotating recovery key")
|
|
require.Contains(t, logBuf.String(), "recovery key rotation failed")
|
|
require.True(t, clientMock.SetOrUpdateDiskEncryptionKeyInvoked)
|
|
require.True(t, rotateKeyFnCalled, "rotate key function should have been called")
|
|
require.False(t, encryptFnCalled, "encryption function should not be called")
|
|
})
|
|
|
|
t.Run("encryption skipped if last run too recent", func(t *testing.T) {
|
|
setupTest()
|
|
enrollReceiver.lastRun = time.Now().Add(-30 * time.Minute)
|
|
enrollReceiver.Frequency = 1 * time.Hour
|
|
|
|
err := enrollReceiver.Run(makeConfig())
|
|
require.NoError(t, err)
|
|
require.Contains(t, logBuf.String(), "skipped encryption process, last run was too recent")
|
|
require.False(t, encryptFnCalled, "encryption function should not be called")
|
|
})
|
|
|
|
t.Run("successful fleet server update", func(t *testing.T) {
|
|
setupTest()
|
|
shouldFailEncryption = false
|
|
mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: bitlocker.ConversionStatusFullyDecrypted}
|
|
enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
|
|
return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
|
|
}
|
|
|
|
err := enrollReceiver.Run(makeConfig())
|
|
require.NoError(t, err)
|
|
require.True(t, clientMock.SetOrUpdateDiskEncryptionKeyInvoked)
|
|
require.True(t, encryptFnCalled, "encryption function should have been called")
|
|
})
|
|
|
|
t.Run("failed fleet server update", func(t *testing.T) {
|
|
setupTest()
|
|
shouldFailEncryption = false
|
|
shouldFailServerUpdate = true
|
|
mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: bitlocker.ConversionStatusFullyDecrypted}
|
|
enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
|
|
return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
|
|
}
|
|
|
|
err := enrollReceiver.Run(makeConfig())
|
|
require.NoError(t, err)
|
|
require.Contains(t, logBuf.String(), "failed to send encryption result to Fleet Server")
|
|
require.True(t, clientMock.SetOrUpdateDiskEncryptionKeyInvoked)
|
|
require.True(t, encryptFnCalled, "encryption function should have been called")
|
|
})
|
|
|
|
t.Run("failed escrow caches key for retry", func(t *testing.T) {
|
|
setupTest()
|
|
shouldFailServerUpdate = true
|
|
lastRunBefore := enrollReceiver.lastRun
|
|
mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: bitlocker.ConversionStatusFullyEncrypted}
|
|
enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
|
|
return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
|
|
}
|
|
|
|
// First run: rotation succeeds but escrow fails, key should be cached
|
|
err := enrollReceiver.Run(makeConfig())
|
|
require.NoError(t, err)
|
|
require.True(t, rotateKeyFnCalled, "rotate key function should have been called")
|
|
require.Equal(t, "rotated-key-789", enrollReceiver.pendingRecoveryKey, "key should be cached after failed escrow")
|
|
require.Equal(t, lastRunBefore, enrollReceiver.lastRun, "lastRun should not advance when escrow fails")
|
|
})
|
|
|
|
t.Run("cached key retried without re-rotating", func(t *testing.T) {
|
|
setupTest()
|
|
shouldFailServerUpdate = true
|
|
mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: bitlocker.ConversionStatusFullyEncrypted}
|
|
enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
|
|
return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
|
|
}
|
|
|
|
// First run: rotation succeeds, escrow fails, key cached
|
|
err := enrollReceiver.Run(makeConfig())
|
|
require.NoError(t, err)
|
|
require.True(t, rotateKeyFnCalled)
|
|
require.Equal(t, "rotated-key-789", enrollReceiver.pendingRecoveryKey)
|
|
|
|
// Second run: escrow succeeds, key cleared, no re-rotation
|
|
rotateKeyFnCalled = false
|
|
encryptFnCalled = false
|
|
shouldFailServerUpdate = false
|
|
logBuf.Reset()
|
|
|
|
err = enrollReceiver.Run(makeConfig())
|
|
require.NoError(t, err)
|
|
require.Contains(t, logBuf.String(), "retrying escrow of previously rotated recovery key")
|
|
require.False(t, rotateKeyFnCalled, "should NOT rotate again")
|
|
require.False(t, encryptFnCalled, "should NOT encrypt again")
|
|
require.Empty(t, enrollReceiver.pendingRecoveryKey, "cached key should be cleared after successful escrow")
|
|
require.False(t, enrollReceiver.lastRun.IsZero(), "lastRun should be set after successful escrow")
|
|
})
|
|
|
|
}
|