diff --git a/server/fleet/errors.go b/server/fleet/errors.go index 35f0c4a667..b4f3374208 100644 --- a/server/fleet/errors.go +++ b/server/fleet/errors.go @@ -140,6 +140,27 @@ func (e AuthRequiredError) StatusCode() int { return http.StatusUnauthorized } +type AuthHeaderRequiredError struct { + // internal is the reason that should only be logged internally + internal string +} + +func NewAuthHeaderRequiredError(internal string) *AuthHeaderRequiredError { + return &AuthHeaderRequiredError{internal: internal} +} + +func (e AuthHeaderRequiredError) Error() string { + return "Authorization header required" +} + +func (e AuthHeaderRequiredError) Internal() string { + return e.internal +} + +func (e AuthHeaderRequiredError) StatusCode() int { + return http.StatusUnauthorized +} + // PermissionError, set when user is authenticated, but not allowed to perform action type PermissionError struct { message string diff --git a/server/service/endpoint_middleware.go b/server/service/endpoint_middleware.go index 3ae16f4eb1..c7a1f4d358 100644 --- a/server/service/endpoint_middleware.go +++ b/server/service/endpoint_middleware.go @@ -73,7 +73,7 @@ func authenticatedUser(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpo // if not succesful, try again this time with errors sessionKey, ok := token.FromContext(ctx) if !ok { - return nil, fleet.NewAuthRequiredError("no auth token") + return nil, fleet.NewAuthHeaderRequiredError("no auth token") } v, err := authViewer(ctx, string(sessionKey), svc) diff --git a/server/service/http_auth_test.go b/server/service/http_auth_test.go index ce1bd1c9d0..472c6b1d0a 100644 --- a/server/service/http_auth_test.go +++ b/server/service/http_auth_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/go-kit/kit/transport" "io" "io/ioutil" "net/http" @@ -26,30 +27,7 @@ import ( ) func TestLogin(t *testing.T) { - ds, _ := inmem.New(config.TestConfig()) - svc := newTestService(ds, nil, nil) - users := createTestUsers(t, ds) - logger := kitlog.NewLogfmtLogger(os.Stdout) - - opts := []kithttp.ServerOption{ - kithttp.ServerBefore( - setRequestsContexts(svc), - ), - kithttp.ServerErrorLogger(logger), - kithttp.ServerAfter( - kithttp.SetContentType("application/json; charset=utf-8"), - ), - } - r := mux.NewRouter() - limitStore, _ := memstore.New(0) - ke := MakeFleetServerEndpoints(svc, "", limitStore) - kh := makeKitHandlers(ke, opts) - attachFleetAPIRoutes(r, kh) - r.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "index") - })) - - server := httptest.NewServer(r) + ds, users, server := setupAuthTest(t) var loginTests = []struct { email string status int @@ -135,6 +113,56 @@ func TestLogin(t *testing.T) { } } +func setupAuthTest(t *testing.T) (*inmem.Datastore, map[string]fleet.User, *httptest.Server) { + ds, _ := inmem.New(config.TestConfig()) + svc := newTestService(ds, nil, nil) + users := createTestUsers(t, ds) + logger := kitlog.NewLogfmtLogger(os.Stdout) + + opts := []kithttp.ServerOption{ + kithttp.ServerBefore( + setRequestsContexts(svc), + ), + kithttp.ServerErrorHandler(transport.NewLogErrorHandler(logger)), + kithttp.ServerAfter( + kithttp.SetContentType("application/json; charset=utf-8"), + ), + } + r := mux.NewRouter() + limitStore, _ := memstore.New(0) + ke := MakeFleetServerEndpoints(svc, "", limitStore) + kh := makeKitHandlers(ke, opts) + attachFleetAPIRoutes(r, kh) + r.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "index") + })) + + server := httptest.NewServer(r) + return ds, users, server +} + +func TestNoHeaderErrorsDifferently(t *testing.T) { + _, _, server := setupAuthTest(t) + + req, _ := http.NewRequest("GET", server.URL+"/api/v1/fleet/users", nil) + client := &http.Client{} + resp, err := client.Do(req) + require.Nil(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + bodyBytes, err := ioutil.ReadAll(resp.Body) + require.Nil(t, err) + assert.Equal(t, "Authorization header required", string(bodyBytes)) + + req, _ = http.NewRequest("GET", server.URL+"/api/v1/fleet/users", nil) + req.Header.Add("Authorization", "Bearer AAAA") + resp, err = client.Do(req) + require.Nil(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + bodyBytes, err = ioutil.ReadAll(resp.Body) + require.Nil(t, err) + assert.Equal(t, "Authentication required", string(bodyBytes)) +} + // an io.ReadCloser for new request body type nopCloser struct { io.Reader diff --git a/server/service/middleware/authzcheck/authzcheck.go b/server/service/middleware/authzcheck/authzcheck.go index e2ee4dd6f0..d8e79000b9 100644 --- a/server/service/middleware/authzcheck/authzcheck.go +++ b/server/service/middleware/authzcheck/authzcheck.go @@ -34,7 +34,11 @@ func (m *Middleware) AuthzCheck() endpoint.Middleware { // appropriately). var authFailedError *fleet.AuthFailedError var authRequiredError *fleet.AuthRequiredError - if errors.As(err, &authFailedError) || errors.As(err, &authRequiredError) || errors.Is(err, fleet.ErrPasswordResetRequired) { + var authHeaderRequiredError *fleet.AuthHeaderRequiredError + if errors.As(err, &authFailedError) || + errors.As(err, &authRequiredError) || + errors.As(err, &authHeaderRequiredError) || + errors.Is(err, fleet.ErrPasswordResetRequired) { return nil, err }