fleet/orbit/pkg/update/notifications_test.go
Jordan Montgomery 95178043cf
Fix race condition in TestRenewEnrollmentProfilePrevented (#37576)
<!-- Add the related story/sub-task/bug number, like Resolves #123, or
remove if NA -->
**Related issue:** Resolves #35852

Tested by adding a small (100ms, but even smaller should work) sleep in
the goroutine before calling renewReceiver.Run() which simulates the
active goroutine being preempted and the other running before it gets
scheduled again. When I did this it would hang and timeout every time
before the fix. After the fix, I never saw a timeout over 500 runs, both
with the sleep added and without it.

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [x] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)
- [x] If paths of existing endpoints are modified without backwards
compatibility, checked the frontend/CLI for any necessary changes

## Testing

- [x] Added/updated automated tests

- [x] QA'd all new/changed functionality manually
2025-12-19 22:09:26 -05:00

782 lines
27 KiB
Go

package update
import (
"bytes"
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/fleetdm/fleet/v4/orbit/pkg/bitlocker"
"github.com/fleetdm/fleet/v4/orbit/pkg/scripts"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/rs/zerolog/log"
"github.com/stretchr/testify/require"
)
// TestRenewEnrollmentProfile checks that the renew enrollment profile receiver
// invokes the profiles command (and the assigned-profile check) only when the
// notification asks for a renewal, and that the command outcome is logged.
func TestRenewEnrollmentProfile(t *testing.T) {
	var logBuf bytes.Buffer
	prevLogger := log.Logger
	log.Logger = log.Output(&logBuf)
	t.Cleanup(func() { log.Logger = prevLogger })

	type testCase struct {
		desc          string
		renewFlag     bool
		cmdErr        error
		wantCmdCalled bool
		wantLog       string
	}
	testCases := []testCase{
		{desc: "renew=false"},
		{
			desc:          "renew=true; success",
			renewFlag:     true,
			wantCmdCalled: true,
			wantLog:       "successfully called /usr/bin/profiles to renew enrollment profile",
		},
		{
			desc:          "renew=true; fail",
			renewFlag:     true,
			cmdErr:        io.ErrUnexpectedEOF,
			wantCmdCalled: true,
			wantLog:       "calling /usr/bin/profiles to renew enrollment profile failed",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.desc, func(t *testing.T) {
			logBuf.Reset()

			var ranCmd, checkedAssignment bool
			receiver := &renewEnrollmentProfileConfigReceiver{
				Frequency: time.Hour, // a single Run per receiver, so the frequency never matters
				runCmdFn: func() error {
					ranCmd = true
					return tc.cmdErr
				},
				checkEnrollmentFn: func() (bool, string, error) {
					return false, "", nil
				},
				checkAssignedEnrollmentProfileFn: func(string) error {
					checkedAssignment = true
					return nil
				},
			}

			cfg := &fleet.OrbitConfig{
				Notifications: fleet.OrbitConfigNotifications{RenewEnrollmentProfile: tc.renewFlag},
			}
			require.NoError(t, receiver.Run(cfg)) // the dummy receiver never returns an error

			require.Equal(t, tc.wantCmdCalled, ranCmd)
			// the assigned-profile check is expected exactly when the command is
			require.Equal(t, tc.wantCmdCalled, checkedAssignment)
			require.Contains(t, logBuf.String(), tc.wantLog)
		})
	}
}
// TestRenewEnrollmentProfilePrevented checks the situations in which the renew
// enrollment profile receiver must NOT run the profiles command: a concurrent
// run in progress, the frequency restriction, an already-enrolled host and a
// failed assigned-profile check (with its backoff).
func TestRenewEnrollmentProfilePrevented(t *testing.T) {
	var logBuf bytes.Buffer
	oldLog := log.Logger
	log.Logger = log.Output(&logBuf)
	t.Cleanup(func() { log.Logger = oldLog })

	testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{RenewEnrollmentProfile: true}}
	var cmdCallCount int
	isEnrolled := false
	isAssigned := true
	chProceed := make(chan struct{})
	renewReceiver := &renewEnrollmentProfileConfigReceiver{
		Frequency: 2 * time.Second, // just to be safe with slow environments (CI)
		runCmdFn: func() error {
			cmdCallCount++ // no need for sync, single-threaded call of this func is guaranteed by the receiver's mutex
			return nil
		},
		checkEnrollmentFn: func() (bool, string, error) {
			<-chProceed // will be unblocked only when allowed
			return isEnrolled, "", nil
		},
		checkAssignedEnrollmentProfileFn: func(url string) error {
			<-chProceed // will be unblocked only when allowed
			if !isAssigned {
				return errors.New("not assigned")
			}
			return nil
		},
	}

	// One of the two concurrent calls to renewReceiver.Run() will run first and get blocked in
	// checkEnrollmentFn. The other won't call the command (won't be able to lock the mutex), so it
	// completes successfully without being blocked by the call in progress. Whichever one exits
	// first closes chProceed so the other one is unblocked; the CompareAndSwap guarantees that
	// chProceed is closed exactly once.
	var shouldCloseChProceed atomic.Bool
	shouldCloseChProceed.Store(true)
	unblock := func() bool {
		if shouldCloseChProceed.CompareAndSwap(true, false) {
			close(chProceed)
			return true
		}
		return false
	}

	var (
		wg    sync.WaitGroup
		goErr error
	)
	// If the test fails before the concurrent call has been unblocked, don't leak its goroutine:
	// unblock it and wait for it to return before the test completes.
	t.Cleanup(func() {
		unblock()
		wg.Wait()
	})

	started := make(chan struct{})
	wg.Add(1)
	go func() {
		defer wg.Done()
		close(started)
		// require (t.FailNow) must only be called from the test goroutine, so the error is
		// recorded here and checked by the test goroutine after wg.Wait().
		goErr = renewReceiver.Run(testConfig)
		if unblock() {
			t.Logf("%v unblock the first call from the goroutine", time.Now())
		}
	}()
	<-started
	t.Logf("%v started", time.Now())

	err := renewReceiver.Run(testConfig)
	require.NoError(t, err)
	if unblock() {
		t.Logf("%v unblock the first call", time.Now())
	}

	// this next call won't execute the command because of the frequency
	// restriction (it got called less than N seconds ago)
	err = renewReceiver.Run(testConfig)
	require.NoError(t, err)
	t.Logf("%v frequency restriction check done", time.Now())

	// Wait for the concurrent call to return before changing the receiver's
	// Frequency (read by Run) and before checking its error.
	wg.Wait()
	require.NoError(t, goErr)
	renewReceiver.Frequency = 200 * time.Millisecond

	// wait for the receiver's frequency to pass
	time.Sleep(renewReceiver.Frequency)
	// this call executes the command
	err = renewReceiver.Run(testConfig)
	require.NoError(t, err)

	// wait for the receiver's frequency to pass
	time.Sleep(renewReceiver.Frequency)
	// this call doesn't execute the command since the host is already
	// enrolled
	isEnrolled = true
	err = renewReceiver.Run(testConfig)
	require.NoError(t, err)
	require.Equal(t, 2, cmdCallCount) // the initial call and the one after sleep

	// wait for the receiver's frequency to pass
	time.Sleep(renewReceiver.Frequency)
	// this call doesn't execute the command since the assigned profile check fails
	isAssigned = false
	isEnrolled = false
	err = renewReceiver.Run(testConfig)
	require.NoError(t, err)
	require.Equal(t, 2, cmdCallCount) // the initial call and the one after sleep

	// wait for the receiver's frequency to pass
	time.Sleep(renewReceiver.Frequency)
	// this next call won't execute the command because the backoff
	// for a failed assigned check is always 2 minutes
	err = renewReceiver.Run(testConfig)
	require.NoError(t, err)
	require.Equal(t, 2, cmdCallCount) // still only the initial call and the one after the first sleep
}
// mockNodeKeyGetter is a stub node key provider that always yields the same
// fixed key and never fails.
type mockNodeKeyGetter struct{}

// GetNodeKey returns the constant test node key with a nil error.
func (mockNodeKeyGetter) GetNodeKey() (string, error) {
	const testNodeKey = "nodekey-test"
	return testNodeKey, nil
}
// TestWindowsMDMEnrollment checks that the Windows MDM enrollment receiver
// calls the enroll or the unenroll API depending on the notification flags
// (migration goes through the unenroll API), and that the outcome is logged,
// including the Windows Server special case.
func TestWindowsMDMEnrollment(t *testing.T) {
	var logBuf bytes.Buffer
	prevLogger := log.Logger
	log.Logger = log.Output(&logBuf)
	t.Cleanup(func() { log.Logger = prevLogger })

	type testCase struct {
		desc          string
		enrollFlag    *bool
		unenrollFlag  *bool
		migrateFlag   *bool
		discoveryURL  string
		apiErr        error
		wantAPICalled bool
		wantLog       string
	}
	const discovery = "http://example.com"
	testCases := []testCase{
		{desc: "enroll=false", enrollFlag: ptr.Bool(false)},
		{desc: "enroll=true,discovery=''", enrollFlag: ptr.Bool(true), wantLog: "discovery endpoint is empty"},
		{desc: "enroll=true,discovery!='',success", enrollFlag: ptr.Bool(true), discoveryURL: discovery, wantAPICalled: true, wantLog: "successfully called RegisterDeviceWithManagement"},
		{desc: "enroll=true,discovery!='',fail", enrollFlag: ptr.Bool(true), discoveryURL: discovery, apiErr: io.ErrUnexpectedEOF, wantAPICalled: true, wantLog: "enroll Windows device failed"},
		{desc: "enroll=true,discovery!='',server", enrollFlag: ptr.Bool(true), discoveryURL: discovery, apiErr: errIsWindowsServer, wantAPICalled: true, wantLog: "device is a Windows Server, skipping enrollment"},

		{desc: "unenroll=false", unenrollFlag: ptr.Bool(false)},
		{desc: "unenroll=true,success", unenrollFlag: ptr.Bool(true), wantAPICalled: true, wantLog: "successfully called UnregisterDeviceWithManagement to unenroll"},
		{desc: "unenroll=true,fail", unenrollFlag: ptr.Bool(true), apiErr: io.ErrUnexpectedEOF, wantAPICalled: true, wantLog: "unenroll Windows device failed"},
		{desc: "unenroll=true,server", unenrollFlag: ptr.Bool(true), apiErr: errIsWindowsServer, wantAPICalled: true, wantLog: "device is a Windows Server, skipping unenroll"},

		{desc: "migrate=false", migrateFlag: ptr.Bool(false)},
		{desc: "migrate=true,success", migrateFlag: ptr.Bool(true), wantAPICalled: true, wantLog: "successfully called UnregisterDeviceWithManagement to migrate"},
		{desc: "migrate=true,fail", migrateFlag: ptr.Bool(true), apiErr: io.ErrUnexpectedEOF, wantAPICalled: true, wantLog: "migrate Windows device failed"},
		{desc: "migrate=true,server", migrateFlag: ptr.Bool(true), apiErr: errIsWindowsServer, wantAPICalled: true, wantLog: "device is a Windows Server, skipping migrate"},
	}

	// isSet reports whether an optional flag is present and true.
	isSet := func(b *bool) bool { return b != nil && *b }

	for _, tc := range testCases {
		t.Run(tc.desc, func(t *testing.T) {
			logBuf.Reset()

			migrate := isSet(tc.migrateFlag)
			cfg := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
				NeedsProgrammaticWindowsMDMEnrollment:   isSet(tc.enrollFlag),
				NeedsProgrammaticWindowsMDMUnenrollment: isSet(tc.unenrollFlag),
				NeedsMDMMigration:                       migrate,
				WindowsMDMDiscoveryEndpoint:             tc.discoveryURL,
			}}

			var calledEnroll, calledUnenroll bool
			receiver := &windowsMDMEnrollmentConfigReceiver{
				Frequency: time.Hour, // a single Run per receiver, so the frequency never matters
				execEnrollFn: func(WindowsMDMEnrollmentArgs) error {
					calledEnroll = true
					return tc.apiErr
				},
				execUnenrollFn: func(WindowsMDMEnrollmentArgs) error {
					calledUnenroll = true
					return tc.apiErr
				},
				nodeKeyGetter: mockNodeKeyGetter{},
			}

			require.NoError(t, receiver.Run(cfg)) // the dummy receiver never returns an error

			// Cases that set the unenroll flag (either value) or migrate=true
			// target the unenroll API; all the others target the enroll API.
			// The API that is not targeted must never be called.
			targeted, other := calledEnroll, calledUnenroll
			if tc.unenrollFlag != nil || migrate {
				targeted, other = calledUnenroll, calledEnroll
			}
			require.Equal(t, tc.wantAPICalled, targeted)
			require.False(t, other)
			require.Contains(t, logBuf.String(), tc.wantLog)
		})
	}
}
// TestWindowsMDMEnrollmentPrevented checks the situations in which the Windows
// MDM enrollment receiver must NOT call the enroll/unenroll API: a concurrent
// run in progress, the frequency restriction and a previously detected Windows
// Server.
func TestWindowsMDMEnrollmentPrevented(t *testing.T) {
	var logBuf bytes.Buffer
	oldLog := log.Logger
	log.Logger = log.Output(&logBuf)
	t.Cleanup(func() { log.Logger = oldLog })

	cfgs := []fleet.OrbitConfigNotifications{
		{
			NeedsProgrammaticWindowsMDMEnrollment: true,
			WindowsMDMDiscoveryEndpoint:           "http://example.com",
		},
		{
			NeedsProgrammaticWindowsMDMUnenrollment: true,
		},
	}
	for _, cfg := range cfgs {
		t.Run(fmt.Sprintf("%+v", cfg), func(t *testing.T) {
			testConfig := &fleet.OrbitConfig{Notifications: cfg}
			var (
				apiCallCount int
				apiErr       error
			)
			chProceed := make(chan struct{})

			// blockingAPI stands in for the API this config is expected to trigger.
			blockingAPI := func(args WindowsMDMEnrollmentArgs) error {
				<-chProceed    // will be unblocked only when allowed
				apiCallCount++ // no need for sync, single-threaded call of this func is guaranteed by the receiver's mutex
				return apiErr
			}
			// unexpectedAPI stands in for the API this config must never trigger.
			unexpectedAPI := func(args WindowsMDMEnrollmentArgs) error {
				panic("should not be called")
			}
			receiver := &windowsMDMEnrollmentConfigReceiver{
				Frequency:     2 * time.Second, // just to be safe with slow environments (CI)
				nodeKeyGetter: mockNodeKeyGetter{},
			}
			if cfg.NeedsProgrammaticWindowsMDMEnrollment {
				receiver.execEnrollFn = blockingAPI
				receiver.execUnenrollFn = unexpectedAPI
			} else {
				receiver.execUnenrollFn = blockingAPI
				receiver.execEnrollFn = unexpectedAPI
			}

			// Whichever of the two concurrent Run calls returns first closes chProceed, so the other
			// one is unblocked no matter which got scheduled first. (Relying on a sleep to order them
			// deadlocks if the test goroutine's call wins the mutex and blocks in the API func.) The
			// CompareAndSwap guarantees chProceed is closed exactly once.
			var shouldCloseChProceed atomic.Bool
			shouldCloseChProceed.Store(true)
			unblock := func() {
				if shouldCloseChProceed.CompareAndSwap(true, false) {
					close(chProceed)
				}
			}

			var (
				wg    sync.WaitGroup
				goErr error
			)
			// If the test fails before the concurrent call has been unblocked, don't leak its
			// goroutine: unblock it and wait for it to return before the subtest completes.
			t.Cleanup(func() {
				unblock()
				wg.Wait()
			})

			wg.Add(1)
			go func() {
				defer wg.Done()
				// require (t.FailNow) must only be called from the test goroutine, so the error is
				// recorded here and checked by the test goroutine after wg.Wait().
				goErr = receiver.Run(testConfig)
				unblock()
			}()

			// One of the two calls locks the receiver's mutex and blocks in the enroll/unenroll
			// func; the other won't be able to lock the mutex, so it won't call the API. However it
			// will still complete successfully without being blocked by the call in progress.
			err := receiver.Run(testConfig)
			require.NoError(t, err)
			unblock()

			// wait for the concurrent call to complete
			wg.Wait()
			require.NoError(t, goErr)
			require.Equal(t, 1, apiCallCount) // only one of the concurrent calls reached the API

			// this next call won't execute the command because of the frequency
			// restriction (it got called less than N seconds ago)
			err = receiver.Run(testConfig)
			require.NoError(t, err)
			// wait for the receiver's frequency to pass
			time.Sleep(receiver.Frequency)
			// this call executes the command, and it returns the Is Windows Server error
			apiErr = errIsWindowsServer
			err = receiver.Run(testConfig)
			require.NoError(t, err)
			// this next call won't execute the command (both due to frequency and the
			// detection of windows server)
			err = receiver.Run(testConfig)
			require.NoError(t, err)
			// wait for the receiver's frequency to pass
			time.Sleep(receiver.Frequency)
			// this next call still won't execute the command (due to the detection of
			// windows server)
			err = receiver.Run(testConfig)
			require.NoError(t, err)
			require.Equal(t, 2, apiCallCount) // the initial call and the one that returned errIsWindowsServer after first sleep
		})
	}
}
// TestRunScripts checks that the run scripts receiver executes all pending
// scripts in a single background run, logs the outcome, prevents concurrent
// runs, and honors the dynamically-fetched "scripts enabled" setting.
func TestRunScripts(t *testing.T) {
	var logBuf bytes.Buffer
	oldLog := log.Logger
	log.Logger = log.Output(&logBuf)
	t.Cleanup(func() { log.Logger = oldLog })

	var (
		callsCount atomic.Int64
		runFailure error
		blockRun   chan struct{}
	)
	mockRun := func(r *scripts.Runner, ids []string) error {
		callsCount.Add(1)
		if blockRun != nil {
			<-blockRun
		}
		return runFailure
	}

	// waitForRun waits (for up to one second) until the receiver's mutex can be
	// acquired, which the subtests below use as the signal that the run
	// started by a previous r.Run call is over. It releases the mutex again
	// before returning.
	waitForRun := func(t *testing.T, r *runScriptsConfigReceiver) {
		t.Helper()
		deadline := time.Now().Add(time.Second)
		ok := r.mu.TryLock()
		for !ok && time.Now().Before(deadline) {
			// poll instead of spinning a CPU core while the run completes
			time.Sleep(time.Millisecond)
			ok = r.mu.TryLock()
		}
		require.True(t, ok, "timed out waiting for the lock to become available")
		r.mu.Unlock()
	}

	t.Run("no pending scripts", func(t *testing.T) {
		t.Cleanup(func() { callsCount.Store(0); logBuf.Reset() })

		testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
			PendingScriptExecutionIDs: nil,
		}}
		runner := &runScriptsConfigReceiver{
			runScriptsFn: mockRun,
		}
		err := runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error

		// the lock should be available because no goroutine was started
		require.True(t, runner.mu.TryLock())
		require.Zero(t, callsCount.Load()) // no calls to execute scripts
		require.Empty(t, logBuf.String())  // no logs written
	})

	t.Run("pending scripts succeed", func(t *testing.T) {
		t.Cleanup(func() { callsCount.Store(0); logBuf.Reset() })

		testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
			PendingScriptExecutionIDs: []string{"a", "b", "c"},
		}}
		runner := &runScriptsConfigReceiver{
			runScriptsFn: mockRun,
		}
		err := runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error
		waitForRun(t, runner)

		require.Equal(t, int64(1), callsCount.Load()) // all scripts executed in a single run
		require.Contains(t, logBuf.String(), "received notification to run scripts [a b c]")
		require.Contains(t, logBuf.String(), "running scripts [a b c] succeeded")
	})

	t.Run("pending scripts failed", func(t *testing.T) {
		t.Cleanup(func() { callsCount.Store(0); logBuf.Reset(); runFailure = nil })

		testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
			PendingScriptExecutionIDs: []string{"a", "b", "c"},
		}}
		runFailure = io.ErrUnexpectedEOF
		runner := &runScriptsConfigReceiver{
			runScriptsFn: mockRun,
		}
		err := runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error
		waitForRun(t, runner)

		require.Equal(t, int64(1), callsCount.Load()) // all scripts executed in a single run
		require.Contains(t, logBuf.String(), "received notification to run scripts [a b c]")
		require.Contains(t, logBuf.String(), "running scripts failed")
		require.Contains(t, logBuf.String(), io.ErrUnexpectedEOF.Error())
	})

	t.Run("concurrent run prevented", func(t *testing.T) {
		t.Cleanup(func() { callsCount.Store(0); logBuf.Reset(); blockRun = nil })

		testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
			PendingScriptExecutionIDs: []string{"a", "b", "c"},
		}}
		blockRun = make(chan struct{})
		runner := &runScriptsConfigReceiver{
			runScriptsFn: mockRun,
		}
		err := runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error

		// call it again, while the previous run is still running
		err = runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error

		// unblock the initial run
		close(blockRun)
		waitForRun(t, runner)

		require.Equal(t, int64(1), callsCount.Load()) // only called once because of mutex
		require.Contains(t, logBuf.String(), "received notification to run scripts [a b c]")
		require.Contains(t, logBuf.String(), "running scripts [a b c] succeeded")
	})

	t.Run("dynamic enabling of scripts", func(t *testing.T) {
		t.Cleanup(logBuf.Reset)

		testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
			PendingScriptExecutionIDs: []string{"a"},
		}}

		var (
			scriptsEnabledCalls []bool
			dynamicEnabled      atomic.Bool
			dynamicInterval     = 300 * time.Millisecond
		)
		runner := &runScriptsConfigReceiver{
			ScriptsExecutionEnabled: false,
			runScriptsFn: func(r *scripts.Runner, s []string) error {
				scriptsEnabledCalls = append(scriptsEnabledCalls, r.ScriptExecutionEnabled)
				return nil
			},
			testGetFleetdConfig: func() (*fleet.MDMAppleFleetdConfig, error) {
				return &fleet.MDMAppleFleetdConfig{
					EnableScripts: dynamicEnabled.Load(),
				}, nil
			},
			dynamicScriptsEnabledCheckInterval: dynamicInterval,
		}

		// the static Scripts Enabled flag is false, so it relies on the dynamic check
		runner.runDynamicScriptsEnabledCheck()

		// first call, scripts are disabled
		err := runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error
		waitForRun(t, runner)

		// swap scripts execution to true and wait to ensure the dynamic check
		// did run.
		dynamicEnabled.Store(true)
		time.Sleep(dynamicInterval + 100*time.Millisecond)

		// second call, scripts are enabled (change exec ID to "b")
		testConfig.Notifications.PendingScriptExecutionIDs[0] = "b"
		err = runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error
		waitForRun(t, runner)

		// swap scripts execution back to false and wait to ensure the dynamic
		// check did run.
		dynamicEnabled.Store(false)
		time.Sleep(dynamicInterval + 100*time.Millisecond)

		// third call, scripts are disabled (change exec ID to "c")
		testConfig.Notifications.PendingScriptExecutionIDs[0] = "c"
		err = runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error
		waitForRun(t, runner)

		// validate the Scripts Enabled flags that were passed to the runScriptsFn
		require.Equal(t, []bool{false, true, false}, scriptsEnabledCalls)
		require.Contains(t, logBuf.String(), "received notification to run scripts [a]")
		require.Contains(t, logBuf.String(), "running scripts [a] succeeded")
		require.Contains(t, logBuf.String(), "received notification to run scripts [b]")
		require.Contains(t, logBuf.String(), "running scripts [b] succeeded")
		require.Contains(t, logBuf.String(), "received notification to run scripts [c]")
		require.Contains(t, logBuf.String(), "running scripts [c] succeeded")
	})
}
// mockDiskEncryptionKeySetter is a test double for the receiver's
// EncryptionResult dependency. It records whether SetOrUpdateDiskEncryptionKey
// was called and delegates the returned error to SetOrUpdateDiskEncryptionKeyImpl.
type mockDiskEncryptionKeySetter struct {
	// SetOrUpdateDiskEncryptionKeyImpl decides what SetOrUpdateDiskEncryptionKey returns.
	SetOrUpdateDiskEncryptionKeyImpl func(diskEncryptionStatus fleet.OrbitHostDiskEncryptionKeyPayload) error
	// SetOrUpdateDiskEncryptionKeyInvoked is set to true on every call; tests reset it manually.
	SetOrUpdateDiskEncryptionKeyInvoked bool
}

// SetOrUpdateDiskEncryptionKey flags the mock as invoked and forwards the
// payload to SetOrUpdateDiskEncryptionKeyImpl, returning its result.
func (m *mockDiskEncryptionKeySetter) SetOrUpdateDiskEncryptionKey(diskEncryptionStatus fleet.OrbitHostDiskEncryptionKeyPayload) error {
	impl := m.SetOrUpdateDiskEncryptionKeyImpl
	m.SetOrUpdateDiskEncryptionKeyInvoked = true
	return impl(diskEncryptionStatus)
}
// TestBitlockerOperations exercises the Windows BitLocker receiver: encryption
// when requested, skipping on in-progress/paused conversions, decrypting a disk
// that was previously encrypted, honoring the run frequency, and reporting the
// result (or failure) to the Fleet server.
func TestBitlockerOperations(t *testing.T) {
	var logBuf bytes.Buffer
	oldLog := log.Logger
	log.Logger = log.Output(&logBuf)
	t.Cleanup(func() { log.Logger = oldLog })

	var (
		shouldEncrypt          = true
		shouldFailEncryption   = false
		shouldFailDecryption   = false
		shouldFailServerUpdate = false
		encryptFnCalled        = false
		decryptFnCalled        = false
	)

	// NOTE: the value of shouldEncrypt is copied here once; later assignments
	// to shouldEncrypt do not change testConfig.
	testConfig := &fleet.OrbitConfig{
		Notifications: fleet.OrbitConfigNotifications{
			EnforceBitLockerEncryption: shouldEncrypt,
		},
	}

	clientMock := &mockDiskEncryptionKeySetter{}
	clientMock.SetOrUpdateDiskEncryptionKeyImpl = func(diskEncryptionStatus fleet.OrbitHostDiskEncryptionKeyPayload) error {
		if shouldFailServerUpdate {
			return errors.New("server error")
		}
		return nil
	}

	var enrollReceiver *windowsMDMBitlockerConfigReceiver
	// setupTest builds a fresh receiver whose last run is older than its
	// frequency, with no reported volumes, and resets all flags and logs.
	setupTest := func() {
		enrollReceiver = &windowsMDMBitlockerConfigReceiver{
			Frequency:        time.Hour, // doesn't matter for this test
			lastRun:          time.Now().Add(-2 * time.Hour),
			EncryptionResult: clientMock,
			execGetEncryptionStatusFn: func() ([]bitlocker.VolumeStatus, error) {
				return []bitlocker.VolumeStatus{}, nil
			},
			execEncryptVolumeFn: func(string) (string, error) {
				encryptFnCalled = true
				if shouldFailEncryption {
					return "", errors.New("error encrypting")
				}
				return "123456", nil
			},
			execDecryptVolumeFn: func(string) error {
				decryptFnCalled = true
				if shouldFailDecryption {
					return errors.New("error decrypting")
				}
				return nil
			},
		}

		shouldEncrypt = true
		shouldFailEncryption = false
		shouldFailDecryption = false
		shouldFailServerUpdate = false
		encryptFnCalled = false
		decryptFnCalled = false
		clientMock.SetOrUpdateDiskEncryptionKeyInvoked = false
		logBuf.Reset()
	}

	t.Run("bitlocker encryption is performed", func(t *testing.T) {
		setupTest()
		shouldEncrypt = true
		shouldFailEncryption = false
		shouldFailDecryption = false

		err := enrollReceiver.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error
		require.True(t, encryptFnCalled, "encryption function should have been called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
	})

	t.Run("bitlocker encryption is not performed", func(t *testing.T) {
		setupTest()
		// NOTE(review): testConfig captured shouldEncrypt when it was built
		// (EnforceBitLockerEncryption is always true), so this assignment does
		// not change what Run sees; that is why the encryption function is
		// still expected to be called below.
		shouldEncrypt = false
		shouldFailEncryption = false

		err := enrollReceiver.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error
		require.True(t, encryptFnCalled, "encryption function should have been called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
	})

	t.Run("bitlocker encryption returns an error", func(t *testing.T) {
		setupTest()
		shouldEncrypt = true
		shouldFailEncryption = true

		err := enrollReceiver.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error
		require.True(t, encryptFnCalled, "encryption function should have been called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
	})

	t.Run("encryption skipped based on various current statuses", func(t *testing.T) {
		setupTest()
		statusesToTest := []int32{
			bitlocker.ConversionStatusDecryptionInProgress,
			bitlocker.ConversionStatusDecryptionPaused,
			bitlocker.ConversionStatusEncryptionInProgress,
			bitlocker.ConversionStatusEncryptionPaused,
		}

		for _, status := range statusesToTest {
			t.Run(fmt.Sprintf("status %d", status), func(t *testing.T) {
				mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: status}
				enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
					return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
				}

				err := enrollReceiver.Run(testConfig)
				require.NoError(t, err)
				require.Contains(t, logBuf.String(), "skipping encryption as the disk is not available")
				require.False(t, encryptFnCalled, "encryption function should not be called")
				require.False(t, decryptFnCalled, "decryption function should not be called")
				logBuf.Reset() // Reset the log buffer for the next iteration
			})
		}
	})

	t.Run("handle misreported decryption error", func(t *testing.T) {
		setupTest()
		mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: bitlocker.ConversionStatusFullyDecrypted}
		enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
			return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
		}
		// replaces the default encrypt func, so encryptFnCalled is never set here
		enrollReceiver.execEncryptVolumeFn = func(string) (string, error) {
			return "", bitlocker.NewEncryptionError("", bitlocker.ErrorCodeNotDecrypted)
		}

		err := enrollReceiver.Run(testConfig)
		require.NoError(t, err)
		require.Contains(t, logBuf.String(), "disk encryption failed due to previous unsuccessful attempt, user action required")
		require.False(t, encryptFnCalled, "encryption function should not be called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
	})

	t.Run("decrypts the disk if previously encrypted", func(t *testing.T) {
		setupTest()
		mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: bitlocker.ConversionStatusFullyEncrypted}
		enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
			return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
		}

		err := enrollReceiver.Run(testConfig)
		require.NoError(t, err)
		require.Contains(t, logBuf.String(), "disk was previously encrypted. Attempting to decrypt it")
		require.False(t, clientMock.SetOrUpdateDiskEncryptionKeyInvoked)
		require.False(t, encryptFnCalled, "encryption function should not have been called")
		require.True(t, decryptFnCalled, "decryption function should have been called")
	})

	t.Run("reports to the server if decryption fails", func(t *testing.T) {
		setupTest()
		shouldFailDecryption = true
		mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: bitlocker.ConversionStatusFullyEncrypted}
		enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
			return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
		}

		err := enrollReceiver.Run(testConfig)
		require.NoError(t, err)
		require.Contains(t, logBuf.String(), "disk was previously encrypted. Attempting to decrypt it")
		require.Contains(t, logBuf.String(), "decryption failed")
		require.True(t, clientMock.SetOrUpdateDiskEncryptionKeyInvoked)
		require.False(t, encryptFnCalled, "encryption function should not be called")
		require.True(t, decryptFnCalled, "decryption function should have been called")
	})

	t.Run("encryption skipped if last run too recent", func(t *testing.T) {
		setupTest()
		enrollReceiver.lastRun = time.Now().Add(-30 * time.Minute)
		enrollReceiver.Frequency = 1 * time.Hour

		err := enrollReceiver.Run(testConfig)
		require.NoError(t, err)
		require.Contains(t, logBuf.String(), "skipped encryption process, last run was too recent")
		require.False(t, encryptFnCalled, "encryption function should not be called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
	})

	t.Run("successful fleet server update", func(t *testing.T) {
		setupTest()
		shouldFailEncryption = false
		mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: bitlocker.ConversionStatusFullyDecrypted}
		enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
			return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
		}

		err := enrollReceiver.Run(testConfig)
		require.NoError(t, err)
		require.True(t, clientMock.SetOrUpdateDiskEncryptionKeyInvoked)
		require.True(t, encryptFnCalled, "encryption function should have been called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
	})

	t.Run("failed fleet server update", func(t *testing.T) {
		setupTest()
		shouldFailEncryption = false
		shouldFailServerUpdate = true
		mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: bitlocker.ConversionStatusFullyDecrypted}
		enrollReceiver.execGetEncryptionStatusFn = func() ([]bitlocker.VolumeStatus, error) {
			return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
		}

		err := enrollReceiver.Run(testConfig)
		require.NoError(t, err)
		require.Contains(t, logBuf.String(), "failed to send encryption result to Fleet Server")
		require.True(t, clientMock.SetOrUpdateDiskEncryptionKeyInvoked)
		require.True(t, encryptFnCalled, "encryption function should have been called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
	})
}