mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 21:47:20 +00:00
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #36670 Refactoring `middleware/endpoint_utils` package to remove direct dependencies on: - fleet.Service - android.Service Specific changes are: - replace AuthFunc+FleetService with AuthMiddleware - Move the definition of handler functions to the respective services and use a generic `CommonEndpointer[H any] struct` Although this was discovered as part of Activity bounded context research, this change is not directly related to bounded contexts. In retrospect, this decoupling should have been done when creating the Android service for ADR-0001. ## Testing - [x] QA'd all new/changed functionality manually <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Internal restructuring of endpoint handling and authentication middleware composition to improve code maintainability and type safety. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
125 lines
3.1 KiB
Go
125 lines
3.1 KiB
Go
package endpoint_utils
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/go-kit/kit/endpoint"
|
|
kithttp "github.com/go-kit/kit/transport/http"
|
|
"github.com/gorilla/mux"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// testHandlerFunc is the handler type that CommonEndpointer is instantiated
// with in these tests. A handler receives a context, the decoded request value
// and a service value, and returns a fleet.Errorer response plus an error.
type testHandlerFunc func(ctx context.Context, req any, service any) (fleet.Errorer, error)
|
|
|
|
// TestCustomMiddlewareAfterAuth checks the order in which a CommonEndpointer
// runs middleware around an endpoint registered via handleEndpoint:
// CustomMiddleware first, then AuthMiddleware, then each entry of
// CustomMiddlewareAfterAuth in slice order. Every middleware records the value
// of a shared call counter when it is entered. The test asserts each recorded
// position and also that exactly four middleware invocations happened.
func TestCustomMiddlewareAfterAuth(t *testing.T) {
	// calls counts middleware entries in the order they occur. Each *Index
	// variable captures the value of calls when its middleware was entered.
	var calls, beforeIndex, authIndex, afterFirstIndex, afterSecondIndex int

	// recordAt builds a pass-through middleware that increments calls, stores
	// the new value in *dst, and then delegates to next unchanged.
	recordAt := func(dst *int) endpoint.Middleware {
		return func(next endpoint.Endpoint) endpoint.Endpoint {
			return func(ctx context.Context, req any) (any, error) {
				calls++
				*dst = calls
				return next(ctx, req)
			}
		}
	}

	// authMiddleware records its position like the others. When an authz
	// context is present, it also marks that context as checked.
	// NOTE(review): presumably SetChecked is needed so the endpoint stack does
	// not reject the request for a missing authorization check — confirm.
	authMiddleware := func(next endpoint.Endpoint) endpoint.Endpoint {
		return func(ctx context.Context, req any) (any, error) {
			calls++
			authIndex = calls
			if authctx, ok := authz_ctx.FromContext(ctx); ok {
				authctx.SetChecked()
			}
			return next(ctx, req)
		}
	}

	router := mux.NewRouter()
	ce := &CommonEndpointer[testHandlerFunc]{
		EP: nopEP{},
		// The decoder ignores the incoming request and always yields nopRequest{}.
		MakeDecoderFn: func(_ any) kithttp.DecodeRequestFunc {
			return func(_ context.Context, _ *http.Request) (any, error) {
				return nopRequest{}, nil
			}
		},
		// The encoder ignores the response value and always writes 200 OK.
		EncodeFn: func(_ context.Context, w http.ResponseWriter, _ any) error {
			w.WriteHeader(http.StatusOK)
			return nil
		},
		AuthMiddleware: authMiddleware,
		CustomMiddleware: []endpoint.Middleware{
			recordAt(&beforeIndex),
		},
		CustomMiddlewareAfterAuth: []endpoint.Middleware{
			recordAt(&afterFirstIndex),
			recordAt(&afterSecondIndex),
		},
		Router: router,
	}

	// nopEP.CallHandlerFunc ignores its handler argument, so this body is
	// presumably never executed; the test only observes middleware ordering.
	ce.handleEndpoint("/", func(ctx context.Context, request any, svc any) (fleet.Errorer, error) {
		fmt.Printf("handler\n")
		return nopResponse{}, nil
	}, nil, http.MethodGet)

	srv := httptest.NewServer(router)
	t.Cleanup(srv.Close)

	req, err := http.NewRequest(http.MethodGet, srv.URL+"/", nil)
	require.NoError(t, err)
	// Use the client bound to the test server rather than http.DefaultClient.
	resp, err := srv.Client().Do(req)
	require.NoError(t, err)
	t.Cleanup(func() {
		resp.Body.Close()
	})

	require.Equal(t, http.StatusOK, resp.StatusCode)
	require.Equal(t, 1, beforeIndex)
	require.Equal(t, 2, authIndex)
	require.Equal(t, 3, afterFirstIndex)
	require.Equal(t, 4, afterSecondIndex)
	// Guard against any middleware being applied (and thus run) more than once.
	require.Equal(t, 4, calls)
}
|
|
|
|
// Empty stub request/response types used by the endpoint tests in this file.
type (
	// nopRequest carries no data; the test decoder returns it for every request.
	nopRequest struct{}

	// nopResponse carries no data and never reports an error.
	nopResponse struct{}
)

// Error always returns nil. nopResponse values are returned as fleet.Errorer
// elsewhere in this file, which this method satisfies.
func (nopResponse) Error() error { return nil }
|
|
|
|
type nopEP struct{}
|
|
|
|
func (n nopEP) CallHandlerFunc(_ testHandlerFunc, _ context.Context, _ any, _ any) (fleet.Errorer, error) {
|
|
return nopResponse{}, nil
|
|
}
|
|
|
|
func (n nopEP) Service() interface{} {
|
|
return nil
|
|
}
|