diff --git a/server/service/service.go b/server/service/service.go index e34fa0eb1b..62c6622c3d 100644 --- a/server/service/service.go +++ b/server/service/service.go @@ -4,13 +4,9 @@ package service import ( "context" - "crypto/rand" "fmt" "html/template" - "math" - "math/big" "sync" - "sync/atomic" "time" "github.com/WatchBeam/clock" @@ -46,7 +42,8 @@ type Service struct { authz *authz.Authorizer - jitterSeed int64 + jitterMu *sync.Mutex + jitterH map[time.Duration]*jitterHashTable } // NewService creates a new service from the config struct @@ -87,39 +84,12 @@ func NewService( license: license, failingPolicySet: failingPolicySet, authz: authorizer, + jitterH: make(map[time.Duration]*jitterHashTable), + jitterMu: new(sync.Mutex), } - - // Try setting a first seed - svc.updateJitterSeedRand() - go svc.updateJitterSeed(ctx) - return validationMiddleware{svc, ds, sso}, nil } -func (s *Service) updateJitterSeedRand() { - nBig, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt)) - if err != nil { - panic(err) - } - n := nBig.Int64() - atomic.StoreInt64(&s.jitterSeed, n) -} - -func (s *Service) updateJitterSeed(ctx context.Context) { - for { - select { - case <-time.After(1 * time.Hour): - s.updateJitterSeedRand() - case <-ctx.Done(): - return - } - } -} - -func (s *Service) getJitterSeed() int64 { - return atomic.LoadInt64(&s.jitterSeed) -} - func (s Service) SendEmail(mail fleet.Email) error { return s.mailService.SendEmail(mail) } diff --git a/server/service/service_osquery.go b/server/service/service_osquery.go index 6c08f3cd2e..c2d739ed63 100644 --- a/server/service/service_osquery.go +++ b/server/service/service_osquery.go @@ -7,6 +7,7 @@ import ( "fmt" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -447,15 +448,97 @@ func (svc *Service) detailQueriesForHost(ctx context.Context, host *fleet.Host) } func (svc *Service) shouldUpdate(lastUpdated time.Time, interval time.Duration, hostID uint) bool { - var jitter time.Duration - if svc.config.Osquery.MaxJitterPercent > 0 { - maxJitter := int64(svc.config.Osquery.MaxJitterPercent) * int64(interval) / 100.0 - jitter = time.Duration((int64(hostID) + svc.getJitterSeed()) % maxJitter) + svc.jitterMu.Lock() + defer svc.jitterMu.Unlock() + + if svc.jitterH[interval] == nil { + svc.jitterH[interval] = newJitterHashTable(int(int64(svc.config.Osquery.MaxJitterPercent) * int64(interval.Minutes()) / 100.0)) + level.Debug(svc.logger).Log("jitter", "created", "bucketCount", svc.jitterH[interval].bucketCount) } + + jitter := svc.jitterH[interval].jitterForHost(hostID) cutoff := svc.clock.Now().Add(-(interval + jitter)) return lastUpdated.Before(cutoff) } +// jitterHashTable implements a data structure that allows a fleet to generate a static jitter value +// that is properly balanced. Balance in this context means that hosts would be distributed uniformly +// across the total jitter time so there are no spikes. +// The way this structure works is as follows: +// Given an amount of buckets, we want to place hosts in buckets evenly. So we don't want bucket 0 to +// have 1000 hosts, and all the other buckets 0. If there were 1000 buckets, and 1000 hosts, we should +// end up with 1 per bucket. +// The total amount of online hosts is unknown, so first it assumes that amount of buckets >= amount +// of total hosts (maxCapacity of 1 per bucket). Once we have more hosts than buckets, then we +// increase the maxCapacity by 1 for all buckets, and start placing hosts. +// Hosts that have been placed in a bucket remain in that bucket for as long as the fleet instance is +// running. +// The preferred bucket for a host is the one at (host id % bucketCount). If that bucket is full, the +// next one will be tried. If all buckets are full, then capacity gets increased and the bucket +// selection process restarts. +// Once a bucket is found, the index for the bucket (going from 0 to bucketCount) will be the amount of +// minutes added to the host check in time. +// For example: at a 1hr interval, and the default 10% max jitter percent. That allows hosts to +// distribute within 6 minutes around the hour mark. We would have 6 buckets in that case. +// In the worst possible case that all hosts start at the same time, max jitter percent can be set to +// 100, and this method will distribute hosts evenly. +// The main caveat of this approach is that it works at the fleet instance. So depending on what +// instance gets chosen by the load balancer, the jitter might be different. However, load tests have +// shown that the distribution in practice is pretty balance even when all hosts try to check in at +// the same time. +type jitterHashTable struct { + mu sync.Mutex + maxCapacity int + bucketCount int + buckets map[int]int + cache map[uint]time.Duration +} + +func newJitterHashTable(bucketCount int) *jitterHashTable { + if bucketCount == 0 { + bucketCount = 1 + } + return &jitterHashTable{ + maxCapacity: 1, + bucketCount: bucketCount, + buckets: make(map[int]int), + cache: make(map[uint]time.Duration), + } +} + +func (jh *jitterHashTable) jitterForHost(hostID uint) time.Duration { + // if no jitter is configured just return 0 + if jh.bucketCount <= 1 { + return 0 + } + + jh.mu.Lock() + if jitter, ok := jh.cache[hostID]; ok { + jh.mu.Unlock() + return jitter + } + + for i := 0; i < jh.bucketCount; i++ { + possibleBucket := (int(hostID) + i) % jh.bucketCount + + // if the next bucket has capacity, great! + if jh.buckets[possibleBucket] < jh.maxCapacity { + jh.buckets[possibleBucket]++ + jitter := time.Duration(possibleBucket) * time.Minute + jh.cache[hostID] = jitter + + jh.mu.Unlock() + return jitter + } + } + + // otherwise, bump the capacity and restart the process + jh.maxCapacity++ + + jh.mu.Unlock() + return jh.jitterForHost(hostID) +} + func (svc *Service) labelQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) { labelReportedAt := svc.task.GetHostLabelReportedAt(ctx, host) if !svc.shouldUpdate(labelReportedAt, svc.config.Osquery.LabelUpdateInterval, host.ID) && !host.RefetchRequested { diff --git a/server/service/service_osquery_test.go b/server/service/service_osquery_test.go index 41a36d3c50..03a43d21c4 100644 --- a/server/service/service_osquery_test.go +++ b/server/service/service_osquery_test.go @@ -3,10 +3,13 @@ package service import ( "bytes" "context" + crand "crypto/rand" "encoding/json" "errors" "fmt" "io/ioutil" + "math" + "math/big" "reflect" "sort" "strconv" @@ -283,7 +286,14 @@ func TestHostDetailQueries(t *testing.T) { UUID: "test_uuid", } - svc := &Service{clock: mockClock, config: config.TestConfig(), ds: ds} + svc := &Service{ + clock: mockClock, + logger: log.NewNopLogger(), + config: config.TestConfig(), + ds: ds, + jitterMu: new(sync.Mutex), + jitterH: make(map[time.Duration]*jitterHashTable), + } queries, err := svc.detailQueriesForHost(context.Background(), &host) require.NoError(t, err) @@ -2458,3 +2468,45 @@ func TestLiveQueriesFailing(t *testing.T) { require.Contains(t, string(logs), "level=error") require.Contains(t, string(logs), "failed to get queries for host") } + +func TestJitterForHost(t *testing.T) { + jh := newJitterHashTable(30) + + histogram := make(map[int64]int) + hostCount := 3000 + for i := 0; i < hostCount; i++ { + hostID, err := crand.Int(crand.Reader, big.NewInt(10000)) + require.NoError(t, err) + jitter := jh.jitterForHost(uint(hostID.Int64() + 10000)) + jitterMinutes := int64(jitter.Minutes()) + histogram[jitterMinutes]++ + } + min, max := math.MaxInt, 0 + for jitterMinutes, count := range histogram { + if count < min { + min = count + } + if count > max { + max = count + } + t.Logf("jitterMinutes=%d \t count=%d\n", jitterMinutes, count) + } + variation := max - min + t.Logf("min=%d \t max=%d \t variation=%d\n", min, max, variation) + + // check that variation is below 1% of the total amount of hosts + require.Less(t, variation, int(float32(hostCount)/0.01)) +} + +func TestNoJitter(t *testing.T) { + jh := newJitterHashTable(0) + + hostCount := 3000 + for i := 0; i < hostCount; i++ { + hostID, err := crand.Int(crand.Reader, big.NewInt(10000)) + require.NoError(t, err) + jitter := jh.jitterForHost(uint(hostID.Int64() + 10000)) + jitterMinutes := int64(jitter.Minutes()) + require.Equal(t, int64(0), jitterMinutes) + } +}