From f7048b711c06761104fbec949be33c3b6e0ed1ec Mon Sep 17 00:00:00 2001 From: Lucas Manuel Rodriguez Date: Tue, 19 Apr 2022 10:35:53 -0300 Subject: [PATCH] Fix race condition in tests when using global var loginRateLimit (#5197) --- server/service/handler.go | 50 +++++++++++++++++++++++--------- server/service/testing_client.go | 4 +-- server/service/testing_utils.go | 3 +- 3 files changed, 40 insertions(+), 17 deletions(-) diff --git a/server/service/handler.go b/server/service/handler.go index cc75372a0d..0f2cd7ad53 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -75,8 +75,33 @@ func checkLicenseExpiration(svc fleet.Service) func(context.Context, http.Respon } } +type extraHandlerOpts struct { + loginRateLimit *throttled.Rate +} + +// ExtraHandlerOption allows adding extra configuration to the HTTP handler. +type ExtraHandlerOption func(*extraHandlerOpts) + +// WithLoginRateLimit configures the rate limit for the login endpoint. +func WithLoginRateLimit(r throttled.Rate) ExtraHandlerOption { + return func(o *extraHandlerOpts) { + o.loginRateLimit = &r + } +} + // MakeHandler creates an HTTP handler for the Fleet server endpoints. -func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Logger, limitStore throttled.GCRAStore) http.Handler { +func MakeHandler( + svc fleet.Service, + config config.FleetConfig, + logger kitlog.Logger, + limitStore throttled.GCRAStore, + extra ...ExtraHandlerOption, +) http.Handler { + var eopts extraHandlerOpts + for _, fn := range extra { + fn(&eopts) + } + fleetAPIOptions := []kithttp.ServerOption{ kithttp.ServerBefore( kithttp.PopulateRequestContext, // populate the request context with common fields @@ -98,7 +123,7 @@ func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Log r.Use(publicIP) - attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions) + attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions, eopts) // Results endpoint is handled different due to websockets use @@ -203,14 +228,9 @@ func addMetrics(r *mux.Router) { r.Walk(walkFn) } -var ( - // those are conceptually constants, but var so they can be changed in tests - forgotPasswordRateLimit = throttled.PerHour(10) - loginRateLimit = throttled.PerMin(10) -) - func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetConfig, logger kitlog.Logger, limitStore throttled.GCRAStore, opts []kithttp.ServerOption, + extra extraHandlerOpts, ) { apiVersions := []string{"v1", "2022-04"} @@ -413,8 +433,8 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC ne.WithAltPaths("/api/v1/osquery/enroll"). POST("/api/osquery/enroll", enrollAgentEndpoint, enrollAgentRequest{}) - // For some reason osquery does not provide a node key with the block data. - // Instead the carve session ID should be verified in the service method. + // For some reason osquery does not provide a node key with the block data. + // Instead the carve session ID should be verified in the service method. ne.WithAltPaths("/api/v1/osquery/carve/block"). POST("/api/osquery/carve/block", carveBlockEndpoint, carveBlockRequest{}) @@ -429,11 +449,15 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC limiter := ratelimit.NewMiddleware(limitStore) ne. - WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: forgotPasswordRateLimit, MaxBurst: 9})). + WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})). POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{}) - ne. - WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})). + loginRateLimit := throttled.PerMin(10) + if extra.loginRateLimit != nil { + loginRateLimit = *extra.loginRateLimit + } + + ne.WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})). POST("/api/_version_/fleet/login", loginEndpoint, loginRequest{}) } diff --git a/server/service/testing_client.go b/server/service/testing_client.go index 5f71b7cc6e..e58e162c93 100644 --- a/server/service/testing_client.go +++ b/server/service/testing_client.go @@ -17,7 +17,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/throttled/throttled/v2" ) type withDS struct { @@ -46,7 +45,6 @@ type withServer struct { func (ts *withServer) SetupSuite(dbName string) { ts.withDS.SetupSuite(dbName) - loginRateLimit = throttled.PerMin(100) rs := pubsub.NewInmemQueryResults() users, server := RunServerForTestsWithDS(ts.s.T(), ts.ds, TestServerOpts{Rs: rs}) ts.server = server @@ -149,7 +147,7 @@ func (ts *withServer) getTestToken(email string, password string) string { defer resp.Body.Close() assert.Equal(ts.s.T(), http.StatusOK, resp.StatusCode) - var jsn = struct { + jsn := struct { User *fleet.User `json:"user"` Token string `json:"token"` Err []map[string]string `json:"errors,omitempty"` diff --git a/server/service/testing_utils.go b/server/service/testing_utils.go index a191ec27af..e35982f81b 100644 --- a/server/service/testing_utils.go +++ b/server/service/testing_utils.go @@ -22,6 +22,7 @@ import ( kitlog "github.com/go-kit/kit/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/throttled/throttled/v2" "github.com/throttled/throttled/v2/store/memstore" ) @@ -177,7 +178,7 @@ func RunServerForTestsWithDS(t *testing.T, ds fleet.Datastore, opts ...TestServe } limitStore, _ := memstore.New(0) - r := MakeHandler(svc, config.FleetConfig{}, logger, limitStore) + r := MakeHandler(svc, config.FleetConfig{}, logger, limitStore, WithLoginRateLimit(throttled.PerMin(100))) server := httptest.NewServer(r) t.Cleanup(func() { server.Close()