diff --git a/server/datastore/mysql/microsoft_mdm.go b/server/datastore/mysql/microsoft_mdm.go index e9d439d1d7..6e1053ba1c 100644 --- a/server/datastore/mysql/microsoft_mdm.go +++ b/server/datastore/mysql/microsoft_mdm.go @@ -225,15 +225,12 @@ func (ds *Datastore) MDMWindowsSaveResponse(ctx context.Context, deviceID string return ctxerr.New(ctx, "empty raw response") } - const findCommandsStmt = `SELECT command_uuid FROM windows_mdm_commands WHERE command_uuid IN (?)` + const findCommandsStmt = `SELECT command_uuid, raw_command FROM windows_mdm_commands WHERE command_uuid IN (?)` const saveFullRespStmt = `INSERT INTO windows_mdm_responses (enrollment_id, raw_response) VALUES (?, ?)` const dequeueCommandsStmt = `DELETE FROM windows_mdm_command_queue WHERE command_uuid IN (?)` - // raw_results and status_code values might be inserted on different requests? - // TODO: which response_id should we be tracking then? for now, using - // whatever comes first. const insertResultsStmt = ` INSERT INTO windows_mdm_command_results (enrollment_id, command_uuid, raw_result, response_id, status_code) @@ -288,13 +285,13 @@ ON DUPLICATE KEY UPDATE if err != nil { return ctxerr.Wrap(ctx, err, "building IN to search matching commands") } - var matchingUUIDs []string - err = sqlx.SelectContext(ctx, tx, &matchingUUIDs, stmt, params...) + var matchingCmds []fleet.MDMWindowsCommand + err = sqlx.SelectContext(ctx, tx, &matchingCmds, stmt, params...) if err != nil { return ctxerr.Wrap(ctx, err, "selecting matching commands") } - if len(matchingUUIDs) == 0 { + if len(matchingCmds) == 0 { ds.logger.Log("warn", "unmatched commands", "uuids", cmdUUIDs) return nil } @@ -303,24 +300,36 @@ ON DUPLICATE KEY UPDATE // entries to track them as responses. var args []any var sb strings.Builder - for _, uuid := range matchingUUIDs { + var potentialProfilePayloads []*fleet.MDMWindowsProfilePayload + for _, cmd := range matchingCmds { statusCode := "" - if status, ok := uuidsToStatus[uuid]; ok && status.Data != nil { + if status, ok := uuidsToStatus[cmd.CommandUUID]; ok && status.Data != nil { statusCode = *status.Data + if status.Cmd != nil && *status.Cmd == fleet.CmdAtomic { + pp, err := fleet.BuildMDMWindowsProfilePayloadFromMDMResponse(cmd, uuidsToStatus, enrollment.HostUUID) + if err != nil { + return err + } + potentialProfilePayloads = append(potentialProfilePayloads, pp) + } } rawResult := []byte{} - if result, ok := uuidsToResults[uuid]; ok && result.Data != nil { + if result, ok := uuidsToResults[cmd.CommandUUID]; ok && result.Data != nil { var err error rawResult, err = xml.Marshal(result) if err != nil { - ds.logger.Log("err", err, "marshaling command result", "cmd_uuid", uuid) + ds.logger.Log("err", err, "marshaling command result", "cmd_uuid", cmd.CommandUUID) } } - args = append(args, enrollment.ID, uuid, rawResult, responseID, statusCode) + args = append(args, enrollment.ID, cmd.CommandUUID, rawResult, responseID, statusCode) sb.WriteString("(?, ?, ?, ?, ?),") } + if err := updateMDMWindowsHostProfileStatusFromResponseDB(ctx, tx, potentialProfilePayloads); err != nil { + return ctxerr.Wrap(ctx, err, "updating host profile status") + } + // store the command results stmt = fmt.Sprintf(insertResultsStmt, strings.TrimSuffix(sb.String(), ",")) if _, err = tx.ExecContext(ctx, stmt, args...); err != nil { @@ -328,6 +337,10 @@ ON DUPLICATE KEY UPDATE } // dequeue the commands + var matchingUUIDs []string + for _, cmd := range matchingCmds { + matchingUUIDs = append(matchingUUIDs, cmd.CommandUUID) + } stmt, params, err = sqlx.In(dequeueCommandsStmt, matchingUUIDs) if err != nil { return ctxerr.Wrap(ctx, err, "building IN to dequeue commands") @@ -340,6 +353,76 @@ ON DUPLICATE KEY UPDATE }) } +// updateMDMWindowsHostProfileStatusFromResponseDB takes a slice of potential +// profile payloads and updates the corresponding `status` and `detail` columns +// in `host_mdm_windows_profiles` +func updateMDMWindowsHostProfileStatusFromResponseDB( + ctx context.Context, + tx sqlx.ExtContext, + payloads []*fleet.MDMWindowsProfilePayload, +) error { + if len(payloads) == 0 { + return nil + } + + // this statement will act as a batch-update, no new host profiles + // should be inserted from a device MDM response, so we first check for + // matching entries and then perform the INSERT ... ON DUPLICATE KEY to + // update their detail and status. + const updateHostProfilesStmt = ` + INSERT INTO host_mdm_windows_profiles + (host_uuid, profile_uuid, detail, status) + VALUES %s + ON DUPLICATE KEY UPDATE + detail = VALUES(detail), + status = VALUES(status)` + + // MySQL will use the `host_uuid` part of the primary key as a first + // pass, and then filter that subset by `command_uuid`. + const getMatchingHostProfilesStmt = ` + SELECT host_uuid, profile_uuid, command_uuid + FROM host_mdm_windows_profiles + WHERE host_uuid = ? AND command_uuid IN (?)` + + // grab command UUIDs to find matching entries using `getMatchingHostProfilesStmt` + commandUUIDs := make([]string, len(payloads)) + // also grab the payloads keyed by the command uuid, so we can easily + // grab the corresponding `Detail` and `Status` from the matching + // command later on. + uuidsToPayloads := make(map[string]*fleet.MDMWindowsProfilePayload, len(payloads)) + hostUUID := payloads[0].HostUUID + for _, payload := range payloads { + if payload.HostUUID != hostUUID { + return errors.New("all payloads must be for the same host uuid") + } + commandUUIDs = append(commandUUIDs, payload.CommandUUID) + uuidsToPayloads[payload.CommandUUID] = payload + } + + // find the matching entries for the given host_uuid, command_uuid combinations. + stmt, args, err := sqlx.In(getMatchingHostProfilesStmt, hostUUID, commandUUIDs) + if err != nil { + return err + } + var matchingHostProfiles []fleet.MDMWindowsProfilePayload + if err := sqlx.SelectContext(ctx, tx, &matchingHostProfiles, stmt, args...); err != nil { + return err + } + + // batch-update the matching entries with the desired detail and status> + var sb strings.Builder + args = args[:0] + for _, hp := range matchingHostProfiles { + payload := uuidsToPayloads[hp.CommandUUID] + args = append(args, hp.HostUUID, hp.ProfileUUID, payload.Detail, payload.Status) + sb.WriteString("(?, ?, ?, ?),") + } + + stmt = fmt.Sprintf(updateHostProfilesStmt, strings.TrimSuffix(sb.String(), ",")) + _, err = tx.ExecContext(ctx, stmt, args...) + return err +} + func (ds *Datastore) GetMDMWindowsCommandResults(ctx context.Context, commandUUID string) ([]*fleet.MDMCommandResult, error) { query := ` SELECT diff --git a/server/fleet/microsoft_mdm.go b/server/fleet/microsoft_mdm.go index c409470770..7850548ef0 100644 --- a/server/fleet/microsoft_mdm.go +++ b/server/fleet/microsoft_mdm.go @@ -921,13 +921,14 @@ type ProtoCmdOperation struct { // Protocol Command type SyncMLCmd struct { - XMLName xml.Name `xml:",omitempty"` - CmdID string `xml:"CmdID"` - MsgRef *string `xml:"MsgRef,omitempty"` - CmdRef *string `xml:"CmdRef,omitempty"` - Cmd *string `xml:"Cmd,omitempty"` - Data *string `xml:"Data,omitempty"` - Items []CmdItem `xml:"Item,omitempty"` + XMLName xml.Name `xml:",omitempty"` + CmdID string `xml:"CmdID"` + MsgRef *string `xml:"MsgRef,omitempty"` + CmdRef *string `xml:"CmdRef,omitempty"` + Cmd *string `xml:"Cmd,omitempty"` + Data *string `xml:"Data,omitempty"` + Items []CmdItem `xml:"Item,omitempty"` + ReplaceCommands []SyncMLCmd `xml:"Replace,omitempty"` } // ParseWindowsMDMCommand parses the raw XML as a single Windows MDM command. @@ -1377,3 +1378,109 @@ func GetEncodedBinarySecurityToken(typeID WindowsMDMEnrollmentType, payload stri return base64.URLEncoding.EncodeToString(rawBytes), nil } + +// BuildMDMWindowsProfilePayloadFromMDMResponse builds a +// MDMWindowsProfilePayload for a command that was used to deliver a +// configuration profile. +// +// Profiles are groups of `` commands wrapped in an ``, both +// the top-level atomic and each replace have different CmdID values and Status +// responses. For example a profile might look like: +// +// +// +// foo +// +// bar +// ... +// +// +// baz +// ... +// +// +// +// +// And the response from the MDM server will be something like: +// +// +// +// +// foo +// Atomic +// 200 +// ... +// +// +// +// +// bar +// Replace +// 200 +// ... +// +// +// +// +// baz +// Replace +// 200 +// ... +// +// +// ... +// +// +// As currently specified: +// - The status of the resulting command should be the status of the +// top-level `` operation +// - The detail of the resulting command should be an aggregate of all the +// status responses of every nested `Replace` operation +func BuildMDMWindowsProfilePayloadFromMDMResponse( + cmd MDMWindowsCommand, + statuses map[string]SyncMLCmd, + hostUUID string, +) (*MDMWindowsProfilePayload, error) { + status, ok := statuses[cmd.CommandUUID] + if !ok { + return nil, fmt.Errorf("missing status for root command %s", cmd.CommandUUID) + } + commandStatus := WindowsResponseToDeliveryStatus(*status.Data) + var details []string + if status.Data != nil && commandStatus == MDMDeliveryFailed { + syncML := new(SyncMLCmd) + if err := xml.Unmarshal(cmd.RawCommand, syncML); err != nil { + return nil, err + } + for _, nested := range syncML.ReplaceCommands { + if status, ok := statuses[nested.CmdID]; ok && status.Data != nil { + details = append(details, fmt.Sprintf("CmdID %s: status %s", nested.CmdID, *status.Data)) + } + } + } + detail := strings.Join(details, ", ") + return &MDMWindowsProfilePayload{ + HostUUID: hostUUID, + Status: &commandStatus, + OperationType: "", + Detail: detail, + CommandUUID: cmd.CommandUUID, + }, nil +} + +// WindowsResponseToDeliveryStatus converts a response string from Windows MDM +// into an MDMDeliveryStatus. +// +// If the response starts with "2" (any 2xx response), it returns +// MDMDeliveryVerifying, otherwise, it returns MDMDeliveryFailed. +func WindowsResponseToDeliveryStatus(resp string) MDMDeliveryStatus { + if len(resp) == 0 { + return MDMDeliveryPending + } + + if strings.HasPrefix(resp, "2") { + return MDMDeliveryVerifying + } + + return MDMDeliveryFailed +} diff --git a/server/fleet/microsoft_mdm_test.go b/server/fleet/microsoft_mdm_test.go index a829010013..f5c471d468 100644 --- a/server/fleet/microsoft_mdm_test.go +++ b/server/fleet/microsoft_mdm_test.go @@ -4,6 +4,7 @@ import ( "encoding/xml" "testing" + microsoft_mdm "github.com/fleetdm/fleet/v4/server/mdm/microsoft" "github.com/fleetdm/fleet/v4/server/ptr" "github.com/stretchr/testify/require" ) @@ -64,3 +65,132 @@ func TestParseWindowsMDMCommand(t *testing.T) { }) } } + +func TestBuildMDMWindowsProfilePayloadFromMDMResponse(t *testing.T) { + tests := []struct { + name string + cmd MDMWindowsCommand + statuses map[string]SyncMLCmd + hostUUID string + expectedError string + expectedPayload *MDMWindowsProfilePayload + }{ + { + name: "missing status for command", + cmd: MDMWindowsCommand{ + CommandUUID: "foo", + }, + statuses: map[string]SyncMLCmd{}, + hostUUID: "host-uuid", + expectedError: "missing status for root command", + }, + { + name: "bad xml", + cmd: MDMWindowsCommand{ + CommandUUID: "foo", + RawCommand: []byte(`<`), + }, + statuses: map[string]SyncMLCmd{ + "foo": {CmdID: "foo", Data: ptr.String(microsoft_mdm.CmdStatusAtomicFailed)}, + }, + hostUUID: "host-uuid", + expectedError: "XML syntax error", + }, + { + name: "all operations succeded", + cmd: MDMWindowsCommand{ + CommandUUID: "foo", + RawCommand: []byte(` + + foo + bar./Device/Baz + `), + }, + statuses: map[string]SyncMLCmd{ + "foo": {CmdID: "foo", Data: ptr.String("200")}, + "bar": {CmdID: "bar", Data: ptr.String("200")}, + }, + hostUUID: "host-uuid", + expectedPayload: &MDMWindowsProfilePayload{ + HostUUID: "host-uuid", + Status: &MDMDeliveryVerifying, + Detail: "", + CommandUUID: "foo", + }, + }, + { + name: "one operation failed", + cmd: MDMWindowsCommand{ + CommandUUID: "foo", + RawCommand: []byte(` + + foo + bar./Device/Baz + baz./Bad/Loc + `), + }, + statuses: map[string]SyncMLCmd{ + "foo": {CmdID: "foo", Data: ptr.String(microsoft_mdm.CmdStatusAtomicFailed)}, + "bar": {CmdID: "bar", Data: ptr.String(microsoft_mdm.CmdStatusOK)}, + "baz": {CmdID: "baz", Data: ptr.String(microsoft_mdm.CmdStatusBadRequest)}, + }, + hostUUID: "host-uuid", + expectedPayload: &MDMWindowsProfilePayload{ + HostUUID: "host-uuid", + Status: &MDMDeliveryFailed, + Detail: "CmdID bar: status 200, CmdID baz: status 400", + CommandUUID: "foo", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, err := BuildMDMWindowsProfilePayloadFromMDMResponse(tt.cmd, tt.statuses, tt.hostUUID) + + if tt.expectedError != "" { + require.Error(t, err) + require.ErrorContains(t, err, tt.expectedError) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedPayload, payload) + } + }) + } +} + +func TestWindowsResponseToDeliveryStatus(t *testing.T) { + tests := []struct { + name string + resp string + expected MDMDeliveryStatus + }{ + { + name: "response starts with 2", + resp: "202", + expected: MDMDeliveryVerifying, + }, + { + name: "bad requests", + resp: "400", + expected: MDMDeliveryFailed, + }, + { + name: "errors", + resp: "500", + expected: MDMDeliveryFailed, + }, + { + name: "empty response", + resp: "", + expected: MDMDeliveryPending, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := WindowsResponseToDeliveryStatus(tt.resp) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/server/service/integration_mdm_test.go b/server/service/integration_mdm_test.go index 05b817317b..22dda9132f 100644 --- a/server/service/integration_mdm_test.go +++ b/server/service/integration_mdm_test.go @@ -3838,7 +3838,7 @@ func (s *integrationMDMTestSuite) TestDiskEncryptionRotation() { require.False(t, resp.Notifications.RotateDiskEncryptionKey) } -func (s *integrationMDMTestSuite) TestHostMDMProfilesStatus() { +func (s *integrationMDMTestSuite) TestHostMDMAppleProfilesStatus() { t := s.T() ctx := context.Background() @@ -7485,7 +7485,6 @@ func (s *integrationMDMTestSuite) TestWindowsMDM() { require.Len(t, getMDMCmdResp.Results, 1) require.NotZero(t, getMDMCmdResp.Results[0].UpdatedAt) getMDMCmdResp.Results[0].UpdatedAt = time.Time{} - fmt.Println(string(getMDMCmdResp.Results[0].Result)) require.Equal(t, &fleet.MDMCommandResult{ HostUUID: orbitHost.UUID, CommandUUID: cmdOneUUID, @@ -9443,7 +9442,22 @@ func (s *integrationMDMTestSuite) TestWindowsProfileManagement() { }) require.NoError(t, err) - verifyProfiles := func(device *mdmtest.TestWindowsMDMClient, n int) { + verifyHostProfileStatus := func(cmds []fleet.ProtoCmdOperation, wantStatus string) { + for _, cmd := range cmds { + var gotStatus string + mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error { + stmt := `SELECT status FROM host_mdm_windows_profiles WHERE command_uuid = ?` + return sqlx.GetContext(context.Background(), q, &gotStatus, stmt, cmd.Cmd.CmdID) + }) + require.EqualValues(t, fleet.WindowsResponseToDeliveryStatus(wantStatus), gotStatus, "command_uuid", cmd.Cmd.CmdID) + } + } + + verifyProfiles := func(device *mdmtest.TestWindowsMDMClient, n int, fail bool) { + mdmResponseStatus := microsoft_mdm.CmdStatusOK + if fail { + mdmResponseStatus = microsoft_mdm.CmdStatusAtomicFailed + } s.awaitTriggerProfileSchedule(t) cmds, err := device.StartManagementSession() require.NoError(t, err) @@ -9455,15 +9469,17 @@ func (s *integrationMDMTestSuite) TestWindowsProfileManagement() { require.NoError(t, err) for _, c := range cmds { cmdID := c.Cmd.CmdID + status := microsoft_mdm.CmdStatusOK if c.Verb == "Atomic" { atomicCmds = append(atomicCmds, c) + status = mdmResponseStatus } device.AppendResponse(fleet.SyncMLCmd{ XMLName: xml.Name{Local: mdm_types.CmdStatus}, MsgRef: &msgID, CmdRef: &cmdID, - Cmd: ptr.String("Exec"), - Data: ptr.String("200"), + Cmd: ptr.String(c.Verb), + Data: &status, Items: nil, CmdID: uuid.NewString(), }) @@ -9471,10 +9487,16 @@ func (s *integrationMDMTestSuite) TestWindowsProfileManagement() { // TODO: verify profile contents as well require.Len(t, atomicCmds, n) + // before we send the response, commands should be "pending" + verifyHostProfileStatus(atomicCmds, "") + cmds, err = device.SendResponse() require.NoError(t, err) // the ack of the message should be the only returned command require.Len(t, cmds, 1) + + // verify that we updated status in the db + verifyHostProfileStatus(atomicCmds, mdmResponseStatus) } checkHostsProfilesMatch := func(host *fleet.Host, wantUUIDs []string) { @@ -9489,18 +9511,18 @@ func (s *integrationMDMTestSuite) TestWindowsProfileManagement() { // Create a host and then enroll to MDM. host, mdmDevice := createWindowsHostThenEnrollMDM(s.ds, s.server.URL, t) // trigger a profile sync - verifyProfiles(mdmDevice, 3) + verifyProfiles(mdmDevice, 3, false) checkHostsProfilesMatch(host, globalProfiles) // another sync shouldn't return profiles - verifyProfiles(mdmDevice, 0) + verifyProfiles(mdmDevice, 0, false) // add the host to a team err = s.ds.AddHostsToTeam(ctx, &tm.ID, []uint{host.ID}) require.NoError(t, err) // trigger a profile sync, device gets the team profile - verifyProfiles(mdmDevice, 2) + verifyProfiles(mdmDevice, 2, false) checkHostsProfilesMatch(host, teamProfiles) // set new team profiles (delete + addition) @@ -9515,13 +9537,33 @@ func (s *integrationMDMTestSuite) TestWindowsProfileManagement() { } // trigger a profile sync, device gets the team profile - verifyProfiles(mdmDevice, 1) + verifyProfiles(mdmDevice, 1, false) // check that we deleted the old profile in the DB checkHostsProfilesMatch(host, teamProfiles) // another sync shouldn't return profiles - verifyProfiles(mdmDevice, 0) + verifyProfiles(mdmDevice, 0, false) + + // set new team profiles (delete + addition) + mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error { + stmt := `DELETE FROM mdm_windows_configuration_profiles WHERE profile_uuid = ?` + _, err := q.ExecContext(context.Background(), stmt, teamProfiles[1]) + return err + }) + teamProfiles = []string{ + teamProfiles[0], + mysql.InsertWindowsProfileForTest(t, s.ds, tm.ID), + } + // trigger a profile sync, this time fail the delivery + verifyProfiles(mdmDevice, 1, true) + + // check that we deleted the old profile in the DB + checkHostsProfilesMatch(host, teamProfiles) + + // another sync shouldn't return profiles + verifyProfiles(mdmDevice, 0, false) + } func (s *integrationMDMTestSuite) TestAppConfigMDMWindowsProfiles() {