diff --git a/changes/issue-2495-redis-read-replica b/changes/issue-2495-redis-read-replica new file mode 100644 index 0000000000..09bbcadf14 --- /dev/null +++ b/changes/issue-2495-redis-read-replica @@ -0,0 +1 @@ +* Add `cluster_read_from_replica` configuration option to `fleet` command. diff --git a/cmd/fleet/serve.go b/cmd/fleet/serve.go index 1c849fa4fc..1e899c81a8 100644 --- a/cmd/fleet/serve.go +++ b/cmd/fleet/serve.go @@ -202,7 +202,7 @@ the way that the Fleet server works. } } - redisPool, err := redis.NewRedisPool(redis.PoolConfig{ + redisPool, err := redis.NewPool(redis.PoolConfig{ Server: config.Redis.Address, Password: config.Redis.Password, Database: config.Redis.Database, @@ -211,6 +211,7 @@ the way that the Fleet server works. KeepAlive: config.Redis.KeepAlive, ConnectRetryAttempts: config.Redis.ConnectRetryAttempts, ClusterFollowRedirections: config.Redis.ClusterFollowRedirections, + ClusterReadFromReplica: config.Redis.ClusterReadFromReplica, }) if err != nil { initFatal(err, "initialize Redis") diff --git a/docs/02-Deploying/02-Configuration.md b/docs/02-Deploying/02-Configuration.md index 5b7598f74c..fb37c2f858 100644 --- a/docs/02-Deploying/02-Configuration.md +++ b/docs/02-Deploying/02-Configuration.md @@ -416,6 +416,20 @@ handled transparently instead of ending in an error. cluster_follow_redirections: true ``` +##### redis_cluster_read_from_replica + +Whether or not to prefer reading from a replica when possible. Applies only +to Redis Cluster setups, ignored in standalone Redis. + +- Default value: false +- Environment variable: `FLEET_REDIS_CLUSTER_READ_FROM_REPLICA` +- Config file format: + + ``` + redis: + cluster_read_from_replica: true + ``` + #### Server ##### server_address @@ -681,7 +695,7 @@ Valid time units are `s`, `m`, `h`. osquery: label_update_interval: 30m ``` - + ##### osquery_policy_update_interval The interval at which Fleet will ask osquery agents to update their results for policy queries. 
@@ -1346,7 +1360,7 @@ AWS STS role ARN to use for S3 authentication. ##### s3_endpoint_url -AWS S3 Endpoint URL. Override when using a different S3 compatible object storage backend (such as Minio), +AWS S3 Endpoint URL. Override when using a different S3 compatible object storage backend (such as Minio), or running s3 locally with localstack. Leave this blank to use the default S3 service endpoint. - Default value: none @@ -1376,7 +1390,7 @@ AWS S3 Disable SSL. Useful for local testing. AWS S3 Force S3 Path Style. Set this to `true` to force the request to use path-style addressing, i.e., `http://s3.amazonaws.com/BUCKET/KEY`. By default, the S3 client will use virtual hosted bucket addressing when possible -(`http://BUCKET.s3.amazonaws.com/KEY`). +(`http://BUCKET.s3.amazonaws.com/KEY`). See [here](http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html) for details. @@ -1393,7 +1407,7 @@ See [here](http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html) f AWS S3 Region. Leave blank to enable region discovery. 
-- Default value: +- Default value: - Environment variable: `FLEET_S3_REGION` - Config file format: diff --git a/server/config/config.go b/server/config/config.go index 81e99ae62c..251e8dacf8 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -46,6 +46,7 @@ type RedisConfig struct { KeepAlive time.Duration `yaml:"keep_alive"` ConnectRetryAttempts int `yaml:"connect_retry_attempts"` ClusterFollowRedirections bool `yaml:"cluster_follow_redirections"` + ClusterReadFromReplica bool `yaml:"cluster_read_from_replica"` } const ( @@ -253,6 +254,7 @@ func (man Manager) addConfigs() { man.addConfigDuration("redis.keep_alive", 10*time.Second, "Interval between keep alive probes") man.addConfigInt("redis.connect_retry_attempts", 0, "Number of attempts to retry a failed connection") man.addConfigBool("redis.cluster_follow_redirections", false, "Automatically follow Redis Cluster redirections") + man.addConfigBool("redis.cluster_read_from_replica", false, "Prefer reading from a replica when possible (for Redis Cluster)") // Server man.addConfigString("server.address", "0.0.0.0:8080", @@ -444,6 +446,7 @@ func (man Manager) LoadConfig() FleetConfig { KeepAlive: man.getConfigDuration("redis.keep_alive"), ConnectRetryAttempts: man.getConfigInt("redis.connect_retry_attempts"), ClusterFollowRedirections: man.getConfigBool("redis.cluster_follow_redirections"), + ClusterReadFromReplica: man.getConfigBool("redis.cluster_read_from_replica"), }, Server: ServerConfig{ Address: man.getConfigString("server.address"), diff --git a/server/datastore/cached_mysql/cached_mysql.go b/server/datastore/cached_mysql/cached_mysql.go index 41debf3358..86254cdb3f 100644 --- a/server/datastore/cached_mysql/cached_mysql.go +++ b/server/datastore/cached_mysql/cached_mysql.go @@ -5,6 +5,7 @@ import ( "encoding/json" "time" + "github.com/fleetdm/fleet/v4/server/datastore/redis" "github.com/fleetdm/fleet/v4/server/fleet" redigo "github.com/gomodule/redigo/redis" "github.com/pkg/errors" @@ 
-28,7 +29,7 @@ func New(ds fleet.Datastore, redisPool fleet.RedisPool) fleet.Datastore { } func (ds *cachedMysql) storeInRedis(key string, v interface{}) error { - conn := ds.redisPool.ConfigureDoer(ds.redisPool.Get()) + conn := redis.ConfigureDoer(ds.redisPool, ds.redisPool.Get()) defer conn.Close() b, err := json.Marshal(v) @@ -44,7 +45,10 @@ func (ds *cachedMysql) storeInRedis(key string, v interface{}) error { } func (ds *cachedMysql) getFromRedis(key string, v interface{}) error { - conn := ds.redisPool.ConfigureDoer(ds.redisPool.Get()) + // Mark the raw pool connection read-only before wrapping it: ReadOnlyConn + // is silently a no-op on the retry wrapper that ConfigureDoer may return. + conn := redis.ConfigureDoer(ds.redisPool, + redis.ReadOnlyConn(ds.redisPool, ds.redisPool.Get())) defer conn.Close() data, err := redigo.Bytes(conn.Do("GET", key)) diff --git a/server/datastore/cached_mysql/cached_mysql_test.go b/server/datastore/cached_mysql/cached_mysql_test.go index cda15822b7..e2041531e3 100644 --- a/server/datastore/cached_mysql/cached_mysql_test.go +++ b/server/datastore/cached_mysql/cached_mysql_test.go @@ -3,12 +3,9 @@ package cached_mysql import ( "context" "encoding/json" - "os" - "runtime" "testing" - "time" - "github.com/fleetdm/fleet/v4/server/datastore/redis" + "github.com/fleetdm/fleet/v4/server/datastore/redis/redistest" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mock" "github.com/fleetdm/fleet/v4/server/ptr" @@ -17,47 +14,8 @@ import ( "github.com/stretchr/testify/require" ) -// newPool is basically repeated in every package that uses redis -// I tried to move this to a datastoretest package, but there's an import loop with redis -// so I decided to copy and past for now -func newPool(t *testing.T, cluster bool) fleet.RedisPool { - if cluster && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") { - t.Skipf("docker networking limitations prevent running redis cluster tests on %s", runtime.GOOS) - } - - if _, ok := os.LookupEnv("REDIS_TEST"); ok { - var ( - addr = "127.0.0.1:" - password = "" - database = 0 - useTLS = false - port = "6379" - ) - if 
cluster { - port = "7001" - } - addr += port - - pool, err := redis.NewRedisPool(redis.PoolConfig{ - Server: addr, - Password: password, - Database: database, - UseTLS: useTLS, - ConnTimeout: 5 * time.Second, - KeepAlive: 10 * time.Second, - }) - require.NoError(t, err) - conn := pool.Get() - defer conn.Close() - _, err = conn.Do("PING") - require.Nil(t, err) - return pool - } - return nil -} - func TestCachedAppConfig(t *testing.T) { - pool := newPool(t, false) + pool := redistest.SetupRedis(t, false, false, false) conn := pool.Get() _, err := conn.Do("DEL", CacheKeyAppConfig) require.NoError(t, err) diff --git a/server/datastore/redis/redis.go b/server/datastore/redis/redis.go index 3679b7a61a..ba09040968 100644 --- a/server/datastore/redis/redis.go +++ b/server/datastore/redis/redis.go @@ -19,10 +19,6 @@ type standalonePool struct { addr string } -func (p *standalonePool) ConfigureDoer(conn redis.Conn) redis.Conn { - return conn -} - func (p *standalonePool) Stats() map[string]redis.PoolStats { return map[string]redis.PoolStats{ p.addr: p.Pool.Stats(), @@ -32,19 +28,7 @@ func (p *standalonePool) Stats() map[string]redis.PoolStats { type clusterPool struct { *redisc.Cluster followRedirs bool -} - -// ConfigureDoer configures conn to follow redirections if the redis -// configuration requested it. If the conn is already in error, or -// if it is not a redisc cluster connection, it is returned unaltered. -func (p *clusterPool) ConfigureDoer(conn redis.Conn) redis.Conn { - if err := conn.Err(); err == nil && p.followRedirs { - rc, err := redisc.RetryConn(conn, 3, 300*time.Millisecond) - if err == nil { - return rc - } - } - return conn + readReplica bool } // PoolConfig holds the redis pool configuration options. 
@@ -57,14 +41,15 @@ type PoolConfig struct { KeepAlive time.Duration ConnectRetryAttempts int ClusterFollowRedirections bool + ClusterReadFromReplica bool // allows for testing dial retries and other dial-related scenarios testRedisDialFunc func(net, addr string, opts ...redis.DialOption) (redis.Conn, error) } -// NewRedisPool creates a Redis connection pool using the provided server +// NewPool creates a Redis connection pool using the provided server // address, password and database. -func NewRedisPool(config PoolConfig) (fleet.RedisPool, error) { +func NewPool(config PoolConfig) (fleet.RedisPool, error) { cluster := newCluster(config) if err := cluster.Refresh(); err != nil { if isClusterDisabled(err) || isClusterCommandUnknown(err) { @@ -76,30 +61,76 @@ func NewRedisPool(config PoolConfig) (fleet.RedisPool, error) { return nil, errors.Wrap(err, "refresh cluster") } - return &clusterPool{cluster, config.ClusterFollowRedirections}, nil + return &clusterPool{ + cluster, + config.ClusterFollowRedirections, + config.ClusterReadFromReplica, + }, nil } -// SplitRedisKeysBySlot takes a list of redis keys and groups them by hash slot +// ReadOnlyConn turns conn into a connection that will try to connect to a +// replica instead of a primary. Note that this is not guaranteed that it will +// do so (there may not be any replica, or due to redirections it may end up on +// a primary, etc.), and it will only try to do so if pool is a Redis Cluster +// pool. The returned connection should only be used to run read-only +// commands. +func ReadOnlyConn(pool fleet.RedisPool, conn redis.Conn) redis.Conn { + if p, isCluster := pool.(*clusterPool); isCluster && p.readReplica { + // it only fails if the connection is not a redisc connection or the + // connection is already bound, in which case we just return the connection + // as-is. 
+ _ = redisc.ReadOnlyConn(conn) + } + return conn +} + +// ConfigureDoer configures conn to follow redirections if the redis +// configuration requested it and the pool is a Redis Cluster pool. If the conn +// is already in error, or if it is not a redisc cluster connection, it is +// returned unaltered. +func ConfigureDoer(pool fleet.RedisPool, conn redis.Conn) redis.Conn { + if p, isCluster := pool.(*clusterPool); isCluster { + if err := conn.Err(); err == nil && p.followRedirs { + rc, err := redisc.RetryConn(conn, 3, 300*time.Millisecond) + if err == nil { + return rc + } + } + } + return conn +} + +// SplitKeysBySlot takes a list of redis keys and groups them by hash slot // so that keys in a given group are guaranteed to hash to the same slot, making // them safe to run e.g. in a pipeline on the same connection or as part of a // multi-key command in a Redis Cluster setup. When using standalone Redis, it // simply returns all keys in the same group (i.e. the top-level slice has a // length of 1). -func SplitRedisKeysBySlot(pool fleet.RedisPool, keys ...string) [][]string { +func SplitKeysBySlot(pool fleet.RedisPool, keys ...string) [][]string { if _, isCluster := pool.(*clusterPool); isCluster { return redisc.SplitBySlot(keys...) } return [][]string{keys} } -// EachRedisNode calls fn for each node in the redis cluster, with a connection +// EachNode calls fn for each node in the redis cluster, with a connection // to that node, until all nodes have been visited. The connection is // automatically closed after the call. If fn returns an error, the iteration -// of nodes stops and EachRedisNode returns that error. For standalone redis, +// of nodes stops and EachNode returns that error. For standalone redis, // fn is called only once. -func EachRedisNode(pool fleet.RedisPool, fn func(conn redis.Conn) error) error { +// +// If replicas is true, it will visit each replica node instead, otherwise the +// primary nodes are visited. 
Keep in mind that if replicas is true, it will +// visit all known replicas - which is great e.g. to run diagnostics on each +// node, but can be surprising if the goal is e.g. to collect all keys, as it +// is possible that more than one node is acting as replica for the same +// primary, meaning that the same keys could be seen multiple times - you +// should be prepared to handle this scenario. The connection provided to fn is +// not a ReadOnly connection (conn.ReadOnly hasn't been called on it), it is up +// to fn to execute the READONLY redis command if required. +func EachNode(pool fleet.RedisPool, replicas bool, fn func(conn redis.Conn) error) error { if cluster, isCluster := pool.(*clusterPool); isCluster { - return cluster.EachNode(false, func(_ string, conn redis.Conn) error { + return cluster.EachNode(replicas, func(_ string, conn redis.Conn) error { return fn(conn) }) } @@ -109,6 +140,80 @@ func EachRedisNode(pool fleet.RedisPool, fn func(conn redis.Conn) error) error { return fn(conn) } +// BindConn binds the connection to the redis node that serves those keys. +// In a Redis Cluster setup, all keys must hash to the same slot, otherwise +// an error is returned. In a Redis Standalone setup, it is a no-op and never +// fails. On successful return, the connection is ready to be used with those +// keys. +func BindConn(pool fleet.RedisPool, conn redis.Conn, keys ...string) error { + if _, isCluster := pool.(*clusterPool); isCluster { + return redisc.BindConn(conn, keys...) + } + return nil +} + +// PublishHasListeners is like the PUBLISH redis command, but it also returns a +// boolean indicating if channel still has subscribed listeners. It is required +// because the redis command only returns the count of subscribers active on +// the same node as the one that is used to publish, which may not always be +// the case in Redis Cluster (especially with the read from replica option +// set). 
+// +// In Standalone mode, it is the same as PUBLISH (with the count of subscribers +// turned into a boolean), and in Cluster mode, if the count returned by +// PUBLISH is 0, it gets the number of subscribers on each node in the cluster +// to get the accurate count. +func PublishHasListeners(pool fleet.RedisPool, conn redis.Conn, channel, message string) (bool, error) { + n, err := redis.Int(conn.Do("PUBLISH", channel, message)) + if n > 0 || err != nil { + return n > 0, err + } + + // otherwise n == 0, check the actual number of subscribers if this is a + // redis cluster. + if _, isCluster := pool.(*clusterPool); !isCluster { + return false, nil + } + + errDone := errors.New("done") + var count int + + // subscribers can be subscribed on replicas, so we need to iterate on both + // primaries and replicas. + for _, replicas := range []bool{true, false} { + err = EachNode(pool, replicas, func(conn redis.Conn) error { + res, err := redis.Values(conn.Do("PUBSUB", "NUMSUB", channel)) + if err != nil { + return err + } + var ( + name string + n int + ) + _, err = redis.Scan(res, &name, &n) + if err != nil { + return err + } + count += n + if count > 0 { + // end early if we know it has subscribers + return errDone + } + return nil + }) + + if err == errDone { + break + } + } + + // if it completed successfully + if err == nil || err == errDone { + return count > 0, nil + } + return false, errors.Wrap(err, "checking for active subscribers") +} + func newCluster(config PoolConfig) *redisc.Cluster { opts := []redis.DialOption{ redis.DialDatabase(config.Database), diff --git a/server/datastore/redis/redis_external_test.go b/server/datastore/redis/redis_external_test.go new file mode 100644 index 0000000000..ce9649d7c6 --- /dev/null +++ b/server/datastore/redis/redis_external_test.go @@ -0,0 +1,252 @@ +package redis_test + +import ( + "fmt" + "testing" + "time" + + "github.com/fleetdm/fleet/v4/server/datastore/redis" + 
"github.com/fleetdm/fleet/v4/server/datastore/redis/redistest" + "github.com/fleetdm/fleet/v4/server/fleet" + redigo "github.com/gomodule/redigo/redis" + "github.com/mna/redisc" + "github.com/stretchr/testify/require" +) + +func TestRedisPoolConfigureDoer(t *testing.T) { + const prefix = "TestRedisPoolConfigureDoer:" + + t.Run("standalone", func(t *testing.T) { + pool := redistest.SetupRedis(t, false, false, false) + + c1 := pool.Get() + defer c1.Close() + c2 := redis.ConfigureDoer(pool, pool.Get()) + defer c2.Close() + + // both conns work equally well, get nil because keys do not exist, + // but no redirection error (this is standalone redis). + _, err := redigo.String(c1.Do("GET", prefix+"{a}")) + require.Equal(t, redigo.ErrNil, err) + _, err = redigo.String(c1.Do("GET", prefix+"{b}")) + require.Equal(t, redigo.ErrNil, err) + + _, err = redigo.String(c2.Do("GET", prefix+"{a}")) + require.Equal(t, redigo.ErrNil, err) + _, err = redigo.String(c2.Do("GET", prefix+"{b}")) + require.Equal(t, redigo.ErrNil, err) + }) + + t.Run("cluster", func(t *testing.T) { + pool := redistest.SetupRedis(t, true, true, false) + + c1 := pool.Get() + defer c1.Close() + c2 := redis.ConfigureDoer(pool, pool.Get()) + defer c2.Close() + + // unconfigured conn gets MOVED error on the second key + // (it is bound to {a}, {b} is on a different node) + _, err := redigo.String(c1.Do("GET", prefix+"{a}")) + require.Equal(t, redigo.ErrNil, err) + _, err = redigo.String(c1.Do("GET", prefix+"{b}")) + rerr := redisc.ParseRedir(err) + require.Error(t, rerr) + require.Equal(t, "MOVED", rerr.Type) + + // configured conn gets the nil value, it redirected automatically + _, err = redigo.String(c2.Do("GET", prefix+"{a}")) + require.Equal(t, redigo.ErrNil, err) + _, err = redigo.String(c2.Do("GET", prefix+"{b}")) + require.Equal(t, redigo.ErrNil, err) + }) +} + +func TestEachNode(t *testing.T) { + const prefix = "TestEachNode:" + + runTest := func(t *testing.T, pool fleet.RedisPool) { + conn := pool.Get() 
+ defer conn.Close() + if rc, err := redisc.RetryConn(conn, 3, 100*time.Millisecond); err == nil { + conn = rc + } + + for i := 0; i < 10; i++ { + _, err := conn.Do("SET", fmt.Sprintf("%s%d", prefix, i), i) + require.NoError(t, err) + } + + var keys []string + err := redis.EachNode(pool, false, func(conn redigo.Conn) error { + var cursor int + for { + res, err := redigo.Values(conn.Do("SCAN", cursor, "MATCH", prefix+"*")) + if err != nil { + return err + } + var curKeys []string + if _, err = redigo.Scan(res, &cursor, &curKeys); err != nil { + return err + } + keys = append(keys, curKeys...) + if cursor == 0 { + return nil + } + } + }) + require.NoError(t, err) + require.Len(t, keys, 10) + } + + t.Run("standalone", func(t *testing.T) { + pool := redistest.SetupRedis(t, false, false, false) + runTest(t, pool) + }) + + t.Run("cluster", func(t *testing.T) { + pool := redistest.SetupRedis(t, true, false, false) + runTest(t, pool) + }) +} + +func TestBindConn(t *testing.T) { + const prefix = "TestBindConn:" + + t.Run("standalone", func(t *testing.T) { + pool := redistest.SetupRedis(t, false, false, false) + + conn := pool.Get() + defer conn.Close() + + err := redis.BindConn(pool, conn, prefix+"a", prefix+"b", prefix+"c") + require.NoError(t, err) + _, err = redigo.String(conn.Do("GET", prefix+"a")) + require.Equal(t, redigo.ErrNil, err) + _, err = redigo.String(conn.Do("GET", prefix+"b")) + require.Equal(t, redigo.ErrNil, err) + _, err = redigo.String(conn.Do("GET", prefix+"c")) + require.Equal(t, redigo.ErrNil, err) + }) + + t.Run("cluster", func(t *testing.T) { + pool := redistest.SetupRedis(t, true, false, false) + + conn := pool.Get() + defer conn.Close() + + err := redis.BindConn(pool, conn, prefix+"a", prefix+"b", prefix+"c") + require.Error(t, err) + + err = redis.BindConn(pool, conn, prefix+"{z}a", prefix+"{z}b", prefix+"{z}c") + require.NoError(t, err) + + _, err = redigo.String(conn.Do("GET", prefix+"{z}a")) + require.Equal(t, redigo.ErrNil, err) + _, err = 
redigo.String(conn.Do("GET", prefix+"{z}b")) + require.Equal(t, redigo.ErrNil, err) + _, err = redigo.String(conn.Do("GET", prefix+"{z}c")) + require.Equal(t, redigo.ErrNil, err) + }) +} + +func TestPublishHasListeners(t *testing.T) { + const prefix = "TestPublishHasListeners:" + + t.Run("standalone", func(t *testing.T) { + pool := redistest.SetupRedis(t, false, false, false) + + pconn := pool.Get() + defer pconn.Close() + sconn := pool.Get() + defer sconn.Close() + + ok, err := redis.PublishHasListeners(pool, pconn, prefix+"a", "A") + require.NoError(t, err) + require.False(t, ok) + + psc := redigo.PubSubConn{Conn: sconn} + require.NoError(t, psc.Subscribe(prefix+"a")) + + ok, err = redis.PublishHasListeners(pool, pconn, prefix+"a", "B") + require.NoError(t, err) + require.True(t, ok) + + start := time.Now() + loop: + for time.Since(start) < 2*time.Second { + msg := psc.Receive() + switch msg := msg.(type) { + case redigo.Message: + require.Equal(t, "B", string(msg.Data)) + break loop + } + } + }) + + t.Run("cluster", func(t *testing.T) { + pool := redistest.SetupRedis(t, true, false, false) + + pconn := pool.Get() + defer pconn.Close() + sconn := pool.Get() + defer sconn.Close() + + ok, err := redis.PublishHasListeners(pool, pconn, prefix+"{a}", "A") + require.NoError(t, err) + require.False(t, ok) + + // one listener on a different node; fail fast if the bind is rejected + require.NoError(t, redis.BindConn(pool, sconn, "b")) + psc := redigo.PubSubConn{Conn: sconn} + require.NoError(t, psc.Subscribe(prefix+"{a}")) + + // a standard PUBLISH returns 0 + n, err := redigo.Int(pconn.Do("PUBLISH", prefix+"{a}", "B")) + require.NoError(t, err) + require.Equal(t, 0, n) + + // but this returns true + ok, err = redis.PublishHasListeners(pool, pconn, prefix+"{a}", "C") + require.NoError(t, err) + require.True(t, ok) + + start := time.Now() + want := "B" + loop: + for time.Since(start) < 2*time.Second { + msg := psc.Receive() + switch msg := msg.(type) { + case redigo.Message: + require.Equal(t, want, string(msg.Data)) + if want 
== "C" { + break loop + } + want = "C" + } + } + }) +} + +func TestReadOnlyConn(t *testing.T) { + const prefix = "TestReadOnlyConn:" + + t.Run("standalone", func(t *testing.T) { + pool := redistest.SetupRedis(t, false, false, true) + conn := redis.ReadOnlyConn(pool, pool.Get()) + defer conn.Close() + + _, err := conn.Do("SET", prefix+"a", 1) + require.NoError(t, err) + }) + + t.Run("cluster", func(t *testing.T) { + pool := redistest.SetupRedis(t, true, false, true) + conn := redis.ReadOnlyConn(pool, pool.Get()) + defer conn.Close() + + _, err := conn.Do("SET", prefix+"a", 1) + require.Error(t, err) + require.Contains(t, err.Error(), "MOVED") + }) +} diff --git a/server/datastore/redis/redis_test.go b/server/datastore/redis/redis_test.go index c1bbf82b36..d15c14da0c 100644 --- a/server/datastore/redis/redis_test.go +++ b/server/datastore/redis/redis_test.go @@ -1,15 +1,11 @@ package redis import ( - "fmt" "io" - "runtime" "testing" "time" - "github.com/fleetdm/fleet/v4/server/fleet" "github.com/gomodule/redigo/redis" - "github.com/mna/redisc" "github.com/pkg/errors" "github.com/stretchr/testify/require" ) @@ -88,7 +84,7 @@ func TestConnectRetry(t *testing.T) { for _, c := range cases { t.Run(c.err.Error(), func(t *testing.T) { start := time.Now() - _, err := NewRedisPool(PoolConfig{ + _, err := NewPool(PoolConfig{ Server: "127.0.0.1:12345", ConnectRetryAttempts: c.retries, testRedisDialFunc: mockDial(c.err), @@ -115,145 +111,3 @@ func TestConnectRetry(t *testing.T) { }) } } - -func TestRedisPoolConfigureDoer(t *testing.T) { - const prefix = "TestRedisPoolConfigureDoer:" - - t.Run("standalone", func(t *testing.T) { - pool := setupRedisForTest(t, false, false) - - c1 := pool.Get() - defer c1.Close() - c2 := pool.ConfigureDoer(pool.Get()) - defer c2.Close() - - // both conns work equally well, get nil because keys do not exist, - // but no redirection error (this is standalone redis). 
- _, err := redis.String(c1.Do("GET", prefix+"{a}")) - require.Equal(t, redis.ErrNil, err) - _, err = redis.String(c1.Do("GET", prefix+"{b}")) - require.Equal(t, redis.ErrNil, err) - - _, err = redis.String(c2.Do("GET", prefix+"{a}")) - require.Equal(t, redis.ErrNil, err) - _, err = redis.String(c2.Do("GET", prefix+"{b}")) - require.Equal(t, redis.ErrNil, err) - }) - - t.Run("cluster", func(t *testing.T) { - pool := setupRedisForTest(t, true, true) - - c1 := pool.Get() - defer c1.Close() - c2 := pool.ConfigureDoer(pool.Get()) - defer c2.Close() - - // unconfigured conn gets MOVED error on the second key - // (it is bound to {a}, {b} is on a different node) - _, err := redis.String(c1.Do("GET", prefix+"{a}")) - require.Equal(t, redis.ErrNil, err) - _, err = redis.String(c1.Do("GET", prefix+"{b}")) - rerr := redisc.ParseRedir(err) - require.Error(t, rerr) - require.Equal(t, "MOVED", rerr.Type) - - // configured conn gets the nil value, it redirected automatically - _, err = redis.String(c2.Do("GET", prefix+"{a}")) - require.Equal(t, redis.ErrNil, err) - _, err = redis.String(c2.Do("GET", prefix+"{b}")) - require.Equal(t, redis.ErrNil, err) - }) -} - -func TestEachRedisNode(t *testing.T) { - const prefix = "TestEachRedisNode:" - - runTest := func(t *testing.T, pool fleet.RedisPool) { - conn := pool.Get() - defer conn.Close() - if rc, err := redisc.RetryConn(conn, 3, 100*time.Millisecond); err == nil { - conn = rc - } - - for i := 0; i < 10; i++ { - _, err := conn.Do("SET", fmt.Sprintf("%s%d", prefix, i), i) - require.NoError(t, err) - } - - var keys []string - err := EachRedisNode(pool, func(conn redis.Conn) error { - var cursor int - for { - res, err := redis.Values(conn.Do("SCAN", cursor, "MATCH", prefix+"*")) - if err != nil { - return err - } - var curKeys []string - if _, err = redis.Scan(res, &cursor, &curKeys); err != nil { - return err - } - keys = append(keys, curKeys...) 
- if cursor == 0 { - return nil - } - } - }) - require.NoError(t, err) - require.Len(t, keys, 10) - } - - t.Run("standalone", func(t *testing.T) { - pool := setupRedisForTest(t, false, false) - runTest(t, pool) - }) - - t.Run("cluster", func(t *testing.T) { - pool := setupRedisForTest(t, true, false) - runTest(t, pool) - }) -} - -func setupRedisForTest(t *testing.T, cluster, redir bool) (pool fleet.RedisPool) { - if cluster && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") { - t.Skipf("docker networking limitations prevent running redis cluster tests on %s", runtime.GOOS) - } - - var ( - addr = "127.0.0.1:" - password = "" - database = 0 - useTLS = false - port = "6379" - ) - if cluster { - port = "7001" - } - addr += port - - pool, err := NewRedisPool(PoolConfig{ - Server: addr, - Password: password, - Database: database, - UseTLS: useTLS, - ConnTimeout: 5 * time.Second, - KeepAlive: 10 * time.Second, - ClusterFollowRedirections: redir, - }) - require.NoError(t, err) - - conn := pool.Get() - defer conn.Close() - _, err = conn.Do("PING") - require.Nil(t, err) - - t.Cleanup(func() { - err := EachRedisNode(pool, func(conn redis.Conn) error { - _, err := conn.Do("FLUSHDB") - return err - }) - require.NoError(t, err) - pool.Close() - }) - - return pool -} diff --git a/server/datastore/redis/redistest/redistest.go b/server/datastore/redis/redistest/redistest.go new file mode 100644 index 0000000000..ac47a598fb --- /dev/null +++ b/server/datastore/redis/redistest/redistest.go @@ -0,0 +1,62 @@ +package redistest + +import ( + "os" + "runtime" + "testing" + "time" + + "github.com/fleetdm/fleet/v4/server/datastore/redis" + "github.com/fleetdm/fleet/v4/server/fleet" + redigo "github.com/gomodule/redigo/redis" + "github.com/stretchr/testify/require" +) + +func SetupRedis(tb testing.TB, cluster, redir, readReplica bool) fleet.RedisPool { + if _, ok := os.LookupEnv("REDIS_TEST"); !ok { + tb.Skip("set REDIS_TEST environment variable to run redis-based tests") + } + if 
cluster && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") { + tb.Skipf("docker networking limitations prevent running redis cluster tests on %s", runtime.GOOS) + } + + var ( + addr = "127.0.0.1:" + password = "" + database = 0 + useTLS = false + port = "6379" + ) + if cluster { + port = "7001" + } + addr += port + + pool, err := redis.NewPool(redis.PoolConfig{ + Server: addr, + Password: password, + Database: database, + UseTLS: useTLS, + ConnTimeout: 5 * time.Second, + KeepAlive: 10 * time.Second, + ClusterFollowRedirections: redir, + ClusterReadFromReplica: readReplica, + }) + require.NoError(tb, err) + + conn := pool.Get() + defer conn.Close() + _, err = conn.Do("PING") + require.Nil(tb, err) + + tb.Cleanup(func() { + err := redis.EachNode(pool, false, func(conn redigo.Conn) error { + _, err := conn.Do("FLUSHDB") + return err + }) + require.NoError(tb, err) + pool.Close() + }) + + return pool +} diff --git a/server/fleet/redis_pool.go b/server/fleet/redis_pool.go index e1d86daf2c..ddc87ef95a 100644 --- a/server/fleet/redis_pool.go +++ b/server/fleet/redis_pool.go @@ -13,9 +13,4 @@ type RedisPool interface { // Stats returns a map of redis pool statistics for each server address. Stats() map[string]redis.PoolStats - - // ConfigureDoer returns a redis connection that is properly configured - // to execute Do commands. This should only be called when the actions - // to execute are all done with conn.Do. - ConfigureDoer(redis.Conn) redis.Conn } diff --git a/server/live_query/redis_live_query.go b/server/live_query/redis_live_query.go index a54e480b37..8e5766e52a 100644 --- a/server/live_query/redis_live_query.go +++ b/server/live_query/redis_live_query.go @@ -110,7 +110,7 @@ func (r *redisLiveQuery) MigrateKeys() error { } } - keysBySlot := redis.SplitRedisKeysBySlot(r.pool, oldKeys...) + keysBySlot := redis.SplitKeysBySlot(r.pool, oldKeys...) 
for _, keys := range keysBySlot { if err := migrateBatchKeys(r.pool, keys); err != nil { return err @@ -194,7 +194,7 @@ func (r *redisLiveQuery) RunQuery(name, sql string, hostIDs []uint) error { } func (r *redisLiveQuery) StopQuery(name string) error { - conn := r.pool.ConfigureDoer(r.pool.Get()) + conn := redis.ConfigureDoer(r.pool, r.pool.Get()) defer conn.Close() targetKey, sqlKey := generateKeys(name) @@ -212,7 +212,7 @@ func (r *redisLiveQuery) QueriesForHost(hostID uint) (map[string]string, error) return nil, errors.Wrap(err, "scan active queries") } - keysBySlot := redis.SplitRedisKeysBySlot(r.pool, queryKeys...) + keysBySlot := redis.SplitKeysBySlot(r.pool, queryKeys...) queries := make(map[string]string) for _, qkeys := range keysBySlot { if err := r.collectBatchQueriesForHost(hostID, qkeys, queries); err != nil { @@ -223,7 +223,7 @@ func (r *redisLiveQuery) QueriesForHost(hostID uint) (map[string]string, error) } func (r *redisLiveQuery) collectBatchQueriesForHost(hostID uint, queryKeys []string, queriesByHost map[string]string) error { - conn := r.pool.Get() + conn := redis.ReadOnlyConn(r.pool, r.pool.Get()) defer conn.Close() // Pipeline redis calls to check for this host in the bitfield of the @@ -279,7 +279,7 @@ func (r *redisLiveQuery) collectBatchQueriesForHost(hostID uint, queryKeys []str } func (r *redisLiveQuery) QueryCompletedByHost(name string, hostID uint) error { - conn := r.pool.ConfigureDoer(r.pool.Get()) + conn := redis.ConfigureDoer(r.pool, r.pool.Get()) defer conn.Close() targetKey, _ := generateKeys(name) @@ -318,7 +318,7 @@ func mapBitfield(hostIDs []uint) []byte { func scanKeys(pool fleet.RedisPool, pattern string) ([]string, error) { var keys []string - err := redis.EachRedisNode(pool, func(conn redigo.Conn) error { + err := redis.EachNode(pool, false, func(conn redigo.Conn) error { cursor := 0 for { res, err := redigo.Values(conn.Do("SCAN", cursor, "MATCH", pattern)) diff --git a/server/live_query/redis_live_query_test.go 
b/server/live_query/redis_live_query_test.go index 5a13411b32..0471c64492 100644 --- a/server/live_query/redis_live_query_test.go +++ b/server/live_query/redis_live_query_test.go @@ -1,11 +1,11 @@ package live_query import ( - "runtime" "testing" "time" "github.com/fleetdm/fleet/v4/server/datastore/redis" + "github.com/fleetdm/fleet/v4/server/datastore/redis/redistest" "github.com/fleetdm/fleet/v4/server/test" redigo "github.com/gomodule/redigo/redis" "github.com/mna/redisc" @@ -17,14 +17,12 @@ func TestRedisLiveQuery(t *testing.T) { for _, f := range testFunctions { t.Run(test.FunctionName(f), func(t *testing.T) { t.Run("standalone", func(t *testing.T) { - store, teardown := setupRedisLiveQuery(t, false) - defer teardown() + store := setupRedisLiveQuery(t, false) f(t, store) }) t.Run("cluster", func(t *testing.T) { - store, teardown := setupRedisLiveQuery(t, true) - defer teardown() + store := setupRedisLiveQuery(t, true) f(t, store) }) }) @@ -66,7 +64,7 @@ func TestMigrateKeys(t *testing.T) { require.NoError(t, err) got := make(map[string]string) - err = redis.EachRedisNode(store.pool, func(conn redigo.Conn) error { + err = redis.EachNode(store.pool, false, func(conn redigo.Conn) error { keys, err := redigo.Strings(conn.Do("KEYS", "*")) if err != nil { return err @@ -87,61 +85,19 @@ func TestMigrateKeys(t *testing.T) { } t.Run("standalone", func(t *testing.T) { - store, teardown := setupRedisLiveQuery(t, false) - defer teardown() + store := setupRedisLiveQuery(t, false) runTest(t, store) }) t.Run("cluster", func(t *testing.T) { - store, teardown := setupRedisLiveQuery(t, true) - defer teardown() + store := setupRedisLiveQuery(t, true) runTest(t, store) }) } -func setupRedisLiveQuery(t *testing.T, cluster bool) (store *redisLiveQuery, teardown func()) { - if cluster && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") { - t.Skipf("docker networking limitations prevent running redis cluster tests on %s", runtime.GOOS) - } - - var ( - addr = "127.0.0.1:" - 
password = "" - database = 0 - useTLS = false - port = "6379" - ) - if cluster { - port = "7001" - } - addr += port - - pool, err := redis.NewRedisPool(redis.PoolConfig{ - Server: addr, - Password: password, - Database: database, - UseTLS: useTLS, - ConnTimeout: 5 * time.Second, - KeepAlive: 10 * time.Second, - }) - require.NoError(t, err) - store = NewRedisLiveQuery(pool) - - conn := store.pool.Get() - defer conn.Close() - _, err = conn.Do("PING") - require.NoError(t, err) - - teardown = func() { - err := redis.EachRedisNode(store.pool, func(conn redigo.Conn) error { - _, err := conn.Do("FLUSHDB") - return err - }) - require.NoError(t, err) - store.pool.Close() - } - - return store, teardown +func setupRedisLiveQuery(t *testing.T, cluster bool) *redisLiveQuery { + pool := redistest.SetupRedis(t, cluster, false, false) + return NewRedisLiveQuery(pool) } func TestMapBitfield(t *testing.T) { diff --git a/server/pubsub/query_results_test.go b/server/pubsub/query_results_test.go index cb865c23e8..96a11578a8 100644 --- a/server/pubsub/query_results_test.go +++ b/server/pubsub/query_results_test.go @@ -6,8 +6,11 @@ import ( "testing" "time" + "github.com/fleetdm/fleet/v4/server/datastore/redis" "github.com/fleetdm/fleet/v4/server/fleet" + redigo "github.com/gomodule/redigo/redis" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // waitTimeout waits for the waitgroup for the specified max timeout. 
@@ -28,38 +31,51 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { func TestQueryResultsStoreErrors(t *testing.T) { runTest := func(t *testing.T, store *redisQueryResults) { - // Write with no subscriber - err := store.WriteResult( - fleet.DistributedQueryResult{ - DistributedQueryCampaignID: 9999, - Rows: []map[string]string{{"bing": "fds"}}, - Host: fleet.Host{ - ID: 4, - UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ - UpdateTimestamp: fleet.UpdateTimestamp{ - UpdatedAt: time.Now().UTC(), - }, + result := fleet.DistributedQueryResult{ + DistributedQueryCampaignID: 9999, + Rows: []map[string]string{{"bing": "fds"}}, + Host: fleet.Host{ + ID: 4, + UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ + UpdateTimestamp: fleet.UpdateTimestamp{ + UpdatedAt: time.Now().UTC(), }, - DetailUpdatedAt: time.Now().UTC(), }, + DetailUpdatedAt: time.Now().UTC(), }, - ) - assert.NotNil(t, err) + } + + // Write with no subscriber + err := store.WriteResult(result) + require.Error(t, err) castErr, ok := err.(Error) if assert.True(t, ok, "err should be pubsub.Error") { assert.True(t, castErr.NoSubscriber(), "NoSubscriber() should be true") } + + // Write with one subscriber, force it to bind to a different node if + // this is a cluster, so we don't rely on publishing/subscribing on the + // same nodes. 
+ conn := redis.ReadOnlyConn(store.pool, store.pool.Get()) + defer conn.Close() + err = redis.BindConn(store.pool, conn, "ZZZ") + require.NoError(t, err) + + psc := &redigo.PubSubConn{Conn: conn} + pubSubName := pubSubForID(9999) + require.NoError(t, psc.Subscribe(pubSubName)) + + err = store.WriteResult(result) + require.NoError(t, err) } t.Run("standalone", func(t *testing.T) { - store, teardown := SetupRedisForTest(t, false) - defer teardown() + store := SetupRedisForTest(t, false, false) runTest(t, store) }) t.Run("cluster", func(t *testing.T) { - store, teardown := SetupRedisForTest(t, true) - defer teardown() + store := SetupRedisForTest(t, true, true) runTest(t, store) }) } @@ -240,14 +256,12 @@ func TestQueryResultsStore(t *testing.T) { } t.Run("standalone", func(t *testing.T) { - store, teardown := SetupRedisForTest(t, false) - defer teardown() + store := SetupRedisForTest(t, false, false) runTest(t, store) }) t.Run("cluster", func(t *testing.T) { - store, teardown := SetupRedisForTest(t, true) - defer teardown() + store := SetupRedisForTest(t, true, true) runTest(t, store) }) } diff --git a/server/pubsub/redis_query_results.go b/server/pubsub/redis_query_results.go index fe815906dc..b7623f67d4 100644 --- a/server/pubsub/redis_query_results.go +++ b/server/pubsub/redis_query_results.go @@ -6,8 +6,9 @@ import ( "fmt" "time" + "github.com/fleetdm/fleet/v4/server/datastore/redis" "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/gomodule/redigo/redis" + redigo "github.com/gomodule/redigo/redis" "github.com/pkg/errors" ) @@ -35,7 +36,8 @@ func (r *redisQueryResults) Pool() fleet.RedisPool { } func (r *redisQueryResults) WriteResult(result fleet.DistributedQueryResult) error { - conn := r.pool.Get() + // pub-sub can publish and listen on any node in the cluster + conn := redis.ReadOnlyConn(r.pool, r.pool.Get()) defer conn.Close() channelName := pubSubForID(result.DistributedQueryCampaignID) @@ -45,17 +47,17 @@ func (r *redisQueryResults) 
WriteResult(result fleet.DistributedQueryResult) err return errors.Wrap(err, "marshalling JSON for result") } - n, err := redis.Int(conn.Do("PUBLISH", channelName, string(jsonVal))) + hasSubs, err := redis.PublishHasListeners(r.pool, conn, channelName, string(jsonVal)) - if n != 0 && r.duplicateResults { + if hasSubs && r.duplicateResults { // Ignore errors, duplicate result publishing is on a "best-effort" basis. - _, _ = redis.Int(conn.Do("PUBLISH", "LQDuplicate", string(jsonVal))) + _, _ = redigo.Int(conn.Do("PUBLISH", "LQDuplicate", string(jsonVal))) } if err != nil { return errors.Wrap(err, "PUBLISH failed to channel "+channelName) } - if n == 0 { + if !hasSubs { return noSubscriberError{channelName} } @@ -77,7 +79,7 @@ func writeOrDone(ctx context.Context, ch chan<- interface{}, item interface{}) b // connection over the provided channel. This effectively allows a select // statement to run on conn.Receive() (by selecting on outChan that is // passed into this function) -func receiveMessages(ctx context.Context, conn *redis.PubSubConn, outChan chan<- interface{}) { +func receiveMessages(ctx context.Context, conn *redigo.PubSubConn, outChan chan<- interface{}) { defer close(outChan) // conn.Close() needs to be here in this function because Receive and Close should not be called // concurrently. Otherwise we end up with a hang when Close is called. @@ -97,7 +99,7 @@ func receiveMessages(ctx context.Context, conn *redis.PubSubConn, outChan chan<- case error: // If an error occurred (i.e. connection was closed), then we should exit. return - case redis.Subscription: + case redigo.Subscription: // If the subscription count is 0, the ReadChannel call that invoked this goroutine has unsubscribed, // and we can exit. 
if msg.Count == 0 { @@ -111,8 +113,9 @@ func (r *redisQueryResults) ReadChannel(ctx context.Context, query fleet.Distrib outChannel := make(chan interface{}) msgChannel := make(chan interface{}) - conn := r.pool.Get() - psc := &redis.PubSubConn{Conn: conn} + // pub-sub can publish and listen on any node in the cluster + conn := redis.ReadOnlyConn(r.pool, r.pool.Get()) + psc := &redigo.PubSubConn{Conn: conn} pubSubName := pubSubForID(query.ID) if err := psc.Subscribe(pubSubName); err != nil { // Explicit conn.Close() here because we can't defer it until in the goroutine @@ -140,7 +143,7 @@ func (r *redisQueryResults) ReadChannel(ctx context.Context, query fleet.Distrib } switch msg := msg.(type) { - case redis.Message: + case redigo.Message: var res fleet.DistributedQueryResult err := json.Unmarshal(msg.Data, &res) if err != nil { diff --git a/server/pubsub/testing_utils.go b/server/pubsub/testing_utils.go index 9ff5b39c1d..a0c7fa6dae 100644 --- a/server/pubsub/testing_utils.go +++ b/server/pubsub/testing_utils.go @@ -1,57 +1,13 @@ package pubsub import ( - "runtime" "testing" - "time" - "github.com/fleetdm/fleet/v4/server/datastore/redis" - redigo "github.com/gomodule/redigo/redis" - "github.com/stretchr/testify/require" + "github.com/fleetdm/fleet/v4/server/datastore/redis/redistest" ) -func SetupRedisForTest(t *testing.T, cluster bool) (store *redisQueryResults, teardown func()) { - if cluster && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") { - t.Skipf("docker networking limitations prevent running redis cluster tests on %s", runtime.GOOS) - } - - var ( - addr = "127.0.0.1:" - password = "" - database = 0 - useTLS = false - dupResults = false - port = "6379" - ) - if cluster { - port = "7001" - } - addr += port - - pool, err := redis.NewRedisPool(redis.PoolConfig{ - Server: addr, - Password: password, - Database: database, - UseTLS: useTLS, - ConnTimeout: 5 * time.Second, - KeepAlive: 10 * time.Second, - }) - require.NoError(t, err) - store = 
NewRedisQueryResults(pool, dupResults) - - conn := store.pool.Get() - defer conn.Close() - _, err = conn.Do("PING") - require.Nil(t, err) - - teardown = func() { - err := redis.EachRedisNode(store.pool, func(conn redigo.Conn) error { - _, err := conn.Do("FLUSHDB") - return err - }) - require.NoError(t, err) - store.pool.Close() - } - - return store, teardown +func SetupRedisForTest(t *testing.T, cluster, readReplica bool) *redisQueryResults { + const dupResults = false + pool := redistest.SetupRedis(t, cluster, false, readReplica) + return NewRedisQueryResults(pool, dupResults) } diff --git a/server/service/service_campaign_test.go b/server/service/service_campaign_test.go index 5437cca55f..52830f7abd 100644 --- a/server/service/service_campaign_test.go +++ b/server/service/service_campaign_test.go @@ -26,8 +26,7 @@ import ( func TestStreamCampaignResultsClosesReditOnWSClose(t *testing.T) { t.Skip("Seems to be a bit problematic in CI") - store, teardown := pubsub.SetupRedisForTest(t, false) - defer teardown() + store := pubsub.SetupRedisForTest(t, false, false) mockClock := clock.NewMockClock() ds := new(mock.Store) diff --git a/server/sso/session_store.go b/server/sso/session_store.go index 7b0bb36fdc..c8c98ca019 100644 --- a/server/sso/session_store.go +++ b/server/sso/session_store.go @@ -5,8 +5,9 @@ import ( "encoding/json" "time" + "github.com/fleetdm/fleet/v4/server/datastore/redis" "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/gomodule/redigo/redis" + redigo "github.com/gomodule/redigo/redis" "github.com/pkg/errors" ) @@ -46,7 +47,7 @@ func (s *store) create(requestID, originalURL, metadata string, lifetimeSecs uin if len(requestID) < 8 { return errors.New("request id must be 8 or more characters in length") } - conn := s.pool.ConfigureDoer(s.pool.Get()) + conn := redis.ConfigureDoer(s.pool, s.pool.Get()) defer conn.Close() sess := Session{OriginalURL: originalURL, Metadata: metadata} var writer bytes.Buffer @@ -59,11 +60,14 @@ func (s *store) 
create(requestID, originalURL, metadata string, lifetimeSecs uin } func (s *store) Get(requestID string) (*Session, error) { - conn := s.pool.ConfigureDoer(s.pool.Get()) + // not reading from a replica here as this gets called in close succession + // in the auth flow, with initiate SSO writing and callback SSO having to + // read that write. + conn := redis.ConfigureDoer(s.pool, s.pool.Get()) defer conn.Close() - val, err := redis.String(conn.Do("GET", requestID)) + val, err := redigo.String(conn.Do("GET", requestID)) if err != nil { - if err == redis.ErrNil { + if err == redigo.ErrNil { return nil, ErrSessionNotFound } return nil, err @@ -81,7 +85,7 @@ func (s *store) Get(requestID string) (*Session, error) { var ErrSessionNotFound = errors.New("session not found") func (s *store) Expire(requestID string) error { - conn := s.pool.ConfigureDoer(s.pool.Get()) + conn := redis.ConfigureDoer(s.pool, s.pool.Get()) defer conn.Close() _, err := conn.Do("DEL", requestID) return err diff --git a/server/sso/session_store_test.go b/server/sso/session_store_test.go index 318ba12acd..514eb7ce22 100644 --- a/server/sso/session_store_test.go +++ b/server/sso/session_store_test.go @@ -1,58 +1,16 @@ package sso import ( - "os" - "runtime" "testing" "time" - "github.com/fleetdm/fleet/v4/server/datastore/redis" + "github.com/fleetdm/fleet/v4/server/datastore/redis/redistest" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func newPool(t *testing.T, cluster bool) fleet.RedisPool { - if cluster && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") { - t.Skipf("docker networking limitations prevent running redis cluster tests on %s", runtime.GOOS) - } - - if _, ok := os.LookupEnv("REDIS_TEST"); ok { - var ( - addr = "127.0.0.1:" - password = "" - database = 0 - useTLS = false - port = "6379" - ) - if cluster { - port = "7001" - } - addr += port - - pool, err := redis.NewRedisPool(redis.PoolConfig{ - Server: 
addr, - Password: password, - Database: database, - UseTLS: useTLS, - ConnTimeout: 5 * time.Second, - KeepAlive: 10 * time.Second, - }) - require.NoError(t, err) - conn := pool.Get() - defer conn.Close() - _, err = conn.Do("PING") - require.Nil(t, err) - return pool - } - return nil -} - func TestSessionStore(t *testing.T) { - if _, ok := os.LookupEnv("REDIS_TEST"); !ok { - t.Skip("skipping sso session store tests") - } - runTest := func(t *testing.T, pool fleet.RedisPool) { store := NewSessionStore(pool) require.NotNil(t, store) @@ -72,16 +30,14 @@ func TestSessionStore(t *testing.T) { } t.Run("standalone", func(t *testing.T) { - p := newPool(t, false) + p := redistest.SetupRedis(t, false, false, false) require.NotNil(t, p) - defer p.Close() runTest(t, p) }) t.Run("cluster", func(t *testing.T) { - p := newPool(t, true) + p := redistest.SetupRedis(t, true, false, false) require.NotNil(t, p) - defer p.Close() runTest(t, p) }) }