diff --git a/server/service/middleware/authzcheck/authzcheck.go b/server/service/middleware/authzcheck/authzcheck.go index dc2a7939c6..4ef50cc613 100644 --- a/server/service/middleware/authzcheck/authzcheck.go +++ b/server/service/middleware/authzcheck/authzcheck.go @@ -6,11 +6,13 @@ package authzcheck import ( "context" + "errors" "reflect" "runtime" "github.com/fleetdm/fleet/server/authz" authz_ctx "github.com/fleetdm/fleet/server/contexts/authz" + "github.com/fleetdm/fleet/server/kolide" "github.com/go-kit/kit/endpoint" ) @@ -28,7 +30,15 @@ func (m *Middleware) AuthzCheck() endpoint.Middleware { authzctx := &authz_ctx.AuthorizationContext{} ctx = authz_ctx.NewContext(ctx, authzctx) - response, error := next(ctx, req) + response, err := next(ctx, req) + + // If authentication check failed, return that error (so that we log + // appropriately). + var authFailedError *kolide.AuthFailedError + var authRequiredError *kolide.AuthRequiredError + if errors.As(err, &authFailedError) || errors.As(err, &authRequiredError) { + return nil, err + } // If authorization was not checked, return a response that will // marshal to a generic error and log that the check was missed. @@ -37,7 +47,7 @@ func (m *Middleware) AuthzCheck() endpoint.Middleware { return nil, authz.ForbiddenWithInternal("missed authz check: " + funcName) } - return response, error + return response, err } } } diff --git a/server/service/middleware/authzcheck/authzcheck_test.go b/server/service/middleware/authzcheck/authzcheck_test.go index c64e337181..50f96df346 100644 --- a/server/service/middleware/authzcheck/authzcheck_test.go +++ b/server/service/middleware/authzcheck/authzcheck_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/fleetdm/fleet/server/contexts/authz" + "github.com/fleetdm/fleet/server/kolide" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -26,6 +27,36 @@ func TestAuthzCheck(t *testing.T) { assert.NoError(t, err) } +func TestAuthzCheckAuthFailed(t *testing.T) { + t.Parallel() + + checker := NewMiddleware() + + check := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, kolide.NewAuthFailedError("failed") + } + check = checker.AuthzCheck()(check) + + _, err := check(context.Background(), struct{}{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed") +} + +func TestAuthzCheckAuthRequired(t *testing.T) { + t.Parallel() + + checker := NewMiddleware() + + check := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, kolide.NewAuthRequiredError("required") + } + check = checker.AuthzCheck()(check) + + _, err := check(context.Background(), struct{}{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "required") +} + func TestAuthzCheckMissing(t *testing.T) { t.Parallel()