mirror of
https://github.com/fleetdm/fleet
synced 2026-05-24 09:28:54 +00:00
908 lines
31 KiB
Go
908 lines
31 KiB
Go
package apple_mdm
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
|
|
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/push"
|
|
nanomdm_pushsvc "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/push/service"
|
|
mdmmock "github.com/fleetdm/fleet/v4/server/mock/mdm"
|
|
svcmock "github.com/fleetdm/fleet/v4/server/service/mock"
|
|
"github.com/google/uuid"
|
|
"github.com/jmoiron/sqlx"
|
|
micromdm "github.com/micromdm/micromdm/mdm/mdm"
|
|
"github.com/micromdm/nanolib/log/stdlogfmt"
|
|
"github.com/micromdm/plist"
|
|
"github.com/smallstep/pkcs7"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// mockConflictError is a test double for a storage "conflict" error. Its
// IsConflict method always reports true; the DeviceLock tests return it from
// their EnqueueDeviceLockCommand mocks to simulate a host that already has a
// pending lock command.
type mockConflictError struct {
	// msg is the text returned verbatim by Error.
	msg string
}

// Error implements the error interface by returning the stored message.
func (m *mockConflictError) Error() string { return m.msg }

// IsConflict unconditionally reports that this error is a conflict.
func (m *mockConflictError) IsConflict() bool { return true }
|
|
|
|
func TestMDMAppleCommander(t *testing.T) {
|
|
ctx := context.Background()
|
|
mdmStorage := &mdmmock.MDMAppleStore{}
|
|
pushFactory, _ := newMockAPNSPushProviderFactory()
|
|
pusher := nanomdm_pushsvc.New(
|
|
mdmStorage,
|
|
mdmStorage,
|
|
pushFactory,
|
|
stdlogfmt.New(),
|
|
)
|
|
cmdr := NewMDMAppleCommander(mdmStorage, pusher)
|
|
|
|
// TODO(roberto): there's a data race in the mock when more
|
|
// than one host ID is provided because the pusher uses one
|
|
// goroutine per uuid to send the commands
|
|
hostUUIDs := []string{"A"}
|
|
payloadName := "com.foo.bar"
|
|
payloadIdentifier := "com-foo-bar"
|
|
mc := mobileconfigForTest(payloadName, payloadIdentifier)
|
|
|
|
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) {
|
|
require.NotNil(t, cmd)
|
|
require.Equal(t, cmd.Command.Command.RequestType, "InstallProfile")
|
|
var fullCmd micromdm.CommandPayload
|
|
require.NoError(t, plist.Unmarshal(cmd.Raw, &fullCmd))
|
|
p7, err := pkcs7.Parse(fullCmd.Command.InstallProfile.Payload)
|
|
require.NoError(t, err)
|
|
require.Equal(t, string(p7.Content), string(mc))
|
|
return nil, nil
|
|
}
|
|
|
|
mdmStorage.RetrievePushInfoFunc = func(p0 context.Context, targetUUIDs []string) (map[string]*mdm.Push, error) {
|
|
require.ElementsMatch(t, hostUUIDs, targetUUIDs)
|
|
pushes := make(map[string]*mdm.Push, len(targetUUIDs))
|
|
for _, uuid := range targetUUIDs {
|
|
pushes[uuid] = &mdm.Push{
|
|
PushMagic: "magic" + uuid,
|
|
Token: []byte("token" + uuid),
|
|
Topic: "topic" + uuid,
|
|
}
|
|
}
|
|
|
|
return pushes, nil
|
|
}
|
|
|
|
mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) {
|
|
cert, err := tls.LoadX509KeyPair("../../service/testdata/server.pem", "../../service/testdata/server.key")
|
|
return &cert, "", err
|
|
}
|
|
mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) {
|
|
return false, nil
|
|
}
|
|
mdmStorage.GetAllMDMConfigAssetsByNameFunc = func(ctx context.Context, assetNames []fleet.MDMAssetName,
|
|
_ sqlx.QueryerContext,
|
|
) (map[fleet.MDMAssetName]fleet.MDMConfigAsset, error) {
|
|
certPEM, err := os.ReadFile("../../service/testdata/server.pem")
|
|
require.NoError(t, err)
|
|
keyPEM, err := os.ReadFile("../../service/testdata/server.key")
|
|
require.NoError(t, err)
|
|
return map[fleet.MDMAssetName]fleet.MDMConfigAsset{
|
|
fleet.MDMAssetCACert: {Value: certPEM},
|
|
fleet.MDMAssetCAKey: {Value: keyPEM},
|
|
}, nil
|
|
}
|
|
|
|
cmdUUID := uuid.New().String()
|
|
err := cmdr.InstallProfile(ctx, hostUUIDs, mc, cmdUUID, "")
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
|
|
mdmStorage.EnqueueCommandFuncInvoked = false
|
|
require.True(t, mdmStorage.RetrievePushInfoFuncInvoked)
|
|
mdmStorage.RetrievePushInfoFuncInvoked = false
|
|
|
|
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) {
|
|
require.NotNil(t, cmd)
|
|
require.Equal(t, "RemoveProfile", cmd.Command.Command.RequestType)
|
|
require.Contains(t, string(cmd.Raw), payloadIdentifier)
|
|
return nil, nil
|
|
}
|
|
cmdUUID = uuid.New().String()
|
|
err = cmdr.RemoveProfile(ctx, hostUUIDs, payloadIdentifier, cmdUUID, "")
|
|
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
|
|
mdmStorage.EnqueueCommandFuncInvoked = false
|
|
require.True(t, mdmStorage.RetrievePushInfoFuncInvoked)
|
|
mdmStorage.RetrievePushInfoFuncInvoked = false
|
|
require.NoError(t, err)
|
|
|
|
cmdUUID = uuid.New().String()
|
|
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) {
|
|
require.NotNil(t, cmd)
|
|
require.Equal(t, "InstallEnterpriseApplication", cmd.Command.Command.RequestType)
|
|
require.Contains(t, string(cmd.Raw), "http://test.example.com")
|
|
require.Contains(t, string(cmd.Raw), cmdUUID)
|
|
return nil, nil
|
|
}
|
|
err = cmdr.InstallEnterpriseApplication(ctx, hostUUIDs, "http://test.example.com", cmdUUID)
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
|
|
mdmStorage.EnqueueCommandFuncInvoked = false
|
|
require.True(t, mdmStorage.RetrievePushInfoFuncInvoked)
|
|
mdmStorage.RetrievePushInfoFuncInvoked = false
|
|
|
|
host := &fleet.Host{ID: 1, UUID: "A", Platform: "darwin"}
|
|
cmdUUID = uuid.New().String()
|
|
|
|
// Mock GetPendingLockCommand to return nil (no pending command)
|
|
mdmStorage.GetPendingLockCommandFunc = func(ctx context.Context, hostUUID string) (*mdm.Command, string, error) {
|
|
return nil, "", nil
|
|
}
|
|
|
|
mdmStorage.EnqueueDeviceLockCommandFunc = func(ctx context.Context, gotHost *fleet.Host, cmd *mdm.Command, pin string) error {
|
|
require.NotNil(t, gotHost)
|
|
require.Equal(t, host.ID, gotHost.ID)
|
|
require.Equal(t, host.UUID, gotHost.UUID)
|
|
require.Equal(t, "DeviceLock", cmd.Command.RequestType)
|
|
require.Contains(t, string(cmd.Raw), cmdUUID)
|
|
require.Len(t, pin, 6)
|
|
return nil
|
|
}
|
|
pin, err := cmdr.DeviceLock(ctx, host, cmdUUID)
|
|
require.NoError(t, err)
|
|
require.Len(t, pin, 6)
|
|
require.True(t, mdmStorage.EnqueueDeviceLockCommandFuncInvoked)
|
|
mdmStorage.EnqueueDeviceLockCommandFuncInvoked = false
|
|
require.True(t, mdmStorage.RetrievePushInfoFuncInvoked)
|
|
mdmStorage.RetrievePushInfoFuncInvoked = false
|
|
|
|
cmdUUID = uuid.New().String()
|
|
orgName := "My Org Name"
|
|
mdmStorage.EnqueueDeviceLockCommandFunc = func(ctx context.Context, gotHost *fleet.Host, cmd *mdm.Command, pin string) error {
|
|
require.NotNil(t, gotHost)
|
|
require.Equal(t, host.ID, gotHost.ID)
|
|
require.Equal(t, host.UUID, gotHost.UUID)
|
|
require.Equal(t, "EnableLostMode", cmd.Command.RequestType)
|
|
require.Contains(t, string(cmd.Raw), cmdUUID)
|
|
require.Contains(t, string(cmd.Raw), orgName)
|
|
require.Empty(t, pin)
|
|
return nil
|
|
}
|
|
err = cmdr.EnableLostMode(ctx, host, cmdUUID, orgName)
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueDeviceLockCommandFuncInvoked)
|
|
mdmStorage.EnqueueDeviceLockCommandFuncInvoked = false
|
|
require.True(t, mdmStorage.RetrievePushInfoFuncInvoked)
|
|
mdmStorage.RetrievePushInfoFuncInvoked = false
|
|
|
|
mdmStorage.EnqueueDeviceUnlockCommandFunc = func(ctx context.Context, gotHost *fleet.Host, cmd *mdm.Command) error {
|
|
require.NotNil(t, gotHost)
|
|
require.Equal(t, host.ID, gotHost.ID)
|
|
require.Equal(t, host.UUID, gotHost.UUID)
|
|
require.Equal(t, "DisableLostMode", cmd.Command.RequestType)
|
|
require.Contains(t, string(cmd.Raw), cmdUUID)
|
|
return nil
|
|
}
|
|
err = cmdr.DisableLostMode(ctx, host, cmdUUID)
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueDeviceUnlockCommandFuncInvoked)
|
|
mdmStorage.EnqueueDeviceUnlockCommandFuncInvoked = false
|
|
require.True(t, mdmStorage.RetrievePushInfoFuncInvoked)
|
|
mdmStorage.RetrievePushInfoFuncInvoked = false
|
|
|
|
cmdUUID = uuid.New().String()
|
|
mdmStorage.EnqueueDeviceWipeCommandFunc = func(ctx context.Context, gotHost *fleet.Host, cmd *mdm.Command) error {
|
|
require.NotNil(t, gotHost)
|
|
require.Equal(t, host.ID, gotHost.ID)
|
|
require.Equal(t, host.UUID, gotHost.UUID)
|
|
require.Equal(t, "EraseDevice", cmd.Command.RequestType)
|
|
require.Contains(t, string(cmd.Raw), cmdUUID)
|
|
return nil
|
|
}
|
|
err = cmdr.EraseDevice(ctx, host, cmdUUID)
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueDeviceWipeCommandFuncInvoked)
|
|
mdmStorage.EnqueueDeviceWipeCommandFuncInvoked = false
|
|
require.True(t, mdmStorage.RetrievePushInfoFuncInvoked)
|
|
mdmStorage.RetrievePushInfoFuncInvoked = false
|
|
}
|
|
|
|
func TestMDMAppleCommanderConcurrentDeviceLock(t *testing.T) {
|
|
ctx := context.Background()
|
|
mdmStorage := &mdmmock.MDMAppleStore{}
|
|
pushFactory, _ := newMockAPNSPushProviderFactory()
|
|
pusher := nanomdm_pushsvc.New(
|
|
mdmStorage,
|
|
mdmStorage,
|
|
pushFactory,
|
|
stdlogfmt.New(),
|
|
)
|
|
cmdr := NewMDMAppleCommander(mdmStorage, pusher)
|
|
|
|
host := &fleet.Host{ID: 1, UUID: "TEST-HOST", Platform: "darwin"}
|
|
|
|
// Variables to track calls (with mutex for thread safety)
|
|
var mu sync.Mutex
|
|
var pendingCommand *mdm.Command
|
|
var pendingPIN string
|
|
enqueueCalls := 0
|
|
getPendingCalls := 0
|
|
|
|
// Mock GetPendingLockCommand
|
|
// Need to track state across concurrent calls
|
|
var commandCreated bool
|
|
mdmStorage.GetPendingLockCommandFunc = func(ctx context.Context, hostUUID string) (*mdm.Command, string, error) {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
getPendingCalls++
|
|
require.Equal(t, host.UUID, hostUUID)
|
|
// After the first command is enqueued, return it as pending
|
|
if commandCreated && pendingCommand != nil {
|
|
return pendingCommand, pendingPIN, nil
|
|
}
|
|
return nil, "", nil
|
|
}
|
|
|
|
// Mock EnqueueDeviceLockCommand
|
|
mdmStorage.EnqueueDeviceLockCommandFunc = func(ctx context.Context, gotHost *fleet.Host, cmd *mdm.Command, pin string) error {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
enqueueCalls++
|
|
require.NotNil(t, gotHost)
|
|
require.Equal(t, host.ID, gotHost.ID)
|
|
require.Equal(t, host.UUID, gotHost.UUID)
|
|
require.Equal(t, "DeviceLock", cmd.Command.RequestType)
|
|
require.Len(t, pin, 6)
|
|
// Store the first command as pending, reject others with conflict
|
|
if !commandCreated {
|
|
pendingCommand = cmd
|
|
pendingPIN = pin
|
|
commandCreated = true
|
|
return nil
|
|
}
|
|
// Command already exists, return conflict error
|
|
return &mockConflictError{msg: "host already has a pending lock command"}
|
|
}
|
|
|
|
// Mock RetrievePushInfo
|
|
mdmStorage.RetrievePushInfoFunc = func(ctx context.Context, tokens []string) (map[string]*mdm.Push, error) {
|
|
res := make(map[string]*mdm.Push)
|
|
for _, token := range tokens {
|
|
res[token] = &mdm.Push{
|
|
PushMagic: "magic",
|
|
Token: []byte("token"),
|
|
Topic: "topic",
|
|
}
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
// Mock RetrievePushCert
|
|
mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) {
|
|
// Return a mock certificate
|
|
return &tls.Certificate{}, "staleToken", nil
|
|
}
|
|
|
|
// Mock IsPushCertStale - return false (cert is not stale)
|
|
mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) {
|
|
return false, nil
|
|
}
|
|
|
|
// Simulate concurrent lock requests
|
|
numGoroutines := 10
|
|
results := make(chan string, numGoroutines)
|
|
errors := make(chan error, numGoroutines)
|
|
|
|
for i := 0; i < numGoroutines; i++ {
|
|
go func(idx int) {
|
|
cmdUUID := fmt.Sprintf("cmd-uuid-%d", idx)
|
|
pin, err := cmdr.DeviceLock(ctx, host, cmdUUID)
|
|
if err != nil {
|
|
errors <- err
|
|
} else {
|
|
results <- pin
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
// Collect results
|
|
var pins []string
|
|
for i := 0; i < numGoroutines; i++ {
|
|
select {
|
|
case pin := <-results:
|
|
pins = append(pins, pin)
|
|
case err := <-errors:
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
|
|
// Verify results
|
|
require.Len(t, pins, numGoroutines, "All requests should succeed")
|
|
|
|
// All PINs should be the same
|
|
firstPIN := pins[0]
|
|
for _, pin := range pins {
|
|
require.Equal(t, firstPIN, pin, "All requests should return the same PIN")
|
|
}
|
|
|
|
// Due to race conditions, multiple goroutines may attempt to enqueue
|
|
// but only one should succeed, the rest should get conflict errors.
|
|
// The important thing is that all requests return the same PIN
|
|
require.GreaterOrEqual(t, enqueueCalls, 1, "At least one enqueue attempt should be made")
|
|
require.LessOrEqual(t, enqueueCalls, numGoroutines, "At most numGoroutines enqueue attempts")
|
|
|
|
// GetPendingLockCommand should have been called multiple times
|
|
// This includes both initial checks and post-conflict checks
|
|
require.GreaterOrEqual(t, getPendingCalls, numGoroutines, "GetPendingLockCommand should be called at least once per request")
|
|
}
|
|
|
|
func TestMDMAppleCommanderDeviceLockPushNotificationFailure(t *testing.T) {
|
|
ctx := context.Background()
|
|
mdmStorage := &mdmmock.MDMAppleStore{}
|
|
|
|
// Create a mock push provider that will fail
|
|
pushProvider := &svcmock.APNSPushProvider{}
|
|
pushFactory := &svcmock.APNSPushProviderFactory{}
|
|
pushFactory.NewPushProviderFunc = func(*tls.Certificate) (push.PushProvider, error) {
|
|
return pushProvider, nil
|
|
}
|
|
|
|
pusher := nanomdm_pushsvc.New(
|
|
mdmStorage,
|
|
mdmStorage,
|
|
pushFactory,
|
|
stdlogfmt.New(),
|
|
)
|
|
cmdr := NewMDMAppleCommander(mdmStorage, pusher)
|
|
|
|
host := &fleet.Host{ID: 1, UUID: "TEST-HOST-PUSH-FAIL", Platform: "darwin"}
|
|
|
|
// Track whether we're on the first or second request
|
|
var requestCount int
|
|
var existingCommand *mdm.Command
|
|
var existingPIN string
|
|
|
|
// Mock GetPendingLockCommand
|
|
mdmStorage.GetPendingLockCommandFunc = func(ctx context.Context, hostUUID string) (*mdm.Command, string, error) {
|
|
requestCount++
|
|
require.Equal(t, host.UUID, hostUUID)
|
|
|
|
switch requestCount {
|
|
case 1:
|
|
// First request - no pending command
|
|
return nil, "", nil
|
|
case 2:
|
|
// Second request initial check - still no pending command
|
|
// (hasn't been created yet)
|
|
return nil, "", nil
|
|
case 3:
|
|
// Second request after conflict - return the existing command
|
|
return existingCommand, existingPIN, nil
|
|
default:
|
|
t.Fatalf("Unexpected call to GetPendingLockCommand: %d", requestCount)
|
|
return nil, "", nil
|
|
}
|
|
}
|
|
|
|
// Mock EnqueueDeviceLockCommand
|
|
var enqueueCalls int
|
|
mdmStorage.EnqueueDeviceLockCommandFunc = func(ctx context.Context, gotHost *fleet.Host, cmd *mdm.Command, pin string) error {
|
|
enqueueCalls++
|
|
require.NotNil(t, gotHost)
|
|
require.Equal(t, host.ID, gotHost.ID)
|
|
require.Equal(t, "DeviceLock", cmd.Command.RequestType)
|
|
|
|
switch enqueueCalls {
|
|
case 1:
|
|
// First request succeeds
|
|
existingCommand = cmd
|
|
existingPIN = pin
|
|
return nil
|
|
case 2:
|
|
// Second request gets conflict
|
|
return &mockConflictError{msg: "host already has a pending lock command"}
|
|
default:
|
|
t.Fatalf("Unexpected call to EnqueueDeviceLockCommand: %d", enqueueCalls)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Mock RetrievePushInfo
|
|
mdmStorage.RetrievePushInfoFunc = func(ctx context.Context, tokens []string) (map[string]*mdm.Push, error) {
|
|
res := make(map[string]*mdm.Push)
|
|
for _, token := range tokens {
|
|
res[token] = &mdm.Push{
|
|
PushMagic: "magic",
|
|
Token: []byte("token"),
|
|
Topic: "topic",
|
|
}
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
// Mock RetrievePushCert
|
|
mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) {
|
|
return &tls.Certificate{}, "staleToken", nil
|
|
}
|
|
|
|
// Mock IsPushCertStale
|
|
mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) {
|
|
return false, nil
|
|
}
|
|
|
|
// Configure push provider to fail on conflict scenario
|
|
var pushAttempts int
|
|
pushProvider.PushFunc = func(ctx context.Context, pushes []*mdm.Push) (map[string]*push.Response, error) {
|
|
pushAttempts++
|
|
|
|
switch pushAttempts {
|
|
case 1:
|
|
// First request - push succeeds
|
|
return mockSuccessfulPush(ctx, pushes)
|
|
case 2:
|
|
// Second request during conflict handling - push fails
|
|
// This simulates a network error or push service issue
|
|
return nil, errors.New("push notification service unavailable")
|
|
default:
|
|
t.Fatalf("Unexpected push attempt: %d", pushAttempts)
|
|
return nil, nil
|
|
}
|
|
}
|
|
|
|
// First request - should succeed normally
|
|
pin1, err := cmdr.DeviceLock(ctx, host, "cmd-uuid-1")
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, pin1)
|
|
require.Len(t, pin1, 6)
|
|
|
|
// Reset request count for second request
|
|
requestCount = 1
|
|
|
|
// Second concurrent request - should get conflict but still return PIN despite push failure
|
|
pin2, err := cmdr.DeviceLock(ctx, host, "cmd-uuid-2")
|
|
require.NoError(t, err, "Should not return error even when push notification fails")
|
|
require.NotEmpty(t, pin2)
|
|
require.Equal(t, pin1, pin2, "Should return the same PIN as first request")
|
|
|
|
// Verify the expected number of calls
|
|
require.Equal(t, 2, enqueueCalls, "Should have attempted to enqueue twice")
|
|
require.Equal(t, 2, pushAttempts, "Should have attempted push twice")
|
|
require.Equal(t, 3, requestCount, "Should have called GetPendingLockCommand three times")
|
|
}
|
|
|
|
func newMockAPNSPushProviderFactory() (*svcmock.APNSPushProviderFactory, *svcmock.APNSPushProvider) {
|
|
provider := &svcmock.APNSPushProvider{}
|
|
provider.PushFunc = mockSuccessfulPush
|
|
factory := &svcmock.APNSPushProviderFactory{}
|
|
factory.NewPushProviderFunc = func(*tls.Certificate) (push.PushProvider, error) {
|
|
return provider, nil
|
|
}
|
|
|
|
return factory, provider
|
|
}
|
|
|
|
func mockSuccessfulPush(_ context.Context, pushes []*mdm.Push) (map[string]*push.Response, error) {
|
|
res := make(map[string]*push.Response, len(pushes))
|
|
for _, p := range pushes {
|
|
res[p.Token.String()] = &push.Response{
|
|
Id: uuid.New().String(),
|
|
Err: nil,
|
|
}
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func mobileconfigForTest(name, identifier string) []byte {
|
|
return []byte(fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
|
|
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
|
<plist version="1.0">
|
|
<dict>
|
|
<key>PayloadContent</key>
|
|
<array/>
|
|
<key>PayloadDisplayName</key>
|
|
<string>%s</string>
|
|
<key>PayloadIdentifier</key>
|
|
<string>%s</string>
|
|
<key>PayloadType</key>
|
|
<string>Configuration</string>
|
|
<key>PayloadUUID</key>
|
|
<string>%s</string>
|
|
<key>PayloadVersion</key>
|
|
<integer>1</integer>
|
|
</dict>
|
|
</plist>
|
|
`, name, identifier, uuid.New().String()))
|
|
}
|
|
|
|
func TestAPNSDeliveryError(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
errorsByUUID map[string]error
|
|
expectedError string
|
|
expectedFailedUUIDs []string
|
|
expectedStatusCode int
|
|
}{
|
|
{
|
|
name: "single error",
|
|
errorsByUUID: map[string]error{
|
|
"uuid1": errors.New("network error"),
|
|
},
|
|
expectedError: `APNS delivery failed with the following errors:
|
|
UUID: uuid1, Error: network error`,
|
|
expectedFailedUUIDs: []string{"uuid1"},
|
|
expectedStatusCode: http.StatusBadGateway,
|
|
},
|
|
{
|
|
name: "multiple errors, sorted",
|
|
errorsByUUID: map[string]error{
|
|
"uuid3": errors.New("timeout error"),
|
|
"uuid1": errors.New("network error"),
|
|
"uuid2": errors.New("certificate error"),
|
|
},
|
|
expectedError: `APNS delivery failed with the following errors:
|
|
UUID: uuid1, Error: network error
|
|
UUID: uuid2, Error: certificate error
|
|
UUID: uuid3, Error: timeout error`,
|
|
expectedFailedUUIDs: []string{"uuid1", "uuid2", "uuid3"},
|
|
expectedStatusCode: http.StatusBadGateway,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
apnsErr := &APNSDeliveryError{
|
|
errorsByUUID: tt.errorsByUUID,
|
|
}
|
|
|
|
require.Equal(t, tt.expectedError, apnsErr.Error())
|
|
require.Equal(t, tt.expectedFailedUUIDs, apnsErr.FailedUUIDs())
|
|
require.Equal(t, tt.expectedStatusCode, apnsErr.StatusCode())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMDMAppleCommanderSetRecoveryLock(t *testing.T) {
|
|
ctx := context.Background()
|
|
mdmStorage := &mdmmock.MDMAppleStore{}
|
|
pushFactory, _ := newMockAPNSPushProviderFactory()
|
|
pusher := nanomdm_pushsvc.New(
|
|
mdmStorage,
|
|
mdmStorage,
|
|
pushFactory,
|
|
stdlogfmt.New(),
|
|
)
|
|
cmdr := NewMDMAppleCommander(mdmStorage, pusher)
|
|
|
|
hostUUIDs := []string{"host-uuid-1"}
|
|
cmdUUID := uuid.New().String()
|
|
|
|
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) {
|
|
require.NotNil(t, cmd)
|
|
require.Equal(t, "SetRecoveryLock", cmd.Command.Command.RequestType)
|
|
require.Contains(t, string(cmd.Raw), cmdUUID)
|
|
require.Contains(t, string(cmd.Raw), "SetRecoveryLock")
|
|
// Should contain the placeholder, not the actual password
|
|
require.Contains(t, string(cmd.Raw), "$"+fleet.HostSecretPrefix+fleet.HostSecretRecoveryLockPassword)
|
|
require.Contains(t, string(cmd.Raw), "<key>NewPassword</key>")
|
|
return nil, nil
|
|
}
|
|
|
|
mdmStorage.RetrievePushInfoFunc = func(ctx context.Context, targetUUIDs []string) (map[string]*mdm.Push, error) {
|
|
require.ElementsMatch(t, hostUUIDs, targetUUIDs)
|
|
pushes := make(map[string]*mdm.Push, len(targetUUIDs))
|
|
for _, uuid := range targetUUIDs {
|
|
pushes[uuid] = &mdm.Push{
|
|
PushMagic: "magic" + uuid,
|
|
Token: []byte("token" + uuid),
|
|
Topic: "topic" + uuid,
|
|
}
|
|
}
|
|
return pushes, nil
|
|
}
|
|
|
|
mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) {
|
|
cert, err := tls.LoadX509KeyPair("../../service/testdata/server.pem", "../../service/testdata/server.key")
|
|
return &cert, "", err
|
|
}
|
|
|
|
mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) {
|
|
return false, nil
|
|
}
|
|
|
|
err := cmdr.SetRecoveryLock(ctx, hostUUIDs, cmdUUID)
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
|
|
require.True(t, mdmStorage.RetrievePushInfoFuncInvoked)
|
|
}
|
|
|
|
func TestMDMAppleCommanderClearPasscode(t *testing.T) {
|
|
ctx := context.Background()
|
|
mdmStorage := &mdmmock.MDMAppleStore{}
|
|
pushFactory, _ := newMockAPNSPushProviderFactory()
|
|
pusher := nanomdm_pushsvc.New(
|
|
mdmStorage,
|
|
mdmStorage,
|
|
pushFactory,
|
|
stdlogfmt.New(),
|
|
)
|
|
cmdr := NewMDMAppleCommander(mdmStorage, pusher)
|
|
|
|
hostUUIDs := []string{"host-uuid-1"}
|
|
cmdUUID := uuid.New().String()
|
|
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) {
|
|
require.NotNil(t, cmd)
|
|
require.Equal(t, "ClearPasscode", cmd.Command.Command.RequestType)
|
|
require.Contains(t, string(cmd.Raw), "$"+fleet.HostSecretPrefix+fleet.HostSecretMDMUnlockToken, "Clear passcode should not use direct unlock token but rather Host secret")
|
|
require.Contains(t, string(cmd.Raw), cmdUUID)
|
|
return nil, nil
|
|
}
|
|
|
|
mdmStorage.RetrievePushInfoFunc = func(ctx context.Context, targetUUIDs []string) (map[string]*mdm.Push, error) {
|
|
require.ElementsMatch(t, hostUUIDs, targetUUIDs)
|
|
pushes := make(map[string]*mdm.Push, len(targetUUIDs))
|
|
for _, uuid := range targetUUIDs {
|
|
pushes[uuid] = &mdm.Push{
|
|
PushMagic: "magic" + uuid,
|
|
Token: []byte("token" + uuid),
|
|
Topic: "topic" + uuid,
|
|
}
|
|
}
|
|
return pushes, nil
|
|
}
|
|
|
|
mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) {
|
|
cert, err := tls.LoadX509KeyPair("../../service/testdata/server.pem", "../../service/testdata/server.key")
|
|
return &cert, "", err
|
|
}
|
|
|
|
mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) {
|
|
return false, nil
|
|
}
|
|
|
|
err := cmdr.ClearPasscode(ctx, hostUUIDs, cmdUUID)
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
|
|
require.True(t, mdmStorage.RetrievePushInfoFuncInvoked)
|
|
}
|
|
|
|
func TestAccountConfigurationWithAdminAccount(t *testing.T) {
|
|
ctx := context.Background()
|
|
mdmStorage := &mdmmock.MDMAppleStore{}
|
|
pushFactory, _ := newMockAPNSPushProviderFactory()
|
|
pusher := nanomdm_pushsvc.New(
|
|
mdmStorage,
|
|
mdmStorage,
|
|
pushFactory,
|
|
stdlogfmt.New(),
|
|
)
|
|
cmdr := NewMDMAppleCommander(mdmStorage, pusher)
|
|
|
|
hostUUIDs := []string{"ABC"}
|
|
cmdUUID := uuid.New().String()
|
|
|
|
mdmStorage.RetrievePushInfoFunc = func(ctx context.Context, targets []string) (map[string]*mdm.Push, error) {
|
|
pushes := make(map[string]*mdm.Push, len(targets))
|
|
for _, uuid := range targets {
|
|
pushes[uuid] = &mdm.Push{
|
|
PushMagic: "magic",
|
|
Token: []byte("token"),
|
|
Topic: "topic",
|
|
}
|
|
}
|
|
return pushes, nil
|
|
}
|
|
mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) {
|
|
cert, err := tls.LoadX509KeyPair("../../service/testdata/server.pem", "../../service/testdata/server.key")
|
|
return &cert, "", err
|
|
}
|
|
mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) {
|
|
return false, nil
|
|
}
|
|
|
|
t.Run("SSO only produces standard plist", func(t *testing.T) {
|
|
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) {
|
|
raw := string(cmd.Raw)
|
|
require.Contains(t, raw, "AccountConfiguration")
|
|
require.Contains(t, raw, "<key>PrimaryAccountFullName</key>")
|
|
require.Contains(t, raw, "<string>Test User</string>")
|
|
require.Contains(t, raw, "<key>PrimaryAccountUserName</key>")
|
|
require.Contains(t, raw, "<string>testuser</string>")
|
|
require.NotContains(t, raw, "AutoSetupAdminAccounts")
|
|
return nil, nil
|
|
}
|
|
mdmStorage.EnqueueCommandFuncInvoked = false
|
|
|
|
err := cmdr.AccountConfiguration(ctx, hostUUIDs, cmdUUID,
|
|
&SSOAccountConfig{FullName: "Test User", UserName: "testuser", LockPrimaryAccountInfo: true},
|
|
nil,
|
|
)
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
|
|
})
|
|
|
|
t.Run("admin account adds AutoSetupAdminAccounts", func(t *testing.T) {
|
|
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) {
|
|
raw := string(cmd.Raw)
|
|
require.Contains(t, raw, "AccountConfiguration")
|
|
require.Contains(t, raw, "<key>AutoSetupAdminAccounts</key>")
|
|
require.Contains(t, raw, "<string>_fleetadmin</string>")
|
|
require.Contains(t, raw, "<string>Fleet Admin</string>")
|
|
require.Contains(t, raw, "<key>hidden</key>")
|
|
require.Contains(t, raw, "<true />")
|
|
require.Contains(t, raw, "<key>passwordHash</key>")
|
|
require.NotContains(t, raw, "PrimaryAccountFullName")
|
|
return nil, nil
|
|
}
|
|
mdmStorage.EnqueueCommandFuncInvoked = false
|
|
|
|
err := cmdr.AccountConfiguration(ctx, hostUUIDs, cmdUUID,
|
|
nil,
|
|
&AdminAccountConfig{ShortName: "_fleetadmin", FullName: "Fleet Admin", PasswordHash: []byte("fake-hash"), Hidden: true},
|
|
)
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
|
|
})
|
|
|
|
t.Run("SSO + admin combined", func(t *testing.T) {
|
|
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) {
|
|
raw := string(cmd.Raw)
|
|
require.Contains(t, raw, "PrimaryAccountFullName")
|
|
require.Contains(t, raw, "AutoSetupAdminAccounts")
|
|
return nil, nil
|
|
}
|
|
mdmStorage.EnqueueCommandFuncInvoked = false
|
|
|
|
err := cmdr.AccountConfiguration(ctx, hostUUIDs, cmdUUID,
|
|
&SSOAccountConfig{FullName: "SSO User", UserName: "ssouser", LockPrimaryAccountInfo: false},
|
|
&AdminAccountConfig{ShortName: "_fleetadmin", FullName: "Fleet Admin", PasswordHash: []byte("fake-hash"), Hidden: true},
|
|
)
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
|
|
})
|
|
}
|
|
|
|
func TestMDMAppleCommanderPassesCommandName(t *testing.T) {
|
|
ctx := context.Background()
|
|
mdmStorage := &mdmmock.MDMAppleStore{}
|
|
pushFactory, _ := newMockAPNSPushProviderFactory()
|
|
pusher := nanomdm_pushsvc.New(
|
|
mdmStorage,
|
|
mdmStorage,
|
|
pushFactory,
|
|
stdlogfmt.New(),
|
|
)
|
|
cmdr := NewMDMAppleCommander(mdmStorage, pusher)
|
|
|
|
hostUUIDs := []string{"A"}
|
|
payloadName := "com.foo.bar"
|
|
payloadIdentifier := "com-foo-bar"
|
|
mc := mobileconfigForTest(payloadName, payloadIdentifier)
|
|
|
|
mdmStorage.RetrievePushInfoFunc = func(p0 context.Context, targetUUIDs []string) (map[string]*mdm.Push, error) {
|
|
pushes := make(map[string]*mdm.Push, len(targetUUIDs))
|
|
for _, uuid := range targetUUIDs {
|
|
pushes[uuid] = &mdm.Push{
|
|
PushMagic: "magic" + uuid,
|
|
Token: []byte("token" + uuid),
|
|
Topic: "topic" + uuid,
|
|
}
|
|
}
|
|
return pushes, nil
|
|
}
|
|
mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) {
|
|
cert, err := tls.LoadX509KeyPair("../../service/testdata/server.pem", "../../service/testdata/server.key")
|
|
return &cert, "", err
|
|
}
|
|
mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) {
|
|
return false, nil
|
|
}
|
|
mdmStorage.GetAllMDMConfigAssetsByNameFunc = func(ctx context.Context, assetNames []fleet.MDMAssetName,
|
|
_ sqlx.QueryerContext,
|
|
) (map[fleet.MDMAssetName]fleet.MDMConfigAsset, error) {
|
|
certPEM, err := os.ReadFile("../../service/testdata/server.pem")
|
|
require.NoError(t, err)
|
|
keyPEM, err := os.ReadFile("../../service/testdata/server.key")
|
|
require.NoError(t, err)
|
|
return map[fleet.MDMAssetName]fleet.MDMConfigAsset{
|
|
fleet.MDMAssetCACert: {Value: certPEM},
|
|
fleet.MDMAssetCAKey: {Value: keyPEM},
|
|
}, nil
|
|
}
|
|
|
|
t.Run("InstallProfile with a name", func(t *testing.T) {
|
|
var gotName string
|
|
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) {
|
|
gotName = cmd.Name
|
|
return nil, nil
|
|
}
|
|
mdmStorage.EnqueueCommandFuncInvoked = false
|
|
|
|
cmdUUID := uuid.New().String()
|
|
err := cmdr.InstallProfile(ctx, hostUUIDs, mc, cmdUUID, "My Profile")
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
|
|
require.Equal(t, "My Profile", gotName)
|
|
})
|
|
|
|
t.Run("InstallProfile without a name", func(t *testing.T) {
|
|
var gotName string
|
|
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) {
|
|
gotName = cmd.Name
|
|
return nil, nil
|
|
}
|
|
mdmStorage.EnqueueCommandFuncInvoked = false
|
|
|
|
cmdUUID := uuid.New().String()
|
|
err := cmdr.InstallProfile(ctx, hostUUIDs, mc, cmdUUID, "")
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
|
|
require.Empty(t, gotName)
|
|
})
|
|
|
|
t.Run("RemoveProfile with a name", func(t *testing.T) {
|
|
var gotName string
|
|
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) {
|
|
gotName = cmd.Name
|
|
return nil, nil
|
|
}
|
|
mdmStorage.EnqueueCommandFuncInvoked = false
|
|
|
|
cmdUUID := uuid.New().String()
|
|
err := cmdr.RemoveProfile(ctx, hostUUIDs, payloadIdentifier, cmdUUID, "My Profile")
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
|
|
require.Equal(t, "My Profile", gotName)
|
|
})
|
|
|
|
t.Run("EnqueueCommand no name", func(t *testing.T) {
|
|
var gotName string
|
|
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) {
|
|
gotName = cmd.Name
|
|
return nil, nil
|
|
}
|
|
mdmStorage.EnqueueCommandFuncInvoked = false
|
|
|
|
cmdUUID := uuid.New().String()
|
|
rawCmd := fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
|
|
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
|
<plist version="1.0">
|
|
<dict>
|
|
<key>CommandUUID</key>
|
|
<string>%s</string>
|
|
<key>Command</key>
|
|
<dict>
|
|
<key>RequestType</key>
|
|
<string>ProfileList</string>
|
|
</dict>
|
|
</dict>
|
|
</plist>`, cmdUUID)
|
|
err := cmdr.EnqueueCommand(ctx, hostUUIDs, rawCmd)
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
|
|
require.Empty(t, gotName)
|
|
})
|
|
|
|
t.Run("EnqueueCommandInstallProfileWithSecrets with a name", func(t *testing.T) {
|
|
var gotName string
|
|
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) {
|
|
gotName = cmd.Name
|
|
return nil, nil
|
|
}
|
|
mdmStorage.EnqueueCommandFuncInvoked = false
|
|
|
|
cmdUUID := uuid.New().String()
|
|
err := cmdr.EnqueueCommandInstallProfileWithSecrets(ctx, hostUUIDs, mc, cmdUUID, "Secret Profile")
|
|
require.NoError(t, err)
|
|
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
|
|
require.Equal(t, "Secret Profile", gotName)
|
|
})
|
|
}
|