mirror of
https://github.com/fleetdm/fleet
synced 2026-05-23 00:49:03 +00:00
Fixes to the offline indicator (#31685)
#31592 There's still some QA to be done for edge cases and re-connects, but this is ready for review. <img width="341" height="103" alt="Screenshot 2025-08-07 at 11 19 33 AM" src="https://github.com/user-attachments/assets/01e48ca2-8ab1-412c-be01-8e806a5a8b1c" /> Changes: - To improve UX I'm now using `HEAD /api/fleet/device/ping` API every 10 seconds for connectivity/offline check (instead of the expensive DesktopSummary one every 5 minutes). This is to address feedback from a customer: > "If the internet is not connected and we reconnect with an ethernet connection for example, it would be good to try to see if we can refresh it text from the offline indicator given that's not the case anymore. - It might take up to 1m for Fleet Desktop to show the offline indicator (we check every 10s with ping and now we are adding 6 more requests in 1 minute to make sure just one bad request doesn't unnecessarily display the offline indicator). - Requests without proper public IP were being incorrectly rate limited (all under the same bucket). So we will now not make these requests and instead log a WARNING. This is a-ok as the recommended approach to deploy Fleet is with a TLS terminator that will add the public IP of the request before sending it to Fleet. --- - [X] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files) for more information. ## Testing - [X] Added/updated automated tests - [ ] Where appropriate, [automated tests simulate multiple hosts and test for host isolation](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/reference/patterns-backend.md#unit-testing) (updates to one hosts's records do not affect another) - [ ] QA'd all new/changed functionality manually ## fleetd/orbit/Fleet Desktop - [ ] Verified compatibility with the latest released version of Fleet (see [Must rule](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/workflows/fleetd-development-and-release-strategy.md)) - [ ] Verified that fleetd runs on macOS, Linux and Windows - [ ] Verified auto-update works from the released version of component to the new version (see [tools/tuf/test](../tools/tuf/test/README.md)) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Improved accuracy in identifying client public IP addresses, reducing incorrect rate limiting for Fleet Desktop users. * Offline indicator is now less sensitive to brief network interruptions, reducing false offline signals and allowing faster recovery when connectivity is restored. * Updated offline message for clearer status communication. * **New Features** * Enhanced error messages and logging for rate limiting events, providing clearer feedback when limits are reached. * **Tests** * Expanded test coverage for rate limiting, including scenarios with missing public IPs and improved assertions for error handling. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
da38efa526
commit
12f2ee6ad1
7 changed files with 202 additions and 92 deletions
2
changes/31592-improve-offline-indicator
Normal file
2
changes/31592-improve-offline-indicator
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
* Fixed invalid rate limiting applied on Fleet Desktop requests for which a public IP could not be determined.
|
||||
* Improved public IP extraction for Fleet Desktop requests.
|
||||
1
orbit/changes/31592-improve-offline-indicator
Normal file
1
orbit/changes/31592-improve-offline-indicator
Normal file
|
|
@ -0,0 +1 @@
|
|||
* Fixed offline indicator to be less sensitive to transient network failures and faster recovery when connectivity is restored.
|
||||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
|
|
@ -140,10 +141,16 @@ func main() {
|
|||
var swiftDialogCh chan struct{}
|
||||
var offlineWatcher useraction.MDMOfflineWatcher
|
||||
|
||||
// This ticker is used for fetching the desktop summary. It is initialized here because it is
|
||||
// We will execute the summary API every 5 minutes (to refresh policy state).
|
||||
const desktopSummaryInterval = 5 * time.Minute
|
||||
|
||||
// This ticker is used for checking connectivity. It is initialized here because it is
|
||||
// stopped in `OnExit.`
|
||||
const checkInterval = 5 * time.Minute
|
||||
summaryTicker := time.NewTicker(checkInterval)
|
||||
const pingInterval = 10 * time.Second // same value as default distributed/read
|
||||
pingTicker := time.NewTicker(pingInterval)
|
||||
|
||||
// Used to trigger a policy check when clicking on "My device" or "About Fleet".
|
||||
var fleetDesktopCheckTrigger atomic.Bool
|
||||
|
||||
// we have seen some cases where systray.Run() does not call onReady seemingly due to early
|
||||
// initialization states with the GUI such as Windows Autopilot first time setup. This ensures
|
||||
|
|
@ -181,16 +188,12 @@ func main() {
|
|||
// immediately begin showing the migrator again if we were showing it prior.
|
||||
showMDMMigrator := false
|
||||
|
||||
myDeviceItem := systray.AddMenuItem("Connecting...", "")
|
||||
myDeviceItem := systray.AddMenuItem("My device", "")
|
||||
myDeviceItem.Disable()
|
||||
myDeviceItem.Hide()
|
||||
|
||||
// We are doing this using two menu items because line breaks
|
||||
// are not rendered correctly on Windows and MacOS.
|
||||
hostOfflineItemOne := systray.AddMenuItem("🛜🚫 Your computer is offline.", "")
|
||||
hostOfflineItemTwo := systray.AddMenuItem("It might take up to 5 minutes to reconnect to Fleet.", "")
|
||||
hostOfflineItemOne := systray.AddMenuItem("🛜🚫 Your computer is not connected to Fleet.", "")
|
||||
hostOfflineItemOne.Disable()
|
||||
hostOfflineItemTwo.Disable()
|
||||
|
||||
selfServiceItem := systray.AddMenuItem("Self-service", "")
|
||||
selfServiceItem.Disable()
|
||||
|
|
@ -234,8 +237,8 @@ func main() {
|
|||
return newToken
|
||||
})
|
||||
|
||||
disableTray := func() {
|
||||
log.Debug().Msg("disabling tray items")
|
||||
showConnecting := func() {
|
||||
log.Debug().Msg("displaying Connecting...")
|
||||
myDeviceItem.SetTitle("Connecting...")
|
||||
myDeviceItem.Show()
|
||||
myDeviceItem.Disable()
|
||||
|
|
@ -252,7 +255,6 @@ func main() {
|
|||
}
|
||||
|
||||
hostOfflineItemOne.Hide()
|
||||
hostOfflineItemTwo.Hide()
|
||||
}
|
||||
|
||||
reportError := func(err error, info map[string]any) {
|
||||
|
|
@ -296,6 +298,7 @@ func main() {
|
|||
// checkToken performs API test calls to enable the "My device" item as
|
||||
// soon as the device auth token is registered by Fleet.
|
||||
checkToken := func() <-chan interface{} {
|
||||
showConnecting()
|
||||
done := make(chan interface{})
|
||||
|
||||
go func() {
|
||||
|
|
@ -317,7 +320,6 @@ func main() {
|
|||
transparencyItem.Show()
|
||||
|
||||
hostOfflineItemOne.Hide()
|
||||
hostOfflineItemTwo.Hide()
|
||||
|
||||
// Hide Self-Service for Free tier
|
||||
if errors.Is(err, service.ErrMissingLicense) || (summary.SelfService != nil && !*summary.SelfService) {
|
||||
|
|
@ -364,51 +366,92 @@ func main() {
|
|||
case err != nil:
|
||||
log.Error().Err(err).Msg("check token file")
|
||||
case expired:
|
||||
log.Info().Msg("token file changed, rechecking")
|
||||
disableTray()
|
||||
log.Info().Msg("token file expired or invalid, rechecking")
|
||||
<-checkToken()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// poll the server to check the policy status of the host and update the
|
||||
// tray icon accordingly
|
||||
// tray icon accordingly.
|
||||
// We first ping the server to check for connectivity, then get the policy status (every 5 minutes to
|
||||
// not cause performance issues on the server).
|
||||
go func() {
|
||||
<-deviceEnabledChan
|
||||
|
||||
for {
|
||||
<-summaryTicker.C
|
||||
// Reset the ticker to the intended interval, in case we reset it to 1ms
|
||||
summaryTicker.Reset(checkInterval)
|
||||
sum, err := client.DesktopSummary(tokenReader.GetCached())
|
||||
switch {
|
||||
case err == nil:
|
||||
hostOfflineItemOne.Hide()
|
||||
hostOfflineItemTwo.Hide()
|
||||
case errors.Is(err, service.ErrMissingLicense):
|
||||
myDeviceItem.SetTitle("My device")
|
||||
myDeviceItem.Show()
|
||||
hostOfflineItemOne.Hide()
|
||||
hostOfflineItemTwo.Hide()
|
||||
continue
|
||||
case errors.Is(err, service.ErrUnauthenticated):
|
||||
disableTray()
|
||||
hostOfflineItemOne.Hide()
|
||||
hostOfflineItemTwo.Hide()
|
||||
<-checkToken()
|
||||
continue
|
||||
default:
|
||||
var (
|
||||
pingErrCount = 0
|
||||
lastDesktopSummaryCheck time.Time
|
||||
offlineIndicatorDisplayed = false
|
||||
showOffline = func() {
|
||||
myDeviceItem.Hide()
|
||||
transparencyItem.Disable()
|
||||
transparencyItem.Hide()
|
||||
migrateMDMItem.Disable()
|
||||
migrateMDMItem.Hide()
|
||||
hostOfflineItemOne.Show()
|
||||
hostOfflineItemTwo.Show()
|
||||
log.Error().Err(err).Msg("get desktop summary")
|
||||
selfServiceItem.Disable()
|
||||
selfServiceItem.Hide()
|
||||
offlineIndicatorDisplayed = true
|
||||
}
|
||||
)
|
||||
|
||||
for {
|
||||
<-pingTicker.C
|
||||
|
||||
// Reset the ticker to the intended interval,
|
||||
// in case we reset it to 1ms (when clicking on "My device").
|
||||
pingTicker.Reset(pingInterval)
|
||||
|
||||
if err := client.Ping(); err != nil {
|
||||
log.Error().Err(err).Int("count", pingErrCount).Msg("ping failed")
|
||||
pingErrCount++
|
||||
// We try 5 more times to make sure one bad request doesn't trigger the offline indicator.
|
||||
// So it might take up to ~1m (6 * 10s) for Fleet Desktop to show the offline indicator.
|
||||
if pingErrCount >= 6 {
|
||||
showOffline()
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Successfully connected to Fleet.
|
||||
pingErrCount = 0
|
||||
|
||||
// Check if we need to fetch the "Fleet desktop" summary from Fleet.
|
||||
if !offlineIndicatorDisplayed &&
|
||||
!fleetDesktopCheckTrigger.Load() &&
|
||||
(!lastDesktopSummaryCheck.IsZero() && time.Since(lastDesktopSummaryCheck) < desktopSummaryInterval) {
|
||||
continue
|
||||
}
|
||||
|
||||
lastDesktopSummaryCheck = time.Now()
|
||||
fleetDesktopCheckTrigger.Store(false)
|
||||
// We set offlineIndicatorDisplayed to false because we do not want to retry the
|
||||
// Fleet Desktop summary every 10s if Ping works but DesktopSummary doesn't
|
||||
// (to avoid server load issues).
|
||||
offlineIndicatorDisplayed = false
|
||||
|
||||
sum, err := client.DesktopSummary(tokenReader.GetCached())
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, service.ErrMissingLicense):
|
||||
// Policy reporting in Fleet Desktop requires a license,
|
||||
// so we just show the "My device" item as usual.
|
||||
myDeviceItem.SetTitle("My device")
|
||||
myDeviceItem.Show()
|
||||
hostOfflineItemOne.Hide()
|
||||
case errors.Is(err, service.ErrUnauthenticated):
|
||||
log.Debug().Err(err).Msg("get desktop summary auth failure")
|
||||
// This usually happens every ~1 hour when the token expires.
|
||||
<-checkToken()
|
||||
default:
|
||||
log.Error().Err(err).Msg("get desktop summary failed")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
hostOfflineItemOne.Hide()
|
||||
|
||||
refreshMenuItems(sum.DesktopSummary, selfServiceItem, myDeviceItem)
|
||||
myDeviceItem.Enable()
|
||||
myDeviceItem.Show()
|
||||
|
|
@ -509,7 +552,8 @@ func main() {
|
|||
log.Error().Err(err).Str("url", openURL).Msg("open browser policies")
|
||||
}
|
||||
// Also refresh the device status by forcing the polling ticker to fire
|
||||
summaryTicker.Reset(1 * time.Millisecond)
|
||||
fleetDesktopCheckTrigger.Store(true)
|
||||
pingTicker.Reset(1 * time.Millisecond)
|
||||
case <-transparencyItem.ClickedCh:
|
||||
openURL := client.BrowserTransparencyURL(tokenReader.GetCached())
|
||||
if err := open.Browser(openURL); err != nil {
|
||||
|
|
@ -521,7 +565,8 @@ func main() {
|
|||
log.Error().Err(err).Str("url", openURL).Msg("open browser self-service")
|
||||
}
|
||||
// Also refresh the device status by forcing the polling ticker to fire
|
||||
summaryTicker.Reset(1 * time.Millisecond)
|
||||
fleetDesktopCheckTrigger.Store(true)
|
||||
pingTicker.Reset(1 * time.Millisecond)
|
||||
case <-migrateMDMItem.ClickedCh:
|
||||
if offline := offlineWatcher.ShowIfOffline(offlineWatcherCtx); offline {
|
||||
continue
|
||||
|
|
@ -548,8 +593,8 @@ func main() {
|
|||
log.Debug().Err(err).Msg("exiting swiftDialogCh")
|
||||
close(swiftDialogCh)
|
||||
}
|
||||
log.Debug().Msg("stopping ticker")
|
||||
summaryTicker.Stop()
|
||||
log.Debug().Msg("stopping ping ticker")
|
||||
pingTicker.Stop()
|
||||
log.Debug().Msg("canceling offline watcher ctx")
|
||||
cancelOfflineWatcherCtx()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1248,9 +1248,9 @@ func main() {
|
|||
// This is better than using a ticker that ticks every hour because the
|
||||
// we can't ensure the tick actually runs every hour (eg: the computer is
|
||||
// asleep).
|
||||
rotationDuration := 30 * time.Second
|
||||
rotationTicker := time.NewTicker(rotationDuration)
|
||||
defer rotationTicker.Stop()
|
||||
localCheckDuration := 30 * time.Second
|
||||
localCheckTicker := time.NewTicker(localCheckDuration)
|
||||
defer localCheckTicker.Stop()
|
||||
|
||||
// This timer is used to periodically check if the token is valid. The
|
||||
// server might deem a toked as invalid for reasons out of our control,
|
||||
|
|
@ -1262,10 +1262,10 @@ func main() {
|
|||
|
||||
for {
|
||||
select {
|
||||
case <-rotationTicker.C:
|
||||
rotationTicker.Reset(rotationDuration)
|
||||
case <-localCheckTicker.C:
|
||||
localCheckTicker.Reset(localCheckDuration)
|
||||
|
||||
log.Debug().Msgf("checking if token has changed or expired, cached mtime: %s", trw.GetMtime())
|
||||
log.Debug().Msgf("initiating local token check, cached mtime: %s", trw.GetMtime())
|
||||
hasChanged, err := trw.HasChanged()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error checking if token has changed")
|
||||
|
|
@ -1281,10 +1281,10 @@ func main() {
|
|||
if err := trw.Rotate(); err != nil {
|
||||
log.Error().Err(err).Msg("error rotating token")
|
||||
}
|
||||
} else if remain > 0 && remain < rotationDuration {
|
||||
} else if remain > 0 && remain < localCheckDuration {
|
||||
// check again when the token will expire, which will happen
|
||||
// before the next rotation check
|
||||
rotationTicker.Reset(remain)
|
||||
localCheckTicker.Reset(remain)
|
||||
log.Debug().Msgf("token will expire soon, checking again in: %s", remain)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -801,68 +801,68 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
|
|||
// properly
|
||||
desktopQuota := throttled.RateQuota{MaxRate: throttled.PerHour(720), MaxBurst: desktopRateLimitMaxBurst}
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_host", desktopQuota),
|
||||
errorLimiter.Limit("get_device_host", desktopQuota, logger),
|
||||
).GET("/api/_version_/fleet/device/{token}", getDeviceHostEndpoint, getDeviceHostRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_fleet_desktop", desktopQuota),
|
||||
errorLimiter.Limit("get_fleet_desktop", desktopQuota, logger),
|
||||
).GET("/api/_version_/fleet/device/{token}/desktop", getFleetDesktopEndpoint, getFleetDesktopRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("ping_device_auth", desktopQuota),
|
||||
errorLimiter.Limit("ping_device_auth", desktopQuota, logger),
|
||||
).HEAD("/api/_version_/fleet/device/{token}/ping", devicePingEndpoint, deviceAuthPingRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("refetch_device_host", desktopQuota),
|
||||
errorLimiter.Limit("refetch_device_host", desktopQuota, logger),
|
||||
).POST("/api/_version_/fleet/device/{token}/refetch", refetchDeviceHostEndpoint, refetchDeviceHostRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_mapping", desktopQuota),
|
||||
errorLimiter.Limit("get_device_mapping", desktopQuota, logger),
|
||||
).GET("/api/_version_/fleet/device/{token}/device_mapping", listDeviceHostDeviceMappingEndpoint, listDeviceHostDeviceMappingRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_macadmins", desktopQuota),
|
||||
errorLimiter.Limit("get_device_macadmins", desktopQuota, logger),
|
||||
).GET("/api/_version_/fleet/device/{token}/macadmins", getDeviceMacadminsDataEndpoint, getDeviceMacadminsDataRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_policies", desktopQuota),
|
||||
errorLimiter.Limit("get_device_policies", desktopQuota, logger),
|
||||
).GET("/api/_version_/fleet/device/{token}/policies", listDevicePoliciesEndpoint, listDevicePoliciesRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_transparency", desktopQuota),
|
||||
errorLimiter.Limit("get_device_transparency", desktopQuota, logger),
|
||||
).GET("/api/_version_/fleet/device/{token}/transparency", transparencyURL, transparencyURLRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("send_device_error", desktopQuota),
|
||||
errorLimiter.Limit("send_device_error", desktopQuota, logger),
|
||||
).POST("/api/_version_/fleet/device/{token}/debug/errors", fleetdError, fleetdErrorRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_software", desktopQuota),
|
||||
errorLimiter.Limit("get_device_software", desktopQuota, logger),
|
||||
).GET("/api/_version_/fleet/device/{token}/software", getDeviceSoftwareEndpoint, getDeviceSoftwareRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("install_self_service", desktopQuota),
|
||||
errorLimiter.Limit("install_self_service", desktopQuota, logger),
|
||||
).POST("/api/_version_/fleet/device/{token}/software/install/{software_title_id}", submitSelfServiceSoftwareInstall, fleetSelfServiceSoftwareInstallRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("uninstall_self_service", desktopQuota),
|
||||
errorLimiter.Limit("uninstall_self_service", desktopQuota, logger),
|
||||
).POST("/api/_version_/fleet/device/{token}/software/uninstall/{software_title_id}", submitDeviceSoftwareUninstall, fleetDeviceSoftwareUninstallRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_software_install_results", desktopQuota),
|
||||
errorLimiter.Limit("get_device_software_install_results", desktopQuota, logger),
|
||||
).GET("/api/_version_/fleet/device/{token}/software/install/{install_uuid}/results", getDeviceSoftwareInstallResultsEndpoint,
|
||||
getDeviceSoftwareInstallResultsRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_software_uninstall_results", desktopQuota),
|
||||
errorLimiter.Limit("get_device_software_uninstall_results", desktopQuota, logger),
|
||||
).GET("/api/_version_/fleet/device/{token}/software/uninstall/{execution_id}/results", getDeviceSoftwareUninstallResultsEndpoint, getDeviceSoftwareUninstallResultsRequest{})
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_certificates", desktopQuota),
|
||||
errorLimiter.Limit("get_device_certificates", desktopQuota, logger),
|
||||
).GET("/api/_version_/fleet/device/{token}/certificates", listDeviceCertificatesEndpoint, listDeviceCertificatesRequest{})
|
||||
|
||||
// mdm-related endpoints available via device authentication
|
||||
demdm := de.WithCustomMiddleware(mdmConfiguredMiddleware.VerifyAppleMDM())
|
||||
demdm.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_mdm", desktopQuota),
|
||||
errorLimiter.Limit("get_device_mdm", desktopQuota, logger),
|
||||
).GET("/api/_version_/fleet/device/{token}/mdm/apple/manual_enrollment_profile", getDeviceMDMManualEnrollProfileEndpoint, getDeviceMDMManualEnrollProfileRequest{})
|
||||
demdm.WithCustomMiddleware(
|
||||
errorLimiter.Limit("get_device_software_mdm_command_results", desktopQuota),
|
||||
errorLimiter.Limit("get_device_software_mdm_command_results", desktopQuota, logger),
|
||||
).GET("/api/_version_/fleet/device/{token}/software/commands/{command_uuid}/results", getDeviceMDMCommandResultsEndpoint,
|
||||
getDeviceMDMCommandResultsRequest{})
|
||||
|
||||
demdm.WithCustomMiddleware(
|
||||
errorLimiter.Limit("post_device_migrate_mdm", desktopQuota),
|
||||
errorLimiter.Limit("post_device_migrate_mdm", desktopQuota, logger),
|
||||
).POST("/api/_version_/fleet/device/{token}/migrate_mdm", migrateMDMDeviceEndpoint, deviceMigrateMDMRequest{})
|
||||
|
||||
de.WithCustomMiddleware(
|
||||
errorLimiter.Limit("post_device_trigger_linux_escrow", desktopQuota),
|
||||
errorLimiter.Limit("post_device_trigger_linux_escrow", desktopQuota, logger),
|
||||
).POST("/api/_version_/fleet/device/{token}/mdm/linux/trigger_escrow", triggerLinuxDiskEncryptionEscrowEndpoint, triggerLinuxDiskEncryptionEscrowRequest{})
|
||||
|
||||
// host-authenticated endpoints
|
||||
|
|
|
|||
|
|
@ -7,8 +7,10 @@ import (
|
|||
|
||||
authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/publicip"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
kitlog "github.com/go-kit/log"
|
||||
"github.com/throttled/throttled/v2"
|
||||
)
|
||||
|
||||
|
|
@ -54,7 +56,7 @@ func (m *Middleware) Limit(keyName string, quota throttled.RateQuota) endpoint.M
|
|||
if az, ok := authz_ctx.FromContext(ctx); ok {
|
||||
az.SetChecked()
|
||||
}
|
||||
return nil, ctxerr.Wrap(ctx, &ratelimitError{result: result})
|
||||
return nil, ctxerr.Wrap(ctx, &rateLimitError{result: result})
|
||||
}
|
||||
|
||||
return next(ctx, req)
|
||||
|
|
@ -77,7 +79,7 @@ func NewErrorMiddleware(store throttled.GCRAStore) *ErrorMiddleware {
|
|||
}
|
||||
|
||||
// Limit returns a new middleware function enforcing the provided quota only when errors occur in the next middleware
|
||||
func (m *ErrorMiddleware) Limit(keyName string, quota throttled.RateQuota) endpoint.Middleware {
|
||||
func (m *ErrorMiddleware) Limit(keyName string, quota throttled.RateQuota, logger kitlog.Logger) endpoint.Middleware {
|
||||
return func(next endpoint.Endpoint) endpoint.Endpoint {
|
||||
limiter, err := throttled.NewGCRARateLimiter(m.store, quota)
|
||||
if err != nil {
|
||||
|
|
@ -85,8 +87,12 @@ func (m *ErrorMiddleware) Limit(keyName string, quota throttled.RateQuota) endpo
|
|||
}
|
||||
|
||||
return func(ctx context.Context, req interface{}) (response interface{}, err error) {
|
||||
xForwardedFor, _ := ctx.Value(kithttp.ContextKeyRequestXForwardedFor).(string)
|
||||
ipKeyName := fmt.Sprintf("%s-%s", keyName, xForwardedFor)
|
||||
publicIP := publicip.FromContext(ctx)
|
||||
if publicIP == "" {
|
||||
level.Warn(logger).Log("msg", "missing public_ip, skipping rate limit")
|
||||
return next(ctx, req)
|
||||
}
|
||||
ipKeyName := fmt.Sprintf("%s-%s", keyName, publicIP)
|
||||
|
||||
// RateLimit with quantity 0 will never get limited=true, so we check result.Remaining instead
|
||||
_, result, err := limiter.RateLimit(ipKeyName, 0)
|
||||
|
|
@ -104,13 +110,23 @@ func (m *ErrorMiddleware) Limit(keyName string, quota throttled.RateQuota) endpo
|
|||
if az, ok := authz_ctx.FromContext(ctx); ok {
|
||||
az.SetChecked()
|
||||
}
|
||||
return nil, ctxerr.Wrap(ctx, &ratelimitError{result: result})
|
||||
level.Warn(logger).Log(
|
||||
"ip", publicIP,
|
||||
"msg", "limit exceeded",
|
||||
)
|
||||
return nil, ctxerr.Wrap(ctx, &rateLimitError{result: result})
|
||||
}
|
||||
|
||||
resp, err := next(ctx, req)
|
||||
if err != nil {
|
||||
_, _, rateErr := limiter.RateLimit(ipKeyName, 1)
|
||||
if rateErr != nil {
|
||||
// This can happen if the limit store (e.g. Redis) is unavailable.
|
||||
//
|
||||
// We need to set authentication as checked, otherwise we end up returning HTTP 500 errors.
|
||||
if az, ok := authz_ctx.FromContext(ctx); ok {
|
||||
az.SetChecked()
|
||||
}
|
||||
return nil, ctxerr.Wrap(ctx, err, "rate limit ErrorMiddleware: failed to increase rate limit")
|
||||
}
|
||||
}
|
||||
|
|
@ -125,22 +141,29 @@ type Error interface {
|
|||
Result() throttled.RateLimitResult
|
||||
}
|
||||
|
||||
type ratelimitError struct {
|
||||
type rateLimitError struct {
|
||||
result throttled.RateLimitResult
|
||||
}
|
||||
|
||||
func (r ratelimitError) Error() string {
|
||||
return fmt.Sprintf("limit exceeded, retry after: %ds", int(r.result.RetryAfter.Seconds()))
|
||||
func (r rateLimitError) Error() string {
|
||||
// github.com/throttled/throttled has a bug where "peeking" with RateLimit(key, 0)
|
||||
// always returns a RetryAfter=-1. So we just return "limit exceeded" to prevent confusing
|
||||
// errors with "limit exceeded, retry after: 0s".
|
||||
ra := int(r.result.RetryAfter.Seconds())
|
||||
if ra > 0 {
|
||||
return fmt.Sprintf("limit exceeded, retry after: %ds", ra)
|
||||
}
|
||||
return "limit exceeded"
|
||||
}
|
||||
|
||||
func (r ratelimitError) StatusCode() int {
|
||||
func (r rateLimitError) StatusCode() int {
|
||||
return http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
func (r ratelimitError) RetryAfter() int {
|
||||
func (r rateLimitError) RetryAfter() int {
|
||||
return int(r.result.RetryAfter.Seconds())
|
||||
}
|
||||
|
||||
func (r ratelimitError) Result() throttled.RateLimitResult {
|
||||
func (r rateLimitError) Result() throttled.RateLimitResult {
|
||||
return r.result
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,10 +3,14 @@ package ratelimit
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/publicip"
|
||||
kitlog "github.com/go-kit/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/throttled/throttled/v2"
|
||||
"github.com/throttled/throttled/v2/store/memstore"
|
||||
)
|
||||
|
|
@ -50,9 +54,13 @@ func TestLimit(t *testing.T) {
|
|||
_, err = wrapped(ctx, struct{}{})
|
||||
assert.Error(t, err)
|
||||
var rle Error
|
||||
|
||||
assert.True(t, errors.As(err, &rle))
|
||||
assert.True(t, authzCtx.Checked())
|
||||
require.Contains(t, rle.Error(), "limit exceeded, retry after: ")
|
||||
rle_, ok := rle.(*rateLimitError)
|
||||
require.True(t, ok)
|
||||
require.NotZero(t, rle_.RetryAfter())
|
||||
require.Equal(t, http.StatusTooManyRequests, rle_.StatusCode())
|
||||
|
||||
// ensure that the same endpoint wrapped with a different limiter doesn't hit the error
|
||||
_, err = wrapped2(ctx, struct{}{})
|
||||
|
|
@ -85,27 +93,58 @@ func TestLimitOnlyWhenError(t *testing.T) {
|
|||
limiter := NewErrorMiddleware(store)
|
||||
endpoint := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
|
||||
wrapped := limiter.Limit(
|
||||
"test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0},
|
||||
"test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0}, kitlog.NewNopLogger(),
|
||||
)(endpoint)
|
||||
|
||||
// Does NOT hit any rate limits because the endpoint doesn't fail
|
||||
_, err := wrapped(context.Background(), struct{}{})
|
||||
ctx := publicip.NewContext(context.Background(), "0.0.0.0")
|
||||
_, err := wrapped(ctx, struct{}{})
|
||||
assert.NoError(t, err)
|
||||
_, err = wrapped(context.Background(), struct{}{})
|
||||
_, err = wrapped(ctx, struct{}{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedError := errors.New("error")
|
||||
failingEndpoint := func(context.Context, interface{}) (interface{}, error) { return nil, expectedError }
|
||||
wrappedFailer := limiter.Limit(
|
||||
"test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0},
|
||||
"test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0}, kitlog.NewNopLogger(),
|
||||
)(failingEndpoint)
|
||||
|
||||
_, err = wrappedFailer(context.Background(), struct{}{})
|
||||
// First request that fails should be allowed.
|
||||
_, err = wrappedFailer(ctx, struct{}{})
|
||||
assert.ErrorIs(t, err, expectedError)
|
||||
|
||||
// Hits rate limit now that it fails
|
||||
_, err = wrappedFailer(context.Background(), struct{}{})
|
||||
// Second request that fails should not be allowed.
|
||||
_, err = wrappedFailer(ctx, struct{}{})
|
||||
assert.Error(t, err)
|
||||
var rle Error
|
||||
assert.True(t, errors.As(err, &rle))
|
||||
require.True(t, errors.As(err, &rle))
|
||||
// github.com/throttled/throttled has a bug where "peeking" with RateLimit(key, 0)
|
||||
// always returns a RetryAfter=-1. So I'll just leave this here but in the future
|
||||
// we could return the correct Retry-After. Also, we are not making use of "Retry-After"
|
||||
// on the agent side yet.
|
||||
require.EqualValues(t, rle.Result().RetryAfter, -1)
|
||||
require.Equal(t, "limit exceeded", rle.Error())
|
||||
}
|
||||
|
||||
func TestNoRateLimitWithoutPublicIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, _ := memstore.New(1)
|
||||
limiter := NewErrorMiddleware(store)
|
||||
|
||||
expectedError := errors.New("error")
|
||||
failingEndpoint := func(context.Context, interface{}) (interface{}, error) {
|
||||
return nil, expectedError
|
||||
}
|
||||
wrappedFailer := limiter.Limit(
|
||||
"test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0}, kitlog.NewNopLogger(),
|
||||
)(failingEndpoint)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Requests should not be rate limited because there's no "Public IP" identifier in the request.
|
||||
_, err := wrappedFailer(ctx, struct{}{})
|
||||
assert.ErrorIs(t, err, expectedError)
|
||||
_, err = wrappedFailer(ctx, struct{}{})
|
||||
assert.ErrorIs(t, err, expectedError)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue