fleet/orbit/pkg/update/notifications_test.go
Dante Catalfamo 643fc8314b
Orbit config receiver (#18518)
New interface for adding periodic jobs that rely on notifications/config
changes in Orbit.

Previously if we wanted to have recurring checks in Orbit, we would add
them into a chain of `GetConfig` calls. This call chain would be run
periodically by one of the runners registered with the cli application
framework.

The new method is to register `OrbitConfigReceivers` with the
`OrbitClient`, and then register the orbit client itself with the
application framework.

Instead of giving each fetcher an internal reference to the
previous fetcher that it must call, the receiver is registered with the
client and the new config is passed to the receiver.

This is the old `GetConfig()` interface:

```go
type OrbitConfigFetcher interface {
	GetConfig() (*fleet.OrbitConfig, error)
}
```

This is the new `OrbitConfigReceiver` interface:

```go
type OrbitConfigReceiver interface {
	Run(*OrbitConfig) error
}
```

To register a new receiver, you call the `RegisterConfigReceiver` method
on the client.

```go
orbitClient.RegisterConfigReceiver(extRunner)
```

Downsides of the old method:
- Spaghetti call chain setup
- Cascading failure: if one fails, all after it fail
- Run in series: one long function call holds up the rest
- Anything that wants to restart orbit is added as a Runner to the
application, meaning there could be several timers calling `GetConfig`
and running the chain

Benefits of the new method:
- Clean `RegisterConfigReceiver` api, no call chaining required
- Config receivers can be added at runtime
- Isolated receivers: one failing call doesn't affect others
- All calls are run in parallel in goroutines, no calls can hold up the
rest
- No more need for multiple runners: using a context cancel, any
receiver can queue a call to restart orbit
- Single point to handle errors and logging for all receivers
- Panic recovery to stop orbit from crashing
- Easier to test, configs are passed in and do not require a call chain

This branch contains a little bit of code from the installer method I
was working on because I branched it off of that. (oops)

Not all code comments surrounding the old `GetConfig()` methods have been
fully updated yet.

Possible changes:
- Update the interface to take a context, so we can let receivers know
to exit early. I can imagine two cases for this:
  - The application is about to restart
  - We can set a timeout for how long receivers are allowed to take

Closes #12662

---------

Co-authored-by: Martin Angers <martin.n.angers@gmail.com>
Co-authored-by: Roberto Dip <dip.jesusr@gmail.com>
2024-05-09 15:22:56 -04:00

755 lines
26 KiB
Go

package update
import (
"bytes"
"errors"
"fmt"
"io"
"sync/atomic"
"testing"
"time"
"github.com/fleetdm/fleet/v4/orbit/pkg/bitlocker"
"github.com/fleetdm/fleet/v4/orbit/pkg/scripts"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/rs/zerolog/log"
"github.com/stretchr/testify/require"
)
// TestRenewEnrollmentProfile checks that the renew-enrollment-profile
// receiver runs the renewal command (and the assigned-profile check) only
// when the RenewEnrollmentProfile notification is set, and that it logs the
// outcome of the command.
func TestRenewEnrollmentProfile(t *testing.T) {
	var logged bytes.Buffer
	prevLogger := log.Logger
	log.Logger = log.Output(&logged)
	t.Cleanup(func() { log.Logger = prevLogger })

	type testCase struct {
		desc          string
		renewFlag     bool
		cmdErr        error
		wantCmdCalled bool
		wantLog       string
	}
	testCases := []testCase{
		{desc: "renew=false"},
		{desc: "renew=true; success", renewFlag: true, wantCmdCalled: true, wantLog: "successfully called /usr/bin/profiles to renew enrollment profile"},
		{desc: "renew=true; fail", renewFlag: true, cmdErr: io.ErrUnexpectedEOF, wantCmdCalled: true, wantLog: "calling /usr/bin/profiles to renew enrollment profile failed"},
	}

	for _, tc := range testCases {
		t.Run(tc.desc, func(t *testing.T) {
			logged.Reset()

			var ranCmd, ranAssignedCheck bool
			receiver := &renewEnrollmentProfileConfigReceiver{
				Frequency: time.Hour, // irrelevant: every subtest uses a fresh receiver and a single Run
				runCmdFn: func() error {
					ranCmd = true
					return tc.cmdErr
				},
				checkEnrollmentFn: func() (bool, string, error) {
					// report "not enrolled" so the renewal path can be reached
					return false, "", nil
				},
				checkAssignedEnrollmentProfileFn: func(string) error {
					ranAssignedCheck = true
					return nil
				},
			}

			cfg := &fleet.OrbitConfig{
				Notifications: fleet.OrbitConfigNotifications{RenewEnrollmentProfile: tc.renewFlag},
			}
			require.NoError(t, receiver.Run(cfg)) // Run is expected to always return nil here

			require.Equal(t, tc.wantCmdCalled, ranCmd)
			// the assigned-profile check is expected exactly when the command runs
			require.Equal(t, tc.wantCmdCalled, ranAssignedCheck)
			require.Contains(t, logged.String(), tc.wantLog)
		})
	}
}
// TestRenewEnrollmentProfilePrevented verifies the situations in which the
// renew-enrollment-profile receiver must not run the renewal command:
//   - while another Run call is in progress;
//   - within Frequency of the previous run;
//   - when the host is already enrolled;
//   - when the assigned-profile check fails.
//
// The first Run call is held inside a check function (gated by chProceed).
// This lets the concurrent call be issued deterministically while the first
// one is still in progress.
func TestRenewEnrollmentProfilePrevented(t *testing.T) {
	var logBuf bytes.Buffer
	oldLog := log.Logger
	log.Logger = log.Output(&logBuf)
	t.Cleanup(func() { log.Logger = oldLog })

	testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{RenewEnrollmentProfile: true}}

	var cmdCallCount int
	isEnrolled := false
	isAssigned := true

	// chProceed blocks the check functions until it is closed. entered is
	// signaled (non-blocking, buffer of one) whenever a check function is
	// reached. The test waits on it so that the first Run call is known to be
	// in progress before the concurrent call is made. A fixed "started"
	// signal sent before Run would let the main goroutine win the race and
	// then block forever on chProceed.
	chProceed := make(chan struct{})
	entered := make(chan struct{}, 1)
	signalEntered := func() {
		select {
		case entered <- struct{}{}:
		default:
		}
	}

	renewReceiver := &renewEnrollmentProfileConfigReceiver{
		Frequency: 2 * time.Second, // just to be safe with slow environments (CI)
		runCmdFn: func() error {
			cmdCallCount++ // no need for sync, single-threaded call of this func is guaranteed by the receiver's mutex
			return nil
		},
		checkEnrollmentFn: func() (bool, string, error) {
			signalEntered()
			<-chProceed // will be unblocked only when allowed
			return isEnrolled, "", nil
		},
		checkAssignedEnrollmentProfileFn: func(url string) error {
			signalEntered()
			<-chProceed // will be unblocked only when allowed
			if !isAssigned {
				return errors.New("not assigned")
			}
			return nil
		},
	}

	// The first call blocks in a check function until chProceed is closed.
	// Its error is sent back over a channel because require (t.FailNow) must
	// only be used from the test goroutine. The channel also lets the test
	// wait for that call to finish.
	firstErr := make(chan error, 1)
	go func() {
		firstErr <- renewReceiver.Run(testConfig)
	}()
	select {
	case <-entered:
	case <-time.After(5 * time.Second):
		close(chProceed) // release the goroutine before failing
		t.Fatal("timed out waiting for the first Run call to start")
	}

	// this call happens while the first call is blocked in a check function,
	// so it won't call the command (won't be able to lock the mutex). However
	// it will still complete successfully without being blocked by the other
	// call in progress.
	err := renewReceiver.Run(testConfig)
	require.NoError(t, err)

	// unblock the first call and wait for it to complete; it runs the command
	close(chProceed)
	require.NoError(t, <-firstErr)
	require.Equal(t, 1, cmdCallCount)

	// this next call won't execute the command because of the frequency
	// restriction (it got called less than N seconds ago)
	err = renewReceiver.Run(testConfig)
	require.NoError(t, err)
	require.Equal(t, 1, cmdCallCount)

	// wait for the receiver's frequency to pass
	time.Sleep(renewReceiver.Frequency)

	// this call executes the command
	err = renewReceiver.Run(testConfig)
	require.NoError(t, err)
	require.Equal(t, 2, cmdCallCount)

	// wait for the receiver's frequency to pass
	time.Sleep(renewReceiver.Frequency)

	// this call doesn't execute the command since the host is already
	// enrolled
	isEnrolled = true
	err = renewReceiver.Run(testConfig)
	require.NoError(t, err)
	require.Equal(t, 2, cmdCallCount) // the initial call and the one after sleep

	// wait for the receiver's frequency to pass
	time.Sleep(renewReceiver.Frequency)

	// this call doesn't execute the command since the assigned profile check fails
	isAssigned = false
	isEnrolled = false
	err = renewReceiver.Run(testConfig)
	require.NoError(t, err)
	require.Equal(t, 2, cmdCallCount) // the initial call and the one after sleep

	// wait for the receiver's frequency to pass
	time.Sleep(renewReceiver.Frequency)

	// this next call won't execute the command because the backoff
	// for a failed assigned check is always 2 minutes
	err = renewReceiver.Run(testConfig)
	require.NoError(t, err)
	require.Equal(t, 2, cmdCallCount) // the command must still not have run again
}
// mockNodeKeyGetter is a stub node-key source that always yields the same
// fixed key and never fails.
type mockNodeKeyGetter struct{}

// GetNodeKey returns the constant key "nodekey-test" and a nil error.
func (mockNodeKeyGetter) GetNodeKey() (string, error) {
	const nodeKey = "nodekey-test"
	return nodeKey, nil
}
// TestWindowsMDMEnrollment checks that the Windows MDM receiver calls the
// enroll or the unenroll function according to the notifications it
// receives. It also checks that the expected message is logged for the
// success, failure and Windows Server cases.
func TestWindowsMDMEnrollment(t *testing.T) {
	var logged bytes.Buffer
	prevLogger := log.Logger
	log.Logger = log.Output(&logged)
	t.Cleanup(func() { log.Logger = prevLogger })

	testCases := []struct {
		desc          string
		enrollFlag    *bool
		unenrollFlag  *bool
		discoveryURL  string
		apiErr        error
		wantAPICalled bool
		wantLog       string
	}{
		{desc: "enroll=false", enrollFlag: ptr.Bool(false)},
		{desc: "enroll=true,discovery=''", enrollFlag: ptr.Bool(true), wantLog: "discovery endpoint is empty"},
		{desc: "enroll=true,discovery!='',success", enrollFlag: ptr.Bool(true), discoveryURL: "http://example.com", wantAPICalled: true, wantLog: "successfully called RegisterDeviceWithManagement"},
		{desc: "enroll=true,discovery!='',fail", enrollFlag: ptr.Bool(true), discoveryURL: "http://example.com", apiErr: io.ErrUnexpectedEOF, wantAPICalled: true, wantLog: "enroll Windows device failed"},
		{desc: "enroll=true,discovery!='',server", enrollFlag: ptr.Bool(true), discoveryURL: "http://example.com", apiErr: errIsWindowsServer, wantAPICalled: true, wantLog: "device is a Windows Server, skipping enrollment"},
		{desc: "unenroll=false", unenrollFlag: ptr.Bool(false)},
		{desc: "unenroll=true,success", unenrollFlag: ptr.Bool(true), wantAPICalled: true, wantLog: "successfully called UnregisterDeviceWithManagement"},
		{desc: "unenroll=true,fail", unenrollFlag: ptr.Bool(true), apiErr: io.ErrUnexpectedEOF, wantAPICalled: true, wantLog: "unenroll Windows device failed"},
		{desc: "unenroll=true,server", unenrollFlag: ptr.Bool(true), apiErr: errIsWindowsServer, wantAPICalled: true, wantLog: "device is a Windows Server, skipping unenrollment"},
	}

	for _, tc := range testCases {
		t.Run(tc.desc, func(t *testing.T) {
			logged.Reset()

			// A nil flag counts as false. A non-nil unenroll flag (even if
			// false) selects the unenroll-side assertions below.
			enrollNotif := tc.enrollFlag != nil && *tc.enrollFlag
			unenrollNotif := tc.unenrollFlag != nil && *tc.unenrollFlag
			checkUnenroll := tc.unenrollFlag != nil

			var calledEnroll, calledUnenroll bool
			receiver := &windowsMDMEnrollmentConfigReceiver{
				Frequency:     time.Hour, // irrelevant: every subtest uses a fresh receiver and a single Run
				nodeKeyGetter: mockNodeKeyGetter{},
				execEnrollFn: func(WindowsMDMEnrollmentArgs) error {
					calledEnroll = true
					return tc.apiErr
				},
				execUnenrollFn: func(WindowsMDMEnrollmentArgs) error {
					calledUnenroll = true
					return tc.apiErr
				},
			}

			cfg := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
				NeedsProgrammaticWindowsMDMEnrollment:   enrollNotif,
				NeedsProgrammaticWindowsMDMUnenrollment: unenrollNotif,
				WindowsMDMDiscoveryEndpoint:             tc.discoveryURL,
			}}
			require.NoError(t, receiver.Run(cfg)) // Run is expected to always return nil here

			// Each case exercises exactly one of the two API functions; the
			// other one must never be invoked.
			if !checkUnenroll {
				require.Equal(t, tc.wantAPICalled, calledEnroll)
				require.False(t, calledUnenroll)
			} else {
				require.Equal(t, tc.wantAPICalled, calledUnenroll)
				require.False(t, calledEnroll)
			}
			require.Contains(t, logged.String(), tc.wantLog)
		})
	}
}
// TestWindowsMDMEnrollmentPrevented verifies, for both the enroll and the
// unenroll notification, when the Windows MDM receiver must not call the API:
//   - while another Run call is in progress;
//   - within Frequency of the previous call;
//   - after the API reported that the host is a Windows Server.
//
// The first Run call is held inside the API function (gated by chProceed).
// This lets the concurrent call be issued deterministically while the first
// one is still in progress.
func TestWindowsMDMEnrollmentPrevented(t *testing.T) {
	var logBuf bytes.Buffer
	oldLog := log.Logger
	log.Logger = log.Output(&logBuf)
	t.Cleanup(func() { log.Logger = oldLog })

	cfgs := []fleet.OrbitConfigNotifications{
		{
			NeedsProgrammaticWindowsMDMEnrollment: true,
			WindowsMDMDiscoveryEndpoint:           "http://example.com",
		},
		{
			NeedsProgrammaticWindowsMDMUnenrollment: true,
		},
	}
	for _, cfg := range cfgs {
		t.Run(fmt.Sprintf("%+v", cfg), func(t *testing.T) {
			testConfig := &fleet.OrbitConfig{Notifications: cfg}

			var (
				apiCallCount int
				apiErr       error
			)

			// chProceed blocks the API function until it is closed. entered
			// is signaled (non-blocking, buffer of one) when the API function
			// is reached. The test waits on it instead of sleeping, so the
			// first Run call is known to be in progress before the concurrent
			// call is made.
			chProceed := make(chan struct{})
			entered := make(chan struct{}, 1)
			gatedAPI := func(WindowsMDMEnrollmentArgs) error {
				select {
				case entered <- struct{}{}:
				default:
				}
				<-chProceed    // will be unblocked only when allowed
				apiCallCount++ // no need for sync, single-threaded call of this func is guaranteed by the receiver's mutex
				return apiErr
			}
			mustNotBeCalled := func(WindowsMDMEnrollmentArgs) error {
				panic("should not be called")
			}

			receiver := &windowsMDMEnrollmentConfigReceiver{
				Frequency:     2 * time.Second, // just to be safe with slow environments (CI)
				nodeKeyGetter: mockNodeKeyGetter{},
			}
			if cfg.NeedsProgrammaticWindowsMDMEnrollment {
				receiver.execEnrollFn = gatedAPI
				receiver.execUnenrollFn = mustNotBeCalled
			} else {
				receiver.execUnenrollFn = gatedAPI
				receiver.execEnrollFn = mustNotBeCalled
			}

			// The first call blocks in the enroll/unenroll func. Its error is
			// sent back over a channel because require (t.FailNow) must only
			// be used from the test goroutine. The channel also lets the test
			// wait for that call to finish.
			firstErr := make(chan error, 1)
			go func() {
				firstErr <- receiver.Run(testConfig)
			}()
			select {
			case <-entered:
			case <-time.After(5 * time.Second):
				close(chProceed) // release the goroutine before failing
				t.Fatal("timed out waiting for the first Run call to reach the API function")
			}

			// this call will happen while the first call is blocked in
			// the enroll/unenroll func, so it won't call the API (won't be
			// able to lock the mutex). However it will still complete
			// successfully without being blocked by the other call in
			// progress.
			err := receiver.Run(testConfig)
			require.NoError(t, err)

			// unblock the first call and wait for it to complete
			close(chProceed)
			require.NoError(t, <-firstErr)
			require.Equal(t, 1, apiCallCount)

			// this next call won't execute the command because of the frequency
			// restriction (it got called less than N seconds ago)
			err = receiver.Run(testConfig)
			require.NoError(t, err)

			// wait for the receiver's frequency to pass
			time.Sleep(receiver.Frequency)

			// this call executes the command, and it returns the Is Windows Server error
			apiErr = errIsWindowsServer
			err = receiver.Run(testConfig)
			require.NoError(t, err)

			// this next call won't execute the command (both due to frequency and the
			// detection of windows server)
			err = receiver.Run(testConfig)
			require.NoError(t, err)

			// wait for the receiver's frequency to pass
			time.Sleep(receiver.Frequency)

			// this next call still won't execute the command (due to the detection of
			// windows server)
			err = receiver.Run(testConfig)
			require.NoError(t, err)

			require.Equal(t, 2, apiCallCount) // the initial call and the one that returned errIsWindowsServer after first sleep
		})
	}
}
// TestRunScripts exercises runScriptsConfigReceiver:
//   - no work when there are no pending script IDs;
//   - logging for successful and failed runs;
//   - prevention of concurrent runs;
//   - the dynamic (fleetd-config driven) scripts-enabled check.
func TestRunScripts(t *testing.T) {
	var logBuf bytes.Buffer
	oldLog := log.Logger
	log.Logger = log.Output(&logBuf)
	t.Cleanup(func() { log.Logger = oldLog })

	var (
		callsCount atomic.Int64
		runFailure error
		blockRun   chan struct{}
	)

	// mockRun counts its invocations and, when blockRun is non-nil, waits
	// until it is closed. It returns runFailure.
	mockRun := func(r *scripts.Runner, ids []string) error {
		callsCount.Add(1)
		if blockRun != nil {
			<-blockRun
		}
		return runFailure
	}

	// waitForRun waits up to one second for the receiver's mutex to become
	// available, i.e. for any in-progress script run to release it. It polls
	// with a short sleep instead of spinning on TryLock, so it does not
	// monopolize a CPU while the run goroutine is trying to finish.
	waitForRun := func(t *testing.T, r *runScriptsConfigReceiver) {
		t.Helper()
		deadline := time.Now().Add(time.Second)
		for !r.mu.TryLock() {
			if time.Now().After(deadline) {
				t.Fatal("timed out waiting for the lock to become available")
			}
			time.Sleep(time.Millisecond)
		}
		r.mu.Unlock()
	}

	t.Run("no pending scripts", func(t *testing.T) {
		t.Cleanup(func() { callsCount.Store(0); logBuf.Reset() })

		testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
			PendingScriptExecutionIDs: nil,
		}}

		runner := &runScriptsConfigReceiver{
			runScriptsFn: mockRun,
		}
		err := runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error

		// the lock should be available because no goroutine was started
		require.True(t, runner.mu.TryLock())
		require.Zero(t, callsCount.Load()) // no calls to execute scripts
		require.Empty(t, logBuf.String())  // no logs written
	})

	t.Run("pending scripts succeed", func(t *testing.T) {
		t.Cleanup(func() { callsCount.Store(0); logBuf.Reset() })

		testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
			PendingScriptExecutionIDs: []string{"a", "b", "c"},
		}}

		runner := &runScriptsConfigReceiver{
			runScriptsFn: mockRun,
		}
		err := runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error

		waitForRun(t, runner)
		require.Equal(t, int64(1), callsCount.Load()) // all scripts executed in a single run
		require.Contains(t, logBuf.String(), "received request to run scripts [a b c]")
		require.Contains(t, logBuf.String(), "running scripts [a b c] succeeded")
	})

	t.Run("pending scripts failed", func(t *testing.T) {
		t.Cleanup(func() { callsCount.Store(0); logBuf.Reset(); runFailure = nil })

		testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
			PendingScriptExecutionIDs: []string{"a", "b", "c"},
		}}

		runFailure = io.ErrUnexpectedEOF
		runner := &runScriptsConfigReceiver{
			runScriptsFn: mockRun,
		}
		err := runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error

		waitForRun(t, runner)
		require.Equal(t, int64(1), callsCount.Load()) // all scripts executed in a single run
		require.Contains(t, logBuf.String(), "received request to run scripts [a b c]")
		require.Contains(t, logBuf.String(), "running scripts failed")
		require.Contains(t, logBuf.String(), io.ErrUnexpectedEOF.Error())
	})

	t.Run("concurrent run prevented", func(t *testing.T) {
		t.Cleanup(func() { callsCount.Store(0); logBuf.Reset(); blockRun = nil })

		testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
			PendingScriptExecutionIDs: []string{"a", "b", "c"},
		}}

		blockRun = make(chan struct{})
		runner := &runScriptsConfigReceiver{
			runScriptsFn: mockRun,
		}
		err := runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error

		// call it again, while the previous run is still running
		err = runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error

		// unblock the initial run
		close(blockRun)

		waitForRun(t, runner)
		require.Equal(t, int64(1), callsCount.Load()) // only called once because of mutex
		require.Contains(t, logBuf.String(), "received request to run scripts [a b c]")
		require.Contains(t, logBuf.String(), "running scripts [a b c] succeeded")
	})

	t.Run("dynamic enabling of scripts", func(t *testing.T) {
		t.Cleanup(logBuf.Reset)

		testConfig := &fleet.OrbitConfig{Notifications: fleet.OrbitConfigNotifications{
			PendingScriptExecutionIDs: []string{"a"},
		}}

		var (
			scriptsEnabledCalls []bool
			dynamicEnabled      atomic.Bool
			dynamicInterval     = 300 * time.Millisecond
		)

		runner := &runScriptsConfigReceiver{
			ScriptsExecutionEnabled: false,
			runScriptsFn: func(r *scripts.Runner, s []string) error {
				scriptsEnabledCalls = append(scriptsEnabledCalls, r.ScriptExecutionEnabled)
				return nil
			},
			testGetFleetdConfig: func() (*fleet.MDMAppleFleetdConfig, error) {
				return &fleet.MDMAppleFleetdConfig{
					EnableScripts: dynamicEnabled.Load(),
				}, nil
			},
			dynamicScriptsEnabledCheckInterval: dynamicInterval,
		}
		// the static Scripts Enabled flag is false, so it relies on the dynamic check
		runner.runDynamicScriptsEnabledCheck()

		// first call, scripts are disabled
		err := runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error
		waitForRun(t, runner)

		// swap scripts execution to true and wait to ensure the dynamic check
		// did run.
		dynamicEnabled.Store(true)
		time.Sleep(dynamicInterval + 100*time.Millisecond)

		// second call, scripts are enabled (change exec ID to "b")
		testConfig.Notifications.PendingScriptExecutionIDs[0] = "b"
		err = runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error
		waitForRun(t, runner)

		// swap scripts execution back to false and wait to ensure the dynamic
		// check did run.
		dynamicEnabled.Store(false)
		time.Sleep(dynamicInterval + 100*time.Millisecond)

		// third call, scripts are disabled (change exec ID to "c")
		testConfig.Notifications.PendingScriptExecutionIDs[0] = "c"
		err = runner.Run(testConfig)
		require.NoError(t, err) // the dummy receiver never returns an error
		waitForRun(t, runner)

		// validate the Scripts Enabled flags that were passed to the runScriptsFn
		require.Equal(t, []bool{false, true, false}, scriptsEnabledCalls)
		require.Contains(t, logBuf.String(), "received request to run scripts [a]")
		require.Contains(t, logBuf.String(), "running scripts [a] succeeded")
		require.Contains(t, logBuf.String(), "received request to run scripts [b]")
		require.Contains(t, logBuf.String(), "running scripts [b] succeeded")
		require.Contains(t, logBuf.String(), "received request to run scripts [c]")
		require.Contains(t, logBuf.String(), "running scripts [c] succeeded")
	})
}
// mockDiskEncryptionKeySetter is a test double for a disk-encryption-key
// reporter. It records that SetOrUpdateDiskEncryptionKey was called and
// delegates the returned error to a configurable implementation.
type mockDiskEncryptionKeySetter struct {
	// SetOrUpdateDiskEncryptionKeyImpl produces the result of each call. It
	// must be non-nil whenever SetOrUpdateDiskEncryptionKey is invoked.
	SetOrUpdateDiskEncryptionKeyImpl func(diskEncryptionStatus fleet.OrbitHostDiskEncryptionKeyPayload) error
	// SetOrUpdateDiskEncryptionKeyInvoked becomes true on the first call and
	// stays true until the caller resets it.
	SetOrUpdateDiskEncryptionKeyInvoked bool
}

// SetOrUpdateDiskEncryptionKey marks the mock as invoked, then returns
// whatever the configured implementation returns for the given payload.
func (m *mockDiskEncryptionKeySetter) SetOrUpdateDiskEncryptionKey(diskEncryptionStatus fleet.OrbitHostDiskEncryptionKeyPayload) error {
	m.SetOrUpdateDiskEncryptionKeyInvoked = true
	impl := m.SetOrUpdateDiskEncryptionKeyImpl
	return impl(diskEncryptionStatus)
}
// TestBitlockerOperations exercises windowsMDMBitlockerConfigReceiver using
// stubbed volume-status/encrypt/decrypt functions and a mock key reporter.
// For each scenario it checks which operations run and what gets logged.
func TestBitlockerOperations(t *testing.T) {
	var logBuf bytes.Buffer
	oldLog := log.Logger
	log.Logger = log.Output(&logBuf)
	t.Cleanup(func() { log.Logger = oldLog })

	// Per-subtest knobs and call recorders; setupTest resets all of them.
	var (
		shouldEncrypt          = true
		shouldFailEncryption   = false
		shouldFailDecryption   = false
		shouldFailServerUpdate = false
		encryptFnCalled        = false
		decryptFnCalled        = false
	)

	// newConfig builds the orbit config from the *current* value of
	// shouldEncrypt. Building it per Run, instead of once up front, is what
	// makes a subtest that toggles shouldEncrypt actually change what the
	// receiver sees.
	newConfig := func() *fleet.OrbitConfig {
		return &fleet.OrbitConfig{
			Notifications: fleet.OrbitConfigNotifications{
				EnforceBitLockerEncryption: shouldEncrypt,
			},
		}
	}

	clientMock := &mockDiskEncryptionKeySetter{}
	clientMock.SetOrUpdateDiskEncryptionKeyImpl = func(diskEncryptionStatus fleet.OrbitHostDiskEncryptionKeyPayload) error {
		if shouldFailServerUpdate {
			return errors.New("server error")
		}
		return nil
	}

	// volumeStatusFn returns a status function that reports a single "C:"
	// volume with the given conversion status.
	volumeStatusFn := func(status int32) func() ([]bitlocker.VolumeStatus, error) {
		mockStatus := &bitlocker.EncryptionStatus{ConversionStatus: status}
		return func() ([]bitlocker.VolumeStatus, error) {
			return []bitlocker.VolumeStatus{{DriveVolume: "C:", Status: mockStatus}}, nil
		}
	}

	var enrollReceiver *windowsMDMBitlockerConfigReceiver
	// setupTest creates a fresh receiver whose last run was two hours ago
	// (older than its one-hour Frequency). It reports no volumes by default.
	// It also resets every knob/recorder and clears the log buffer.
	setupTest := func() {
		enrollReceiver = &windowsMDMBitlockerConfigReceiver{
			Frequency:        time.Hour, // doesn't matter for this test
			lastRun:          time.Now().Add(-2 * time.Hour),
			EncryptionResult: clientMock,
			execGetEncryptionStatusFn: func() ([]bitlocker.VolumeStatus, error) {
				return []bitlocker.VolumeStatus{}, nil
			},
			execEncryptVolumeFn: func(string) (string, error) {
				encryptFnCalled = true
				if shouldFailEncryption {
					return "", errors.New("error encrypting")
				}
				return "123456", nil
			},
			execDecryptVolumeFn: func(string) error {
				decryptFnCalled = true
				if shouldFailDecryption {
					return errors.New("error decrypting")
				}
				return nil
			},
		}

		shouldEncrypt = true
		shouldFailEncryption = false
		shouldFailDecryption = false
		shouldFailServerUpdate = false
		encryptFnCalled = false
		decryptFnCalled = false
		clientMock.SetOrUpdateDiskEncryptionKeyInvoked = false
		logBuf.Reset()
	}

	t.Run("bitlocker encryption is performed", func(t *testing.T) {
		setupTest()
		err := enrollReceiver.Run(newConfig())
		require.NoError(t, err) // the dummy receiver never returns an error
		require.True(t, encryptFnCalled, "encryption function should have been called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
	})

	t.Run("bitlocker encryption is not performed", func(t *testing.T) {
		setupTest()
		shouldEncrypt = false
		err := enrollReceiver.Run(newConfig())
		require.NoError(t, err) // the dummy receiver never returns an error
		require.False(t, encryptFnCalled, "encryption function should not be called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
		require.False(t, clientMock.SetOrUpdateDiskEncryptionKeyInvoked)
	})

	t.Run("bitlocker encryption returns an error", func(t *testing.T) {
		setupTest()
		shouldFailEncryption = true
		err := enrollReceiver.Run(newConfig())
		require.NoError(t, err) // the dummy receiver never returns an error
		require.True(t, encryptFnCalled, "encryption function should have been called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
	})

	t.Run("encryption skipped based on various current statuses", func(t *testing.T) {
		statusesToTest := []int32{
			bitlocker.ConversionStatusDecryptionInProgress,
			bitlocker.ConversionStatusDecryptionPaused,
			bitlocker.ConversionStatusEncryptionInProgress,
			bitlocker.ConversionStatusEncryptionPaused,
		}

		for _, status := range statusesToTest {
			t.Run(fmt.Sprintf("status %d", status), func(t *testing.T) {
				// fresh receiver per status so no state leaks between iterations
				setupTest()
				enrollReceiver.execGetEncryptionStatusFn = volumeStatusFn(status)

				err := enrollReceiver.Run(newConfig())
				require.NoError(t, err)
				require.Contains(t, logBuf.String(), "skipping encryption as the disk is not available")
				require.False(t, encryptFnCalled, "encryption function should not be called")
				require.False(t, decryptFnCalled, "decryption function should not be called")
			})
		}
	})

	t.Run("handle misreported decryption error", func(t *testing.T) {
		setupTest()
		enrollReceiver.execGetEncryptionStatusFn = volumeStatusFn(bitlocker.ConversionStatusFullyDecrypted)
		// The encryption attempt itself reports the "not decrypted" error, so
		// the call is still recorded to prove that encryption was attempted.
		enrollReceiver.execEncryptVolumeFn = func(string) (string, error) {
			encryptFnCalled = true
			return "", bitlocker.NewEncryptionError("", bitlocker.ErrorCodeNotDecrypted)
		}

		err := enrollReceiver.Run(newConfig())
		require.NoError(t, err)
		require.Contains(t, logBuf.String(), "disk encryption failed due to previous unsuccessful attempt, user action required")
		require.True(t, encryptFnCalled, "encryption function should have been called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
	})

	t.Run("decrypts the disk if previously encrypted", func(t *testing.T) {
		setupTest()
		enrollReceiver.execGetEncryptionStatusFn = volumeStatusFn(bitlocker.ConversionStatusFullyEncrypted)

		err := enrollReceiver.Run(newConfig())
		require.NoError(t, err)
		require.Contains(t, logBuf.String(), "disk was previously encrypted. Attempting to decrypt it")
		require.False(t, clientMock.SetOrUpdateDiskEncryptionKeyInvoked)
		require.False(t, encryptFnCalled, "encryption function should not have been called")
		require.True(t, decryptFnCalled, "decryption function should have been called")
	})

	t.Run("reports to the server if decryption fails", func(t *testing.T) {
		setupTest()
		shouldFailDecryption = true
		enrollReceiver.execGetEncryptionStatusFn = volumeStatusFn(bitlocker.ConversionStatusFullyEncrypted)

		err := enrollReceiver.Run(newConfig())
		require.NoError(t, err)
		require.Contains(t, logBuf.String(), "disk was previously encrypted. Attempting to decrypt it")
		require.Contains(t, logBuf.String(), "decryption failed")
		require.True(t, clientMock.SetOrUpdateDiskEncryptionKeyInvoked)
		require.False(t, encryptFnCalled, "encryption function should not be called")
		require.True(t, decryptFnCalled, "decryption function should have been called")
	})

	t.Run("encryption skipped if last run too recent", func(t *testing.T) {
		setupTest()
		enrollReceiver.lastRun = time.Now().Add(-30 * time.Minute)
		enrollReceiver.Frequency = 1 * time.Hour

		err := enrollReceiver.Run(newConfig())
		require.NoError(t, err)
		require.Contains(t, logBuf.String(), "skipped encryption process, last run was too recent")
		require.False(t, encryptFnCalled, "encryption function should not be called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
	})

	t.Run("successful fleet server update", func(t *testing.T) {
		setupTest()
		enrollReceiver.execGetEncryptionStatusFn = volumeStatusFn(bitlocker.ConversionStatusFullyDecrypted)

		err := enrollReceiver.Run(newConfig())
		require.NoError(t, err)
		require.True(t, clientMock.SetOrUpdateDiskEncryptionKeyInvoked)
		require.True(t, encryptFnCalled, "encryption function should have been called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
	})

	t.Run("failed fleet server update", func(t *testing.T) {
		setupTest()
		shouldFailServerUpdate = true
		enrollReceiver.execGetEncryptionStatusFn = volumeStatusFn(bitlocker.ConversionStatusFullyDecrypted)

		err := enrollReceiver.Run(newConfig())
		require.NoError(t, err)
		require.Contains(t, logBuf.String(), "failed to send encryption result to Fleet Server")
		require.True(t, clientMock.SetOrUpdateDiskEncryptionKeyInvoked)
		require.True(t, encryptFnCalled, "encryption function should have been called")
		require.False(t, decryptFnCalled, "decryption function should not be called")
	})
}