diff --git a/changes/sso-ratelimit b/changes/sso-ratelimit new file mode 100644 index 0000000000..ad91f5af42 --- /dev/null +++ b/changes/sso-ratelimit @@ -0,0 +1 @@ +* Added `FLEET_MDM_SSO_RATE_LIMIT_PER_MINUTE` environment variable to allow increasing MDM SSO endpoint rate limit from 10 per minute. diff --git a/cmd/fleet/serve.go b/cmd/fleet/serve.go index 22f222b713..69cdd6e707 100644 --- a/cmd/fleet/serve.go +++ b/cmd/fleet/serve.go @@ -71,6 +71,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/cobra" + "github.com/throttled/throttled/v2" "go.elastic.co/apm/module/apmhttp/v2" _ "go.elastic.co/apm/module/apmsql/v2" _ "go.elastic.co/apm/module/apmsql/v2/mysql" @@ -1086,8 +1087,13 @@ the way that the Fleet server works. frontendHandler = service.WithMDMEnrollmentMiddleware(svc, httpLogger, frontendHandler) + var extra []service.ExtraHandlerOption + if config.MDM.SSORateLimitPerMinute > 0 { + extra = append(extra, service.WithMdmSsoRateLimit(throttled.PerMin(config.MDM.SSORateLimitPerMinute))) + } + apiHandler = service.MakeHandler(svc, config, httpLogger, limiterStore, - []endpoint_utils.HandlerRoutesFunc{android_service.GetRoutes(svc, androidSvc)}) + []endpoint_utils.HandlerRoutesFunc{android_service.GetRoutes(svc, androidSvc)}, extra...) setupRequired, err := svc.SetupRequired(baseCtx) if err != nil { diff --git a/server/config/config.go b/server/config/config.go index 0e38a10fc3..2335605ce0 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -695,6 +695,8 @@ type MDMConfig struct { microsoftWSTEP *tls.Certificate microsoftWSTEPCertPEM []byte microsoftWSTEPKeyPEM []byte + + SSORateLimitPerMinute int `yaml:"sso_rate_limit_per_minute"` } type CalendarConfig struct { @@ -1405,6 +1407,7 @@ func (man Manager) addConfigs() { man.addConfigString("mdm.windows_wstep_identity_key", "", "Microsoft WSTEP PEM-encoded private key path") man.addConfigString("mdm.windows_wstep_identity_cert_bytes", "", "Microsoft WSTEP PEM-encoded certificate bytes") man.addConfigString("mdm.windows_wstep_identity_key_bytes", "", "Microsoft WSTEP PEM-encoded private key bytes") + man.addConfigInt("mdm.sso_rate_limit_per_minute", 10, "Number of allowed requests per minute to MDM SSO endpoints") // Calendar integration man.addConfigDuration( @@ -1689,6 +1692,7 @@ func (man Manager) LoadConfig() FleetConfig { WindowsWSTEPIdentityKey: man.getConfigString("mdm.windows_wstep_identity_key"), WindowsWSTEPIdentityCertBytes: man.getConfigString("mdm.windows_wstep_identity_cert_bytes"), WindowsWSTEPIdentityKeyBytes: man.getConfigString("mdm.windows_wstep_identity_key_bytes"), + SSORateLimitPerMinute: man.getConfigInt("mdm.sso_rate_limit_per_minute"), }, Calendar: CalendarConfig{ Periodicity: man.getConfigDuration("calendar.periodicity"), diff --git a/server/service/handler.go b/server/service/handler.go index 307e1853f3..d3ba046506 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -58,19 +58,27 @@ func checkLicenseExpiration(svc fleet.Service) func(context.Context, http.Respon } type extraHandlerOpts struct { - loginRateLimit *throttled.Rate + loginRateLimit *throttled.Rate + mdmSsoRateLimit *throttled.Rate } // ExtraHandlerOption allows adding extra configuration to the HTTP handler. type ExtraHandlerOption func(*extraHandlerOpts) -// WithLoginRateLimit configures the rate limit for the login endpoint. +// WithLoginRateLimit configures the rate limit for the login endpoints. func WithLoginRateLimit(r throttled.Rate) ExtraHandlerOption { return func(o *extraHandlerOpts) { o.loginRateLimit = &r } } +// WithMdmSsoRateLimit configures the rate limit for the MDM SSO endpoints (falls back to login rate limit otherwise). +func WithMdmSsoRateLimit(r throttled.Rate) ExtraHandlerOption { + return func(o *extraHandlerOpts) { + o.mdmSsoRateLimit = &r + } +} + // MakeHandler creates an HTTP handler for the Fleet server endpoints. func MakeHandler( svc fleet.Service, @@ -213,6 +221,7 @@ func addMetrics(r *mux.Router) { const ( desktopRateLimitMaxBurst = 100 // Max burst used for device request rate limiting. forgotPasswordRateLimitMaxBurst = 9 // Max burst used for rate limiting on the the forgot_password endpoint. + DefaultLoginRateLimit = 10 // Normal per-minute rate limit for logins (and MDM SSO if not overridden) ) func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetConfig, @@ -979,10 +988,14 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC WithCustomMiddleware(limiter.Limit("forgot_password", quota)). POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{}) - loginRateLimit := throttled.PerMin(10) + loginRateLimit := throttled.PerMin(DefaultLoginRateLimit) if extra.loginRateLimit != nil { loginRateLimit = *extra.loginRateLimit } + mdmSsoRateLimit := loginRateLimit + if extra.mdmSsoRateLimit != nil { + mdmSsoRateLimit = *extra.mdmSsoRateLimit + } ne.WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})). POST("/api/_version_/fleet/login", loginEndpoint, contract.LoginRequest{}) @@ -996,10 +1009,10 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC // This is a callback endpoint for calendar integration -- it is called to notify an event change in a user calendar ne.POST("/api/_version_/fleet/calendar/webhook/{event_uuid}", calendarWebhookEndpoint, calendarWebhookRequest{}) - neAppleMDM.WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})). + neAppleMDM.WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: mdmSsoRateLimit, MaxBurst: 9})). POST("/api/_version_/fleet/mdm/sso", initiateMDMAppleSSOEndpoint, initiateMDMAppleSSORequest{}) - neAppleMDM.WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})). + neAppleMDM.WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: mdmSsoRateLimit, MaxBurst: 9})). POST("/api/_version_/fleet/mdm/sso/callback", callbackMDMAppleSSOEndpoint, callbackMDMAppleSSORequest{}) }