diff --git a/server/service/apple_mdm.go b/server/service/apple_mdm.go index e0f039fe46..a622527ca6 100644 --- a/server/service/apple_mdm.go +++ b/server/service/apple_mdm.go @@ -30,7 +30,9 @@ import ( "github.com/micromdm/nanodep/godep" "github.com/micromdm/nanomdm/mdm" "github.com/micromdm/nanomdm/push" + nanomdm_push "github.com/micromdm/nanomdm/push" "github.com/micromdm/nanomdm/storage" + nanomdm_storage "github.com/micromdm/nanomdm/storage" ) type createMDMAppleEnrollmentProfileRequest struct { @@ -805,16 +807,16 @@ func (svc *Service) EnqueueMDMAppleCommand( if err := svc.authz.Authorize(ctx, command, fleet.ActionWrite); err != nil { return 0, nil, ctxerr.Wrap(ctx, err) } - return rawCommandEnqueue(ctx, svc.mdmStorage, svc.mdmPushService, command.Command, deviceIDs, noPush, svc.logger) + return deprecatedRawCommandEnqueue(ctx, svc.mdmStorage, svc.mdmPushService, command.Command, deviceIDs, noPush, svc.logger) } -// rawCommandEnqueue enqueues a command to be executed on the given devices. +// deprecatedRawCommandEnqueue enqueues a command to be executed on the given devices. // // This method was extracted from: // https://github.com/fleetdm/nanomdm/blob/a261f081323c80fb7f6575a64ac1a912dffe44ba/http/api/api.go#L134-L261 // NOTE(lucas): At the time, I found no way to reuse Fleet's gokit middlewares with a raw http.Handler // like api.RawCommandEnqueueHandler. -func rawCommandEnqueue( +func deprecatedRawCommandEnqueue( ctx context.Context, enqueuer storage.CommandEnqueuer, pusher push.Pusher, @@ -1148,7 +1150,7 @@ func (svc *Service) EnqueueMDMAppleCommandRemoveEnrollmentProfile(ctx context.Co return fleet.NewUserMessageError(ctxerr.New(ctx, fmt.Sprintf("mdm is not enabled for host %d", hostID)), http.StatusConflict) } - cmdUUID, err := svc.enqueueMDMAppleCommandRemoveEnrollmentProfile(ctx, h.UUID) + cmdUUID, err := svc.mdmAppleCommander.RemoveProfile(ctx, []string{h.UUID}, apple_mdm.FleetPayloadIdentifier) if err != nil { return ctxerr.Wrap(ctx, err, "enqueuing mdm apple remove profile command") } @@ -1164,59 +1166,6 @@ func (svc *Service) EnqueueMDMAppleCommandRemoveEnrollmentProfile(ctx context.Co return svc.pollResultMDMAppleCommandRemoveEnrollmentProfile(ctx, cmdUUID, h.UUID) } -func (svc *Service) enqueueMDMAppleCommandRemoveEnrollmentProfile(ctx context.Context, hostUUID string) (string, error) { - cmd := new(mdm.Command) - cmdUUID := uuid.New().String() - cmd.CommandUUID = cmdUUID - cmd.Command.RequestType = "RemoveProfile" - cmd.Raw = []byte(generateMDMAppleCommandRemoveEnrollmentProfile(cmdUUID, apple_mdm.FleetPayloadIdentifier)) - - status, _, err := rawCommandEnqueue(ctx, svc.mdmStorage, svc.mdmPushService, cmd, []string{hostUUID}, false, svc.logger) - if err != nil { - // NOTE(sarah): rawCommandEnqueue does not currently return actionable errors so we rely on - // status code instead. - return cmdUUID, ctxerr.Wrap(ctx, err) - } - - if status != http.StatusOK { - level.Debug(svc.logger).Log( - "msg", fmt.Sprintf("enqueuing mdm apple remove profile command resulted in unexpected status %d", status), - "host_uuid", hostUUID, - "command_uuid", cmdUUID, - ) - if status != http.StatusMultiStatus { - // Status 207 is also possible with rawCommandEnqueue but should never happen in - // this case because we are only enqueueing this command to one device. If it does - // unexpectedly, we can proceed to the polling stage and will time out after 5 seconds - // in the worst case. - // - // Check the logs generated by rawCommandEnqueue if debugging unexpected results. - return cmdUUID, ctxerr.New(ctx, fmt.Sprintf("enqueuing mdm apple remove profile command resulted in unexpected status %d", status)) - } - } - - return cmdUUID, nil -} - -func generateMDMAppleCommandRemoveEnrollmentProfile(cmdUUID string, profileUUID string) string { - return fmt.Sprintf(` - - - - - CommandUUID - %s - Command - - RequestType - RemoveProfile - Identifier - %s - - -`, cmdUUID, profileUUID) -} - func (svc *Service) pollResultMDMAppleCommandRemoveEnrollmentProfile(ctx context.Context, cmdUUID string, deviceID string) error { ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(5*time.Second)) ticker := time.NewTicker(300 * time.Millisecond) @@ -1613,3 +1562,121 @@ func (svc *MDMAppleCheckinAndCommandService) DeclarativeManagement(*mdm.Request, func (svc *MDMAppleCheckinAndCommandService) CommandAndReportResults(*mdm.Request, *mdm.CommandResults) (*mdm.Command, error) { return nil, nil } + +// MDMAppleCommander contains methods to enqueue commands managed by Fleet and +// send push notifications to hosts. +// +// It's intentionally decoupled from fleet.Service so it can be used internally +// in crons and other services, leaving authentication/permission handling to +// the caller. +type MDMAppleCommander struct { + storage nanomdm_storage.AllStorage + pusher nanomdm_push.Pusher +} + +// NewMDMAppleCommander creates a new commander instance. +func NewMDMAppleCommander(mdmStorage nanomdm_storage.AllStorage, mdmPushService nanomdm_push.Pusher) *MDMAppleCommander { + return &MDMAppleCommander{ + storage: mdmStorage, + pusher: mdmPushService, + } +} + +// InstallProfile sends the homonymous MDM command to the given hosts, it also +// takes care of the base64 encoding of the provided profile bytes. +func (svc *MDMAppleCommander) InstallProfile(ctx context.Context, hostUUIDs []string, profile fleet.Mobileconfig) (string, error) { + base64Profile := base64.StdEncoding.EncodeToString(profile) + uuid := uuid.New().String() + raw := fmt.Sprintf(` + + + + CommandUUID + %s + Command + + RequestType + InstallProfile + Payload + %s + + +`, uuid, base64Profile) + err := svc.enqueue(ctx, hostUUIDs, raw) + return uuid, ctxerr.Wrap(ctx, err, "commander install profile") +} + +// InstallProfile sends the homonymous MDM command to the given hosts. +func (svc *MDMAppleCommander) RemoveProfile(ctx context.Context, hostUUIDs []string, profileIdentifier string) (string, error) { + uuid := uuid.New().String() + raw := fmt.Sprintf(` + + + + CommandUUID + %s + Command + + RequestType + RemoveProfile + Identifier + %s + + +`, uuid, profileIdentifier) + err := svc.enqueue(ctx, hostUUIDs, raw) + return uuid, ctxerr.Wrap(ctx, err, "commander remove profile") +} + +// enqueue takes care of enqueuing the commands and sending push notifications +// to the devices. +// +// Always sending the push notification when a command is enqueued was decided +// internally, leaving making pushes optional as an optimization to be tackled +// later. +func (svc *MDMAppleCommander) enqueue(ctx context.Context, hostUUIDs []string, rawCommand string) error { + cmd, err := mdm.DecodeCommand([]byte(rawCommand)) + if err != nil { + return ctxerr.Wrap(ctx, err, "commander enqueue") + } + + // MySQL implementation always returns nil for the first parameter + _, err = svc.storage.EnqueueCommand(ctx, hostUUIDs, cmd) + if err != nil { + return ctxerr.Wrap(ctx, err, "commander enqueue") + } + + apnsResponses, err := svc.pusher.Push(ctx, hostUUIDs) + if err != nil { + return ctxerr.Wrap(ctx, err, "commander push") + } + + // Even if we didn't get an error, some of the APNs + // responses might have failed, signal that to the caller. + var failed []string + for uuid, response := range apnsResponses { + if response.Err != nil { + failed = append(failed, uuid) + } + } + if len(failed) > 0 { + return &APNSDeliveryError{FailedUUIDs: failed, Err: err} + } + + return nil +} + +// APNSDeliveryError records an error and the associated host UUIDs in which it +// occurred. +type APNSDeliveryError struct { + FailedUUIDs []string + Err error +} + +func (e *APNSDeliveryError) Error() string { + return fmt.Sprintf("APNS delivery failed with: %e, for UUIDs: %v", e.Err, e.FailedUUIDs) +} + +func (e *APNSDeliveryError) Unwrap() error { return e.Err } + +func (e *APNSDeliveryError) StatusCode() int { return http.StatusBadGateway } diff --git a/server/service/apple_mdm_test.go b/server/service/apple_mdm_test.go index 65dee8f8fe..865dcea5e5 100644 --- a/server/service/apple_mdm_test.go +++ b/server/service/apple_mdm_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/tls" + "encoding/base64" "fmt" "net/http" "net/http/httptest" @@ -946,6 +947,79 @@ func TestMDMBatchSetAppleProfiles(t *testing.T) { } } +func TestMDMAppleCommander(t *testing.T) { + ctx := context.Background() + mdmStorage := &nanomdm_mock.Storage{} + pushFactory, _ := newMockAPNSPushProviderFactory() + pusher := nanomdm_pushsvc.New( + mdmStorage, + mdmStorage, + pushFactory, + NewNanoMDMLogger(kitlog.NewJSONLogger(os.Stdout)), + ) + cmdr := NewMDMAppleCommander(mdmStorage, pusher) + + // TODO(roberto): there's a data race in the mock when more + // than one host ID is provided because the pusher uses one + // goroutine per uuid to send the commands + hostUUIDs := []string{"A"} + payloadName := "com.foo.bar" + payloadIdentifier := "com-foo-bar" + mc := mobileconfigForTest(payloadName, payloadIdentifier) + + mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) { + require.NotNil(t, cmd) + require.Equal(t, cmd.Command.RequestType, "InstallProfile") + require.Contains(t, string(cmd.Raw), base64.StdEncoding.EncodeToString(mc)) + return nil, nil + } + + mdmStorage.RetrievePushInfoFunc = func(p0 context.Context, targetUUIDs []string) (map[string]*mdm.Push, error) { + require.ElementsMatch(t, hostUUIDs, targetUUIDs) + pushes := make(map[string]*mdm.Push, len(targetUUIDs)) + for _, uuid := range targetUUIDs { + pushes[uuid] = &mdm.Push{ + + PushMagic: "magic" + uuid, + Token: []byte("token" + uuid), + Topic: "topic" + uuid, + } + } + + return pushes, nil + } + + mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) { + cert, err := tls.LoadX509KeyPair("testdata/server.pem", "testdata/server.key") + return &cert, "", err + } + mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) { + return false, nil + } + + uuid, err := cmdr.InstallProfile(ctx, hostUUIDs, mc) + require.NotEmpty(t, uuid) + require.NoError(t, err) + require.True(t, mdmStorage.EnqueueCommandFuncInvoked) + mdmStorage.EnqueueCommandFuncInvoked = false + require.True(t, mdmStorage.RetrievePushInfoFuncInvoked) + mdmStorage.RetrievePushInfoFuncInvoked = false + + mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) { + require.NotNil(t, cmd) + require.Equal(t, "RemoveProfile", cmd.Command.RequestType) + require.Contains(t, string(cmd.Raw), payloadIdentifier) + return nil, nil + } + uuid, err = cmdr.RemoveProfile(ctx, hostUUIDs, payloadIdentifier) + require.True(t, mdmStorage.EnqueueCommandFuncInvoked) + mdmStorage.EnqueueCommandFuncInvoked = false + require.True(t, mdmStorage.RetrievePushInfoFuncInvoked) + mdmStorage.RetrievePushInfoFuncInvoked = false + require.NotEmpty(t, uuid) + require.NoError(t, err) +} + func mobileconfigForTest(name, identifier string) []byte { return []byte(fmt.Sprintf(` diff --git a/server/service/integration_mdm_test.go b/server/service/integration_mdm_test.go index 966c10a34c..412c3d8ecb 100644 --- a/server/service/integration_mdm_test.go +++ b/server/service/integration_mdm_test.go @@ -9,6 +9,7 @@ import ( "crypto/x509/pkix" "encoding/base64" "encoding/json" + "errors" "fmt" "io" "math/big" @@ -574,13 +575,28 @@ func (s *integrationMDMTestSuite) TestMDMAppleUnenroll() { originalPushMock := s.pushProvider.PushFunc defer func() { s.pushProvider.PushFunc = originalPushMock }() - // TODO: this is not working as expected, we're still waiting on the - // device to unenroll even if we weren't able to send a push - // notification, we should return an error instead. - // - // the APNs service returns an error - // s.pushProvider.PushFunc = mockFailedPush - // s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusOK) + // if there's an error coming from APNs servers + s.pushProvider.PushFunc = func(pushes []*mdm.Push) (map[string]*push.Response, error) { + return map[string]*push.Response{ + pushes[0].Token.String(): { + Id: uuid.New().String(), + Err: errors.New("test"), + }, + }, nil + } + s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusBadGateway) + + // if there was an error unrelated to APNs + s.pushProvider.PushFunc = func(pushes []*mdm.Push) (map[string]*push.Response, error) { + res := map[string]*push.Response{ + pushes[0].Token.String(): { + Id: uuid.New().String(), + Err: nil, + }, + } + return res, errors.New("baz") + } + s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusInternalServerError) // try again, but this time the host is online and answers s.pushProvider.PushFunc = func(pushes []*mdm.Push) (map[string]*push.Response, error) { diff --git a/server/service/service.go b/server/service/service.go index 01e52f1230..bb1946527b 100644 --- a/server/service/service.go +++ b/server/service/service.go @@ -52,10 +52,11 @@ type Service struct { *fleet.EnterpriseOverrides - depStorage nanodep_storage.AllStorage - mdmStorage nanomdm_storage.AllStorage - mdmPushService nanomdm_push.Pusher - mdmPushCertTopic string + depStorage nanodep_storage.AllStorage + mdmStorage nanomdm_storage.AllStorage + mdmPushService nanomdm_push.Pusher + mdmPushCertTopic string + mdmAppleCommander *MDMAppleCommander cronSchedulesService fleet.CronSchedulesService } @@ -110,28 +111,32 @@ func NewService( } svc := &Service{ - ds: ds, - task: task, - carveStore: carveStore, - installerStore: installerStore, - resultStore: resultStore, - liveQueryStore: lq, - logger: logger, - config: config, - clock: c, - osqueryLogWriter: osqueryLogger, - mailService: mailService, - ssoSessionStore: sso, - failingPolicySet: failingPolicySet, - authz: authorizer, - jitterH: make(map[time.Duration]*jitterHashTable), - jitterMu: new(sync.Mutex), - geoIP: geoIP, - enrollHostLimiter: enrollHostLimiter, - depStorage: depStorage, + ds: ds, + task: task, + carveStore: carveStore, + installerStore: installerStore, + resultStore: resultStore, + liveQueryStore: lq, + logger: logger, + config: config, + clock: c, + osqueryLogWriter: osqueryLogger, + mailService: mailService, + ssoSessionStore: sso, + failingPolicySet: failingPolicySet, + authz: authorizer, + jitterH: make(map[time.Duration]*jitterHashTable), + jitterMu: new(sync.Mutex), + geoIP: geoIP, + enrollHostLimiter: enrollHostLimiter, + depStorage: depStorage, + // TODO: remove mdmStorage and mdmPushService when + // we remove deprecated top-level service methods + // from the prototype. mdmStorage: mdmStorage, mdmPushService: mdmPushService, mdmPushCertTopic: mdmPushCertTopic, + mdmAppleCommander: NewMDMAppleCommander(mdmStorage, mdmPushService), cronSchedulesService: cronSchedulesService, } return validationMiddleware{svc, ds, sso}, nil