diff --git a/changes/31592-improve-offline-indicator b/changes/31592-improve-offline-indicator new file mode 100644 index 0000000000..84aefcbbf4 --- /dev/null +++ b/changes/31592-improve-offline-indicator @@ -0,0 +1,2 @@ +* Fixed invalid rate limiting applied on Fleet Desktop requests for which a public IP could not be determined. +* Improved public IP extraction for Fleet Desktop requests. diff --git a/orbit/changes/31592-improve-offline-indicator b/orbit/changes/31592-improve-offline-indicator new file mode 100644 index 0000000000..17937a6d20 --- /dev/null +++ b/orbit/changes/31592-improve-offline-indicator @@ -0,0 +1 @@ +* Fixed offline indicator to be less sensitive to transient network failures and faster recovery when connectivity is restored. diff --git a/orbit/cmd/desktop/desktop.go b/orbit/cmd/desktop/desktop.go index ce6d5bfa1b..4158e385ab 100644 --- a/orbit/cmd/desktop/desktop.go +++ b/orbit/cmd/desktop/desktop.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "strings" + "sync/atomic" "syscall" "time" @@ -140,10 +141,16 @@ func main() { var swiftDialogCh chan struct{} var offlineWatcher useraction.MDMOfflineWatcher - // This ticker is used for fetching the desktop summary. It is initialized here because it is + // We will execute the summary API every 5 minutes (to refresh policy state). + const desktopSummaryInterval = 5 * time.Minute + + // This ticker is used for checking connectivity. It is initialized here because it is // stopped in `OnExit.` - const checkInterval = 5 * time.Minute - summaryTicker := time.NewTicker(checkInterval) + const pingInterval = 10 * time.Second // same value as default distributed/read + pingTicker := time.NewTicker(pingInterval) + + // Used to trigger a policy check when clicking on "My device" or "About Fleet". + var fleetDesktopCheckTrigger atomic.Bool // we have seen some cases where systray.Run() does not call onReady seemingly due to early // initialization states with the GUI such as Windows Autopilot first time setup. This ensures @@ -181,16 +188,12 @@ func main() { // immediately begin showing the migrator again if we were showing it prior. showMDMMigrator := false - myDeviceItem := systray.AddMenuItem("Connecting...", "") + myDeviceItem := systray.AddMenuItem("My device", "") myDeviceItem.Disable() myDeviceItem.Hide() - // We are doing this using two menu items because line breaks - // are not rendered correctly on Windows and MacOS. - hostOfflineItemOne := systray.AddMenuItem("🛜🚫 Your computer is offline.", "") - hostOfflineItemTwo := systray.AddMenuItem("It might take up to 5 minutes to reconnect to Fleet.", "") + hostOfflineItemOne := systray.AddMenuItem("🛜🚫 Your computer is not connected to Fleet.", "") hostOfflineItemOne.Disable() - hostOfflineItemTwo.Disable() selfServiceItem := systray.AddMenuItem("Self-service", "") selfServiceItem.Disable() @@ -234,8 +237,8 @@ func main() { return newToken }) - disableTray := func() { - log.Debug().Msg("disabling tray items") + showConnecting := func() { + log.Debug().Msg("displaying Connecting...") myDeviceItem.SetTitle("Connecting...") myDeviceItem.Show() myDeviceItem.Disable() @@ -252,7 +255,6 @@ func main() { } hostOfflineItemOne.Hide() - hostOfflineItemTwo.Hide() } reportError := func(err error, info map[string]any) { @@ -296,6 +298,7 @@ func main() { // checkToken performs API test calls to enable the "My device" item as // soon as the device auth token is registered by Fleet. checkToken := func() <-chan interface{} { + showConnecting() done := make(chan interface{}) go func() { @@ -317,7 +320,6 @@ func main() { transparencyItem.Show() hostOfflineItemOne.Hide() - hostOfflineItemTwo.Hide() // Hide Self-Service for Free tier if errors.Is(err, service.ErrMissingLicense) || (summary.SelfService != nil && !*summary.SelfService) { @@ -364,51 +366,92 @@ func main() { case err != nil: log.Error().Err(err).Msg("check token file") case expired: - log.Info().Msg("token file changed, rechecking") - disableTray() + log.Info().Msg("token file expired or invalid, rechecking") <-checkToken() } } }() // poll the server to check the policy status of the host and update the - // tray icon accordingly + // tray icon accordingly. + // We first ping the server to check for connectivity, then get the policy status (every 5 minutes to + // not cause performance issues on the server). go func() { <-deviceEnabledChan - for { - <-summaryTicker.C - // Reset the ticker to the intended interval, in case we reset it to 1ms - summaryTicker.Reset(checkInterval) - sum, err := client.DesktopSummary(tokenReader.GetCached()) - switch { - case err == nil: - hostOfflineItemOne.Hide() - hostOfflineItemTwo.Hide() - case errors.Is(err, service.ErrMissingLicense): - myDeviceItem.SetTitle("My device") - myDeviceItem.Show() - hostOfflineItemOne.Hide() - hostOfflineItemTwo.Hide() - continue - case errors.Is(err, service.ErrUnauthenticated): - disableTray() - hostOfflineItemOne.Hide() - hostOfflineItemTwo.Hide() - <-checkToken() - continue - default: + var ( + pingErrCount = 0 + lastDesktopSummaryCheck time.Time + offlineIndicatorDisplayed = false + showOffline = func() { myDeviceItem.Hide() transparencyItem.Disable() transparencyItem.Hide() migrateMDMItem.Disable() migrateMDMItem.Hide() hostOfflineItemOne.Show() - hostOfflineItemTwo.Show() - log.Error().Err(err).Msg("get desktop summary") + selfServiceItem.Disable() + selfServiceItem.Hide() + offlineIndicatorDisplayed = true + } + ) + + for { + <-pingTicker.C + + // Reset the ticker to the intended interval, + // in case we reset it to 1ms (when clicking on "My device"). + pingTicker.Reset(pingInterval) + + if err := client.Ping(); err != nil { + log.Error().Err(err).Int("count", pingErrCount).Msg("ping failed") + pingErrCount++ + // We try 5 more times to make sure one bad request doesn't trigger the offline indicator. + // So it might take up to ~1m (6 * 10s) for Fleet Desktop to show the offline indicator. + if pingErrCount >= 6 { + showOffline() + } continue } + // Successfully connected to Fleet. + pingErrCount = 0 + + // Check if we need to fetch the "Fleet desktop" summary from Fleet. + if !offlineIndicatorDisplayed && + !fleetDesktopCheckTrigger.Load() && + (!lastDesktopSummaryCheck.IsZero() && time.Since(lastDesktopSummaryCheck) < desktopSummaryInterval) { + continue + } + + lastDesktopSummaryCheck = time.Now() + fleetDesktopCheckTrigger.Store(false) + // We set offlineIndicatorDisplayed to false because we do not want to retry the + // Fleet Desktop summary every 10s if Ping works but DesktopSummary doesn't + // (to avoid server load issues). + offlineIndicatorDisplayed = false + + sum, err := client.DesktopSummary(tokenReader.GetCached()) + if err != nil { + switch { + case errors.Is(err, service.ErrMissingLicense): + // Policy reporting in Fleet Desktop requires a license, + // so we just show the "My device" item as usual. + myDeviceItem.SetTitle("My device") + myDeviceItem.Show() + hostOfflineItemOne.Hide() + case errors.Is(err, service.ErrUnauthenticated): + log.Debug().Err(err).Msg("get desktop summary auth failure") + // This usually happens every ~1 hour when the token expires. + <-checkToken() + default: + log.Error().Err(err).Msg("get desktop summary failed") + } + continue + } + + hostOfflineItemOne.Hide() + refreshMenuItems(sum.DesktopSummary, selfServiceItem, myDeviceItem) myDeviceItem.Enable() myDeviceItem.Show() @@ -509,7 +552,8 @@ func main() { log.Error().Err(err).Str("url", openURL).Msg("open browser policies") } // Also refresh the device status by forcing the polling ticker to fire - summaryTicker.Reset(1 * time.Millisecond) + fleetDesktopCheckTrigger.Store(true) + pingTicker.Reset(1 * time.Millisecond) case <-transparencyItem.ClickedCh: openURL := client.BrowserTransparencyURL(tokenReader.GetCached()) if err := open.Browser(openURL); err != nil { @@ -521,7 +565,8 @@ func main() { log.Error().Err(err).Str("url", openURL).Msg("open browser self-service") } // Also refresh the device status by forcing the polling ticker to fire - summaryTicker.Reset(1 * time.Millisecond) + fleetDesktopCheckTrigger.Store(true) + pingTicker.Reset(1 * time.Millisecond) case <-migrateMDMItem.ClickedCh: if offline := offlineWatcher.ShowIfOffline(offlineWatcherCtx); offline { continue @@ -548,8 +593,8 @@ func main() { log.Debug().Err(err).Msg("exiting swiftDialogCh") close(swiftDialogCh) } - log.Debug().Msg("stopping ticker") - summaryTicker.Stop() + log.Debug().Msg("stopping ping ticker") + pingTicker.Stop() log.Debug().Msg("canceling offline watcher ctx") cancelOfflineWatcherCtx() } diff --git a/orbit/cmd/orbit/orbit.go b/orbit/cmd/orbit/orbit.go index 9bd5219408..608bcafed3 100644 --- a/orbit/cmd/orbit/orbit.go +++ b/orbit/cmd/orbit/orbit.go @@ -1248,9 +1248,9 @@ func main() { // This is better than using a ticker that ticks every hour because the // we can't ensure the tick actually runs every hour (eg: the computer is // asleep). - rotationDuration := 30 * time.Second - rotationTicker := time.NewTicker(rotationDuration) - defer rotationTicker.Stop() + localCheckDuration := 30 * time.Second + localCheckTicker := time.NewTicker(localCheckDuration) + defer localCheckTicker.Stop() // This timer is used to periodically check if the token is valid. The // server might deem a toked as invalid for reasons out of our control, @@ -1262,10 +1262,10 @@ func main() { for { select { - case <-rotationTicker.C: - rotationTicker.Reset(rotationDuration) + case <-localCheckTicker.C: + localCheckTicker.Reset(localCheckDuration) - log.Debug().Msgf("checking if token has changed or expired, cached mtime: %s", trw.GetMtime()) + log.Debug().Msgf("initiating local token check, cached mtime: %s", trw.GetMtime()) hasChanged, err := trw.HasChanged() if err != nil { log.Error().Err(err).Msg("error checking if token has changed") @@ -1281,10 +1281,10 @@ func main() { if err := trw.Rotate(); err != nil { log.Error().Err(err).Msg("error rotating token") } - } else if remain > 0 && remain < rotationDuration { + } else if remain > 0 && remain < localCheckDuration { // check again when the token will expire, which will happen // before the next rotation check - rotationTicker.Reset(remain) + localCheckTicker.Reset(remain) log.Debug().Msgf("token will expire soon, checking again in: %s", remain) } diff --git a/server/service/handler.go b/server/service/handler.go index c5964f41a6..b8fff0aa4e 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -801,68 +801,68 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC // properly desktopQuota := throttled.RateQuota{MaxRate: throttled.PerHour(720), MaxBurst: desktopRateLimitMaxBurst} de.WithCustomMiddleware( - errorLimiter.Limit("get_device_host", desktopQuota), + errorLimiter.Limit("get_device_host", desktopQuota, logger), ).GET("/api/_version_/fleet/device/{token}", getDeviceHostEndpoint, getDeviceHostRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("get_fleet_desktop", desktopQuota), + errorLimiter.Limit("get_fleet_desktop", desktopQuota, logger), ).GET("/api/_version_/fleet/device/{token}/desktop", getFleetDesktopEndpoint, getFleetDesktopRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("ping_device_auth", desktopQuota), + errorLimiter.Limit("ping_device_auth", desktopQuota, logger), ).HEAD("/api/_version_/fleet/device/{token}/ping", devicePingEndpoint, deviceAuthPingRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("refetch_device_host", desktopQuota), + errorLimiter.Limit("refetch_device_host", desktopQuota, logger), ).POST("/api/_version_/fleet/device/{token}/refetch", refetchDeviceHostEndpoint, refetchDeviceHostRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("get_device_mapping", desktopQuota), + errorLimiter.Limit("get_device_mapping", desktopQuota, logger), ).GET("/api/_version_/fleet/device/{token}/device_mapping", listDeviceHostDeviceMappingEndpoint, listDeviceHostDeviceMappingRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("get_device_macadmins", desktopQuota), + errorLimiter.Limit("get_device_macadmins", desktopQuota, logger), ).GET("/api/_version_/fleet/device/{token}/macadmins", getDeviceMacadminsDataEndpoint, getDeviceMacadminsDataRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("get_device_policies", desktopQuota), + errorLimiter.Limit("get_device_policies", desktopQuota, logger), ).GET("/api/_version_/fleet/device/{token}/policies", listDevicePoliciesEndpoint, listDevicePoliciesRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("get_device_transparency", desktopQuota), + errorLimiter.Limit("get_device_transparency", desktopQuota, logger), ).GET("/api/_version_/fleet/device/{token}/transparency", transparencyURL, transparencyURLRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("send_device_error", desktopQuota), + errorLimiter.Limit("send_device_error", desktopQuota, logger), ).POST("/api/_version_/fleet/device/{token}/debug/errors", fleetdError, fleetdErrorRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("get_device_software", desktopQuota), + errorLimiter.Limit("get_device_software", desktopQuota, logger), ).GET("/api/_version_/fleet/device/{token}/software", getDeviceSoftwareEndpoint, getDeviceSoftwareRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("install_self_service", desktopQuota), + errorLimiter.Limit("install_self_service", desktopQuota, logger), ).POST("/api/_version_/fleet/device/{token}/software/install/{software_title_id}", submitSelfServiceSoftwareInstall, fleetSelfServiceSoftwareInstallRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("uninstall_self_service", desktopQuota), + errorLimiter.Limit("uninstall_self_service", desktopQuota, logger), ).POST("/api/_version_/fleet/device/{token}/software/uninstall/{software_title_id}", submitDeviceSoftwareUninstall, fleetDeviceSoftwareUninstallRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("get_device_software_install_results", desktopQuota), + errorLimiter.Limit("get_device_software_install_results", desktopQuota, logger), ).GET("/api/_version_/fleet/device/{token}/software/install/{install_uuid}/results", getDeviceSoftwareInstallResultsEndpoint, getDeviceSoftwareInstallResultsRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("get_device_software_uninstall_results", desktopQuota), + errorLimiter.Limit("get_device_software_uninstall_results", desktopQuota, logger), ).GET("/api/_version_/fleet/device/{token}/software/uninstall/{execution_id}/results", getDeviceSoftwareUninstallResultsEndpoint, getDeviceSoftwareUninstallResultsRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("get_device_certificates", desktopQuota), + errorLimiter.Limit("get_device_certificates", desktopQuota, logger), ).GET("/api/_version_/fleet/device/{token}/certificates", listDeviceCertificatesEndpoint, listDeviceCertificatesRequest{}) // mdm-related endpoints available via device authentication demdm := de.WithCustomMiddleware(mdmConfiguredMiddleware.VerifyAppleMDM()) demdm.WithCustomMiddleware( - errorLimiter.Limit("get_device_mdm", desktopQuota), + errorLimiter.Limit("get_device_mdm", desktopQuota, logger), ).GET("/api/_version_/fleet/device/{token}/mdm/apple/manual_enrollment_profile", getDeviceMDMManualEnrollProfileEndpoint, getDeviceMDMManualEnrollProfileRequest{}) demdm.WithCustomMiddleware( - errorLimiter.Limit("get_device_software_mdm_command_results", desktopQuota), + errorLimiter.Limit("get_device_software_mdm_command_results", desktopQuota, logger), ).GET("/api/_version_/fleet/device/{token}/software/commands/{command_uuid}/results", getDeviceMDMCommandResultsEndpoint, getDeviceMDMCommandResultsRequest{}) demdm.WithCustomMiddleware( - errorLimiter.Limit("post_device_migrate_mdm", desktopQuota), + errorLimiter.Limit("post_device_migrate_mdm", desktopQuota, logger), ).POST("/api/_version_/fleet/device/{token}/migrate_mdm", migrateMDMDeviceEndpoint, deviceMigrateMDMRequest{}) de.WithCustomMiddleware( - errorLimiter.Limit("post_device_trigger_linux_escrow", desktopQuota), + errorLimiter.Limit("post_device_trigger_linux_escrow", desktopQuota, logger), ).POST("/api/_version_/fleet/device/{token}/mdm/linux/trigger_escrow", triggerLinuxDiskEncryptionEscrowEndpoint, triggerLinuxDiskEncryptionEscrowRequest{}) // host-authenticated endpoints diff --git a/server/service/middleware/ratelimit/ratelimit.go b/server/service/middleware/ratelimit/ratelimit.go index f4d9f1beea..8b427d926b 100644 --- a/server/service/middleware/ratelimit/ratelimit.go +++ b/server/service/middleware/ratelimit/ratelimit.go @@ -7,8 +7,10 @@ import ( authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + "github.com/fleetdm/fleet/v4/server/contexts/publicip" "github.com/go-kit/kit/endpoint" - kithttp "github.com/go-kit/kit/transport/http" + "github.com/go-kit/kit/log/level" + kitlog "github.com/go-kit/log" "github.com/throttled/throttled/v2" ) @@ -54,7 +56,7 @@ func (m *Middleware) Limit(keyName string, quota throttled.RateQuota) endpoint.M if az, ok := authz_ctx.FromContext(ctx); ok { az.SetChecked() } - return nil, ctxerr.Wrap(ctx, &ratelimitError{result: result}) + return nil, ctxerr.Wrap(ctx, &rateLimitError{result: result}) } return next(ctx, req) @@ -77,7 +79,7 @@ func NewErrorMiddleware(store throttled.GCRAStore) *ErrorMiddleware { } // Limit returns a new middleware function enforcing the provided quota only when errors occur in the next middleware -func (m *ErrorMiddleware) Limit(keyName string, quota throttled.RateQuota) endpoint.Middleware { +func (m *ErrorMiddleware) Limit(keyName string, quota throttled.RateQuota, logger kitlog.Logger) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { limiter, err := throttled.NewGCRARateLimiter(m.store, quota) if err != nil { @@ -85,8 +87,12 @@ func (m *ErrorMiddleware) Limit(keyName string, quota throttled.RateQuota) endpo } return func(ctx context.Context, req interface{}) (response interface{}, err error) { - xForwardedFor, _ := ctx.Value(kithttp.ContextKeyRequestXForwardedFor).(string) - ipKeyName := fmt.Sprintf("%s-%s", keyName, xForwardedFor) + publicIP := publicip.FromContext(ctx) + if publicIP == "" { + level.Warn(logger).Log("msg", "missing public_ip, skipping rate limit") + return next(ctx, req) + } + ipKeyName := fmt.Sprintf("%s-%s", keyName, publicIP) // RateLimit with quantity 0 will never get limited=true, so we check result.Remaining instead _, result, err := limiter.RateLimit(ipKeyName, 0) @@ -104,13 +110,23 @@ func (m *ErrorMiddleware) Limit(keyName string, quota throttled.RateQuota) endpo if az, ok := authz_ctx.FromContext(ctx); ok { az.SetChecked() } - return nil, ctxerr.Wrap(ctx, &ratelimitError{result: result}) + level.Warn(logger).Log( + "ip", publicIP, + "msg", "limit exceeded", + ) + return nil, ctxerr.Wrap(ctx, &rateLimitError{result: result}) } resp, err := next(ctx, req) if err != nil { _, _, rateErr := limiter.RateLimit(ipKeyName, 1) if rateErr != nil { + // This can happen if the limit store (e.g. Redis) is unavailable. + // + // We need to set authentication as checked, otherwise we end up returning HTTP 500 errors. + if az, ok := authz_ctx.FromContext(ctx); ok { + az.SetChecked() + } return nil, ctxerr.Wrap(ctx, err, "rate limit ErrorMiddleware: failed to increase rate limit") } } @@ -125,22 +141,29 @@ type Error interface { Result() throttled.RateLimitResult } -type ratelimitError struct { +type rateLimitError struct { result throttled.RateLimitResult } -func (r ratelimitError) Error() string { - return fmt.Sprintf("limit exceeded, retry after: %ds", int(r.result.RetryAfter.Seconds())) +func (r rateLimitError) Error() string { + // github.com/throttled/throttled has a bug where "peeking" with RateLimit(key, 0) + // always returns a RetryAfter=-1. So we just return "limit exceeded" to prevent confusing + // errors with "limit exceeded, retry after: 0s". + ra := int(r.result.RetryAfter.Seconds()) + if ra > 0 { + return fmt.Sprintf("limit exceeded, retry after: %ds", ra) + } + return "limit exceeded" } -func (r ratelimitError) StatusCode() int { +func (r rateLimitError) StatusCode() int { return http.StatusTooManyRequests } -func (r ratelimitError) RetryAfter() int { +func (r rateLimitError) RetryAfter() int { return int(r.result.RetryAfter.Seconds()) } -func (r ratelimitError) Result() throttled.RateLimitResult { +func (r rateLimitError) Result() throttled.RateLimitResult { return r.result } diff --git a/server/service/middleware/ratelimit/ratelimit_test.go b/server/service/middleware/ratelimit/ratelimit_test.go index 477b08f63d..1ddcaa4612 100644 --- a/server/service/middleware/ratelimit/ratelimit_test.go +++ b/server/service/middleware/ratelimit/ratelimit_test.go @@ -3,10 +3,14 @@ package ratelimit import ( "context" "errors" + "net/http" "testing" authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz" + "github.com/fleetdm/fleet/v4/server/contexts/publicip" + kitlog "github.com/go-kit/log" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/throttled/throttled/v2" "github.com/throttled/throttled/v2/store/memstore" ) @@ -50,9 +54,13 @@ func TestLimit(t *testing.T) { _, err = wrapped(ctx, struct{}{}) assert.Error(t, err) var rle Error - assert.True(t, errors.As(err, &rle)) assert.True(t, authzCtx.Checked()) + require.Contains(t, rle.Error(), "limit exceeded, retry after: ") + rle_, ok := rle.(*rateLimitError) + require.True(t, ok) + require.NotZero(t, rle_.RetryAfter()) + require.Equal(t, http.StatusTooManyRequests, rle_.StatusCode()) // ensure that the same endpoint wrapped with a different limiter doesn't hit the error _, err = wrapped2(ctx, struct{}{}) @@ -85,27 +93,58 @@ func TestLimitOnlyWhenError(t *testing.T) { limiter := NewErrorMiddleware(store) endpoint := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } wrapped := limiter.Limit( - "test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0}, + "test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0}, kitlog.NewNopLogger(), )(endpoint) // Does NOT hit any rate limits because the endpoint doesn't fail - _, err := wrapped(context.Background(), struct{}{}) + ctx := publicip.NewContext(context.Background(), "0.0.0.0") + _, err := wrapped(ctx, struct{}{}) assert.NoError(t, err) - _, err = wrapped(context.Background(), struct{}{}) + _, err = wrapped(ctx, struct{}{}) assert.NoError(t, err) expectedError := errors.New("error") failingEndpoint := func(context.Context, interface{}) (interface{}, error) { return nil, expectedError } wrappedFailer := limiter.Limit( - "test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0}, + "test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0}, kitlog.NewNopLogger(), )(failingEndpoint) - _, err = wrappedFailer(context.Background(), struct{}{}) + // First request that fails should be allowed. + _, err = wrappedFailer(ctx, struct{}{}) assert.ErrorIs(t, err, expectedError) - // Hits rate limit now that it fails - _, err = wrappedFailer(context.Background(), struct{}{}) + // Second request that fails should not be allowed. + _, err = wrappedFailer(ctx, struct{}{}) assert.Error(t, err) var rle Error - assert.True(t, errors.As(err, &rle)) + require.True(t, errors.As(err, &rle)) + // github.com/throttled/throttled has a bug where "peeking" with RateLimit(key, 0) + // always returns a RetryAfter=-1. So I'll just leave this here but in the future + // we could return the correct Retry-After. Also, we are not making use of "Retry-After" + // on the agent side yet. + require.EqualValues(t, rle.Result().RetryAfter, -1) + require.Equal(t, "limit exceeded", rle.Error()) +} + +func TestNoRateLimitWithoutPublicIP(t *testing.T) { + t.Parallel() + + store, _ := memstore.New(1) + limiter := NewErrorMiddleware(store) + + expectedError := errors.New("error") + failingEndpoint := func(context.Context, interface{}) (interface{}, error) { + return nil, expectedError + } + wrappedFailer := limiter.Limit( + "test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0}, kitlog.NewNopLogger(), + )(failingEndpoint) + + ctx := context.Background() + + // Requests should not be rate limited because there's no "Public IP" identifier in the request. + _, err := wrappedFailer(ctx, struct{}{}) + assert.ErrorIs(t, err, expectedError) + _, err = wrappedFailer(ctx, struct{}{}) + assert.ErrorIs(t, err, expectedError) }