package mysql
import (
"context"
"database/sql"
"encoding/base64"
"fmt"
"log/slog"
"sync"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
// TestNanoMDMStorage runs each NanoMDM storage subtest against one shared MySQL
// datastore, truncating the tables after every subtest so the next one starts
// from a clean state.
func TestNanoMDMStorage(t *testing.T) {
	ds := CreateMySQLDS(t)

	type storageSubtest struct {
		name string
		run  func(t *testing.T, ds *Datastore)
	}
	subtests := []storageSubtest{
		{name: "TestEnqueueDeviceLockCommand", run: testEnqueueDeviceLockCommand},
		{name: "TestGetPendingLockCommand", run: testGetPendingLockCommand},
		{name: "TestEnqueueDeviceLockCommandRaceCondition", run: testEnqueueDeviceLockCommandRaceCondition},
		{name: "TestEnqueueDeviceUnlockCommand", run: testEnqueueDeviceUnlockCommand},
		{name: "TestStoreAuthenticatePreservesBootstrapTokenDuringSCEPRenewal", run: testStoreAuthenticatePreservesBootstrapTokenDuringSCEPRenewal},
	}

	for _, st := range subtests {
		t.Run(st.name, func(t *testing.T) {
			// Deferred so the tables are still truncated when the subtest aborts
			// via FailNow (which unwinds the goroutine and runs deferred calls).
			defer TruncateTables(t, ds)
			st.run(t, ds)
		})
	}
}
// testEnqueueDeviceLockCommand creates a darwin host, enrolls it in nanomdm via
// nanoEnroll, checks that no MDM commands are listed yet, and then starts
// building a DeviceLock command for it.
//
// NOTE(review): the body below appears to be spliced together from several
// tests. Starting at the cmd.Raw assignment it contains unterminated string
// literals. It also uses names that are never declared in this function (pin,
// renewCmdUUID, deviceUUID, req, token, bootstrapToken, getBootstrapToken), so
// it will not compile as shown. TODO: restore the original body from version
// control.
func testEnqueueDeviceLockCommand(t *testing.T, ds *Datastore) {
ctx := context.Background()
ns, err := ds.NewMDMAppleMDMStorage()
require.NoError(t, err)
host, err := ds.NewHost(ctx, &fleet.Host{
Hostname: "test-host1-name",
OsqueryHostID: ptr.String("1337"),
NodeKey: ptr.String("1337"),
UUID: "test-uuid-1",
TeamID: nil,
Platform: "darwin",
})
require.NoError(t, err)
nanoEnroll(t, ds, host, false)
// no commands yet
res, err := ds.ListMDMAppleCommands(ctx, fleet.TeamFilter{User: test.UserAdmin}, &fleet.MDMCommandListOptions{})
require.NoError(t, err)
require.Empty(t, res)
cmd := &mdm.Command{}
cmd.CommandUUID = "cmd-uuid"
cmd.Command.RequestType = "DeviceLock"
// NOTE(review): the string literal on the next line is unterminated; the
// command payload and the code that originally followed it look truncated.
cmd.Raw = []byte("')`,
host.UUID, "lock-cmd-uuid")
require.NoError(t, err)
// Now no pending command should exist
cmd, pin, err = ns.GetPendingLockCommand(ctx, host.UUID)
require.NoError(t, err)
require.Nil(t, cmd)
require.Empty(t, pin)
// Test 6: After acknowledgment, the lock_ref still exists in host_mdm_actions
// This is expected behavior - the device remains locked until manually unlocked
// Therefore, attempting to create a new lock command should still fail
lockCmd3 := &mdm.Command{}
lockCmd3.CommandUUID = "lock-cmd-uuid-3"
lockCmd3.Command.RequestType = "DeviceLock"
// NOTE(review): the next line is also truncated. What follows appears to come
// from a SCEP-renewal / bootstrap-token test (renewCmdUUID, deviceUUID, req and
// getBootstrapToken are not declared in this function).
lockCmd3.Raw = []byte("')`,
renewCmdUUID)
require.NoError(t, err)
_, err = ds.writer(ctx).ExecContext(ctx,
`UPDATE nano_cert_auth_associations SET renew_command_uuid = ? WHERE id = ?`,
renewCmdUUID, deviceUUID)
require.NoError(t, err)
// Now call StoreAuthenticate again — this simulates the device checking in during SCEP renewal.
authMsg2 := &mdm.Authenticate{
Enrollment: mdm.Enrollment{UDID: deviceUUID},
Raw: []byte("auth-raw-scep-renewal"),
}
authMsg2.SerialNumber = "SERIAL1"
err = ns.StoreAuthenticate(req, authMsg2)
require.NoError(t, err)
token = getBootstrapToken()
require.True(t, token.Valid, "bootstrap token should be preserved during SCEP renewal")
require.Equal(t, bootstrapToken, token.String)
// --- Case 3: After SCEP renewal completes (renew_command_uuid cleared), token should be cleared again ---
_, err = ds.writer(ctx).ExecContext(ctx,
`UPDATE nano_cert_auth_associations SET renew_command_uuid = NULL WHERE id = ?`,
deviceUUID)
require.NoError(t, err)
authMsg3 := &mdm.Authenticate{
Enrollment: mdm.Enrollment{UDID: deviceUUID},
Raw: []byte("auth-raw-post-renewal"),
}
authMsg3.SerialNumber = "SERIAL1"
err = ns.StoreAuthenticate(req, authMsg3)
require.NoError(t, err)
token = getBootstrapToken()
require.False(t, token.Valid, "bootstrap token should be cleared after SCEP renewal completes")
}
// testEnqueueDeviceLockCommandRaceCondition fires numGoroutines concurrent
// EnqueueDeviceLockCommand calls for the same MDM-enrolled darwin host. It
// verifies the following:
//   - exactly one call succeeds;
//   - every other call fails with a conflict error;
//   - only one test DeviceLock command lands in nano_commands;
//   - the PIN stored in host_mdm_actions is the winning request's PIN.
//
// Any error that is neither nil nor a conflict fails the test with the error
// itself, rather than only surfacing as a count mismatch later.
func testEnqueueDeviceLockCommandRaceCondition(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// Create a test host.
	host, err := ds.NewHost(ctx, &fleet.Host{
		UUID:          "test-host-race-" + uuid.NewString(),
		Platform:      "darwin",
		OsqueryHostID: ptr.String("test-osquery-id"),
		NodeKey:       ptr.String("test-node-key"),
		Hostname:      "test-host.local",
	})
	require.NoError(t, err)

	// Enable MDM for the host.
	err = ds.SetOrUpdateMDMData(ctx, host.ID, false, true, "https://test.local", false, "test-ref", "", false)
	require.NoError(t, err)

	// Create the nano_devices record first.
	deviceID := "device-" + host.UUID
	_, err = ds.writer(ctx).ExecContext(ctx, `
INSERT INTO nano_devices (id, authenticate, token_update) VALUES (?, 'Authenticate', 0)`, deviceID)
	require.NoError(t, err)

	// Create the nano_enrollments record (required for MDM commands).
	_, err = ds.writer(ctx).ExecContext(ctx, `
INSERT INTO nano_enrollments (id, device_id, type, topic, push_magic, token_hex, last_seen_at)
VALUES (?, ?, 'Device', 'com.apple.mgmt.test', 'test-magic', 'deadbeef', NOW())`,
		host.UUID, deviceID)
	require.NoError(t, err)

	storage := &NanoMDMStorage{
		db:     ds.writer(ctx),
		logger: slog.New(slog.DiscardHandler),
		ds:     ds,
	}

	// Number of concurrent lock attempts.
	const numGoroutines = 20

	var (
		wg            sync.WaitGroup
		mu            sync.Mutex // guards successCount, conflictCount and pins
		successCount  int
		conflictCount int
		pins          []string // PINs of the requests that succeeded
	)

	// Closing barrier releases every goroutine at once to maximize contention.
	barrier := make(chan struct{})
	wg.Add(numGoroutines)
	for i := 0; i < numGoroutines; i++ {
		go func(idx int) {
			defer wg.Done()
			<-barrier

			// Each request gets its own command UUID and a distinct PIN, so the
			// winner can be identified from host_mdm_actions afterwards.
			cmdUUID := fmt.Sprintf("test-lock-%03d", idx)
			pin := fmt.Sprintf("%06d", 100000+idx)

			// Field assignment (rather than an anonymous struct literal) keeps
			// this independent of the exact shape of mdm.Command.Command.
			cmd := &mdm.Command{}
			cmd.CommandUUID = cmdUUID
			cmd.Command.RequestType = "DeviceLock"
			cmd.Raw = []byte(fmt.Sprintf(`PIN%s`, pin))

			err := storage.EnqueueDeviceLockCommand(ctx, host, cmd, pin)

			mu.Lock()
			defer mu.Unlock()
			switch {
			case err == nil:
				successCount++
				pins = append(pins, pin)
			case isConflict(err):
				conflictCount++
			default:
				// t.Errorf (unlike require/FailNow) may be called from a
				// non-test goroutine, and it reports the actual error.
				t.Errorf("Request %d got unexpected error: %v", idx, err)
			}
		}(i)
	}

	close(barrier)
	wg.Wait()

	// Log the goroutine outcomes before any DB-dependent require can abort.
	t.Logf("===== RACE CONDITION TEST RESULTS =====")
	t.Logf("Concurrent requests sent: %d", numGoroutines)
	t.Logf("Successful lock commands: %d", successCount)
	t.Logf("Conflict errors: %d", conflictCount)

	// 1. Count how many DeviceLock commands were created.
	var commandCount int
	err = ds.writer(ctx).GetContext(ctx, &commandCount,
		`SELECT COUNT(*) FROM nano_commands WHERE command_uuid LIKE 'test-lock-%'`)
	require.NoError(t, err)
	t.Logf("Commands in nano_commands table: %d", commandCount)

	// 2. Check what's stored in host_mdm_actions.
	var storedPIN, lockRef string
	err = ds.writer(ctx).QueryRowContext(ctx,
		`SELECT COALESCE(unlock_pin, ''), COALESCE(lock_ref, '') FROM host_mdm_actions WHERE host_id = ?`,
		host.ID).Scan(&storedPIN, &lockRef)
	require.NoError(t, err, "expected a host_mdm_actions row for the host")
	t.Logf("Final PIN stored in database: %s", storedPIN)
	t.Logf("Final lock_ref in database: %s", lockRef)

	// Only one lock may win; every other request must see a conflict.
	require.Equal(t, 1, successCount, "Only one lock command should succeed")
	require.Equal(t, numGoroutines-1, conflictCount, "All other requests should get conflict error")
	require.Equal(t, 1, commandCount, "Only one command should be in nano_commands table")
	require.Len(t, pins, 1, "Only one PIN should be generated")
	require.Equal(t, pins[0], storedPIN, "Stored PIN should match the successful request")
}