mirror of
https://github.com/fleetdm/fleet
synced 2026-05-23 17:08:53 +00:00
Fix MDM cert auth 500 errors (#32981)
Fixes #30958. Maps certificate authentication errors to proper HTTP status codes (400/403) instead of 500.
This commit is contained in:
parent
fccdd8c152
commit
ed3e755641
4 changed files with 437 additions and 8 deletions
|
|
@ -9,7 +9,12 @@ import (
|
|||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/service"
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/service/certauth"
|
||||
"github.com/micromdm/nanolib/log"
|
||||
"github.com/micromdm/plist"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
|
@ -66,3 +71,161 @@ func TestCertWithEnrollmentIDMiddleware(t *testing.T) {
|
|||
t.Error("body not equal")
|
||||
}
|
||||
}
|
||||
|
||||
// mockCertAuthService simulates certificate auth errors
|
||||
type mockCertAuthService struct {
|
||||
authenticateErr error
|
||||
tokenUpdateErr error
|
||||
commandErr error
|
||||
}
|
||||
|
||||
func (m *mockCertAuthService) Authenticate(r *mdm.Request, msg *mdm.Authenticate) error {
|
||||
return m.authenticateErr
|
||||
}
|
||||
|
||||
func (m *mockCertAuthService) TokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate) error {
|
||||
return m.tokenUpdateErr
|
||||
}
|
||||
|
||||
func (m *mockCertAuthService) CheckOut(r *mdm.Request, msg *mdm.CheckOut) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCertAuthService) UserAuthenticate(r *mdm.Request, msg *mdm.UserAuthenticate) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockCertAuthService) SetBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapToken) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCertAuthService) GetBootstrapToken(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockCertAuthService) DeclarativeManagement(r *mdm.Request, msg *mdm.DeclarativeManagement) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockCertAuthService) GetToken(r *mdm.Request, msg *mdm.GetToken) (*mdm.GetTokenResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockCertAuthService) CommandAndReportResults(r *mdm.Request, results *mdm.CommandResults) (*mdm.Command, error) {
|
||||
return nil, m.commandErr
|
||||
}
|
||||
|
||||
// TestCheckinAndCommandHandler_ErrorHandling verifies handlers return HTTP status codes for errors
|
||||
func TestCheckinAndCommandHandler_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wrapError bool // if true, wrap with HTTPStatusError
|
||||
serviceError error
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Unwrapped_CertAuth_Error_Returns_500",
|
||||
wrapError: false,
|
||||
serviceError: certauth.ErrNoCertAssoc,
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "Wrapped_CertAuth_Error_Returns_403",
|
||||
wrapError: true,
|
||||
serviceError: certauth.ErrNoCertAssoc,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "Wrapped_Missing_Cert_Returns_400",
|
||||
wrapError: true,
|
||||
serviceError: certauth.ErrMissingCert,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for _, isCheckin := range []bool{true, false} {
|
||||
subtest := "Command"
|
||||
if isCheckin {
|
||||
subtest = "Checkin"
|
||||
}
|
||||
t.Run(subtest, func(t *testing.T) {
|
||||
// Setup error
|
||||
err := tt.serviceError
|
||||
if tt.wrapError {
|
||||
if errors.Is(tt.serviceError, certauth.ErrNoCertAssoc) || errors.Is(tt.serviceError, certauth.ErrNoCertReuse) {
|
||||
err = service.NewHTTPStatusError(http.StatusForbidden, tt.serviceError)
|
||||
} else if errors.Is(tt.serviceError, certauth.ErrMissingCert) {
|
||||
err = service.NewHTTPStatusError(http.StatusBadRequest, tt.serviceError)
|
||||
}
|
||||
}
|
||||
|
||||
mockSvc := &mockCertAuthService{
|
||||
tokenUpdateErr: err,
|
||||
commandErr: err,
|
||||
}
|
||||
|
||||
// Create handler and request
|
||||
var handler http.HandlerFunc
|
||||
var body []byte
|
||||
var contentType string
|
||||
|
||||
if isCheckin {
|
||||
handler = CheckinHandler(mockSvc, log.NopLogger)
|
||||
tokenUpdate := &mdm.TokenUpdate{
|
||||
Enrollment: mdm.Enrollment{UDID: "test-udid"},
|
||||
MessageType: mdm.MessageType{MessageType: "TokenUpdate"},
|
||||
}
|
||||
body, err = plist.Marshal(tokenUpdate)
|
||||
require.NoError(t, err)
|
||||
contentType = "application/x-apple-aspen-mdm-checkin"
|
||||
} else {
|
||||
handler = CommandAndReportResultsHandler(mockSvc, log.NopLogger)
|
||||
cmdResults := &mdm.CommandResults{
|
||||
Enrollment: mdm.Enrollment{UDID: "test-udid"},
|
||||
CommandUUID: "test-cmd-uuid",
|
||||
Status: "Acknowledged",
|
||||
}
|
||||
body, err = plist.Marshal(cmdResults)
|
||||
require.NoError(t, err)
|
||||
contentType = "application/x-apple-aspen-mdm"
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/mdm", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", contentType)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, rr.Code,
|
||||
"Expected status %d, got %d", tt.expectedStatus, rr.Code)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorResponseBody verifies error response bodies are correct
|
||||
func TestErrorResponseBody(t *testing.T) {
|
||||
mockSvc := &mockCertAuthService{
|
||||
tokenUpdateErr: service.NewHTTPStatusError(http.StatusForbidden, certauth.ErrNoCertAssoc),
|
||||
}
|
||||
|
||||
handler := CheckinHandler(mockSvc, log.NopLogger)
|
||||
tokenUpdate := &mdm.TokenUpdate{
|
||||
Enrollment: mdm.Enrollment{UDID: "test-udid"},
|
||||
MessageType: mdm.MessageType{MessageType: "TokenUpdate"},
|
||||
}
|
||||
body, err := plist.Marshal(tokenUpdate)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/mdm", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/x-apple-aspen-mdm-checkin")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code)
|
||||
assert.Equal(t, "Forbidden\n", rr.Body.String())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/service"
|
||||
|
|
@ -106,9 +107,31 @@ func HashCert(cert *x509.Certificate) string {
|
|||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// wrapCertAuthError wraps certificate authentication errors with appropriate HTTP status codes
|
||||
func wrapCertAuthError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case errors.Is(err, ErrNoCertAssoc):
|
||||
// cert not associated - authentication/authorization failure
|
||||
return service.NewHTTPStatusError(http.StatusForbidden, err)
|
||||
case errors.Is(err, ErrNoCertReuse):
|
||||
// cert reuse attempt - authentication/authorization failure
|
||||
return service.NewHTTPStatusError(http.StatusForbidden, err)
|
||||
case errors.Is(err, ErrMissingCert):
|
||||
// missing cert - bad request
|
||||
return service.NewHTTPStatusError(http.StatusBadRequest, err)
|
||||
default:
|
||||
// Don't wrap other errors - let them bubble up as-is
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CertAuth) associateNewEnrollment(r *mdm.Request) error {
|
||||
if r.Certificate == nil {
|
||||
return ErrMissingCert
|
||||
return wrapCertAuthError(ErrMissingCert)
|
||||
}
|
||||
if err := r.EnrollID.Validate(); err != nil {
|
||||
return err
|
||||
|
|
@ -139,7 +162,7 @@ func (s *CertAuth) associateNewEnrollment(r *mdm.Request) error {
|
|||
"hash", hash,
|
||||
)
|
||||
if !s.warnOnly {
|
||||
return ErrNoCertReuse
|
||||
return wrapCertAuthError(ErrNoCertReuse)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -157,7 +180,7 @@ func (s *CertAuth) associateNewEnrollment(r *mdm.Request) error {
|
|||
|
||||
func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error {
|
||||
if r.Certificate == nil {
|
||||
return ErrMissingCert
|
||||
return wrapCertAuthError(ErrMissingCert)
|
||||
}
|
||||
if err := r.EnrollID.Validate(); err != nil {
|
||||
return err
|
||||
|
|
@ -177,7 +200,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error {
|
|||
"hash", hash,
|
||||
)
|
||||
if !s.warnOnly {
|
||||
return ErrNoCertAssoc
|
||||
return wrapCertAuthError(ErrNoCertAssoc)
|
||||
}
|
||||
}
|
||||
// even if allowRetroactive is true we don't want to allow arbitrary
|
||||
|
|
@ -193,7 +216,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error {
|
|||
"id", r.ID,
|
||||
)
|
||||
if !s.warnOnly {
|
||||
return ErrNoCertReuse
|
||||
return wrapCertAuthError(ErrNoCertReuse)
|
||||
}
|
||||
}
|
||||
// even if allowDup were true we don't want to allow arbitrary
|
||||
|
|
@ -211,7 +234,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error {
|
|||
"hash", hash,
|
||||
)
|
||||
if !s.warnOnly {
|
||||
return ErrNoCertReuse
|
||||
return wrapCertAuthError(ErrNoCertReuse)
|
||||
}
|
||||
}
|
||||
if s.warnOnly {
|
||||
|
|
|
|||
243
server/mdm/nanomdm/service/certauth/certauth_test.go
Normal file
243
server/mdm/nanomdm/service/certauth/certauth_test.go
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
package certauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// wrapCertAuthError wraps errors with HTTP status codes
|
||||
func TestWrapCertAuthError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputError error
|
||||
expectedStatus int
|
||||
shouldWrap bool
|
||||
}{
|
||||
{
|
||||
name: "ErrNoCertAssoc_Returns_403",
|
||||
inputError: ErrNoCertAssoc,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
shouldWrap: true,
|
||||
},
|
||||
{
|
||||
name: "ErrNoCertReuse_Returns_403",
|
||||
inputError: ErrNoCertReuse,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
shouldWrap: true,
|
||||
},
|
||||
{
|
||||
name: "ErrMissingCert_Returns_400",
|
||||
inputError: ErrMissingCert,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
shouldWrap: true,
|
||||
},
|
||||
{
|
||||
name: "Other_Error_Not_Wrapped",
|
||||
inputError: errors.New("some other error"),
|
||||
expectedStatus: 0,
|
||||
shouldWrap: false,
|
||||
},
|
||||
{
|
||||
name: "Nil_Error_Returns_Nil",
|
||||
inputError: nil,
|
||||
expectedStatus: 0,
|
||||
shouldWrap: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := wrapCertAuthError(tt.inputError)
|
||||
|
||||
if tt.inputError == nil {
|
||||
assert.Nil(t, result)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.shouldWrap {
|
||||
var statusErr *service.HTTPStatusError
|
||||
assert.True(t, errors.As(result, &statusErr), "Expected HTTPStatusError wrapper")
|
||||
if statusErr != nil {
|
||||
assert.Equal(t, tt.expectedStatus, statusErr.Status, "Expected status %d, got %d", tt.expectedStatus, statusErr.Status)
|
||||
assert.True(t, errors.Is(result, tt.inputError), "Original error should be preserved")
|
||||
}
|
||||
} else {
|
||||
// Should return the error unchanged
|
||||
assert.Equal(t, tt.inputError, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockInnerService for testing
|
||||
type mockInnerService struct {
|
||||
authenticateErr error
|
||||
tokenUpdateErr error
|
||||
}
|
||||
|
||||
func (m *mockInnerService) Authenticate(r *mdm.Request, msg *mdm.Authenticate) error {
|
||||
return m.authenticateErr
|
||||
}
|
||||
|
||||
func (m *mockInnerService) TokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate) error {
|
||||
return m.tokenUpdateErr
|
||||
}
|
||||
|
||||
func (m *mockInnerService) CheckOut(r *mdm.Request, msg *mdm.CheckOut) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInnerService) UserAuthenticate(r *mdm.Request, msg *mdm.UserAuthenticate) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockInnerService) SetBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapToken) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInnerService) GetBootstrapToken(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockInnerService) DeclarativeManagement(r *mdm.Request, msg *mdm.DeclarativeManagement) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockInnerService) GetToken(r *mdm.Request, msg *mdm.GetToken) (*mdm.GetTokenResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockInnerService) CommandAndReportResults(r *mdm.Request, results *mdm.CommandResults) (*mdm.Command, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// mockCertAuthStore for testing
|
||||
type mockCertAuthStore struct {
|
||||
hasCertHash bool
|
||||
isAssociated bool
|
||||
enrollmentHasHash bool
|
||||
}
|
||||
|
||||
func (m *mockCertAuthStore) HasCertHash(r *mdm.Request, hash string) (bool, error) {
|
||||
return m.hasCertHash, nil
|
||||
}
|
||||
|
||||
func (m *mockCertAuthStore) IsCertHashAssociated(r *mdm.Request, hash string) (bool, error) {
|
||||
return m.isAssociated, nil
|
||||
}
|
||||
|
||||
func (m *mockCertAuthStore) EnrollmentHasCertHash(r *mdm.Request, hash string) (bool, error) {
|
||||
return m.enrollmentHasHash, nil
|
||||
}
|
||||
|
||||
func (m *mockCertAuthStore) AssociateCertHash(r *mdm.Request, hash string, exp time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// certauth service returns wrapped errors
|
||||
func TestCertAuthService_ErrorWrapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupStore func(*mockCertAuthStore)
|
||||
includeCert bool
|
||||
expectedError error
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Missing_Certificate",
|
||||
setupStore: func(s *mockCertAuthStore) {
|
||||
// No setup needed
|
||||
},
|
||||
includeCert: false,
|
||||
expectedError: ErrMissingCert,
|
||||
description: "Missing certificate should return ErrMissingCert",
|
||||
},
|
||||
{
|
||||
name: "No_Certificate_Association",
|
||||
setupStore: func(s *mockCertAuthStore) {
|
||||
s.hasCertHash = false
|
||||
s.isAssociated = false
|
||||
},
|
||||
includeCert: true,
|
||||
expectedError: ErrNoCertAssoc,
|
||||
description: "No certificate association should return ErrNoCertAssoc",
|
||||
},
|
||||
// N.b., certificate reuse scenario is complex and depends on multiple conditions
|
||||
// The actual error returned depends on the order of checks in validateAssociateExistingEnrollment
|
||||
// This is tested more thoroughly in integration tests
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock store
|
||||
store := &mockCertAuthStore{}
|
||||
tt.setupStore(store)
|
||||
|
||||
// Create certauth service
|
||||
innerSvc := &mockInnerService{}
|
||||
certAuthSvc := New(innerSvc, store)
|
||||
|
||||
// Create request
|
||||
r := &mdm.Request{
|
||||
Context: context.Background(),
|
||||
EnrollID: &mdm.EnrollID{
|
||||
Type: mdm.Device,
|
||||
ID: "test-device",
|
||||
},
|
||||
}
|
||||
|
||||
if tt.includeCert {
|
||||
r.Certificate = &x509.Certificate{
|
||||
Raw: []byte("test-cert"),
|
||||
}
|
||||
}
|
||||
|
||||
// Test with TokenUpdate (existing enrollment)
|
||||
msg := &mdm.TokenUpdate{
|
||||
Enrollment: mdm.Enrollment{
|
||||
UDID: "test-device",
|
||||
},
|
||||
}
|
||||
|
||||
err := certAuthSvc.TokenUpdate(r, msg)
|
||||
|
||||
// Check that the error is wrapped
|
||||
if tt.expectedError != nil {
|
||||
assert.NotNil(t, err, tt.description)
|
||||
|
||||
// Debug output
|
||||
t.Logf("Error returned: %v (type: %T)", err, err)
|
||||
t.Logf("Expected error: %v", tt.expectedError)
|
||||
|
||||
// Check if it's wrapped with HTTPStatusError
|
||||
var statusErr *service.HTTPStatusError
|
||||
if errors.As(err, &statusErr) {
|
||||
// Good - it's wrapped
|
||||
t.Logf("Found HTTPStatusError with status %d", statusErr.Status)
|
||||
assert.True(t, errors.Is(err, tt.expectedError), "Should contain the original error")
|
||||
|
||||
// Check status code
|
||||
switch tt.expectedError {
|
||||
case ErrMissingCert:
|
||||
assert.Equal(t, http.StatusBadRequest, statusErr.Status)
|
||||
case ErrNoCertAssoc, ErrNoCertReuse:
|
||||
assert.Equal(t, http.StatusForbidden, statusErr.Status)
|
||||
}
|
||||
} else {
|
||||
// For now, the error might not be wrapped yet
|
||||
// Just check if it contains the expected error
|
||||
t.Logf("Not wrapped with HTTPStatusError, checking if it contains expected error")
|
||||
assert.True(t, errors.Is(err, tt.expectedError), "Should contain the expected error")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -878,9 +878,9 @@ func (s *integrationMDMTestSuite) TestGetBootstrapToken() {
|
|||
})
|
||||
checkStoredCertAuthAssociation(mdmDevice.UUID, 0)
|
||||
|
||||
// TODO: server returns 500 on account of cert auth but what is the expected behavior?
|
||||
// server returns 403 Forbidden for cert auth failures (enrollment not associated with cert)
|
||||
res, err := mdmDevice.GetBootstrapToken()
|
||||
require.ErrorContains(t, err, "500") // getbootstraptoken service: cert auth: existing enrollment: enrollment not associated with cert
|
||||
require.ErrorContains(t, err, "403") // getbootstraptoken service: cert auth: existing enrollment: enrollment not associated with cert
|
||||
require.Nil(t, res)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue