diff --git a/server/service/apple_mdm.go b/server/service/apple_mdm.go
index e0f039fe46..a622527ca6 100644
--- a/server/service/apple_mdm.go
+++ b/server/service/apple_mdm.go
@@ -30,7 +30,9 @@ import (
"github.com/micromdm/nanodep/godep"
"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/push"
+ nanomdm_push "github.com/micromdm/nanomdm/push"
"github.com/micromdm/nanomdm/storage"
+ nanomdm_storage "github.com/micromdm/nanomdm/storage"
)
type createMDMAppleEnrollmentProfileRequest struct {
@@ -805,16 +807,16 @@ func (svc *Service) EnqueueMDMAppleCommand(
if err := svc.authz.Authorize(ctx, command, fleet.ActionWrite); err != nil {
return 0, nil, ctxerr.Wrap(ctx, err)
}
- return rawCommandEnqueue(ctx, svc.mdmStorage, svc.mdmPushService, command.Command, deviceIDs, noPush, svc.logger)
+ return deprecatedRawCommandEnqueue(ctx, svc.mdmStorage, svc.mdmPushService, command.Command, deviceIDs, noPush, svc.logger)
}
-// rawCommandEnqueue enqueues a command to be executed on the given devices.
+// deprecatedRawCommandEnqueue enqueues a command to be executed on the given devices.
//
// This method was extracted from:
// https://github.com/fleetdm/nanomdm/blob/a261f081323c80fb7f6575a64ac1a912dffe44ba/http/api/api.go#L134-L261
// NOTE(lucas): At the time, I found no way to reuse Fleet's gokit middlewares with a raw http.Handler
// like api.RawCommandEnqueueHandler.
-func rawCommandEnqueue(
+func deprecatedRawCommandEnqueue(
ctx context.Context,
enqueuer storage.CommandEnqueuer,
pusher push.Pusher,
@@ -1148,7 +1150,7 @@ func (svc *Service) EnqueueMDMAppleCommandRemoveEnrollmentProfile(ctx context.Co
return fleet.NewUserMessageError(ctxerr.New(ctx, fmt.Sprintf("mdm is not enabled for host %d", hostID)), http.StatusConflict)
}
- cmdUUID, err := svc.enqueueMDMAppleCommandRemoveEnrollmentProfile(ctx, h.UUID)
+ cmdUUID, err := svc.mdmAppleCommander.RemoveProfile(ctx, []string{h.UUID}, apple_mdm.FleetPayloadIdentifier)
if err != nil {
return ctxerr.Wrap(ctx, err, "enqueuing mdm apple remove profile command")
}
@@ -1164,59 +1166,6 @@ func (svc *Service) EnqueueMDMAppleCommandRemoveEnrollmentProfile(ctx context.Co
return svc.pollResultMDMAppleCommandRemoveEnrollmentProfile(ctx, cmdUUID, h.UUID)
}
-func (svc *Service) enqueueMDMAppleCommandRemoveEnrollmentProfile(ctx context.Context, hostUUID string) (string, error) {
- cmd := new(mdm.Command)
- cmdUUID := uuid.New().String()
- cmd.CommandUUID = cmdUUID
- cmd.Command.RequestType = "RemoveProfile"
- cmd.Raw = []byte(generateMDMAppleCommandRemoveEnrollmentProfile(cmdUUID, apple_mdm.FleetPayloadIdentifier))
-
- status, _, err := rawCommandEnqueue(ctx, svc.mdmStorage, svc.mdmPushService, cmd, []string{hostUUID}, false, svc.logger)
- if err != nil {
- // NOTE(sarah): rawCommandEnqueue does not currently return actionable errors so we rely on
- // status code instead.
- return cmdUUID, ctxerr.Wrap(ctx, err)
- }
-
- if status != http.StatusOK {
- level.Debug(svc.logger).Log(
- "msg", fmt.Sprintf("enqueuing mdm apple remove profile command resulted in unexpected status %d", status),
- "host_uuid", hostUUID,
- "command_uuid", cmdUUID,
- )
- if status != http.StatusMultiStatus {
- // Status 207 is also possible with rawCommandEnqueue but should never happen in
- // this case because we are only enqueueing this command to one device. If it does
- // unexpectedly, we can proceed to the polling stage and will time out after 5 seconds
- // in the worst case.
- //
- // Check the logs generated by rawCommandEnqueue if debugging unexpected results.
- return cmdUUID, ctxerr.New(ctx, fmt.Sprintf("enqueuing mdm apple remove profile command resulted in unexpected status %d", status))
- }
- }
-
- return cmdUUID, nil
-}
-
-func generateMDMAppleCommandRemoveEnrollmentProfile(cmdUUID string, profileUUID string) string {
- return fmt.Sprintf(`
-
-
-
-
- CommandUUID
- %s
- Command
-
- RequestType
- RemoveProfile
- Identifier
- %s
-
-
-`, cmdUUID, profileUUID)
-}
-
func (svc *Service) pollResultMDMAppleCommandRemoveEnrollmentProfile(ctx context.Context, cmdUUID string, deviceID string) error {
ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(5*time.Second))
ticker := time.NewTicker(300 * time.Millisecond)
@@ -1613,3 +1562,121 @@ func (svc *MDMAppleCheckinAndCommandService) DeclarativeManagement(*mdm.Request,
func (svc *MDMAppleCheckinAndCommandService) CommandAndReportResults(*mdm.Request, *mdm.CommandResults) (*mdm.Command, error) {
return nil, nil
}
+
+// MDMAppleCommander contains methods to enqueue commands managed by Fleet and
+// send push notifications to hosts.
+//
+// It's intentionally decoupled from fleet.Service so it can be used internally
+// in crons and other services, leaving authentication/permission handling to
+// the caller.
+type MDMAppleCommander struct {
+ storage nanomdm_storage.AllStorage
+ pusher nanomdm_push.Pusher
+}
+
+// NewMDMAppleCommander creates a new commander instance.
+func NewMDMAppleCommander(mdmStorage nanomdm_storage.AllStorage, mdmPushService nanomdm_push.Pusher) *MDMAppleCommander {
+ return &MDMAppleCommander{
+ storage: mdmStorage,
+ pusher: mdmPushService,
+ }
+}
+
+// InstallProfile sends the homonymous MDM command to the given hosts, it also
+// takes care of the base64 encoding of the provided profile bytes.
+func (svc *MDMAppleCommander) InstallProfile(ctx context.Context, hostUUIDs []string, profile fleet.Mobileconfig) (string, error) {
+ base64Profile := base64.StdEncoding.EncodeToString(profile)
+ uuid := uuid.New().String()
+ raw := fmt.Sprintf(`
+
+
+
+ CommandUUID
+ %s
+ Command
+
+ RequestType
+ InstallProfile
+ Payload
+ %s
+
+
+`, uuid, base64Profile)
+ err := svc.enqueue(ctx, hostUUIDs, raw)
+ return uuid, ctxerr.Wrap(ctx, err, "commander install profile")
+}
+
+// InstallProfile sends the homonymous MDM command to the given hosts.
+func (svc *MDMAppleCommander) RemoveProfile(ctx context.Context, hostUUIDs []string, profileIdentifier string) (string, error) {
+ uuid := uuid.New().String()
+ raw := fmt.Sprintf(`
+
+
+
+ CommandUUID
+ %s
+ Command
+
+ RequestType
+ RemoveProfile
+ Identifier
+ %s
+
+
+`, uuid, profileIdentifier)
+ err := svc.enqueue(ctx, hostUUIDs, raw)
+ return uuid, ctxerr.Wrap(ctx, err, "commander remove profile")
+}
+
+// enqueue takes care of enqueuing the commands and sending push notifications
+// to the devices.
+//
+// Always sending the push notification when a command is enqueued was decided
+// internally, leaving making pushes optional as an optimization to be tackled
+// later.
+func (svc *MDMAppleCommander) enqueue(ctx context.Context, hostUUIDs []string, rawCommand string) error {
+ cmd, err := mdm.DecodeCommand([]byte(rawCommand))
+ if err != nil {
+ return ctxerr.Wrap(ctx, err, "commander enqueue")
+ }
+
+ // MySQL implementation always returns nil for the first parameter
+ _, err = svc.storage.EnqueueCommand(ctx, hostUUIDs, cmd)
+ if err != nil {
+ return ctxerr.Wrap(ctx, err, "commander enqueue")
+ }
+
+ apnsResponses, err := svc.pusher.Push(ctx, hostUUIDs)
+ if err != nil {
+ return ctxerr.Wrap(ctx, err, "commander push")
+ }
+
+ // Even if we didn't get an error, some of the APNs
+ // responses might have failed, signal that to the caller.
+ var failed []string
+ for uuid, response := range apnsResponses {
+ if response.Err != nil {
+ failed = append(failed, uuid)
+ }
+ }
+ if len(failed) > 0 {
+ return &APNSDeliveryError{FailedUUIDs: failed, Err: err}
+ }
+
+ return nil
+}
+
+// APNSDeliveryError records an error and the associated host UUIDs in which it
+// occurred.
+type APNSDeliveryError struct {
+ FailedUUIDs []string
+ Err error
+}
+
+func (e *APNSDeliveryError) Error() string {
+ return fmt.Sprintf("APNS delivery failed with: %e, for UUIDs: %v", e.Err, e.FailedUUIDs)
+}
+
+func (e *APNSDeliveryError) Unwrap() error { return e.Err }
+
+func (e *APNSDeliveryError) StatusCode() int { return http.StatusBadGateway }
diff --git a/server/service/apple_mdm_test.go b/server/service/apple_mdm_test.go
index 65dee8f8fe..865dcea5e5 100644
--- a/server/service/apple_mdm_test.go
+++ b/server/service/apple_mdm_test.go
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/tls"
+ "encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
@@ -946,6 +947,79 @@ func TestMDMBatchSetAppleProfiles(t *testing.T) {
}
}
+func TestMDMAppleCommander(t *testing.T) {
+ ctx := context.Background()
+ mdmStorage := &nanomdm_mock.Storage{}
+ pushFactory, _ := newMockAPNSPushProviderFactory()
+ pusher := nanomdm_pushsvc.New(
+ mdmStorage,
+ mdmStorage,
+ pushFactory,
+ NewNanoMDMLogger(kitlog.NewJSONLogger(os.Stdout)),
+ )
+ cmdr := NewMDMAppleCommander(mdmStorage, pusher)
+
+ // TODO(roberto): there's a data race in the mock when more
+ // than one host ID is provided because the pusher uses one
+ // goroutine per uuid to send the commands
+ hostUUIDs := []string{"A"}
+ payloadName := "com.foo.bar"
+ payloadIdentifier := "com-foo-bar"
+ mc := mobileconfigForTest(payloadName, payloadIdentifier)
+
+ mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) {
+ require.NotNil(t, cmd)
+ require.Equal(t, cmd.Command.RequestType, "InstallProfile")
+ require.Contains(t, string(cmd.Raw), base64.StdEncoding.EncodeToString(mc))
+ return nil, nil
+ }
+
+ mdmStorage.RetrievePushInfoFunc = func(p0 context.Context, targetUUIDs []string) (map[string]*mdm.Push, error) {
+ require.ElementsMatch(t, hostUUIDs, targetUUIDs)
+ pushes := make(map[string]*mdm.Push, len(targetUUIDs))
+ for _, uuid := range targetUUIDs {
+ pushes[uuid] = &mdm.Push{
+
+ PushMagic: "magic" + uuid,
+ Token: []byte("token" + uuid),
+ Topic: "topic" + uuid,
+ }
+ }
+
+ return pushes, nil
+ }
+
+ mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) {
+ cert, err := tls.LoadX509KeyPair("testdata/server.pem", "testdata/server.key")
+ return &cert, "", err
+ }
+ mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) {
+ return false, nil
+ }
+
+ uuid, err := cmdr.InstallProfile(ctx, hostUUIDs, mc)
+ require.NotEmpty(t, uuid)
+ require.NoError(t, err)
+ require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
+ mdmStorage.EnqueueCommandFuncInvoked = false
+ require.True(t, mdmStorage.RetrievePushInfoFuncInvoked)
+ mdmStorage.RetrievePushInfoFuncInvoked = false
+
+ mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) {
+ require.NotNil(t, cmd)
+ require.Equal(t, "RemoveProfile", cmd.Command.RequestType)
+ require.Contains(t, string(cmd.Raw), payloadIdentifier)
+ return nil, nil
+ }
+ uuid, err = cmdr.RemoveProfile(ctx, hostUUIDs, payloadIdentifier)
+ require.True(t, mdmStorage.EnqueueCommandFuncInvoked)
+ mdmStorage.EnqueueCommandFuncInvoked = false
+ require.True(t, mdmStorage.RetrievePushInfoFuncInvoked)
+ mdmStorage.RetrievePushInfoFuncInvoked = false
+ require.NotEmpty(t, uuid)
+ require.NoError(t, err)
+}
+
func mobileconfigForTest(name, identifier string) []byte {
return []byte(fmt.Sprintf(`
diff --git a/server/service/integration_mdm_test.go b/server/service/integration_mdm_test.go
index 966c10a34c..412c3d8ecb 100644
--- a/server/service/integration_mdm_test.go
+++ b/server/service/integration_mdm_test.go
@@ -9,6 +9,7 @@ import (
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
+ "errors"
"fmt"
"io"
"math/big"
@@ -574,13 +575,28 @@ func (s *integrationMDMTestSuite) TestMDMAppleUnenroll() {
originalPushMock := s.pushProvider.PushFunc
defer func() { s.pushProvider.PushFunc = originalPushMock }()
- // TODO: this is not working as expected, we're still waiting on the
- // device to unenroll even if we weren't able to send a push
- // notification, we should return an error instead.
- //
- // the APNs service returns an error
- // s.pushProvider.PushFunc = mockFailedPush
- // s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusOK)
+ // if there's an error coming from APNs servers
+ s.pushProvider.PushFunc = func(pushes []*mdm.Push) (map[string]*push.Response, error) {
+ return map[string]*push.Response{
+ pushes[0].Token.String(): {
+ Id: uuid.New().String(),
+ Err: errors.New("test"),
+ },
+ }, nil
+ }
+ s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusBadGateway)
+
+ // if there was an error unrelated to APNs
+ s.pushProvider.PushFunc = func(pushes []*mdm.Push) (map[string]*push.Response, error) {
+ res := map[string]*push.Response{
+ pushes[0].Token.String(): {
+ Id: uuid.New().String(),
+ Err: nil,
+ },
+ }
+ return res, errors.New("baz")
+ }
+ s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusInternalServerError)
// try again, but this time the host is online and answers
s.pushProvider.PushFunc = func(pushes []*mdm.Push) (map[string]*push.Response, error) {
diff --git a/server/service/service.go b/server/service/service.go
index 01e52f1230..bb1946527b 100644
--- a/server/service/service.go
+++ b/server/service/service.go
@@ -52,10 +52,11 @@ type Service struct {
*fleet.EnterpriseOverrides
- depStorage nanodep_storage.AllStorage
- mdmStorage nanomdm_storage.AllStorage
- mdmPushService nanomdm_push.Pusher
- mdmPushCertTopic string
+ depStorage nanodep_storage.AllStorage
+ mdmStorage nanomdm_storage.AllStorage
+ mdmPushService nanomdm_push.Pusher
+ mdmPushCertTopic string
+ mdmAppleCommander *MDMAppleCommander
cronSchedulesService fleet.CronSchedulesService
}
@@ -110,28 +111,32 @@ func NewService(
}
svc := &Service{
- ds: ds,
- task: task,
- carveStore: carveStore,
- installerStore: installerStore,
- resultStore: resultStore,
- liveQueryStore: lq,
- logger: logger,
- config: config,
- clock: c,
- osqueryLogWriter: osqueryLogger,
- mailService: mailService,
- ssoSessionStore: sso,
- failingPolicySet: failingPolicySet,
- authz: authorizer,
- jitterH: make(map[time.Duration]*jitterHashTable),
- jitterMu: new(sync.Mutex),
- geoIP: geoIP,
- enrollHostLimiter: enrollHostLimiter,
- depStorage: depStorage,
+ ds: ds,
+ task: task,
+ carveStore: carveStore,
+ installerStore: installerStore,
+ resultStore: resultStore,
+ liveQueryStore: lq,
+ logger: logger,
+ config: config,
+ clock: c,
+ osqueryLogWriter: osqueryLogger,
+ mailService: mailService,
+ ssoSessionStore: sso,
+ failingPolicySet: failingPolicySet,
+ authz: authorizer,
+ jitterH: make(map[time.Duration]*jitterHashTable),
+ jitterMu: new(sync.Mutex),
+ geoIP: geoIP,
+ enrollHostLimiter: enrollHostLimiter,
+ depStorage: depStorage,
+ // TODO: remove mdmStorage and mdmPushService when
+ // we remove deprecated top-level service methods
+ // from the prototype.
mdmStorage: mdmStorage,
mdmPushService: mdmPushService,
mdmPushCertTopic: mdmPushCertTopic,
+ mdmAppleCommander: NewMDMAppleCommander(mdmStorage, mdmPushService),
cronSchedulesService: cronSchedulesService,
}
return validationMiddleware{svc, ds, sso}, nil