From ed3e755641904c9aa0aa78f3d08e8283a41702b3 Mon Sep 17 00:00:00 2001 From: Carlo <1778532+cdcme@users.noreply.github.com> Date: Mon, 15 Sep 2025 15:04:13 -0400 Subject: [PATCH] Fix MDM cert auth 500 errors (#32981) Fixes #30958. Maps certificate authentication errors to proper HTTP status codes (400/403) instead of 500. --- server/mdm/nanomdm/http/mdm/mdm_test.go | 163 ++++++++++++ .../mdm/nanomdm/service/certauth/certauth.go | 35 ++- .../nanomdm/service/certauth/certauth_test.go | 243 ++++++++++++++++++ server/service/integration_mdm_test.go | 4 +- 4 files changed, 437 insertions(+), 8 deletions(-) create mode 100644 server/mdm/nanomdm/service/certauth/certauth_test.go diff --git a/server/mdm/nanomdm/http/mdm/mdm_test.go b/server/mdm/nanomdm/http/mdm/mdm_test.go index 6600567dbb..002c842227 100644 --- a/server/mdm/nanomdm/http/mdm/mdm_test.go +++ b/server/mdm/nanomdm/http/mdm/mdm_test.go @@ -9,7 +9,12 @@ import ( "net/http/httptest" "testing" + "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm" + "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/service" + "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/service/certauth" "github.com/micromdm/nanolib/log" + "github.com/micromdm/plist" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -66,3 +71,161 @@ func TestCertWithEnrollmentIDMiddleware(t *testing.T) { t.Error("body not equal") } } + +// mockCertAuthService simulates certificate auth errors +type mockCertAuthService struct { + authenticateErr error + tokenUpdateErr error + commandErr error +} + +func (m *mockCertAuthService) Authenticate(r *mdm.Request, msg *mdm.Authenticate) error { + return m.authenticateErr +} + +func (m *mockCertAuthService) TokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate) error { + return m.tokenUpdateErr +} + +func (m *mockCertAuthService) CheckOut(r *mdm.Request, msg *mdm.CheckOut) error { + return nil +} + +func (m *mockCertAuthService) UserAuthenticate(r *mdm.Request, msg *mdm.UserAuthenticate) ([]byte, error) { + return nil, nil +} + +func (m *mockCertAuthService) SetBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapToken) error { + return nil +} + +func (m *mockCertAuthService) GetBootstrapToken(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) { + return nil, nil +} + +func (m *mockCertAuthService) DeclarativeManagement(r *mdm.Request, msg *mdm.DeclarativeManagement) ([]byte, error) { + return nil, nil +} + +func (m *mockCertAuthService) GetToken(r *mdm.Request, msg *mdm.GetToken) (*mdm.GetTokenResponse, error) { + return nil, nil +} + +func (m *mockCertAuthService) CommandAndReportResults(r *mdm.Request, results *mdm.CommandResults) (*mdm.Command, error) { + return nil, m.commandErr +} + +// TestCheckinAndCommandHandler_ErrorHandling verifies handlers return HTTP status codes for errors +func TestCheckinAndCommandHandler_ErrorHandling(t *testing.T) { + tests := []struct { + name string + wrapError bool // if true, wrap with HTTPStatusError + serviceError error + expectedStatus int + }{ + { + name: "Unwrapped_CertAuth_Error_Returns_500", + wrapError: false, + serviceError: certauth.ErrNoCertAssoc, + expectedStatus: http.StatusInternalServerError, + }, + { + name: "Wrapped_CertAuth_Error_Returns_403", + wrapError: true, + serviceError: certauth.ErrNoCertAssoc, + expectedStatus: http.StatusForbidden, + }, + { + name: "Wrapped_Missing_Cert_Returns_400", + wrapError: true, + serviceError: certauth.ErrMissingCert, + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for _, isCheckin := range []bool{true, false} { + subtest := "Command" + if isCheckin { + subtest = "Checkin" + } + t.Run(subtest, func(t *testing.T) { + // Setup error + err := tt.serviceError + if tt.wrapError { + if errors.Is(tt.serviceError, certauth.ErrNoCertAssoc) || errors.Is(tt.serviceError, certauth.ErrNoCertReuse) { + err = service.NewHTTPStatusError(http.StatusForbidden, tt.serviceError) + } else if errors.Is(tt.serviceError, certauth.ErrMissingCert) { + err = service.NewHTTPStatusError(http.StatusBadRequest, tt.serviceError) + } + } + + mockSvc := &mockCertAuthService{ + tokenUpdateErr: err, + commandErr: err, + } + + // Create handler and request + var handler http.HandlerFunc + var body []byte + var contentType string + + if isCheckin { + handler = CheckinHandler(mockSvc, log.NopLogger) + tokenUpdate := &mdm.TokenUpdate{ + Enrollment: mdm.Enrollment{UDID: "test-udid"}, + MessageType: mdm.MessageType{MessageType: "TokenUpdate"}, + } + body, err = plist.Marshal(tokenUpdate) + require.NoError(t, err) + contentType = "application/x-apple-aspen-mdm-checkin" + } else { + handler = CommandAndReportResultsHandler(mockSvc, log.NopLogger) + cmdResults := &mdm.CommandResults{ + Enrollment: mdm.Enrollment{UDID: "test-udid"}, + CommandUUID: "test-cmd-uuid", + Status: "Acknowledged", + } + body, err = plist.Marshal(cmdResults) + require.NoError(t, err) + contentType = "application/x-apple-aspen-mdm" + } + + req := httptest.NewRequest(http.MethodPost, "/mdm", bytes.NewReader(body)) + req.Header.Set("Content-Type", contentType) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code, + "Expected status %d, got %d", tt.expectedStatus, rr.Code) + }) + } + }) + } +} + +// TestErrorResponseBody verifies error response bodies are correct +func TestErrorResponseBody(t *testing.T) { + mockSvc := &mockCertAuthService{ + tokenUpdateErr: service.NewHTTPStatusError(http.StatusForbidden, certauth.ErrNoCertAssoc), + } + + handler := CheckinHandler(mockSvc, log.NopLogger) + tokenUpdate := &mdm.TokenUpdate{ + Enrollment: mdm.Enrollment{UDID: "test-udid"}, + MessageType: mdm.MessageType{MessageType: "TokenUpdate"}, + } + body, err := plist.Marshal(tokenUpdate) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/mdm", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/x-apple-aspen-mdm-checkin") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Equal(t, "Forbidden\n", rr.Body.String()) +} diff --git a/server/mdm/nanomdm/service/certauth/certauth.go b/server/mdm/nanomdm/service/certauth/certauth.go index e3421ba0cb..5732bb0444 100644 --- a/server/mdm/nanomdm/service/certauth/certauth.go +++ b/server/mdm/nanomdm/service/certauth/certauth.go @@ -7,6 +7,7 @@ import ( "encoding/hex" "errors" "fmt" + "net/http" "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm" "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/service" @@ -106,9 +107,31 @@ func HashCert(cert *x509.Certificate) string { return hex.EncodeToString(b) } +// wrapCertAuthError wraps certificate authentication errors with appropriate HTTP status codes +func wrapCertAuthError(err error) error { + if err == nil { + return nil + } + + switch { + case errors.Is(err, ErrNoCertAssoc): + // cert not associated - authentication/authorization failure + return service.NewHTTPStatusError(http.StatusForbidden, err) + case errors.Is(err, ErrNoCertReuse): + // cert reuse attempt - authentication/authorization failure + return service.NewHTTPStatusError(http.StatusForbidden, err) + case errors.Is(err, ErrMissingCert): + // missing cert - bad request + return service.NewHTTPStatusError(http.StatusBadRequest, err) + default: + // Don't wrap other errors - let them bubble up as-is + return err + } +} + func (s *CertAuth) associateNewEnrollment(r *mdm.Request) error { if r.Certificate == nil { - return ErrMissingCert + return wrapCertAuthError(ErrMissingCert) } if err := r.EnrollID.Validate(); err != nil { return err @@ -139,7 +162,7 @@ func (s *CertAuth) associateNewEnrollment(r *mdm.Request) error { "hash", hash, ) if !s.warnOnly { - return ErrNoCertReuse + return wrapCertAuthError(ErrNoCertReuse) } } } @@ -157,7 +180,7 @@ func (s *CertAuth) associateNewEnrollment(r *mdm.Request) error { func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error { if r.Certificate == nil { - return ErrMissingCert + return wrapCertAuthError(ErrMissingCert) } if err := r.EnrollID.Validate(); err != nil { return err @@ -177,7 +200,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error { "hash", hash, ) if !s.warnOnly { - return ErrNoCertAssoc + return wrapCertAuthError(ErrNoCertAssoc) } } // even if allowRetroactive is true we don't want to allow arbitrary @@ -193,7 +216,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error { "id", r.ID, ) if !s.warnOnly { - return ErrNoCertReuse + return wrapCertAuthError(ErrNoCertReuse) } } // even if allowDup were true we don't want to allow arbitrary @@ -211,7 +234,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error { "hash", hash, ) if !s.warnOnly { - return ErrNoCertReuse + return wrapCertAuthError(ErrNoCertReuse) } } if s.warnOnly { diff --git a/server/mdm/nanomdm/service/certauth/certauth_test.go b/server/mdm/nanomdm/service/certauth/certauth_test.go new file mode 100644 index 0000000000..5b092bd00d --- /dev/null +++ b/server/mdm/nanomdm/service/certauth/certauth_test.go @@ -0,0 +1,243 @@ +package certauth + +import ( + "context" + "crypto/x509" + "errors" + "net/http" + "testing" + "time" + + "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm" + "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/service" + "github.com/stretchr/testify/assert" +) + +// wrapCertAuthError wraps errors with HTTP status codes +func TestWrapCertAuthError(t *testing.T) { + tests := []struct { + name string + inputError error + expectedStatus int + shouldWrap bool + }{ + { + name: "ErrNoCertAssoc_Returns_403", + inputError: ErrNoCertAssoc, + expectedStatus: http.StatusForbidden, + shouldWrap: true, + }, + { + name: "ErrNoCertReuse_Returns_403", + inputError: ErrNoCertReuse, + expectedStatus: http.StatusForbidden, + shouldWrap: true, + }, + { + name: "ErrMissingCert_Returns_400", + inputError: ErrMissingCert, + expectedStatus: http.StatusBadRequest, + shouldWrap: true, + }, + { + name: "Other_Error_Not_Wrapped", + inputError: errors.New("some other error"), + expectedStatus: 0, + shouldWrap: false, + }, + { + name: "Nil_Error_Returns_Nil", + inputError: nil, + expectedStatus: 0, + shouldWrap: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := wrapCertAuthError(tt.inputError) + + if tt.inputError == nil { + assert.Nil(t, result) + return + } + + if tt.shouldWrap { + var statusErr *service.HTTPStatusError + assert.True(t, errors.As(result, &statusErr), "Expected HTTPStatusError wrapper") + if statusErr != nil { + assert.Equal(t, tt.expectedStatus, statusErr.Status, "Expected status %d, got %d", tt.expectedStatus, statusErr.Status) + assert.True(t, errors.Is(result, tt.inputError), "Original error should be preserved") + } + } else { + // Should return the error unchanged + assert.Equal(t, tt.inputError, result) + } + }) + } +} + +// mockInnerService for testing +type mockInnerService struct { + authenticateErr error + tokenUpdateErr error +} + +func (m *mockInnerService) Authenticate(r *mdm.Request, msg *mdm.Authenticate) error { + return m.authenticateErr +} + +func (m *mockInnerService) TokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate) error { + return m.tokenUpdateErr +} + +func (m *mockInnerService) CheckOut(r *mdm.Request, msg *mdm.CheckOut) error { + return nil +} + +func (m *mockInnerService) UserAuthenticate(r *mdm.Request, msg *mdm.UserAuthenticate) ([]byte, error) { + return nil, nil +} + +func (m *mockInnerService) SetBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapToken) error { + return nil +} + +func (m *mockInnerService) GetBootstrapToken(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) { + return nil, nil +} + +func (m *mockInnerService) DeclarativeManagement(r *mdm.Request, msg *mdm.DeclarativeManagement) ([]byte, error) { + return nil, nil +} + +func (m *mockInnerService) GetToken(r *mdm.Request, msg *mdm.GetToken) (*mdm.GetTokenResponse, error) { + return nil, nil +} + +func (m *mockInnerService) CommandAndReportResults(r *mdm.Request, results *mdm.CommandResults) (*mdm.Command, error) { + return nil, nil +} + +// mockCertAuthStore for testing +type mockCertAuthStore struct { + hasCertHash bool + isAssociated bool + enrollmentHasHash bool +} + +func (m *mockCertAuthStore) HasCertHash(r *mdm.Request, hash string) (bool, error) { + return m.hasCertHash, nil +} + +func (m *mockCertAuthStore) IsCertHashAssociated(r *mdm.Request, hash string) (bool, error) { + return m.isAssociated, nil +} + +func (m *mockCertAuthStore) EnrollmentHasCertHash(r *mdm.Request, hash string) (bool, error) { + return m.enrollmentHasHash, nil +} + +func (m *mockCertAuthStore) AssociateCertHash(r *mdm.Request, hash string, exp time.Time) error { + return nil +} + +// certauth service returns wrapped errors +func TestCertAuthService_ErrorWrapping(t *testing.T) { + tests := []struct { + name string + setupStore func(*mockCertAuthStore) + includeCert bool + expectedError error + description string + }{ + { + name: "Missing_Certificate", + setupStore: func(s *mockCertAuthStore) { + // No setup needed + }, + includeCert: false, + expectedError: ErrMissingCert, + description: "Missing certificate should return ErrMissingCert", + }, + { + name: "No_Certificate_Association", + setupStore: func(s *mockCertAuthStore) { + s.hasCertHash = false + s.isAssociated = false + }, + includeCert: true, + expectedError: ErrNoCertAssoc, + description: "No certificate association should return ErrNoCertAssoc", + }, + // N.b., certificate reuse scenario is complex and depends on multiple conditions + // The actual error returned depends on the order of checks in validateAssociateExistingEnrollment + // This is tested more thoroughly in integration tests + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mock store + store := &mockCertAuthStore{} + tt.setupStore(store) + + // Create certauth service + innerSvc := &mockInnerService{} + certAuthSvc := New(innerSvc, store) + + // Create request + r := &mdm.Request{ + Context: context.Background(), + EnrollID: &mdm.EnrollID{ + Type: mdm.Device, + ID: "test-device", + }, + } + + if tt.includeCert { + r.Certificate = &x509.Certificate{ + Raw: []byte("test-cert"), + } + } + + // Test with TokenUpdate (existing enrollment) + msg := &mdm.TokenUpdate{ + Enrollment: mdm.Enrollment{ + UDID: "test-device", + }, + } + + err := certAuthSvc.TokenUpdate(r, msg) + + // Check that the error is wrapped + if tt.expectedError != nil { + assert.NotNil(t, err, tt.description) + + // Debug output + t.Logf("Error returned: %v (type: %T)", err, err) + t.Logf("Expected error: %v", tt.expectedError) + + // Check if it's wrapped with HTTPStatusError + var statusErr *service.HTTPStatusError + if errors.As(err, &statusErr) { + // Good - it's wrapped + t.Logf("Found HTTPStatusError with status %d", statusErr.Status) + assert.True(t, errors.Is(err, tt.expectedError), "Should contain the original error") + + // Check status code + switch tt.expectedError { + case ErrMissingCert: + assert.Equal(t, http.StatusBadRequest, statusErr.Status) + case ErrNoCertAssoc, ErrNoCertReuse: + assert.Equal(t, http.StatusForbidden, statusErr.Status) + } + } else { + // For now, the error might not be wrapped yet + // Just check if it contains the expected error + t.Logf("Not wrapped with HTTPStatusError, checking if it contains expected error") + assert.True(t, errors.Is(err, tt.expectedError), "Should contain the expected error") + } + } + }) + } +} diff --git a/server/service/integration_mdm_test.go b/server/service/integration_mdm_test.go index 056d5a3c5e..0acd36a6f5 100644 --- a/server/service/integration_mdm_test.go +++ b/server/service/integration_mdm_test.go @@ -878,9 +878,9 @@ func (s *integrationMDMTestSuite) TestGetBootstrapToken() { }) checkStoredCertAuthAssociation(mdmDevice.UUID, 0) - // TODO: server returns 500 on account of cert auth but what is the expected behavior? + // server returns 403 Forbidden for cert auth failures (enrollment not associated with cert) res, err := mdmDevice.GetBootstrapToken() - require.ErrorContains(t, err, "500") // getbootstraptoken service: cert auth: existing enrollment: enrollment not associated with cert + require.ErrorContains(t, err, "403") // getbootstraptoken service: cert auth: existing enrollment: enrollment not associated with cert require.Nil(t, res) }) }