record declarative checkin command responses (#17693)

this is to prevent nanomdm to send the DeclarativeManagement command
every time the host checks in.
This commit is contained in:
Roberto Dip 2024-03-18 14:41:33 -03:00 committed by GitHub
parent f5cf156653
commit e26d23460c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 159 additions and 0 deletions

View file

@ -3080,3 +3080,76 @@ WHERE h.uuid = ?
return nil
}
func (ds *Datastore) MDMAppleRecordDeclarativeCheckIn(ctx context.Context, hostUUID string, result []byte) error {
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
res, err := tx.ExecContext(
ctx,
`UPDATE nano_enrollments SET last_seen_at = CURRENT_TIMESTAMP WHERE id = ?`,
hostUUID,
)
if err != nil {
return ctxerr.Wrap(ctx, err, "updating last_seen times")
}
if n, _ := res.RowsAffected(); n == 0 {
return ctxerr.New(ctx, "host is not enrolled in MDM")
}
// NOTE: DeclarativeManagement checkin commands sent by the device
// don't carry a CommandUUID reference like commands in
// CommandAndReportResults messages do.
//
// In nanomdm's view of the world, a command is pending until
// it receives a result or is deactivated, so we'll grab the
// command_uuid of the oldest DeclarativeManagement command we
// sent and assume this is the response for it.
//
// Other DeclarativeManagement commands will still be in the
// queue and they will trigger DDM syncs when the device checks
// in, so eventually all DDM commands wil get acknowledged.
//
// Alternatively, we could mark all DDM commands as
// acknowledged here, TBD based on the behaviors we see.
var cmdUUID string
err = sqlx.GetContext(ctx, tx, &cmdUUID, `
SELECT nc.command_uuid
FROM nano_enrollment_queue neq
JOIN nano_commands nc
ON neq.command_uuid = nc.command_uuid
WHERE
id = ? AND
request_type = 'DeclarativeManagement'
ORDER BY neq.created_at ASC
LIMIT 1
`, hostUUID)
if err != nil {
// it's okay if the host doesn't have matching command enqueued, the
// check-in could be initiated by the device.
if err == sql.ErrNoRows {
return nil
}
return ctxerr.Wrap(ctx, err, "getting DDM command")
}
_, err = tx.ExecContext(
ctx, `
INSERT INTO nano_command_results
(id, command_uuid, status, result)
VALUES
(?, ?, ?, ?)
ON DUPLICATE KEY
UPDATE
status = VALUES(status),
result = VALUES(result)`,
hostUUID,
cmdUUID,
fleet.MDMAppleStatusAcknowledged,
result,
)
return ctxerr.Wrap(ctx, err, "updating nano_command_results")
})
return ctxerr.Wrap(ctx, err, "saving declarative management response")
}

View file

@ -69,6 +69,7 @@ func TestMDMApple(t *testing.T) {
{"TestMDMAppleDeleteHostDEPAssignments", testMDMAppleDeleteHostDEPAssignments},
{"LockUnlockWipeMacOS", testLockUnlockWipeMacOS},
{"ScreenDEPAssignProfileSerialsForCooldown", testScreenDEPAssignProfileSerialsForCooldown},
{"MDMAppleRecordDeclarativeCheckIn", testMDMAppleRecordDeclarativeCheckIn},
}
for _, c := range cases {
@ -4570,6 +4571,49 @@ func testScreenDEPAssignProfileSerialsForCooldown(t *testing.T, ds *Datastore) {
require.Empty(t, assign)
}
func testMDMAppleRecordDeclarativeCheckIn(t *testing.T, ds *Datastore) {
ctx := context.Background()
host, err := ds.NewHost(ctx, &fleet.Host{
Hostname: "test-host1-name",
OsqueryHostID: ptr.String("1337"),
NodeKey: ptr.String("1337"),
UUID: "test-uuid-1",
TeamID: nil,
Platform: "darwin",
})
require.NoError(t, err)
// error if the host is not enrolled
err = ds.MDMAppleRecordDeclarativeCheckIn(ctx, host.UUID, []byte{})
require.Error(t, err)
// enroll the host
nanoEnroll(t, ds, host, true)
// it's okay if the host doesn't have matching command enqueued, the
// check-in could be initiated by the device.
err = ds.MDMAppleRecordDeclarativeCheckIn(ctx, host.UUID, []byte{})
require.NoError(t, err)
// enqueue a declarative checkin request
commander, _ := createMDMAppleCommanderAndStorage(t, ds)
cmdUUID := uuid.New().String()
err = commander.DeclarativeManagement(ctx, []string{host.UUID}, cmdUUID)
require.NoError(t, err)
// record a response from the host
err = ds.MDMAppleRecordDeclarativeCheckIn(ctx, host.UUID, []byte("foo"))
require.NoError(t, err)
res, err := ds.GetMDMAppleCommandResults(ctx, cmdUUID)
require.NoError(t, err)
require.Len(t, res, 1)
require.Equal(t, host.UUID, res[0].HostUUID)
require.Equal(t, fleet.MDMAppleStatusAcknowledged, res[0].Status)
require.EqualValues(t, []byte("foo"), res[0].Result)
}
func TestMDMAppleProfileVerification(t *testing.T) {
ds := CreateMySQLDS(t)
ctx := context.Background()

View file

@ -1137,6 +1137,11 @@ type Datastore interface {
// host_dep_assignments for host with matching serials.
DeleteHostDEPAssignments(ctx context.Context, serials []string) error
// MDMAppleRecordDeclarativeCheckIn records a DeclarativeManagement
// checking from a host, so we know the host received the command to
// start the declarative management sync.
MDMAppleRecordDeclarativeCheckIn(ctx context.Context, hostUUID string, response []byte) error
// UpdateHostDEPAssignProfileResponses receives a profile UUID and threes lists of serials, each representing
// one of the three possible responses, and updates the host_dep_assignments table with the corresponding responses.
UpdateHostDEPAssignProfileResponses(ctx context.Context, resp *godep.ProfileResponse) error

View file

@ -226,6 +226,28 @@ func (svc *MDMAppleCommander) AccountConfiguration(ctx context.Context, hostUUID
return svc.EnqueueCommand(ctx, hostUUIDs, raw)
}
// DeclarativeManagement sends the homonym [command][1] to the device to enable DDM or start a new DDM session.
//
// [1]: https://developer.apple.com/documentation/devicemanagement/declarativemanagementcommand
func (svc *MDMAppleCommander) DeclarativeManagement(ctx context.Context, hostUUIDs []string, uuid string) error {
raw := fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Command</key>
<dict>
<key>RequestType</key>
<string>DeclarativeManagement</string>
</dict>
<key>CommandUUID</key>
<string>%s</string>
</dict>
</plist>`, uuid)
return svc.EnqueueCommand(ctx, hostUUIDs, raw)
}
// EnqueueCommand takes care of enqueuing the commands and sending push
// notifications to the devices.
//

View file

@ -748,6 +748,8 @@ type GetMatchingHostSerialsFunc func(ctx context.Context, serials []string) (map
type DeleteHostDEPAssignmentsFunc func(ctx context.Context, serials []string) error
type MDMAppleRecordDeclarativeCheckInFunc func(ctx context.Context, hostUUID string, response []byte) error
type UpdateHostDEPAssignProfileResponsesFunc func(ctx context.Context, resp *godep.ProfileResponse) error
type ScreenDEPAssignProfileSerialsForCooldownFunc func(ctx context.Context, serials []string) (skipSerials []string, assignSerials []string, err error)
@ -1954,6 +1956,9 @@ type DataStore struct {
DeleteHostDEPAssignmentsFunc DeleteHostDEPAssignmentsFunc
DeleteHostDEPAssignmentsFuncInvoked bool
MDMAppleRecordDeclarativeCheckInFunc MDMAppleRecordDeclarativeCheckInFunc
MDMAppleRecordDeclarativeCheckInFuncInvoked bool
UpdateHostDEPAssignProfileResponsesFunc UpdateHostDEPAssignProfileResponsesFunc
UpdateHostDEPAssignProfileResponsesFuncInvoked bool
@ -4677,6 +4682,13 @@ func (s *DataStore) DeleteHostDEPAssignments(ctx context.Context, serials []stri
return s.DeleteHostDEPAssignmentsFunc(ctx, serials)
}
func (s *DataStore) MDMAppleRecordDeclarativeCheckIn(ctx context.Context, hostUUID string, response []byte) error {
s.mu.Lock()
s.MDMAppleRecordDeclarativeCheckInFuncInvoked = true
s.mu.Unlock()
return s.MDMAppleRecordDeclarativeCheckInFunc(ctx, hostUUID, response)
}
func (s *DataStore) UpdateHostDEPAssignProfileResponses(ctx context.Context, resp *godep.ProfileResponse) error {
s.mu.Lock()
s.UpdateHostDEPAssignProfileResponsesFuncInvoked = true

View file

@ -2976,6 +2976,9 @@ func (svc *MDMAppleDDMService) DeclarativeManagement(r *mdm.Request, dm *mdm.Dec
switch {
case dm.Endpoint == "tokens":
if err := svc.ds.MDMAppleRecordDeclarativeCheckIn(r.Context, dm.UDID, dm.Raw); err != nil {
return nil, ctxerr.Wrap(r.Context, err, "recording declarative checkin")
}
// TODO(sarah): handle tokens
level.Debug(svc.logger).Log("msg", "received tokens request")
return nil, nil