refactor how we send Fleet initiated MDM commands (#9903)

https://github.com/fleetdm/fleet/issues/9590

- move the logic to send commands into its own service method that can
be used internally by cron jobs and other services.
- deprecate the use of `rawEnqueueCommand` as it's copyied from the
nanomdm codebase where it's used in other context as a general command
API handler
This commit is contained in:
Roberto Dip 2023-02-17 16:26:51 -03:00 committed by GitHub
parent 345a1f4c36
commit 5a09ac0bfc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 249 additions and 87 deletions

View file

@ -30,7 +30,9 @@ import (
"github.com/micromdm/nanodep/godep"
"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/push"
nanomdm_push "github.com/micromdm/nanomdm/push"
"github.com/micromdm/nanomdm/storage"
nanomdm_storage "github.com/micromdm/nanomdm/storage"
)
type createMDMAppleEnrollmentProfileRequest struct {
@ -805,16 +807,16 @@ func (svc *Service) EnqueueMDMAppleCommand(
if err := svc.authz.Authorize(ctx, command, fleet.ActionWrite); err != nil {
return 0, nil, ctxerr.Wrap(ctx, err)
}
return rawCommandEnqueue(ctx, svc.mdmStorage, svc.mdmPushService, command.Command, deviceIDs, noPush, svc.logger)
return deprecatedRawCommandEnqueue(ctx, svc.mdmStorage, svc.mdmPushService, command.Command, deviceIDs, noPush, svc.logger)
}
// rawCommandEnqueue enqueues a command to be executed on the given devices.
// deprecatedRawCommandEnqueue enqueues a command to be executed on the given devices.
//
// This method was extracted from:
// https://github.com/fleetdm/nanomdm/blob/a261f081323c80fb7f6575a64ac1a912dffe44ba/http/api/api.go#L134-L261
// NOTE(lucas): At the time, I found no way to reuse Fleet's gokit middlewares with a raw http.Handler
// like api.RawCommandEnqueueHandler.
func rawCommandEnqueue(
func deprecatedRawCommandEnqueue(
ctx context.Context,
enqueuer storage.CommandEnqueuer,
pusher push.Pusher,
@ -1148,7 +1150,7 @@ func (svc *Service) EnqueueMDMAppleCommandRemoveEnrollmentProfile(ctx context.Co
return fleet.NewUserMessageError(ctxerr.New(ctx, fmt.Sprintf("mdm is not enabled for host %d", hostID)), http.StatusConflict)
}
cmdUUID, err := svc.enqueueMDMAppleCommandRemoveEnrollmentProfile(ctx, h.UUID)
cmdUUID, err := svc.mdmAppleCommander.RemoveProfile(ctx, []string{h.UUID}, apple_mdm.FleetPayloadIdentifier)
if err != nil {
return ctxerr.Wrap(ctx, err, "enqueuing mdm apple remove profile command")
}
@ -1164,59 +1166,6 @@ func (svc *Service) EnqueueMDMAppleCommandRemoveEnrollmentProfile(ctx context.Co
return svc.pollResultMDMAppleCommandRemoveEnrollmentProfile(ctx, cmdUUID, h.UUID)
}
func (svc *Service) enqueueMDMAppleCommandRemoveEnrollmentProfile(ctx context.Context, hostUUID string) (string, error) {
cmd := new(mdm.Command)
cmdUUID := uuid.New().String()
cmd.CommandUUID = cmdUUID
cmd.Command.RequestType = "RemoveProfile"
cmd.Raw = []byte(generateMDMAppleCommandRemoveEnrollmentProfile(cmdUUID, apple_mdm.FleetPayloadIdentifier))
status, _, err := rawCommandEnqueue(ctx, svc.mdmStorage, svc.mdmPushService, cmd, []string{hostUUID}, false, svc.logger)
if err != nil {
// NOTE(sarah): rawCommandEnqueue does not currently return actionable errors so we rely on
// status code instead.
return cmdUUID, ctxerr.Wrap(ctx, err)
}
if status != http.StatusOK {
level.Debug(svc.logger).Log(
"msg", fmt.Sprintf("enqueuing mdm apple remove profile command resulted in unexpected status %d", status),
"host_uuid", hostUUID,
"command_uuid", cmdUUID,
)
if status != http.StatusMultiStatus {
// Status 207 is also possible with rawCommandEnqueue but should never happen in
// this case because we are only enqueueing this command to one device. If it does
// unexpectedly, we can proceed to the polling stage and will time out after 5 seconds
// in the worst case.
//
// Check the logs generated by rawCommandEnqueue if debugging unexpected results.
return cmdUUID, ctxerr.New(ctx, fmt.Sprintf("enqueuing mdm apple remove profile command resulted in unexpected status %d", status))
}
}
return cmdUUID, nil
}
func generateMDMAppleCommandRemoveEnrollmentProfile(cmdUUID string, profileUUID string) string {
return fmt.Sprintf(`
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CommandUUID</key>
<string>%s</string>
<key>Command</key>
<dict>
<key>RequestType</key>
<string>RemoveProfile</string>
<key>Identifier</key>
<string>%s</string>
</dict>
</dict>
</plist>`, cmdUUID, profileUUID)
}
func (svc *Service) pollResultMDMAppleCommandRemoveEnrollmentProfile(ctx context.Context, cmdUUID string, deviceID string) error {
ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(5*time.Second))
ticker := time.NewTicker(300 * time.Millisecond)
@ -1613,3 +1562,121 @@ func (svc *MDMAppleCheckinAndCommandService) DeclarativeManagement(*mdm.Request,
func (svc *MDMAppleCheckinAndCommandService) CommandAndReportResults(*mdm.Request, *mdm.CommandResults) (*mdm.Command, error) {
return nil, nil
}
// MDMAppleCommander contains methods to enqueue commands managed by Fleet and
// send push notifications to hosts.
//
// It's intentionally decoupled from fleet.Service so it can be used internally
// in crons and other services, leaving authentication/permission handling to
// the caller.
type MDMAppleCommander struct {
storage nanomdm_storage.AllStorage
pusher nanomdm_push.Pusher
}
// NewMDMAppleCommander creates a new commander instance.
func NewMDMAppleCommander(mdmStorage nanomdm_storage.AllStorage, mdmPushService nanomdm_push.Pusher) *MDMAppleCommander {
return &MDMAppleCommander{
storage: mdmStorage,
pusher: mdmPushService,
}
}
// InstallProfile sends the homonymous MDM command to the given hosts, it also
// takes care of the base64 encoding of the provided profile bytes.
func (svc *MDMAppleCommander) InstallProfile(ctx context.Context, hostUUIDs []string, profile fleet.Mobileconfig) (string, error) {
base64Profile := base64.StdEncoding.EncodeToString(profile)
uuid := uuid.New().String()
raw := fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CommandUUID</key>
<string>%s</string>
<key>Command</key>
<dict>
<key>RequestType</key>
<string>InstallProfile</string>
<key>Payload</key>
<string>%s</string>
</dict>
</dict>
</plist>`, uuid, base64Profile)
err := svc.enqueue(ctx, hostUUIDs, raw)
return uuid, ctxerr.Wrap(ctx, err, "commander install profile")
}
// InstallProfile sends the homonymous MDM command to the given hosts.
func (svc *MDMAppleCommander) RemoveProfile(ctx context.Context, hostUUIDs []string, profileIdentifier string) (string, error) {
uuid := uuid.New().String()
raw := fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CommandUUID</key>
<string>%s</string>
<key>Command</key>
<dict>
<key>RequestType</key>
<string>RemoveProfile</string>
<key>Identifier</key>
<string>%s</string>
</dict>
</dict>
</plist>`, uuid, profileIdentifier)
err := svc.enqueue(ctx, hostUUIDs, raw)
return uuid, ctxerr.Wrap(ctx, err, "commander remove profile")
}
// enqueue takes care of enqueuing the commands and sending push notifications
// to the devices.
//
// Always sending the push notification when a command is enqueued was decided
// internally, leaving making pushes optional as an optimization to be tackled
// later.
func (svc *MDMAppleCommander) enqueue(ctx context.Context, hostUUIDs []string, rawCommand string) error {
cmd, err := mdm.DecodeCommand([]byte(rawCommand))
if err != nil {
return ctxerr.Wrap(ctx, err, "commander enqueue")
}
// MySQL implementation always returns nil for the first parameter
_, err = svc.storage.EnqueueCommand(ctx, hostUUIDs, cmd)
if err != nil {
return ctxerr.Wrap(ctx, err, "commander enqueue")
}
apnsResponses, err := svc.pusher.Push(ctx, hostUUIDs)
if err != nil {
return ctxerr.Wrap(ctx, err, "commander push")
}
// Even if we didn't get an error, some of the APNs
// responses might have failed, signal that to the caller.
var failed []string
for uuid, response := range apnsResponses {
if response.Err != nil {
failed = append(failed, uuid)
}
}
if len(failed) > 0 {
return &APNSDeliveryError{FailedUUIDs: failed, Err: err}
}
return nil
}
// APNSDeliveryError records an error and the associated host UUIDs in which it
// occurred.
type APNSDeliveryError struct {
FailedUUIDs []string
Err error
}
func (e *APNSDeliveryError) Error() string {
return fmt.Sprintf("APNS delivery failed with: %e, for UUIDs: %v", e.Err, e.FailedUUIDs)
}
func (e *APNSDeliveryError) Unwrap() error { return e.Err }
func (e *APNSDeliveryError) StatusCode() int { return http.StatusBadGateway }

View file

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
@ -946,6 +947,79 @@ func TestMDMBatchSetAppleProfiles(t *testing.T) {
}
}
func TestMDMAppleCommander(t *testing.T) {
ctx := context.Background()
mdmStorage := &nanomdm_mock.Storage{}
pushFactory, _ := newMockAPNSPushProviderFactory()
pusher := nanomdm_pushsvc.New(
mdmStorage,
mdmStorage,
pushFactory,
NewNanoMDMLogger(kitlog.NewJSONLogger(os.Stdout)),
)
cmdr := NewMDMAppleCommander(mdmStorage, pusher)
// TODO(roberto): there's a data race in the mock when more
// than one host ID is provided because the pusher uses one
// goroutine per uuid to send the commands
hostUUIDs := []string{"A"}
payloadName := "com.foo.bar"
payloadIdentifier := "com-foo-bar"
mc := mobileconfigForTest(payloadName, payloadIdentifier)
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) {
require.NotNil(t, cmd)
require.Equal(t, cmd.Command.RequestType, "InstallProfile")
require.Contains(t, string(cmd.Raw), base64.StdEncoding.EncodeToString(mc))
return nil, nil
}
mdmStorage.RetrievePushInfoFunc = func(p0 context.Context, targetUUIDs []string) (map[string]*mdm.Push, error) {
require.ElementsMatch(t, hostUUIDs, targetUUIDs)
pushes := make(map[string]*mdm.Push, len(targetUUIDs))
for _, uuid := range targetUUIDs {
pushes[uuid] = &mdm.Push{
PushMagic: "magic" + uuid,
Token: []byte("token" + uuid),
Topic: "topic" + uuid,
}
}
return pushes, nil
}
mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) {
cert, err := tls.LoadX509KeyPair("testdata/server.pem", "testdata/server.key")
return &cert, "", err
}
mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) {
return false, nil
}
uuid, err := cmdr.InstallProfile(ctx, hostUUIDs, mc)
require.NotEmpty(t, uuid)
require.NoError(t, err)
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
mdmStorage.EnqueueCommandFuncInvoked = false
require.True(t, mdmStorage.RetrievePushInfoFuncInvoked)
mdmStorage.RetrievePushInfoFuncInvoked = false
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) {
require.NotNil(t, cmd)
require.Equal(t, "RemoveProfile", cmd.Command.RequestType)
require.Contains(t, string(cmd.Raw), payloadIdentifier)
return nil, nil
}
uuid, err = cmdr.RemoveProfile(ctx, hostUUIDs, payloadIdentifier)
require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
mdmStorage.EnqueueCommandFuncInvoked = false
require.True(t, mdmStorage.RetrievePushInfoFuncInvoked)
mdmStorage.RetrievePushInfoFuncInvoked = false
require.NotEmpty(t, uuid)
require.NoError(t, err)
}
func mobileconfigForTest(name, identifier string) []byte {
return []byte(fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">

View file

@ -9,6 +9,7 @@ import (
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math/big"
@ -574,13 +575,28 @@ func (s *integrationMDMTestSuite) TestMDMAppleUnenroll() {
originalPushMock := s.pushProvider.PushFunc
defer func() { s.pushProvider.PushFunc = originalPushMock }()
// TODO: this is not working as expected, we're still waiting on the
// device to unenroll even if we weren't able to send a push
// notification, we should return an error instead.
//
// the APNs service returns an error
// s.pushProvider.PushFunc = mockFailedPush
// s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusOK)
// if there's an error coming from APNs servers
s.pushProvider.PushFunc = func(pushes []*mdm.Push) (map[string]*push.Response, error) {
return map[string]*push.Response{
pushes[0].Token.String(): {
Id: uuid.New().String(),
Err: errors.New("test"),
},
}, nil
}
s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusBadGateway)
// if there was an error unrelated to APNs
s.pushProvider.PushFunc = func(pushes []*mdm.Push) (map[string]*push.Response, error) {
res := map[string]*push.Response{
pushes[0].Token.String(): {
Id: uuid.New().String(),
Err: nil,
},
}
return res, errors.New("baz")
}
s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusInternalServerError)
// try again, but this time the host is online and answers
s.pushProvider.PushFunc = func(pushes []*mdm.Push) (map[string]*push.Response, error) {

View file

@ -52,10 +52,11 @@ type Service struct {
*fleet.EnterpriseOverrides
depStorage nanodep_storage.AllStorage
mdmStorage nanomdm_storage.AllStorage
mdmPushService nanomdm_push.Pusher
mdmPushCertTopic string
depStorage nanodep_storage.AllStorage
mdmStorage nanomdm_storage.AllStorage
mdmPushService nanomdm_push.Pusher
mdmPushCertTopic string
mdmAppleCommander *MDMAppleCommander
cronSchedulesService fleet.CronSchedulesService
}
@ -110,28 +111,32 @@ func NewService(
}
svc := &Service{
ds: ds,
task: task,
carveStore: carveStore,
installerStore: installerStore,
resultStore: resultStore,
liveQueryStore: lq,
logger: logger,
config: config,
clock: c,
osqueryLogWriter: osqueryLogger,
mailService: mailService,
ssoSessionStore: sso,
failingPolicySet: failingPolicySet,
authz: authorizer,
jitterH: make(map[time.Duration]*jitterHashTable),
jitterMu: new(sync.Mutex),
geoIP: geoIP,
enrollHostLimiter: enrollHostLimiter,
depStorage: depStorage,
ds: ds,
task: task,
carveStore: carveStore,
installerStore: installerStore,
resultStore: resultStore,
liveQueryStore: lq,
logger: logger,
config: config,
clock: c,
osqueryLogWriter: osqueryLogger,
mailService: mailService,
ssoSessionStore: sso,
failingPolicySet: failingPolicySet,
authz: authorizer,
jitterH: make(map[time.Duration]*jitterHashTable),
jitterMu: new(sync.Mutex),
geoIP: geoIP,
enrollHostLimiter: enrollHostLimiter,
depStorage: depStorage,
// TODO: remove mdmStorage and mdmPushService when
// we remove deprecated top-level service methods
// from the prototype.
mdmStorage: mdmStorage,
mdmPushService: mdmPushService,
mdmPushCertTopic: mdmPushCertTopic,
mdmAppleCommander: NewMDMAppleCommander(mdmStorage, mdmPushService),
cronSchedulesService: cronSchedulesService,
}
return validationMiddleware{svc, ds, sso}, nil