update host profile status when we get a Windows MDM response (#15172)

related to #14364, this adds logic to update the `status` and `detail`
columns of `host_mdm_windows_profiles` when we get a management
response.
This commit is contained in:
Roberto Dip 2023-11-20 11:25:54 -03:00 committed by GitHub
parent 420dfe1cd0
commit d9f0f86002
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 391 additions and 29 deletions

View file

@ -225,15 +225,12 @@ func (ds *Datastore) MDMWindowsSaveResponse(ctx context.Context, deviceID string
return ctxerr.New(ctx, "empty raw response")
}
const findCommandsStmt = `SELECT command_uuid FROM windows_mdm_commands WHERE command_uuid IN (?)`
const findCommandsStmt = `SELECT command_uuid, raw_command FROM windows_mdm_commands WHERE command_uuid IN (?)`
const saveFullRespStmt = `INSERT INTO windows_mdm_responses (enrollment_id, raw_response) VALUES (?, ?)`
const dequeueCommandsStmt = `DELETE FROM windows_mdm_command_queue WHERE command_uuid IN (?)`
// raw_results and status_code values might be inserted on different requests?
// TODO: which response_id should we be tracking then? for now, using
// whatever comes first.
const insertResultsStmt = `
INSERT INTO windows_mdm_command_results
(enrollment_id, command_uuid, raw_result, response_id, status_code)
@ -288,13 +285,13 @@ ON DUPLICATE KEY UPDATE
if err != nil {
return ctxerr.Wrap(ctx, err, "building IN to search matching commands")
}
var matchingUUIDs []string
err = sqlx.SelectContext(ctx, tx, &matchingUUIDs, stmt, params...)
var matchingCmds []fleet.MDMWindowsCommand
err = sqlx.SelectContext(ctx, tx, &matchingCmds, stmt, params...)
if err != nil {
return ctxerr.Wrap(ctx, err, "selecting matching commands")
}
if len(matchingUUIDs) == 0 {
if len(matchingCmds) == 0 {
ds.logger.Log("warn", "unmatched commands", "uuids", cmdUUIDs)
return nil
}
@ -303,24 +300,36 @@ ON DUPLICATE KEY UPDATE
// <Result> entries to track them as responses.
var args []any
var sb strings.Builder
for _, uuid := range matchingUUIDs {
var potentialProfilePayloads []*fleet.MDMWindowsProfilePayload
for _, cmd := range matchingCmds {
statusCode := ""
if status, ok := uuidsToStatus[uuid]; ok && status.Data != nil {
if status, ok := uuidsToStatus[cmd.CommandUUID]; ok && status.Data != nil {
statusCode = *status.Data
if status.Cmd != nil && *status.Cmd == fleet.CmdAtomic {
pp, err := fleet.BuildMDMWindowsProfilePayloadFromMDMResponse(cmd, uuidsToStatus, enrollment.HostUUID)
if err != nil {
return err
}
potentialProfilePayloads = append(potentialProfilePayloads, pp)
}
}
rawResult := []byte{}
if result, ok := uuidsToResults[uuid]; ok && result.Data != nil {
if result, ok := uuidsToResults[cmd.CommandUUID]; ok && result.Data != nil {
var err error
rawResult, err = xml.Marshal(result)
if err != nil {
ds.logger.Log("err", err, "marshaling command result", "cmd_uuid", uuid)
ds.logger.Log("err", err, "marshaling command result", "cmd_uuid", cmd.CommandUUID)
}
}
args = append(args, enrollment.ID, uuid, rawResult, responseID, statusCode)
args = append(args, enrollment.ID, cmd.CommandUUID, rawResult, responseID, statusCode)
sb.WriteString("(?, ?, ?, ?, ?),")
}
if err := updateMDMWindowsHostProfileStatusFromResponseDB(ctx, tx, potentialProfilePayloads); err != nil {
return ctxerr.Wrap(ctx, err, "updating host profile status")
}
// store the command results
stmt = fmt.Sprintf(insertResultsStmt, strings.TrimSuffix(sb.String(), ","))
if _, err = tx.ExecContext(ctx, stmt, args...); err != nil {
@ -328,6 +337,10 @@ ON DUPLICATE KEY UPDATE
}
// dequeue the commands
var matchingUUIDs []string
for _, cmd := range matchingCmds {
matchingUUIDs = append(matchingUUIDs, cmd.CommandUUID)
}
stmt, params, err = sqlx.In(dequeueCommandsStmt, matchingUUIDs)
if err != nil {
return ctxerr.Wrap(ctx, err, "building IN to dequeue commands")
@ -340,6 +353,76 @@ ON DUPLICATE KEY UPDATE
})
}
// updateMDMWindowsHostProfileStatusFromResponseDB takes a slice of potential
// profile payloads and updates the corresponding `status` and `detail` columns
// in `host_mdm_windows_profiles`
func updateMDMWindowsHostProfileStatusFromResponseDB(
ctx context.Context,
tx sqlx.ExtContext,
payloads []*fleet.MDMWindowsProfilePayload,
) error {
if len(payloads) == 0 {
return nil
}
// this statement will act as a batch-update, no new host profiles
// should be inserted from a device MDM response, so we first check for
// matching entries and then perform the INSERT ... ON DUPLICATE KEY to
// update their detail and status.
const updateHostProfilesStmt = `
INSERT INTO host_mdm_windows_profiles
(host_uuid, profile_uuid, detail, status)
VALUES %s
ON DUPLICATE KEY UPDATE
detail = VALUES(detail),
status = VALUES(status)`
// MySQL will use the `host_uuid` part of the primary key as a first
// pass, and then filter that subset by `command_uuid`.
const getMatchingHostProfilesStmt = `
SELECT host_uuid, profile_uuid, command_uuid
FROM host_mdm_windows_profiles
WHERE host_uuid = ? AND command_uuid IN (?)`
// grab command UUIDs to find matching entries using `getMatchingHostProfilesStmt`
commandUUIDs := make([]string, len(payloads))
// also grab the payloads keyed by the command uuid, so we can easily
// grab the corresponding `Detail` and `Status` from the matching
// command later on.
uuidsToPayloads := make(map[string]*fleet.MDMWindowsProfilePayload, len(payloads))
hostUUID := payloads[0].HostUUID
for _, payload := range payloads {
if payload.HostUUID != hostUUID {
return errors.New("all payloads must be for the same host uuid")
}
commandUUIDs = append(commandUUIDs, payload.CommandUUID)
uuidsToPayloads[payload.CommandUUID] = payload
}
// find the matching entries for the given host_uuid, command_uuid combinations.
stmt, args, err := sqlx.In(getMatchingHostProfilesStmt, hostUUID, commandUUIDs)
if err != nil {
return err
}
var matchingHostProfiles []fleet.MDMWindowsProfilePayload
if err := sqlx.SelectContext(ctx, tx, &matchingHostProfiles, stmt, args...); err != nil {
return err
}
// batch-update the matching entries with the desired detail and status>
var sb strings.Builder
args = args[:0]
for _, hp := range matchingHostProfiles {
payload := uuidsToPayloads[hp.CommandUUID]
args = append(args, hp.HostUUID, hp.ProfileUUID, payload.Detail, payload.Status)
sb.WriteString("(?, ?, ?, ?),")
}
stmt = fmt.Sprintf(updateHostProfilesStmt, strings.TrimSuffix(sb.String(), ","))
_, err = tx.ExecContext(ctx, stmt, args...)
return err
}
func (ds *Datastore) GetMDMWindowsCommandResults(ctx context.Context, commandUUID string) ([]*fleet.MDMCommandResult, error) {
query := `
SELECT

View file

@ -921,13 +921,14 @@ type ProtoCmdOperation struct {
// Protocol Command
type SyncMLCmd struct {
XMLName xml.Name `xml:",omitempty"`
CmdID string `xml:"CmdID"`
MsgRef *string `xml:"MsgRef,omitempty"`
CmdRef *string `xml:"CmdRef,omitempty"`
Cmd *string `xml:"Cmd,omitempty"`
Data *string `xml:"Data,omitempty"`
Items []CmdItem `xml:"Item,omitempty"`
XMLName xml.Name `xml:",omitempty"`
CmdID string `xml:"CmdID"`
MsgRef *string `xml:"MsgRef,omitempty"`
CmdRef *string `xml:"CmdRef,omitempty"`
Cmd *string `xml:"Cmd,omitempty"`
Data *string `xml:"Data,omitempty"`
Items []CmdItem `xml:"Item,omitempty"`
ReplaceCommands []SyncMLCmd `xml:"Replace,omitempty"`
}
// ParseWindowsMDMCommand parses the raw XML as a single Windows MDM command.
@ -1377,3 +1378,109 @@ func GetEncodedBinarySecurityToken(typeID WindowsMDMEnrollmentType, payload stri
return base64.URLEncoding.EncodeToString(rawBytes), nil
}
// BuildMDMWindowsProfilePayloadFromMDMResponse builds a
// MDMWindowsProfilePayload for a command that was used to deliver a
// configuration profile.
//
// Profiles are groups of `<Replace>` commands wrapped in an `<Atomic>`, both
// the top-level atomic and each replace have different CmdID values and Status
// responses. For example a profile might look like:
//
// <Atomic>
//
// <CmdID>foo</CmdID>
// <Replace>
// <CmdID>bar</CmdID>
// ...
// </Replace>
// <Replace>
// <CmdID>baz</CmdID>
// ...
// </Replace>
//
// </Atomic>
//
// And the response from the MDM server will be something like:
//
// <SyncBody>
// <Status>
//
// <CmdID>foo</CmdID>
// <Cmd>Atomic</Cmd>
// <Data>200</Data>
// ...
//
// </Status>
// <Status>
//
// <CmdID>bar</CmdID>
// <Cmd>Replace</Cmd>
// <Data>200</Data>
// ...
//
// </Status>
// <Status>
//
// <CmdID>baz</CmdID>
// <Cmd>Replace</Cmd>
// <Data>200</Data>
// ...
//
// </Status>
// ...
// </SyncBody>
//
// As currently specified:
// - The status of the resulting command should be the status of the
// top-level `<Atomic>` operation
// - The detail of the resulting command should be an aggregate of all the
// status responses of every nested `Replace` operation
func BuildMDMWindowsProfilePayloadFromMDMResponse(
cmd MDMWindowsCommand,
statuses map[string]SyncMLCmd,
hostUUID string,
) (*MDMWindowsProfilePayload, error) {
status, ok := statuses[cmd.CommandUUID]
if !ok {
return nil, fmt.Errorf("missing status for root command %s", cmd.CommandUUID)
}
commandStatus := WindowsResponseToDeliveryStatus(*status.Data)
var details []string
if status.Data != nil && commandStatus == MDMDeliveryFailed {
syncML := new(SyncMLCmd)
if err := xml.Unmarshal(cmd.RawCommand, syncML); err != nil {
return nil, err
}
for _, nested := range syncML.ReplaceCommands {
if status, ok := statuses[nested.CmdID]; ok && status.Data != nil {
details = append(details, fmt.Sprintf("CmdID %s: status %s", nested.CmdID, *status.Data))
}
}
}
detail := strings.Join(details, ", ")
return &MDMWindowsProfilePayload{
HostUUID: hostUUID,
Status: &commandStatus,
OperationType: "",
Detail: detail,
CommandUUID: cmd.CommandUUID,
}, nil
}
// WindowsResponseToDeliveryStatus converts a response string from Windows MDM
// into an MDMDeliveryStatus.
//
// If the response starts with "2" (any 2xx response), it returns
// MDMDeliveryVerifying, otherwise, it returns MDMDeliveryFailed.
func WindowsResponseToDeliveryStatus(resp string) MDMDeliveryStatus {
if len(resp) == 0 {
return MDMDeliveryPending
}
if strings.HasPrefix(resp, "2") {
return MDMDeliveryVerifying
}
return MDMDeliveryFailed
}

View file

@ -4,6 +4,7 @@ import (
"encoding/xml"
"testing"
microsoft_mdm "github.com/fleetdm/fleet/v4/server/mdm/microsoft"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/stretchr/testify/require"
)
@ -64,3 +65,132 @@ func TestParseWindowsMDMCommand(t *testing.T) {
})
}
}
func TestBuildMDMWindowsProfilePayloadFromMDMResponse(t *testing.T) {
tests := []struct {
name string
cmd MDMWindowsCommand
statuses map[string]SyncMLCmd
hostUUID string
expectedError string
expectedPayload *MDMWindowsProfilePayload
}{
{
name: "missing status for command",
cmd: MDMWindowsCommand{
CommandUUID: "foo",
},
statuses: map[string]SyncMLCmd{},
hostUUID: "host-uuid",
expectedError: "missing status for root command",
},
{
name: "bad xml",
cmd: MDMWindowsCommand{
CommandUUID: "foo",
RawCommand: []byte(`<Atomic><Replace><</Atomic>`),
},
statuses: map[string]SyncMLCmd{
"foo": {CmdID: "foo", Data: ptr.String(microsoft_mdm.CmdStatusAtomicFailed)},
},
hostUUID: "host-uuid",
expectedError: "XML syntax error",
},
{
name: "all operations succeded",
cmd: MDMWindowsCommand{
CommandUUID: "foo",
RawCommand: []byte(`
<Atomic>
<CmdID>foo</CmdID>
<Replace><CmdID>bar</CmdID><Target><LocURI>./Device/Baz</LocURI></Target></Replace>
</Atomic>`),
},
statuses: map[string]SyncMLCmd{
"foo": {CmdID: "foo", Data: ptr.String("200")},
"bar": {CmdID: "bar", Data: ptr.String("200")},
},
hostUUID: "host-uuid",
expectedPayload: &MDMWindowsProfilePayload{
HostUUID: "host-uuid",
Status: &MDMDeliveryVerifying,
Detail: "",
CommandUUID: "foo",
},
},
{
name: "one operation failed",
cmd: MDMWindowsCommand{
CommandUUID: "foo",
RawCommand: []byte(`
<Atomic>
<CmdID>foo</CmdID>
<Replace><CmdID>bar</CmdID><Target><LocURI>./Device/Baz</LocURI></Target></Replace>
<Replace><CmdID>baz</CmdID><Target><LocURI>./Bad/Loc</LocURI></Target></Replace>
</Atomic>`),
},
statuses: map[string]SyncMLCmd{
"foo": {CmdID: "foo", Data: ptr.String(microsoft_mdm.CmdStatusAtomicFailed)},
"bar": {CmdID: "bar", Data: ptr.String(microsoft_mdm.CmdStatusOK)},
"baz": {CmdID: "baz", Data: ptr.String(microsoft_mdm.CmdStatusBadRequest)},
},
hostUUID: "host-uuid",
expectedPayload: &MDMWindowsProfilePayload{
HostUUID: "host-uuid",
Status: &MDMDeliveryFailed,
Detail: "CmdID bar: status 200, CmdID baz: status 400",
CommandUUID: "foo",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
payload, err := BuildMDMWindowsProfilePayloadFromMDMResponse(tt.cmd, tt.statuses, tt.hostUUID)
if tt.expectedError != "" {
require.Error(t, err)
require.ErrorContains(t, err, tt.expectedError)
} else {
require.NoError(t, err)
require.Equal(t, tt.expectedPayload, payload)
}
})
}
}
func TestWindowsResponseToDeliveryStatus(t *testing.T) {
tests := []struct {
name string
resp string
expected MDMDeliveryStatus
}{
{
name: "response starts with 2",
resp: "202",
expected: MDMDeliveryVerifying,
},
{
name: "bad requests",
resp: "400",
expected: MDMDeliveryFailed,
},
{
name: "errors",
resp: "500",
expected: MDMDeliveryFailed,
},
{
name: "empty response",
resp: "",
expected: MDMDeliveryPending,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := WindowsResponseToDeliveryStatus(tt.resp)
require.Equal(t, tt.expected, result)
})
}
}

View file

@ -3838,7 +3838,7 @@ func (s *integrationMDMTestSuite) TestDiskEncryptionRotation() {
require.False(t, resp.Notifications.RotateDiskEncryptionKey)
}
func (s *integrationMDMTestSuite) TestHostMDMProfilesStatus() {
func (s *integrationMDMTestSuite) TestHostMDMAppleProfilesStatus() {
t := s.T()
ctx := context.Background()
@ -7485,7 +7485,6 @@ func (s *integrationMDMTestSuite) TestWindowsMDM() {
require.Len(t, getMDMCmdResp.Results, 1)
require.NotZero(t, getMDMCmdResp.Results[0].UpdatedAt)
getMDMCmdResp.Results[0].UpdatedAt = time.Time{}
fmt.Println(string(getMDMCmdResp.Results[0].Result))
require.Equal(t, &fleet.MDMCommandResult{
HostUUID: orbitHost.UUID,
CommandUUID: cmdOneUUID,
@ -9443,7 +9442,22 @@ func (s *integrationMDMTestSuite) TestWindowsProfileManagement() {
})
require.NoError(t, err)
verifyProfiles := func(device *mdmtest.TestWindowsMDMClient, n int) {
verifyHostProfileStatus := func(cmds []fleet.ProtoCmdOperation, wantStatus string) {
for _, cmd := range cmds {
var gotStatus string
mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error {
stmt := `SELECT status FROM host_mdm_windows_profiles WHERE command_uuid = ?`
return sqlx.GetContext(context.Background(), q, &gotStatus, stmt, cmd.Cmd.CmdID)
})
require.EqualValues(t, fleet.WindowsResponseToDeliveryStatus(wantStatus), gotStatus, "command_uuid", cmd.Cmd.CmdID)
}
}
verifyProfiles := func(device *mdmtest.TestWindowsMDMClient, n int, fail bool) {
mdmResponseStatus := microsoft_mdm.CmdStatusOK
if fail {
mdmResponseStatus = microsoft_mdm.CmdStatusAtomicFailed
}
s.awaitTriggerProfileSchedule(t)
cmds, err := device.StartManagementSession()
require.NoError(t, err)
@ -9455,15 +9469,17 @@ func (s *integrationMDMTestSuite) TestWindowsProfileManagement() {
require.NoError(t, err)
for _, c := range cmds {
cmdID := c.Cmd.CmdID
status := microsoft_mdm.CmdStatusOK
if c.Verb == "Atomic" {
atomicCmds = append(atomicCmds, c)
status = mdmResponseStatus
}
device.AppendResponse(fleet.SyncMLCmd{
XMLName: xml.Name{Local: mdm_types.CmdStatus},
MsgRef: &msgID,
CmdRef: &cmdID,
Cmd: ptr.String("Exec"),
Data: ptr.String("200"),
Cmd: ptr.String(c.Verb),
Data: &status,
Items: nil,
CmdID: uuid.NewString(),
})
@ -9471,10 +9487,16 @@ func (s *integrationMDMTestSuite) TestWindowsProfileManagement() {
// TODO: verify profile contents as well
require.Len(t, atomicCmds, n)
// before we send the response, commands should be "pending"
verifyHostProfileStatus(atomicCmds, "")
cmds, err = device.SendResponse()
require.NoError(t, err)
// the ack of the message should be the only returned command
require.Len(t, cmds, 1)
// verify that we updated status in the db
verifyHostProfileStatus(atomicCmds, mdmResponseStatus)
}
checkHostsProfilesMatch := func(host *fleet.Host, wantUUIDs []string) {
@ -9489,18 +9511,18 @@ func (s *integrationMDMTestSuite) TestWindowsProfileManagement() {
// Create a host and then enroll to MDM.
host, mdmDevice := createWindowsHostThenEnrollMDM(s.ds, s.server.URL, t)
// trigger a profile sync
verifyProfiles(mdmDevice, 3)
verifyProfiles(mdmDevice, 3, false)
checkHostsProfilesMatch(host, globalProfiles)
// another sync shouldn't return profiles
verifyProfiles(mdmDevice, 0)
verifyProfiles(mdmDevice, 0, false)
// add the host to a team
err = s.ds.AddHostsToTeam(ctx, &tm.ID, []uint{host.ID})
require.NoError(t, err)
// trigger a profile sync, device gets the team profile
verifyProfiles(mdmDevice, 2)
verifyProfiles(mdmDevice, 2, false)
checkHostsProfilesMatch(host, teamProfiles)
// set new team profiles (delete + addition)
@ -9515,13 +9537,33 @@ func (s *integrationMDMTestSuite) TestWindowsProfileManagement() {
}
// trigger a profile sync, device gets the team profile
verifyProfiles(mdmDevice, 1)
verifyProfiles(mdmDevice, 1, false)
// check that we deleted the old profile in the DB
checkHostsProfilesMatch(host, teamProfiles)
// another sync shouldn't return profiles
verifyProfiles(mdmDevice, 0)
verifyProfiles(mdmDevice, 0, false)
// set new team profiles (delete + addition)
mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error {
stmt := `DELETE FROM mdm_windows_configuration_profiles WHERE profile_uuid = ?`
_, err := q.ExecContext(context.Background(), stmt, teamProfiles[1])
return err
})
teamProfiles = []string{
teamProfiles[0],
mysql.InsertWindowsProfileForTest(t, s.ds, tm.ID),
}
// trigger a profile sync, this time fail the delivery
verifyProfiles(mdmDevice, 1, true)
// check that we deleted the old profile in the DB
checkHostsProfilesMatch(host, teamProfiles)
// another sync shouldn't return profiles
verifyProfiles(mdmDevice, 0, false)
}
func (s *integrationMDMTestSuite) TestAppConfigMDMWindowsProfiles() {