Mirror of https://github.com/fleetdm/fleet, synced 2026-04-21 13:37:30 +00:00
Parent commit: 3755a58070
Commit: 9a0871a2f1
18 changed files with 876 additions and 311 deletions

.github/workflows/test-go.yaml (vendored, 2 lines changed)
@ -28,7 +28,7 @@ jobs:
  # Pre-starting dependencies here means they are ready to go when we need them.
  - name: Start Infra Dependencies
    # Use & to background this
-   run: docker-compose up -d mysql_test redis &
+   run: docker-compose up -d mysql_test redis redis-cluster-1 redis-cluster-2 redis-cluster-3 redis-cluster-setup &

  # It seems faster not to cache Go dependencies
  - name: Install Go Dependencies
changes/issue-1847-redis-cluster (new file, 1 line)
@ -0,0 +1 @@
* Fix Redis Cluster support.
@ -27,6 +27,7 @@ import (
|
|||
eeservice "github.com/fleetdm/fleet/v4/ee/server/service"
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/redis"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/s3"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/health"
|
||||
|
|
@ -199,12 +200,20 @@ the way that the Fleet server works.
|
|||
}
|
||||
}
|
||||
|
||||
redisPool, err := pubsub.NewRedisPool(config.Redis.Address, config.Redis.Password, config.Redis.Database, config.Redis.UseTLS)
|
||||
redisPool, err := redis.NewRedisPool(config.Redis.Address, config.Redis.Password, config.Redis.Database, config.Redis.UseTLS)
|
||||
if err != nil {
|
||||
initFatal(err, "initialize Redis")
|
||||
}
|
||||
resultStore := pubsub.NewRedisQueryResults(redisPool, config.Redis.DuplicateResults)
|
||||
liveQueryStore := live_query.NewRedisLiveQuery(redisPool)
|
||||
// TODO: should that only be done when a certain "migrate" flag is set,
|
||||
// to prevent affecting every startup?
|
||||
if err := liveQueryStore.MigrateKeys(); err != nil {
|
||||
level.Info(logger).Log(
|
||||
"err", err,
|
||||
"msg", "failed to migrate live query redis keys",
|
||||
)
|
||||
}
|
||||
ssoSessionStore := sso.NewSessionStore(redisPool)
|
||||
|
||||
osqueryLogger, err := logging.New(config, logger)
|
||||
|
|
|
|||
|
|
@ -39,6 +39,50 @@ services:
|
|||
ports:
|
||||
- "6379:6379"
|
||||
|
||||
redis-cluster-setup:
|
||||
image: redis:5
|
||||
command: redis-cli --cluster create 172.20.0.31:7001 172.20.0.32:7002 172.20.0.33:7003 --cluster-yes
|
||||
networks:
|
||||
cluster_network:
|
||||
ipv4_address: 172.20.0.30
|
||||
depends_on:
|
||||
- redis-cluster-1
|
||||
- redis-cluster-2
|
||||
- redis-cluster-3
|
||||
|
||||
redis-cluster-1:
|
||||
image: redis:5
|
||||
command: redis-server /usr/local/etc/redis/redis.conf
|
||||
ports:
|
||||
- '7001:7001'
|
||||
volumes:
|
||||
- ./tools/redis-tests/redis-cluster-1.conf:/usr/local/etc/redis/redis.conf
|
||||
networks:
|
||||
cluster_network:
|
||||
ipv4_address: 172.20.0.31
|
||||
|
||||
redis-cluster-2:
|
||||
image: redis:5
|
||||
command: redis-server /usr/local/etc/redis/redis.conf
|
||||
ports:
|
||||
- '7002:7002'
|
||||
volumes:
|
||||
- ./tools/redis-tests/redis-cluster-2.conf:/usr/local/etc/redis/redis.conf
|
||||
networks:
|
||||
cluster_network:
|
||||
ipv4_address: 172.20.0.32
|
||||
|
||||
redis-cluster-3:
|
||||
image: redis:5
|
||||
command: redis-server /usr/local/etc/redis/redis.conf
|
||||
ports:
|
||||
- '7003:7003'
|
||||
volumes:
|
||||
- ./tools/redis-tests/redis-cluster-3.conf:/usr/local/etc/redis/redis.conf
|
||||
networks:
|
||||
cluster_network:
|
||||
ipv4_address: 172.20.0.33
|
||||
|
||||
saml_idp:
|
||||
image: fleetdm/docker-idp:latest
|
||||
environment:
|
||||
|
|
@ -80,3 +124,10 @@ services:
|
|||
|
||||
volumes:
|
||||
mysql-persistent-volume:
|
||||
|
||||
networks:
|
||||
cluster_network:
|
||||
driver: bridge
|
||||
ipam:
|
||||
config:
|
||||
- subnet: 172.20.0.0/24
|
||||
|
|
|
|||
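Note: the Compose services above stand up a three-node Redis Cluster on 172.20.0.31-33 (ports 7001-7003) plus a one-shot redis-cluster-setup container that runs redis-cli --cluster create. A quick way to confirm the cluster is actually formed before pointing tests at port 7001 is to open a plain redigo connection and check CLUSTER INFO. This is only a hedged sketch to illustrate the setup, not part of the change:

package main

import (
    "fmt"
    "log"
    "strings"

    "github.com/gomodule/redigo/redis"
)

func main() {
    // Port 7001 is the first cluster node published by the compose file above.
    conn, err := redis.Dial("tcp", "127.0.0.1:7001")
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    // CLUSTER INFO returns a text blob; cluster_state:ok means slots are assigned.
    info, err := redis.String(conn.Do("CLUSTER", "INFO"))
    if err != nil {
        log.Fatal(err)
    }
    if strings.Contains(info, "cluster_state:ok") {
        fmt.Println("cluster is ready")
    } else {
        fmt.Println("cluster not ready yet:\n" + info)
    }
}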
server/datastore/redis/redis.go (new file, 173 lines)
@ -0,0 +1,173 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/gomodule/redigo/redis"
|
||||
"github.com/mna/redisc"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// this is an adapter type to implement the same Stats method as for
|
||||
// redisc.Cluster, so both can satisfy the same interface.
|
||||
type standalonePool struct {
|
||||
*redis.Pool
|
||||
addr string
|
||||
}
|
||||
|
||||
func (p *standalonePool) Stats() map[string]redis.PoolStats {
|
||||
return map[string]redis.PoolStats{
|
||||
p.addr: p.Pool.Stats(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewRedisPool creates a Redis connection pool using the provided server
|
||||
// address, password and database.
|
||||
func NewRedisPool(server, password string, database int, useTLS bool) (fleet.RedisPool, error) {
|
||||
cluster := newCluster(server, password, database, useTLS)
|
||||
if err := cluster.Refresh(); err != nil {
|
||||
if isClusterDisabled(err) || isClusterCommandUnknown(err) {
|
||||
// not a Redis Cluster setup, use a standalone Redis pool
|
||||
pool, _ := cluster.CreatePool(server)
|
||||
cluster.Close()
|
||||
return &standalonePool{pool, server}, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "refresh cluster")
|
||||
}
|
||||
|
||||
return cluster, nil
|
||||
}
|
||||
|
||||
// SplitRedisKeysBySlot takes a list of redis keys and groups them by hash slot
|
||||
// so that keys in a given group are guaranteed to hash to the same slot, making
|
||||
// them safe to run e.g. in a pipeline on the same connection or as part of a
|
||||
// multi-key command in a Redis Cluster setup. When using standalone Redis, it
|
||||
// simply returns all keys in the same group (i.e. the top-level slice has a
|
||||
// length of 1).
|
||||
func SplitRedisKeysBySlot(pool fleet.RedisPool, keys ...string) [][]string {
|
||||
if _, isCluster := pool.(*redisc.Cluster); isCluster {
|
||||
return redisc.SplitBySlot(keys...)
|
||||
}
|
||||
return [][]string{keys}
|
||||
}
|
||||
|
||||
// EachRedisNode calls fn for each node in the redis cluster, with a connection
|
||||
// to that node, until all nodes have been visited. The connection is
|
||||
// automatically closed after the call. If fn returns an error, the iteration
|
||||
// of nodes stops and EachRedisNode returns that error. For standalone redis,
|
||||
// fn is called only once.
|
||||
func EachRedisNode(pool fleet.RedisPool, fn func(conn redis.Conn) error) error {
|
||||
if cluster, isCluster := pool.(*redisc.Cluster); isCluster {
|
||||
addrs, err := getClusterPrimaryAddrs(cluster)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
err := func() error {
|
||||
// NOTE(mna): using CreatePool means that we respect the redis timeouts
|
||||
// and configs. This is a temporary pool as we can't reuse the
|
||||
// (internal) cluster pools for each host at the moment, would require
|
||||
// a change to redisc (one that would make sense to make for that
|
||||
// use-case of visiting each node, IMO).
|
||||
tempPool, err := cluster.CreatePool(addr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create pool")
|
||||
}
|
||||
defer tempPool.Close()
|
||||
|
||||
conn := tempPool.Get()
|
||||
defer conn.Close()
|
||||
return fn(conn)
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
conn := pool.Get()
|
||||
defer conn.Close()
|
||||
return fn(conn)
|
||||
}
|
||||
|
||||
func getClusterPrimaryAddrs(pool *redisc.Cluster) ([]string, error) {
|
||||
conn := pool.Get()
|
||||
defer conn.Close()
|
||||
nodes, err := redis.String(conn.Do("CLUSTER", "NODES"))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get cluster nodes")
|
||||
}
|
||||
|
||||
var addrs []string
|
||||
s := bufio.NewScanner(strings.NewReader(nodes))
|
||||
for s.Scan() {
|
||||
fields := strings.Fields(s.Text())
|
||||
if len(fields) > 2 {
|
||||
flags := fields[2]
|
||||
if strings.Contains(flags, "master") {
|
||||
addrField := fields[1]
|
||||
if ix := strings.Index(addrField, "@"); ix >= 0 {
|
||||
addrField = addrField[:ix]
|
||||
}
|
||||
addrs = append(addrs, addrField)
|
||||
}
|
||||
}
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func newCluster(server, password string, database int, useTLS bool) *redisc.Cluster {
|
||||
return &redisc.Cluster{
|
||||
StartupNodes: []string{server},
|
||||
CreatePool: func(server string, opts ...redis.DialOption) (*redis.Pool, error) {
|
||||
return &redis.Pool{
|
||||
MaxIdle: 3,
|
||||
IdleTimeout: 240 * time.Second,
|
||||
Dial: func() (redis.Conn, error) {
|
||||
c, err := redis.Dial(
|
||||
"tcp",
|
||||
server,
|
||||
redis.DialDatabase(database),
|
||||
redis.DialUseTLS(useTLS),
|
||||
redis.DialConnectTimeout(5*time.Second),
|
||||
redis.DialKeepAlive(10*time.Second),
|
||||
// Read/Write timeouts not set here because we may see results
|
||||
// only rarely on the pub/sub channel.
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if password != "" {
|
||||
if _, err := c.Do("AUTH", password); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c, err
|
||||
},
|
||||
TestOnBorrow: func(c redis.Conn, t time.Time) error {
|
||||
if time.Since(t) < time.Minute {
|
||||
return nil
|
||||
}
|
||||
_, err := c.Do("PING")
|
||||
return err
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func isClusterDisabled(err error) bool {
|
||||
return strings.Contains(err.Error(), "ERR This instance has cluster support disabled")
|
||||
}
|
||||
|
||||
// On GCP Memorystore the CLUSTER command is entirely unavailable and fails with
|
||||
// this error. See
|
||||
// https://cloud.google.com/memorystore/docs/redis/product-constraints#blocked_redis_commands
|
||||
func isClusterCommandUnknown(err error) bool {
|
||||
return strings.Contains(err.Error(), "ERR unknown command `CLUSTER`")
|
||||
}
|
||||
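Taken together, NewRedisPool, SplitRedisKeysBySlot and EachRedisNode let callers stay agnostic of whether they are talking to standalone Redis or a Redis Cluster. As a hedged illustration (not part of the diff), a hypothetical helper that deletes an arbitrary set of keys could group them by hash slot first, so that each multi-key DEL only ever touches keys living on a single node; on standalone Redis there is simply one batch:

package main

import (
    "fmt"
    "log"

    redisds "github.com/fleetdm/fleet/v4/server/datastore/redis"
    "github.com/fleetdm/fleet/v4/server/fleet"
    redigo "github.com/gomodule/redigo/redis"
)

// deleteKeys is a hypothetical helper: it batches keys by cluster slot so a
// single DEL never spans nodes, which would fail on a Redis Cluster.
func deleteKeys(pool fleet.RedisPool, keys ...string) error {
    for _, batch := range redisds.SplitRedisKeysBySlot(pool, keys...) {
        if len(batch) == 0 {
            continue
        }
        conn := pool.Get()
        _, err := conn.Do("DEL", redigo.Args{}.AddFlat(batch)...)
        conn.Close()
        if err != nil {
            return err
        }
    }
    return nil
}

func main() {
    // The same call works against the standalone instance (6379) or the
    // docker-compose test cluster (7001), thanks to the Refresh fallback.
    pool, err := redisds.NewRedisPool("127.0.0.1:6379", "", 0, false)
    if err != nil {
        log.Fatal(err)
    }
    defer pool.Close()
    if err := deleteKeys(pool, "livequery:{1}", "sql:livequery:{1}"); err != nil {
        log.Fatal(err)
    }
    fmt.Println("deleted")
}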
server/datastore/redis/redis_test.go (new file, 95 lines)
@ -0,0 +1,95 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/gomodule/redigo/redis"
|
||||
"github.com/mna/redisc"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEachRedisNode(t *testing.T) {
|
||||
const prefix = "TestEachRedisNode:"
|
||||
|
||||
runTest := func(t *testing.T, pool fleet.RedisPool) {
|
||||
conn := pool.Get()
|
||||
defer conn.Close()
|
||||
if rc, err := redisc.RetryConn(conn, 3, 100*time.Millisecond); err == nil {
|
||||
conn = rc
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err := conn.Do("SET", fmt.Sprintf("%s%d", prefix, i), i)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
var keys []string
|
||||
err := EachRedisNode(pool, func(conn redis.Conn) error {
|
||||
var cursor int
|
||||
for {
|
||||
res, err := redis.Values(conn.Do("SCAN", cursor, "MATCH", prefix+"*"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var curKeys []string
|
||||
if _, err = redis.Scan(res, &cursor, &curKeys); err != nil {
|
||||
return err
|
||||
}
|
||||
keys = append(keys, curKeys...)
|
||||
if cursor == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 10)
|
||||
}
|
||||
|
||||
t.Run("standalone", func(t *testing.T) {
|
||||
pool, teardown := setupRedisForTest(t, false)
|
||||
defer teardown()
|
||||
runTest(t, pool)
|
||||
})
|
||||
|
||||
t.Run("cluster", func(t *testing.T) {
|
||||
pool, teardown := setupRedisForTest(t, true)
|
||||
defer teardown()
|
||||
runTest(t, pool)
|
||||
})
|
||||
}
|
||||
|
||||
func setupRedisForTest(t *testing.T, cluster bool) (pool fleet.RedisPool, teardown func()) {
|
||||
var (
|
||||
addr = "127.0.0.1:"
|
||||
password = ""
|
||||
database = 0
|
||||
useTLS = false
|
||||
port = "6379"
|
||||
)
|
||||
if cluster {
|
||||
port = "7001"
|
||||
}
|
||||
addr += port
|
||||
|
||||
pool, err := NewRedisPool(addr, password, database, useTLS)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn := pool.Get()
|
||||
defer conn.Close()
|
||||
_, err = conn.Do("PING")
|
||||
require.Nil(t, err)
|
||||
|
||||
teardown = func() {
|
||||
err := EachRedisNode(pool, func(conn redis.Conn) error {
|
||||
_, err := conn.Do("FLUSHDB")
|
||||
return err
|
||||
})
|
||||
require.NoError(t, err)
|
||||
pool.Close()
|
||||
}
|
||||
|
||||
return pool, teardown
|
||||
}
|
||||
server/fleet/redis_pool.go (new file, 11 lines)
@ -0,0 +1,11 @@
package fleet

import "github.com/gomodule/redigo/redis"

// RedisPool is the common interface for redigo's Pool for standalone Redis
// and redisc's Cluster for Redis Cluster.
type RedisPool interface {
	Get() redis.Conn
	Close() error
	Stats() map[string]redis.PoolStats
}
|
|
@ -21,13 +21,31 @@
|
|||
// We believe that normal fleet usage has many hosts, and a small
|
||||
// number of live queries targeting all of them. This was a big
|
||||
// factor in choosing this implementation.
|
||||
//
|
||||
// Implementation
|
||||
//
|
||||
// As mentioned in the Design section, there are two keys for each
|
||||
// live query: the bitfield and the SQL of the query:
|
||||
//
|
||||
// livequery:<ID> is the bitfield that indicates the hosts
|
||||
// sql:livequery:<ID> is the SQL of the query.
|
||||
//
|
||||
// Both have an expiration, and <ID> is the campaign ID of the query. To make
|
||||
// efficient use of Redis Cluster (without impacting standalone Redis), the
|
||||
// <ID> is stored in braces (hash tags, e.g. livequery:{1} and
|
||||
// sql:livequery:{1}), so that the two keys for the same <ID> are always stored
|
||||
// on the same node (as they hash to the same cluster slot). See
|
||||
// https://redis.io/topics/cluster-spec#keys-hash-tags for details.
|
||||
//
|
||||
package live_query
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gomodule/redigo/redis"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/redis"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
redigo "github.com/gomodule/redigo/redis"
|
||||
"github.com/mna/redisc"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
|
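The hash-tag scheme described in the package comment above is easy to check with redisc's SplitBySlot (the function the new SplitRedisKeysBySlot helper wraps): with the braces, livequery:{1} and sql:livequery:{1} always land in one group because only the tag is hashed, whereas the old un-tagged names are hashed in full and may end up on different slots — which is exactly why MigrateKeys rewrites them at startup. A hedged sketch:

package main

import (
    "fmt"

    "github.com/mna/redisc"
)

func main() {
    // New format: the {1} hash tag forces both keys onto the same cluster slot,
    // so multi-key operations on the pair are safe in a Redis Cluster.
    tagged := redisc.SplitBySlot("livequery:{1}", "sql:livequery:{1}")
    fmt.Println(len(tagged)) // 1 group

    // Old format: the full key names are hashed, so the pair may be split
    // across two slots (and therefore possibly two nodes).
    untagged := redisc.SplitBySlot("livequery:1", "sql:livequery:1")
    fmt.Println(len(untagged)) // 1 or 2 groups, depending on the hashes
}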
@ -41,17 +59,111 @@ const (
|
|||
|
||||
type redisLiveQuery struct {
|
||||
// connection pool
|
||||
pool *redisc.Cluster
|
||||
pool fleet.RedisPool
|
||||
}
|
||||
|
||||
// NewRedisQueryResults creates a new Redis implementation of the
|
||||
// QueryResultStore interface using the provided Redis connection pool.
|
||||
func NewRedisLiveQuery(pool *redisc.Cluster) *redisLiveQuery {
|
||||
func NewRedisLiveQuery(pool fleet.RedisPool) *redisLiveQuery {
|
||||
return &redisLiveQuery{pool: pool}
|
||||
}
|
||||
|
||||
// generate keys for the bitfield and sql of a query - those always go in pair
|
||||
// and should live on the same cluster node when Redis Cluster is used, so
|
||||
// the common part of the key (the 'name' parameter) is used as key tag.
|
||||
func generateKeys(name string) (targetsKey, sqlKey string) {
|
||||
return queryKeyPrefix + name, sqlKeyPrefix + queryKeyPrefix + name
|
||||
keyTag := "{" + name + "}"
|
||||
return queryKeyPrefix + keyTag, sqlKeyPrefix + queryKeyPrefix + keyTag
|
||||
}
|
||||
|
||||
// returns the base name part of a target key, i.e. so that this is true:
|
||||
// tkey, _ := generateKeys(name)
|
||||
// baseName := extractTargetKeyName(tkey)
|
||||
// baseName == name
|
||||
func extractTargetKeyName(key string) string {
|
||||
name := strings.TrimPrefix(key, queryKeyPrefix)
|
||||
if len(name) > 0 && name[0] == '{' {
|
||||
name = name[1:]
|
||||
}
|
||||
if len(name) > 0 && name[len(name)-1] == '}' {
|
||||
name = name[:len(name)-1]
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// MigrateKeys migrates keys using a deprecated format to the new format. It
|
||||
// should be called at startup and never after that, so for this reason it is
|
||||
// not added to the fleet.LiveQueryStore interface.
|
||||
func (r *redisLiveQuery) MigrateKeys() error {
|
||||
qkeys, err := scanKeys(r.pool, queryKeyPrefix+"*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// identify which of those keys are in a deprecated format
|
||||
var oldKeys []string
|
||||
for _, key := range qkeys {
|
||||
name := extractTargetKeyName(key)
|
||||
if !strings.Contains(key, "{"+name+"}") {
|
||||
// add the corresponding sql key to the list
|
||||
oldKeys = append(oldKeys, key, sqlKeyPrefix+key)
|
||||
}
|
||||
}
|
||||
|
||||
keysBySlot := redis.SplitRedisKeysBySlot(r.pool, oldKeys...)
|
||||
for _, keys := range keysBySlot {
|
||||
if err := migrateBatchKeys(r.pool, keys); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateBatchKeys(pool fleet.RedisPool, keys []string) error {
|
||||
readConn := pool.Get()
|
||||
defer readConn.Close()
|
||||
|
||||
writeConn := pool.Get()
|
||||
defer writeConn.Close()
|
||||
|
||||
// use a retry conn so that we follow MOVED redirections in a Redis Cluster,
|
||||
// as we will attempt to write new keys which may not belong to the same
|
||||
// cluster slot. It returns an error if writeConn is not a redis cluster
|
||||
// connection, in which case we simply continue with the standalone Redis
|
||||
// writeConn.
|
||||
if rc, err := redisc.RetryConn(writeConn, 3, 100*time.Millisecond); err == nil {
|
||||
writeConn = rc
|
||||
}
|
||||
|
||||
// using a straightforward "read one, write one" approach as this is meant to
|
||||
// run at startup, not on a hot path, and we expect a relatively small number
|
||||
// of queries vs hosts (as documented in the design comment at the top).
|
||||
for _, key := range keys {
|
||||
s, err := redigo.String(readConn.Do("GET", key))
|
||||
if err != nil {
|
||||
if err == redigo.ErrNil {
|
||||
// key may have expired since the scan, ignore
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var newKey string
|
||||
if strings.HasPrefix(key, sqlKeyPrefix) {
|
||||
name := extractTargetKeyName(strings.TrimPrefix(key, sqlKeyPrefix))
|
||||
_, newKey = generateKeys(name)
|
||||
} else {
|
||||
name := extractTargetKeyName(key)
|
||||
newKey, _ = generateKeys(name)
|
||||
}
|
||||
if _, err := writeConn.Do("SET", newKey, s, "EX", queryExpiration.Seconds()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// best-effort deletion of the old key, ignore error
|
||||
readConn.Do("DEL", key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *redisLiveQuery) RunQuery(name, sql string, hostIDs []uint) error {
|
||||
|
|
@ -94,68 +206,76 @@ func (r *redisLiveQuery) StopQuery(name string) error {
|
|||
}
|
||||
|
||||
func (r *redisLiveQuery) QueriesForHost(hostID uint) (map[string]string, error) {
|
||||
conn := r.pool.Get()
|
||||
defer conn.Close()
|
||||
|
||||
// Get keys for active queries
|
||||
queryKeys, err := scanKeys(conn, queryKeyPrefix+"*")
|
||||
queryKeys, err := scanKeys(r.pool, queryKeyPrefix+"*")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "scan active queries")
|
||||
}
|
||||
|
||||
keysBySlot := redis.SplitRedisKeysBySlot(r.pool, queryKeys...)
|
||||
queries := make(map[string]string)
|
||||
for _, qkeys := range keysBySlot {
|
||||
if err := r.collectBatchQueriesForHost(hostID, qkeys, queries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return queries, nil
|
||||
}
|
||||
|
||||
func (r *redisLiveQuery) collectBatchQueriesForHost(hostID uint, queryKeys []string, queriesByHost map[string]string) error {
|
||||
conn := r.pool.Get()
|
||||
defer conn.Close()
|
||||
|
||||
// Pipeline redis calls to check for this host in the bitfield of the
|
||||
// targets of the query.
|
||||
for _, key := range queryKeys {
|
||||
if err := conn.Send("GETBIT", key, hostID); err != nil {
|
||||
return nil, errors.Wrap(err, "getbit query targets")
|
||||
return errors.Wrap(err, "getbit query targets")
|
||||
}
|
||||
|
||||
// Additionally get SQL even though we don't yet know whether this query
|
||||
// is targeted to the host. This allows us to avoid an additional
|
||||
// roundtrip to the Redis server and likely has little cost due to the
|
||||
// small number of queries and limited size of SQL
|
||||
if err = conn.Send("GET", sqlKeyPrefix+key); err != nil {
|
||||
return nil, errors.Wrap(err, "get query sql")
|
||||
if err := conn.Send("GET", sqlKeyPrefix+key); err != nil {
|
||||
return errors.Wrap(err, "get query sql")
|
||||
}
|
||||
}
|
||||
|
||||
// Flush calls to begin receiving results.
|
||||
if err := conn.Flush(); err != nil {
|
||||
return nil, errors.Wrap(err, "flush pipeline")
|
||||
return errors.Wrap(err, "flush pipeline")
|
||||
}
|
||||
|
||||
// Receive target and SQL in order of pipelined calls.
|
||||
queries := make(map[string]string)
|
||||
for _, key := range queryKeys {
|
||||
name := strings.TrimPrefix(key, queryKeyPrefix)
|
||||
name := extractTargetKeyName(key)
|
||||
|
||||
targeted, err := redis.Int(conn.Receive())
|
||||
targeted, err := redigo.Int(conn.Receive())
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "receive target")
|
||||
return errors.Wrap(err, "receive target")
|
||||
}
|
||||
|
||||
// Be sure to read SQL even if we are not going to include this query.
|
||||
// Otherwise we will read an incorrect number of returned results from
|
||||
// the pipeline.
|
||||
sql, err := redis.String(conn.Receive())
|
||||
sql, err := redigo.String(conn.Receive())
|
||||
if err != nil {
|
||||
// Not being able to get the sql for a matched could mean things
|
||||
// Not being able to get the sql for a matched query could mean things
|
||||
// have ended up in a weird state. Or it could be that the query was
|
||||
// stopped since we did the key scan. In any case, attempt to clean
|
||||
// up here.
|
||||
_ = r.StopQuery(name)
|
||||
return nil, errors.Wrap(err, "receive sql")
|
||||
return errors.Wrap(err, "receive sql")
|
||||
}
|
||||
|
||||
if targeted == 0 {
|
||||
// Host not targeted with this query
|
||||
continue
|
||||
}
|
||||
|
||||
queries[name] = sql
|
||||
queriesByHost[name] = sql
|
||||
}
|
||||
|
||||
return queries, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *redisLiveQuery) QueryCompletedByHost(name string, hostID uint) error {
|
||||
|
|
@ -195,23 +315,29 @@ func mapBitfield(hostIDs []uint) []byte {
|
|||
return field
|
||||
}
|
||||
|
||||
func scanKeys(conn redis.Conn, pattern string) ([]string, error) {
|
||||
func scanKeys(pool fleet.RedisPool, pattern string) ([]string, error) {
|
||||
var keys []string
|
||||
cursor := 0
|
||||
for {
|
||||
res, err := redis.Values(conn.Do("SCAN", cursor, "MATCH", pattern))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "scan keys")
|
||||
}
|
||||
var curKeys []string
|
||||
_, err = redis.Scan(res, &cursor, &curKeys)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert scan results")
|
||||
}
|
||||
keys = append(keys, curKeys...)
|
||||
if cursor == 0 {
|
||||
break
|
||||
|
||||
err := redis.EachRedisNode(pool, func(conn redigo.Conn) error {
|
||||
cursor := 0
|
||||
for {
|
||||
res, err := redigo.Values(conn.Do("SCAN", cursor, "MATCH", pattern))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "scan keys")
|
||||
}
|
||||
var curKeys []string
|
||||
_, err = redigo.Scan(res, &cursor, &curKeys)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "convert scan results")
|
||||
}
|
||||
keys = append(keys, curKeys...)
|
||||
if cursor == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,9 +2,12 @@ package live_query
|
|||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/pubsub"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/redis"
|
||||
"github.com/fleetdm/fleet/v4/server/test"
|
||||
redigo "github.com/gomodule/redigo/redis"
|
||||
"github.com/mna/redisc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
|
@ -12,30 +15,117 @@ import (
|
|||
func TestRedisLiveQuery(t *testing.T) {
|
||||
for _, f := range testFunctions {
|
||||
t.Run(test.FunctionName(f), func(t *testing.T) {
|
||||
store, teardown := setupRedisLiveQuery(t)
|
||||
defer teardown()
|
||||
f(t, store)
|
||||
t.Run("standalone", func(t *testing.T) {
|
||||
store, teardown := setupRedisLiveQuery(t, false)
|
||||
defer teardown()
|
||||
f(t, store)
|
||||
})
|
||||
|
||||
t.Run("cluster", func(t *testing.T) {
|
||||
store, teardown := setupRedisLiveQuery(t, true)
|
||||
defer teardown()
|
||||
f(t, store)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setupRedisLiveQuery(t *testing.T) (store *redisLiveQuery, teardown func()) {
|
||||
func TestMigrateKeys(t *testing.T) {
|
||||
startKeys := map[string]string{
|
||||
"unrelated": "u",
|
||||
queryKeyPrefix + "a": "a",
|
||||
sqlKeyPrefix + queryKeyPrefix + "a": "sqla",
|
||||
queryKeyPrefix + "b": "b",
|
||||
queryKeyPrefix + "{c}": "c",
|
||||
sqlKeyPrefix + queryKeyPrefix + "{c}": "sqlc",
|
||||
}
|
||||
|
||||
endKeys := map[string]string{
|
||||
"unrelated": "u",
|
||||
queryKeyPrefix + "{a}": "a",
|
||||
sqlKeyPrefix + queryKeyPrefix + "{a}": "sqla",
|
||||
queryKeyPrefix + "{b}": "b",
|
||||
queryKeyPrefix + "{c}": "c",
|
||||
sqlKeyPrefix + queryKeyPrefix + "{c}": "sqlc",
|
||||
}
|
||||
|
||||
runTest := func(t *testing.T, store *redisLiveQuery) {
|
||||
conn := store.pool.Get()
|
||||
defer conn.Close()
|
||||
if rc, err := redisc.RetryConn(conn, 3, 100*time.Millisecond); err == nil {
|
||||
conn = rc
|
||||
}
|
||||
|
||||
for k, v := range startKeys {
|
||||
_, err := conn.Do("SET", k, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := store.MigrateKeys()
|
||||
require.NoError(t, err)
|
||||
|
||||
got := make(map[string]string)
|
||||
err = redis.EachRedisNode(store.pool, func(conn redigo.Conn) error {
|
||||
keys, err := redigo.Strings(conn.Do("KEYS", "*"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, k := range keys {
|
||||
v, err := redigo.String(conn.Do("GET", k))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
got[k] = v
|
||||
}
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.EqualValues(t, endKeys, got)
|
||||
}
|
||||
|
||||
t.Run("standalone", func(t *testing.T) {
|
||||
store, teardown := setupRedisLiveQuery(t, false)
|
||||
defer teardown()
|
||||
runTest(t, store)
|
||||
})
|
||||
|
||||
t.Run("cluster", func(t *testing.T) {
|
||||
store, teardown := setupRedisLiveQuery(t, true)
|
||||
defer teardown()
|
||||
runTest(t, store)
|
||||
})
|
||||
}
|
||||
|
||||
func setupRedisLiveQuery(t *testing.T, cluster bool) (store *redisLiveQuery, teardown func()) {
|
||||
var (
|
||||
addr = "127.0.0.1:6379"
|
||||
addr = "127.0.0.1:"
|
||||
password = ""
|
||||
database = 0
|
||||
useTLS = false
|
||||
port = "6379"
|
||||
)
|
||||
if cluster {
|
||||
port = "7001"
|
||||
}
|
||||
addr += port
|
||||
|
||||
pool, err := pubsub.NewRedisPool(addr, password, database, useTLS)
|
||||
pool, err := redis.NewRedisPool(addr, password, database, useTLS)
|
||||
require.NoError(t, err)
|
||||
store = NewRedisLiveQuery(pool)
|
||||
|
||||
_, err = store.pool.Get().Do("PING")
|
||||
conn := store.pool.Get()
|
||||
defer conn.Close()
|
||||
_, err = conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
|
||||
teardown = func() {
|
||||
store.pool.Get().Do("FLUSHDB")
|
||||
err := redis.EachRedisNode(store.pool, func(conn redigo.Conn) error {
|
||||
_, err := conn.Do("FLUSHDB")
|
||||
return err
|
||||
})
|
||||
require.NoError(t, err)
|
||||
store.pool.Close()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -27,205 +27,227 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
|
|||
}
|
||||
|
||||
func TestQueryResultsStoreErrors(t *testing.T) {
|
||||
store, teardown := SetupRedisForTest(t)
|
||||
defer teardown()
|
||||
|
||||
// Write with no subscriber
|
||||
err := store.WriteResult(
|
||||
fleet.DistributedQueryResult{
|
||||
DistributedQueryCampaignID: 9999,
|
||||
Rows: []map[string]string{{"bing": "fds"}},
|
||||
Host: fleet.Host{
|
||||
ID: 4,
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
runTest := func(t *testing.T, store *redisQueryResults) {
|
||||
// Write with no subscriber
|
||||
err := store.WriteResult(
|
||||
fleet.DistributedQueryResult{
|
||||
DistributedQueryCampaignID: 9999,
|
||||
Rows: []map[string]string{{"bing": "fds"}},
|
||||
Host: fleet.Host{
|
||||
ID: 4,
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
DetailUpdatedAt: time.Now().UTC(),
|
||||
},
|
||||
DetailUpdatedAt: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
)
|
||||
assert.NotNil(t, err)
|
||||
castErr, ok := err.(Error)
|
||||
if assert.True(t, ok, "err should be pubsub.Error") {
|
||||
assert.True(t, castErr.NoSubscriber(), "NoSubscriber() should be true")
|
||||
)
|
||||
assert.NotNil(t, err)
|
||||
castErr, ok := err.(Error)
|
||||
if assert.True(t, ok, "err should be pubsub.Error") {
|
||||
assert.True(t, castErr.NoSubscriber(), "NoSubscriber() should be true")
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("standalone", func(t *testing.T) {
|
||||
store, teardown := SetupRedisForTest(t, false)
|
||||
defer teardown()
|
||||
runTest(t, store)
|
||||
})
|
||||
|
||||
t.Run("cluster", func(t *testing.T) {
|
||||
store, teardown := SetupRedisForTest(t, true)
|
||||
defer teardown()
|
||||
runTest(t, store)
|
||||
})
|
||||
}
|
||||
|
||||
func TestQueryResultsStore(t *testing.T) {
|
||||
store, teardown := SetupRedisForTest(t)
|
||||
defer teardown()
|
||||
runTest := func(t *testing.T, store *redisQueryResults) {
|
||||
// Test handling results for two campaigns in parallel
|
||||
campaign1 := fleet.DistributedQueryCampaign{ID: 1}
|
||||
|
||||
// Test handling results for two campaigns in parallel
|
||||
campaign1 := fleet.DistributedQueryCampaign{ID: 1}
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
channel1, err := store.ReadChannel(ctx1, campaign1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
channel1, err := store.ReadChannel(ctx1, campaign1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
expected1 := []fleet.DistributedQueryResult{
|
||||
{
|
||||
DistributedQueryCampaignID: 1,
|
||||
Rows: []map[string]string{{"foo": "bar"}},
|
||||
Host: fleet.Host{
|
||||
ID: 1,
|
||||
// Note these times need to be set to avoid
|
||||
// issues with roundtrip serializing the zero
|
||||
// time value. See https://goo.gl/CCEs8x
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
},
|
||||
CreateTimestamp: fleet.CreateTimestamp{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
expected1 := []fleet.DistributedQueryResult{
|
||||
{
|
||||
DistributedQueryCampaignID: 1,
|
||||
Rows: []map[string]string{{"foo": "bar"}},
|
||||
Host: fleet.Host{
|
||||
ID: 1,
|
||||
// Note these times need to be set to avoid
|
||||
// issues with roundtrip serializing the zero
|
||||
// time value. See https://goo.gl/CCEs8x
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
},
|
||||
CreateTimestamp: fleet.CreateTimestamp{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
|
||||
DetailUpdatedAt: time.Now().UTC(),
|
||||
SeenTime: time.Now().UTC(),
|
||||
},
|
||||
|
||||
DetailUpdatedAt: time.Now().UTC(),
|
||||
SeenTime: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
{
|
||||
DistributedQueryCampaignID: 1,
|
||||
Rows: []map[string]string{{"whoo": "wahh"}},
|
||||
Host: fleet.Host{
|
||||
ID: 3,
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
},
|
||||
CreateTimestamp: fleet.CreateTimestamp{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
{
|
||||
DistributedQueryCampaignID: 1,
|
||||
Rows: []map[string]string{{"whoo": "wahh"}},
|
||||
Host: fleet.Host{
|
||||
ID: 3,
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
},
|
||||
CreateTimestamp: fleet.CreateTimestamp{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
|
||||
DetailUpdatedAt: time.Now().UTC(),
|
||||
SeenTime: time.Now().UTC(),
|
||||
},
|
||||
|
||||
DetailUpdatedAt: time.Now().UTC(),
|
||||
SeenTime: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
{
|
||||
DistributedQueryCampaignID: 1,
|
||||
Rows: []map[string]string{{"bing": "fds"}},
|
||||
Host: fleet.Host{
|
||||
ID: 4,
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
},
|
||||
CreateTimestamp: fleet.CreateTimestamp{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
{
|
||||
DistributedQueryCampaignID: 1,
|
||||
Rows: []map[string]string{{"bing": "fds"}},
|
||||
Host: fleet.Host{
|
||||
ID: 4,
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
},
|
||||
CreateTimestamp: fleet.CreateTimestamp{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
|
||||
DetailUpdatedAt: time.Now().UTC(),
|
||||
SeenTime: time.Now().UTC(),
|
||||
},
|
||||
|
||||
DetailUpdatedAt: time.Now().UTC(),
|
||||
SeenTime: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
campaign2 := fleet.DistributedQueryCampaign{ID: 2}
|
||||
campaign2 := fleet.DistributedQueryCampaign{ID: 2}
|
||||
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
channel2, err := store.ReadChannel(ctx2, campaign2)
|
||||
assert.Nil(t, err)
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
channel2, err := store.ReadChannel(ctx2, campaign2)
|
||||
assert.Nil(t, err)
|
||||
|
||||
expected2 := []fleet.DistributedQueryResult{
|
||||
{
|
||||
DistributedQueryCampaignID: 2,
|
||||
Rows: []map[string]string{{"tim": "tom"}},
|
||||
Host: fleet.Host{
|
||||
ID: 1,
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
},
|
||||
CreateTimestamp: fleet.CreateTimestamp{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
expected2 := []fleet.DistributedQueryResult{
|
||||
{
|
||||
DistributedQueryCampaignID: 2,
|
||||
Rows: []map[string]string{{"tim": "tom"}},
|
||||
Host: fleet.Host{
|
||||
ID: 1,
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
},
|
||||
CreateTimestamp: fleet.CreateTimestamp{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
|
||||
DetailUpdatedAt: time.Now().UTC(),
|
||||
SeenTime: time.Now().UTC(),
|
||||
},
|
||||
|
||||
DetailUpdatedAt: time.Now().UTC(),
|
||||
SeenTime: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
{
|
||||
DistributedQueryCampaignID: 2,
|
||||
Rows: []map[string]string{{"slim": "slam"}},
|
||||
Host: fleet.Host{
|
||||
ID: 3,
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
},
|
||||
CreateTimestamp: fleet.CreateTimestamp{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
{
|
||||
DistributedQueryCampaignID: 2,
|
||||
Rows: []map[string]string{{"slim": "slam"}},
|
||||
Host: fleet.Host{
|
||||
ID: 3,
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
},
|
||||
CreateTimestamp: fleet.CreateTimestamp{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
|
||||
DetailUpdatedAt: time.Now().UTC(),
|
||||
SeenTime: time.Now().UTC(),
|
||||
},
|
||||
|
||||
DetailUpdatedAt: time.Now().UTC(),
|
||||
SeenTime: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var results1, results2 []fleet.DistributedQueryResult
|
||||
var results1, results2 []fleet.DistributedQueryResult
|
||||
|
||||
var readerWg, writerWg sync.WaitGroup
|
||||
var readerWg, writerWg sync.WaitGroup
|
||||
|
||||
readerWg.Add(1)
|
||||
go func() {
|
||||
defer readerWg.Done()
|
||||
for res := range channel1 {
|
||||
switch res := res.(type) {
|
||||
case fleet.DistributedQueryResult:
|
||||
results1 = append(results1, res)
|
||||
readerWg.Add(1)
|
||||
go func() {
|
||||
defer readerWg.Done()
|
||||
for res := range channel1 {
|
||||
switch res := res.(type) {
|
||||
case fleet.DistributedQueryResult:
|
||||
results1 = append(results1, res)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}()
|
||||
readerWg.Add(1)
|
||||
go func() {
|
||||
defer readerWg.Done()
|
||||
for res := range channel2 {
|
||||
switch res := res.(type) {
|
||||
case fleet.DistributedQueryResult:
|
||||
results2 = append(results2, res)
|
||||
}()
|
||||
readerWg.Add(1)
|
||||
go func() {
|
||||
defer readerWg.Done()
|
||||
for res := range channel2 {
|
||||
switch res := res.(type) {
|
||||
case fleet.DistributedQueryResult:
|
||||
results2 = append(results2, res)
|
||||
}
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
// Wait to ensure subscriptions are activated before writing
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
writerWg.Add(1)
|
||||
go func() {
|
||||
defer writerWg.Done()
|
||||
for _, res := range expected1 {
|
||||
assert.Nil(t, store.WriteResult(res))
|
||||
}
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
cancel1()
|
||||
}()
|
||||
writerWg.Add(1)
|
||||
go func() {
|
||||
defer writerWg.Done()
|
||||
for _, res := range expected2 {
|
||||
assert.Nil(t, store.WriteResult(res))
|
||||
}
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
cancel2()
|
||||
}()
|
||||
|
||||
// wait with a timeout to ensure that the test can't hang
|
||||
if waitTimeout(&writerWg, 5*time.Second) {
|
||||
t.Error("Timed out waiting for writers to join")
|
||||
}
|
||||
if waitTimeout(&readerWg, 5*time.Second) {
|
||||
t.Error("Timed out waiting for readers to join")
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
// Wait to ensure subscriptions are activated before writing
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
writerWg.Add(1)
|
||||
go func() {
|
||||
defer writerWg.Done()
|
||||
for _, res := range expected1 {
|
||||
assert.Nil(t, store.WriteResult(res))
|
||||
}
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
cancel1()
|
||||
}()
|
||||
writerWg.Add(1)
|
||||
go func() {
|
||||
defer writerWg.Done()
|
||||
for _, res := range expected2 {
|
||||
assert.Nil(t, store.WriteResult(res))
|
||||
}
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
cancel2()
|
||||
}()
|
||||
|
||||
// wait with a timeout to ensure that the test can't hang
|
||||
if waitTimeout(&writerWg, 5*time.Second) {
|
||||
t.Error("Timed out waiting for writers to join")
|
||||
}
|
||||
if waitTimeout(&readerWg, 5*time.Second) {
|
||||
t.Error("Timed out waiting for readers to join")
|
||||
assert.EqualValues(t, expected1, results1)
|
||||
assert.EqualValues(t, expected2, results2)
|
||||
}
|
||||
|
||||
assert.EqualValues(t, expected1, results1)
|
||||
assert.EqualValues(t, expected2, results2)
|
||||
t.Run("standalone", func(t *testing.T) {
|
||||
store, teardown := SetupRedisForTest(t, false)
|
||||
defer teardown()
|
||||
runTest(t, store)
|
||||
})
|
||||
|
||||
t.Run("cluster", func(t *testing.T) {
|
||||
store, teardown := SetupRedisForTest(t, true)
|
||||
defer teardown()
|
||||
runTest(t, store)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,87 +4,24 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/gomodule/redigo/redis"
|
||||
"github.com/mna/redisc"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type redisQueryResults struct {
|
||||
// connection pool
|
||||
pool *redisc.Cluster
|
||||
pool fleet.RedisPool
|
||||
duplicateResults bool
|
||||
}
|
||||
|
||||
var _ fleet.QueryResultStore = &redisQueryResults{}
|
||||
|
||||
// NewRedisPool creates a Redis connection pool using the provided server
|
||||
// address, password and database.
|
||||
func NewRedisPool(server, password string, database int, useTLS bool) (*redisc.Cluster, error) {
|
||||
// Create the Cluster
|
||||
cluster := &redisc.Cluster{
|
||||
StartupNodes: []string{server},
|
||||
CreatePool: func(server string, opts ...redis.DialOption) (*redis.Pool, error) {
|
||||
return &redis.Pool{
|
||||
MaxIdle: 3,
|
||||
IdleTimeout: 240 * time.Second,
|
||||
Dial: func() (redis.Conn, error) {
|
||||
c, err := redis.Dial(
|
||||
"tcp",
|
||||
server,
|
||||
redis.DialDatabase(database),
|
||||
redis.DialUseTLS(useTLS),
|
||||
redis.DialConnectTimeout(5*time.Second),
|
||||
redis.DialKeepAlive(10*time.Second),
|
||||
// Read/Write timeouts not set here because we may see results
|
||||
// only rarely on the pub/sub channel.
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if password != "" {
|
||||
if _, err := c.Do("AUTH", password); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c, err
|
||||
},
|
||||
TestOnBorrow: func(c redis.Conn, t time.Time) error {
|
||||
if time.Since(t) < time.Minute {
|
||||
return nil
|
||||
}
|
||||
_, err := c.Do("PING")
|
||||
return err
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
if err := cluster.Refresh(); err != nil && !isClusterDisabled(err) && !isClusterCommandUnknown(err) {
|
||||
return nil, errors.Wrap(err, "refresh cluster")
|
||||
}
|
||||
|
||||
return cluster, nil
|
||||
}
|
||||
|
||||
func isClusterDisabled(err error) bool {
|
||||
return strings.Contains(err.Error(), "ERR This instance has cluster support disabled")
|
||||
}
|
||||
|
||||
// On GCP Memorystore the CLUSTER command is entirely unavailable and fails with
|
||||
// this error. See
|
||||
// https://cloud.google.com/memorystore/docs/redis/product-constraints#blocked_redis_commands
|
||||
func isClusterCommandUnknown(err error) bool {
|
||||
return strings.Contains(err.Error(), "ERR unknown command `CLUSTER`")
|
||||
}
|
||||
|
||||
// NewRedisQueryResults creates a new Redis implementation of the
|
||||
// QueryResultStore interface using the provided Redis connection pool.
|
||||
func NewRedisQueryResults(pool *redisc.Cluster, duplicateResults bool) *redisQueryResults {
|
||||
func NewRedisQueryResults(pool fleet.RedisPool, duplicateResults bool) *redisQueryResults {
|
||||
return &redisQueryResults{pool: pool, duplicateResults: duplicateResults}
|
||||
}
|
||||
|
||||
|
|
@ -93,7 +30,7 @@ func pubSubForID(id uint) string {
|
|||
}
|
||||
|
||||
// Pool returns the redisc connection pool (used in tests).
|
||||
func (r *redisQueryResults) Pool() *redisc.Cluster {
|
||||
func (r *redisQueryResults) Pool() fleet.RedisPool {
|
||||
return r.pool
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,26 +3,40 @@ package pubsub
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/redis"
|
||||
redigo "github.com/gomodule/redigo/redis"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func SetupRedisForTest(t *testing.T) (store *redisQueryResults, teardown func()) {
|
||||
func SetupRedisForTest(t *testing.T, cluster bool) (store *redisQueryResults, teardown func()) {
|
||||
var (
|
||||
addr = "127.0.0.1:6379"
|
||||
addr = "127.0.0.1:"
|
||||
password = ""
|
||||
database = 0
|
||||
useTLS = false
|
||||
dupResults = false
|
||||
port = "6379"
|
||||
)
|
||||
if cluster {
|
||||
port = "7001"
|
||||
}
|
||||
addr += port
|
||||
|
||||
pool, err := NewRedisPool(addr, password, database, useTLS)
|
||||
pool, err := redis.NewRedisPool(addr, password, database, useTLS)
|
||||
require.NoError(t, err)
|
||||
store = NewRedisQueryResults(pool, dupResults)
|
||||
|
||||
_, err = store.pool.Get().Do("PING")
|
||||
conn := store.pool.Get()
|
||||
defer conn.Close()
|
||||
_, err = conn.Do("PING")
|
||||
require.Nil(t, err)
|
||||
|
||||
teardown = func() {
|
||||
err := redis.EachRedisNode(store.pool, func(conn redigo.Conn) error {
|
||||
_, err := conn.Do("FLUSHDB")
|
||||
return err
|
||||
})
|
||||
require.NoError(t, err)
|
||||
store.pool.Close()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import (
|
|||
func TestStreamCampaignResultsClosesReditOnWSClose(t *testing.T) {
|
||||
t.Skip("Seems to be a bit problematic in CI")
|
||||
|
||||
store, teardown := pubsub.SetupRedisForTest(t)
|
||||
store, teardown := pubsub.SetupRedisForTest(t, false)
|
||||
defer teardown()
|
||||
|
||||
mockClock := clock.NewMockClock()
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ import (
|
|||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/gomodule/redigo/redis"
|
||||
"github.com/mna/redisc"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
|
@ -34,12 +34,12 @@ type SessionStore interface {
|
|||
}
|
||||
|
||||
// NewSessionStore creates a SessionStore
|
||||
func NewSessionStore(pool *redisc.Cluster) SessionStore {
|
||||
func NewSessionStore(pool fleet.RedisPool) SessionStore {
|
||||
return &store{pool}
|
||||
}
|
||||
|
||||
type store struct {
|
||||
pool *redisc.Cluster
|
||||
pool fleet.RedisPool
|
||||
}
|
||||
|
||||
func (s *store) create(requestID, originalURL, metadata string, lifetimeSecs uint) error {
|
||||
|
|
|
|||
|
|
@ -5,24 +5,31 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/pubsub"
|
||||
"github.com/mna/redisc"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/redis"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newPool(t *testing.T) *redisc.Cluster {
|
||||
func newPool(t *testing.T, cluster bool) fleet.RedisPool {
|
||||
if _, ok := os.LookupEnv("REDIS_TEST"); ok {
|
||||
var (
|
||||
addr = "127.0.0.1:6379"
|
||||
addr = "127.0.0.1:"
|
||||
password = ""
|
||||
database = 0
|
||||
useTLS = false
|
||||
port = "6379"
|
||||
)
|
||||
if cluster {
|
||||
port = "7001"
|
||||
}
|
||||
addr += port
|
||||
|
||||
pool, err := pubsub.NewRedisPool(addr, password, database, useTLS)
|
||||
pool, err := redis.NewRedisPool(addr, password, database, useTLS)
|
||||
require.NoError(t, err)
|
||||
_, err = pool.Get().Do("PING")
|
||||
conn := pool.Get()
|
||||
defer conn.Close()
|
||||
_, err = conn.Do("PING")
|
||||
require.Nil(t, err)
|
||||
return pool
|
||||
}
|
||||
|
|
@ -33,22 +40,36 @@ func TestSessionStore(t *testing.T) {
|
|||
if _, ok := os.LookupEnv("REDIS_TEST"); !ok {
|
||||
t.Skip("skipping sso session store tests")
|
||||
}
|
||||
p := newPool(t)
|
||||
require.NotNil(t, p)
|
||||
defer p.Close()
|
||||
store := NewSessionStore(p)
|
||||
require.NotNil(t, store)
|
||||
// Create session that lives for 1 second.
|
||||
err := store.create("request123", "https://originalurl.com", "some metadata", 1)
|
||||
require.Nil(t, err)
|
||||
sess, err := store.Get("request123")
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, sess)
|
||||
assert.Equal(t, "https://originalurl.com", sess.OriginalURL)
|
||||
assert.Equal(t, "some metadata", sess.Metadata)
|
||||
// Wait a little bit more than one second, session should no longer be present.
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
sess, err = store.Get("request123")
|
||||
assert.Equal(t, ErrSessionNotFound, err)
|
||||
assert.Nil(t, sess)
|
||||
|
||||
runTest := func(t *testing.T, pool fleet.RedisPool) {
|
||||
store := NewSessionStore(pool)
|
||||
require.NotNil(t, store)
|
||||
// Create session that lives for 1 second.
|
||||
err := store.create("request123", "https://originalurl.com", "some metadata", 1)
|
||||
require.Nil(t, err)
|
||||
sess, err := store.Get("request123")
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, sess)
|
||||
assert.Equal(t, "https://originalurl.com", sess.OriginalURL)
|
||||
assert.Equal(t, "some metadata", sess.Metadata)
|
||||
// Wait a little bit more than one second, session should no longer be present.
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
sess, err = store.Get("request123")
|
||||
assert.Equal(t, ErrSessionNotFound, err)
|
||||
assert.Nil(t, sess)
|
||||
}
|
||||
|
||||
t.Run("standalone", func(t *testing.T) {
|
||||
p := newPool(t, false)
|
||||
require.NotNil(t, p)
|
||||
defer p.Close()
|
||||
runTest(t, p)
|
||||
})
|
||||
|
||||
t.Run("cluster", func(t *testing.T) {
|
||||
p := newPool(t, true)
|
||||
require.NotNil(t, p)
|
||||
defer p.Close()
|
||||
runTest(t, p)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
tools/redis-tests/redis-cluster-1.conf (new file, 5 lines)
@ -0,0 +1,5 @@
port 7001
cluster-enabled yes
cluster-config-file nodes.conf
cluster-node-timeout 5000
appendonly yes

tools/redis-tests/redis-cluster-2.conf (new file, 5 lines)
@ -0,0 +1,5 @@
port 7002
cluster-enabled yes
cluster-config-file nodes.conf
cluster-node-timeout 5000
appendonly yes

tools/redis-tests/redis-cluster-3.conf (new file, 5 lines)
@ -0,0 +1,5 @@
port 7003
cluster-enabled yes
cluster-config-file nodes.conf
cluster-node-timeout 5000
appendonly yes