From 685245c2bd9c72470b5adba0fbebd070e06d470b Mon Sep 17 00:00:00 2001 From: Tomas Touceda Date: Tue, 28 Sep 2021 10:01:53 -0300 Subject: [PATCH] Cache app config in redis (#2205) * Cache app config in redis * Add changes files * Replace string with constant * Revert some test refactorign and duplicate a bit of test code * Add test for AppConfig with redis failing * Fix lint * Use Doer so it works better in clusters * Skip unmarshalling if we already did * Allow to cache hosts if configured * Omit the setting if empty * Remove hashing, too much CPU * Revert caching of host auth... needs a more thought through approach * Remove config * Remove old config * Remove locker interface * Fix test and address review comments --- changes/cache-app-config | 1 + cmd/fleet/serve.go | 40 ++--- cmd/fleet/serve_test.go | 49 ++++-- server/datastore/cached_mysql/cached_mysql.go | 100 +++++++++++ .../cached_mysql/cached_mysql_test.go | 159 ++++++++++++++++++ server/datastore/redis/redis_test.go | 20 +-- server/fleet/datastore.go | 15 ++ server/mock/datastore_mock.go | 20 +++ 8 files changed, 349 insertions(+), 55 deletions(-) create mode 100644 changes/cache-app-config create mode 100644 server/datastore/cached_mysql/cached_mysql.go create mode 100644 server/datastore/cached_mysql/cached_mysql_test.go diff --git a/changes/cache-app-config b/changes/cache-app-config new file mode 100644 index 0000000000..2e8338411a --- /dev/null +++ b/changes/cache-app-config @@ -0,0 +1 @@ +* Cache AppConfig in redis to speed up requests. diff --git a/cmd/fleet/serve.go b/cmd/fleet/serve.go index ab90d2bba6..1c849fa4fc 100644 --- a/cmd/fleet/serve.go +++ b/cmd/fleet/serve.go @@ -8,6 +8,7 @@ import ( "github.com/e-dard/netbug" "github.com/fleetdm/fleet/v4/server" + "github.com/fleetdm/fleet/v4/server/datastore/cached_mysql" "github.com/fleetdm/fleet/v4/server/logging" "github.com/fleetdm/fleet/v4/server/webhooks" @@ -158,6 +159,7 @@ the way that the Fleet server works. if err != nil { initFatal(err, "initializing datastore") } + if config.S3.Bucket != "" { carveStore, err = s3.New(config.S3, ds) if err != nil { @@ -213,6 +215,7 @@ the way that the Fleet server works. if err != nil { initFatal(err, "initialize Redis") } + ds = cached_mysql.New(ds, redisPool) resultStore := pubsub.NewRedisQueryResults(redisPool, config.Redis.DuplicateResults) liveQueryStore := live_query.NewRedisLiveQuery(redisPool) if err := liveQueryStore.MigrateKeys(); err != nil { @@ -412,22 +415,6 @@ the way that the Fleet server works. return serveCmd } -// Locker represents an object that can obtain an atomic lock on a resource -// in a non blocking manner for an owner, with an expiration time. -type Locker interface { - // Lock tries to get an atomic lock on an instance named with `name` - // and an `owner` identified by a random string per instance. - // Subsequently locking the same resource name for the same owner - // renews the lock expiration. - // It returns true, nil if it managed to obtain a lock on the instance. - // false and potentially an error otherwise. - // This must not be blocking. - Lock(ctx context.Context, name string, owner string, expiration time.Duration) (bool, error) - // Unlock tries to unlock the lock by that `name` for the specified - // `owner`. Unlocking when not holding the lock shouldn't error - Unlock(ctx context.Context, name string, owner string) error -} - const ( lockKeyLeader = "leader" lockKeyVulnerabilities = "vulnerabilities" @@ -459,10 +446,6 @@ func trySendStatistics(ctx context.Context, ds fleet.Datastore, frequency time.D } func runCrons(ds fleet.Datastore, logger kitlog.Logger, config config.FleetConfig) context.CancelFunc { - locker, ok := ds.(Locker) - if !ok { - initFatal(errors.New("No global locker available"), "") - } ctx, cancelBackground := context.WithCancel(context.Background()) ourIdentifier, err := server.GenerateRandomText(64) @@ -470,15 +453,15 @@ func runCrons(ds fleet.Datastore, logger kitlog.Logger, config config.FleetConfi initFatal(errors.New("Error generating random instance identifier"), "") } - go cronCleanups(ctx, ds, kitlog.With(logger, "cron", "cleanups"), locker, ourIdentifier) + go cronCleanups(ctx, ds, kitlog.With(logger, "cron", "cleanups"), ourIdentifier) go cronVulnerabilities( - ctx, ds, kitlog.With(logger, "cron", "vulnerabilities"), locker, ourIdentifier, config) - go cronWebhooks(ctx, ds, kitlog.With(logger, "cron", "webhooks"), locker, ourIdentifier) + ctx, ds, kitlog.With(logger, "cron", "vulnerabilities"), ourIdentifier, config) + go cronWebhooks(ctx, ds, kitlog.With(logger, "cron", "webhooks"), ourIdentifier) return cancelBackground } -func cronCleanups(ctx context.Context, ds fleet.Datastore, logger kitlog.Logger, locker Locker, identifier string) { +func cronCleanups(ctx context.Context, ds fleet.Datastore, logger kitlog.Logger, identifier string) { ticker := time.NewTicker(1 * time.Hour) for { level.Debug(logger).Log("waiting", "on ticker") @@ -489,7 +472,7 @@ func cronCleanups(ctx context.Context, ds fleet.Datastore, logger kitlog.Logger, level.Debug(logger).Log("exit", "done with cron.") return } - if locked, err := locker.Lock(ctx, lockKeyLeader, identifier, time.Hour); err != nil || !locked { + if locked, err := ds.Lock(ctx, lockKeyLeader, identifier, time.Hour); err != nil || !locked { level.Debug(logger).Log("leader", "Not the leader. Skipping...") continue } @@ -526,7 +509,6 @@ func cronVulnerabilities( ctx context.Context, ds fleet.Datastore, logger kitlog.Logger, - locker Locker, identifier string, config config.FleetConfig, ) { @@ -581,7 +563,7 @@ func cronVulnerabilities( return } if config.Vulnerabilities.CurrentInstanceChecks == "auto" { - if locked, err := locker.Lock(ctx, lockKeyVulnerabilities, identifier, time.Hour); err != nil || !locked { + if locked, err := ds.Lock(ctx, lockKeyVulnerabilities, identifier, time.Hour); err != nil || !locked { level.Debug(logger).Log("leader", "Not the leader. Skipping...") continue } @@ -603,7 +585,7 @@ func cronVulnerabilities( } } -func cronWebhooks(ctx context.Context, ds fleet.Datastore, logger kitlog.Logger, locker Locker, identifier string) { +func cronWebhooks(ctx context.Context, ds fleet.Datastore, logger kitlog.Logger, identifier string) { appConfig, err := ds.AppConfig(ctx) if err != nil { level.Error(logger).Log("config", "couldn't read app config", "err", err) @@ -621,7 +603,7 @@ func cronWebhooks(ctx context.Context, ds fleet.Datastore, logger kitlog.Logger, level.Debug(logger).Log("exit", "done with cron.") return } - if locked, err := locker.Lock(ctx, lockKeyWebhooks, identifier, interval); err != nil || !locked { + if locked, err := ds.Lock(ctx, lockKeyWebhooks, identifier, interval); err != nil || !locked { level.Debug(logger).Log("leader", "Not the leader. Skipping...") continue } diff --git a/cmd/fleet/serve_test.go b/cmd/fleet/serve_test.go index c1afe59d5c..f91fc7c116 100644 --- a/cmd/fleet/serve_test.go +++ b/cmd/fleet/serve_test.go @@ -104,15 +104,6 @@ func TestMaybeSendStatisticsSkipsIfNotConfigured(t *testing.T) { assert.False(t, called) } -type alwaysLocker struct{} - -func (m *alwaysLocker) Lock(ctx context.Context, name string, owner string, expiration time.Duration) (bool, error) { - return true, nil -} -func (m *alwaysLocker) Unlock(ctx context.Context, name string, owner string) error { - return nil -} - func TestCronWebhooks(t *testing.T) { ds := new(mock.Store) @@ -135,6 +126,12 @@ func TestCronWebhooks(t *testing.T) { }, }, nil } + ds.LockFunc = func(ctx context.Context, name string, owner string, expiration time.Duration) (bool, error) { + return true, nil + } + ds.UnlockFunc = func(ctx context.Context, name string, owner string) error { + return nil + } calledOnce := make(chan struct{}) calledTwice := make(chan struct{}) @@ -157,7 +154,7 @@ func TestCronWebhooks(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - go cronWebhooks(ctx, ds, kitlog.With(kitlog.NewNopLogger(), "cron", "webhooks"), &alwaysLocker{}, "1234") + go cronWebhooks(ctx, ds, kitlog.With(kitlog.NewNopLogger(), "cron", "webhooks"), "1234") <-calledOnce time.Sleep(1 * time.Second) @@ -174,6 +171,12 @@ func TestCronVulnerabilitiesCreatesDatabasesPath(t *testing.T) { ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { return &fleet.AppConfig{}, nil } + ds.LockFunc = func(ctx context.Context, name string, owner string, expiration time.Duration) (bool, error) { + return true, nil + } + ds.UnlockFunc = func(ctx context.Context, name string, owner string) error { + return nil + } vulnPath := path.Join(t.TempDir(), "something") require.NoDirExists(t, vulnPath) @@ -188,7 +191,7 @@ func TestCronVulnerabilitiesCreatesDatabasesPath(t *testing.T) { // We cancel right away so cronsVulnerailities finishes. The logic we are testing happens before the loop starts cancelFunc() - cronVulnerabilities(ctx, ds, kitlog.NewNopLogger(), &alwaysLocker{}, "AAA", fleetConfig) + cronVulnerabilities(ctx, ds, kitlog.NewNopLogger(), "AAA", fleetConfig) require.DirExists(t, vulnPath) } @@ -204,6 +207,12 @@ func TestCronVulnerabilitiesAcceptsExistingDbPath(t *testing.T) { ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { return &fleet.AppConfig{}, nil } + ds.LockFunc = func(ctx context.Context, name string, owner string, expiration time.Duration) (bool, error) { + return true, nil + } + ds.UnlockFunc = func(ctx context.Context, name string, owner string) error { + return nil + } fleetConfig := config.FleetConfig{ Vulnerabilities: config.VulnerabilitiesConfig{ @@ -215,7 +224,7 @@ func TestCronVulnerabilitiesAcceptsExistingDbPath(t *testing.T) { // We cancel right away so cronsVulnerailities finishes. The logic we are testing happens before the loop starts cancelFunc() - cronVulnerabilities(ctx, ds, logger, &alwaysLocker{}, "AAA", fleetConfig) + cronVulnerabilities(ctx, ds, logger, "AAA", fleetConfig) require.Contains(t, buf.String(), `{"level":"debug","waiting":"on ticker"}`) } @@ -231,6 +240,12 @@ func TestCronVulnerabilitiesQuitsIfErrorVulnPath(t *testing.T) { ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { return &fleet.AppConfig{}, nil } + ds.LockFunc = func(ctx context.Context, name string, owner string, expiration time.Duration) (bool, error) { + return true, nil + } + ds.UnlockFunc = func(ctx context.Context, name string, owner string) error { + return nil + } fileVulnPath := path.Join(t.TempDir(), "somefile") _, err := os.Create(fileVulnPath) @@ -246,7 +261,7 @@ func TestCronVulnerabilitiesQuitsIfErrorVulnPath(t *testing.T) { // We cancel right away so cronsVulnerailities finishes. The logic we are testing happens before the loop starts cancelFunc() - cronVulnerabilities(ctx, ds, logger, &alwaysLocker{}, "AAA", fleetConfig) + cronVulnerabilities(ctx, ds, logger, "AAA", fleetConfig) require.Contains(t, buf.String(), `"databases-path":"creation failed, returning"`) } @@ -262,6 +277,12 @@ func TestCronVulnerabilitiesSkipCreationIfStatic(t *testing.T) { ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { return &fleet.AppConfig{}, nil } + ds.LockFunc = func(ctx context.Context, name string, owner string, expiration time.Duration) (bool, error) { + return true, nil + } + ds.UnlockFunc = func(ctx context.Context, name string, owner string) error { + return nil + } vulnPath := path.Join(t.TempDir(), "something") require.NoDirExists(t, vulnPath) @@ -276,7 +297,7 @@ func TestCronVulnerabilitiesSkipCreationIfStatic(t *testing.T) { // We cancel right away so cronsVulnerailities finishes. The logic we are testing happens before the loop starts cancelFunc() - cronVulnerabilities(ctx, ds, logger, &alwaysLocker{}, "AAA", fleetConfig) + cronVulnerabilities(ctx, ds, logger, "AAA", fleetConfig) require.NoDirExists(t, vulnPath) } diff --git a/server/datastore/cached_mysql/cached_mysql.go b/server/datastore/cached_mysql/cached_mysql.go new file mode 100644 index 0000000000..41debf3358 --- /dev/null +++ b/server/datastore/cached_mysql/cached_mysql.go @@ -0,0 +1,100 @@ +package cached_mysql + +import ( + "context" + "encoding/json" + "time" + + "github.com/fleetdm/fleet/v4/server/fleet" + redigo "github.com/gomodule/redigo/redis" + "github.com/pkg/errors" +) + +type cachedMysql struct { + fleet.Datastore + + redisPool fleet.RedisPool +} + +const ( + CacheKeyAppConfig = "cache:AppConfig" +) + +func New(ds fleet.Datastore, redisPool fleet.RedisPool) fleet.Datastore { + return &cachedMysql{ + Datastore: ds, + redisPool: redisPool, + } +} + +func (ds *cachedMysql) storeInRedis(key string, v interface{}) error { + conn := ds.redisPool.ConfigureDoer(ds.redisPool.Get()) + defer conn.Close() + + b, err := json.Marshal(v) + if err != nil { + return errors.Wrap(err, "marshaling object to cache in redis") + } + + if _, err := conn.Do("SET", key, b, "EX", (24 * time.Hour).Seconds()); err != nil { + return errors.Wrap(err, "caching object in redis") + } + + return nil +} + +func (ds *cachedMysql) getFromRedis(key string, v interface{}) error { + conn := ds.redisPool.ConfigureDoer(ds.redisPool.Get()) + defer conn.Close() + + data, err := redigo.Bytes(conn.Do("GET", key)) + if err != nil { + return errors.Wrap(err, "getting value from cache") + } + + err = json.Unmarshal(data, v) + if err != nil { + return errors.Wrap(err, "unmarshaling object from cache") + } + + return nil +} + +func (ds *cachedMysql) NewAppConfig(ctx context.Context, info *fleet.AppConfig) (*fleet.AppConfig, error) { + ac, err := ds.Datastore.NewAppConfig(ctx, info) + if err != nil { + return nil, errors.Wrap(err, "calling new app config") + } + + err = ds.storeInRedis(CacheKeyAppConfig, ac) + + return ac, err +} + +func (ds *cachedMysql) AppConfig(ctx context.Context) (*fleet.AppConfig, error) { + ac := &fleet.AppConfig{} + ac.ApplyDefaults() + + err := ds.getFromRedis(CacheKeyAppConfig, ac) + if err == nil { + return ac, nil + } + + ac, err = ds.Datastore.AppConfig(ctx) + if err != nil { + return nil, errors.Wrap(err, "calling app config") + } + + err = ds.storeInRedis(CacheKeyAppConfig, ac) + + return ac, err +} + +func (ds *cachedMysql) SaveAppConfig(ctx context.Context, info *fleet.AppConfig) error { + err := ds.Datastore.SaveAppConfig(ctx, info) + if err != nil { + return errors.Wrap(err, "calling save app config") + } + + return ds.storeInRedis(CacheKeyAppConfig, info) +} diff --git a/server/datastore/cached_mysql/cached_mysql_test.go b/server/datastore/cached_mysql/cached_mysql_test.go new file mode 100644 index 0000000000..cda15822b7 --- /dev/null +++ b/server/datastore/cached_mysql/cached_mysql_test.go @@ -0,0 +1,159 @@ +package cached_mysql + +import ( + "context" + "encoding/json" + "os" + "runtime" + "testing" + "time" + + "github.com/fleetdm/fleet/v4/server/datastore/redis" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/mock" + "github.com/fleetdm/fleet/v4/server/ptr" + redigo "github.com/gomodule/redigo/redis" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newPool is basically repeated in every package that uses redis +// I tried to move this to a datastoretest package, but there's an import loop with redis +// so I decided to copy and past for now +func newPool(t *testing.T, cluster bool) fleet.RedisPool { + if cluster && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") { + t.Skipf("docker networking limitations prevent running redis cluster tests on %s", runtime.GOOS) + } + + if _, ok := os.LookupEnv("REDIS_TEST"); ok { + var ( + addr = "127.0.0.1:" + password = "" + database = 0 + useTLS = false + port = "6379" + ) + if cluster { + port = "7001" + } + addr += port + + pool, err := redis.NewRedisPool(redis.PoolConfig{ + Server: addr, + Password: password, + Database: database, + UseTLS: useTLS, + ConnTimeout: 5 * time.Second, + KeepAlive: 10 * time.Second, + }) + require.NoError(t, err) + conn := pool.Get() + defer conn.Close() + _, err = conn.Do("PING") + require.Nil(t, err) + return pool + } + return nil +} + +func TestCachedAppConfig(t *testing.T) { + pool := newPool(t, false) + conn := pool.Get() + _, err := conn.Do("DEL", CacheKeyAppConfig) + require.NoError(t, err) + + mockedDS := new(mock.Store) + ds := New(mockedDS, pool) + + var appConfigSet *fleet.AppConfig + mockedDS.NewAppConfigFunc = func(ctx context.Context, info *fleet.AppConfig) (*fleet.AppConfig, error) { + appConfigSet = info + return info, nil + } + mockedDS.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return appConfigSet, err + } + mockedDS.SaveAppConfigFunc = func(ctx context.Context, info *fleet.AppConfig) error { + appConfigSet = info + return nil + } + _, err = ds.NewAppConfig(context.Background(), &fleet.AppConfig{ + HostSettings: fleet.HostSettings{ + AdditionalQueries: ptr.RawMessage(json.RawMessage(`"TestCachedAppConfig"`)), + }, + }) + require.NoError(t, err) + + t.Run("NewAppConfig", func(t *testing.T) { + data, err := redigo.Bytes(conn.Do("GET", CacheKeyAppConfig)) + require.NoError(t, err) + + require.NotEmpty(t, data) + newAc := &fleet.AppConfig{} + require.NoError(t, json.Unmarshal(data, &newAc)) + require.NotNil(t, newAc.HostSettings.AdditionalQueries) + assert.Equal(t, json.RawMessage(`"TestCachedAppConfig"`), *newAc.HostSettings.AdditionalQueries) + }) + + t.Run("AppConfig", func(t *testing.T) { + require.False(t, mockedDS.AppConfigFuncInvoked) + ac, err := ds.AppConfig(context.Background()) + require.NoError(t, err) + require.False(t, mockedDS.AppConfigFuncInvoked) + + require.Equal(t, ptr.RawMessage(json.RawMessage(`"TestCachedAppConfig"`)), ac.HostSettings.AdditionalQueries) + }) + + t.Run("AppConfig uses DS if redis fails", func(t *testing.T) { + _, err = conn.Do("DEL", CacheKeyAppConfig) + require.NoError(t, err) + ac, err := ds.AppConfig(context.Background()) + require.NoError(t, err) + require.True(t, mockedDS.AppConfigFuncInvoked) + + require.Equal(t, ptr.RawMessage(json.RawMessage(`"TestCachedAppConfig"`)), ac.HostSettings.AdditionalQueries) + }) + + t.Run("SaveAppConfig", func(t *testing.T) { + require.NoError(t, ds.SaveAppConfig(context.Background(), &fleet.AppConfig{ + HostSettings: fleet.HostSettings{ + AdditionalQueries: ptr.RawMessage(json.RawMessage(`"NewSAVED"`)), + }, + })) + + data, err := redigo.Bytes(conn.Do("GET", CacheKeyAppConfig)) + require.NoError(t, err) + + require.NotEmpty(t, data) + newAc := &fleet.AppConfig{} + require.NoError(t, json.Unmarshal(data, &newAc)) + require.NotNil(t, newAc.HostSettings.AdditionalQueries) + assert.Equal(t, json.RawMessage(`"NewSAVED"`), *newAc.HostSettings.AdditionalQueries) + + ac, err := ds.AppConfig(context.Background()) + require.NoError(t, err) + require.NotNil(t, ac.HostSettings.AdditionalQueries) + assert.Equal(t, json.RawMessage(`"NewSAVED"`), *ac.HostSettings.AdditionalQueries) + }) + + t.Run("AuthenticateHost skips cache if disabled", func(t *testing.T) { + _, err = conn.Do("DEL", CacheKeyAppConfig) + require.NoError(t, err) + + mockedDS.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + mockedDS.AuthenticateHostFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { + return &fleet.Host{ID: 999}, nil + } + _, err = ds.AuthenticateHost(context.Background(), "1234") + require.NoError(t, err) + require.True(t, mockedDS.AuthenticateHostFuncInvoked) + mockedDS.AuthenticateHostFuncInvoked = false + + _, err = ds.AuthenticateHost(context.Background(), "1234") + require.NoError(t, err) + require.True(t, mockedDS.AuthenticateHostFuncInvoked) + mockedDS.AuthenticateHostFuncInvoked = false + }) +} diff --git a/server/datastore/redis/redis_test.go b/server/datastore/redis/redis_test.go index 30b55a3c16..c1bbf82b36 100644 --- a/server/datastore/redis/redis_test.go +++ b/server/datastore/redis/redis_test.go @@ -120,8 +120,7 @@ func TestRedisPoolConfigureDoer(t *testing.T) { const prefix = "TestRedisPoolConfigureDoer:" t.Run("standalone", func(t *testing.T) { - pool, teardown := setupRedisForTest(t, false, false) - defer teardown() + pool := setupRedisForTest(t, false, false) c1 := pool.Get() defer c1.Close() @@ -142,8 +141,7 @@ func TestRedisPoolConfigureDoer(t *testing.T) { }) t.Run("cluster", func(t *testing.T) { - pool, teardown := setupRedisForTest(t, true, true) - defer teardown() + pool := setupRedisForTest(t, true, true) c1 := pool.Get() defer c1.Close() @@ -205,19 +203,17 @@ func TestEachRedisNode(t *testing.T) { } t.Run("standalone", func(t *testing.T) { - pool, teardown := setupRedisForTest(t, false, false) - defer teardown() + pool := setupRedisForTest(t, false, false) runTest(t, pool) }) t.Run("cluster", func(t *testing.T) { - pool, teardown := setupRedisForTest(t, true, false) - defer teardown() + pool := setupRedisForTest(t, true, false) runTest(t, pool) }) } -func setupRedisForTest(t *testing.T, cluster, redir bool) (pool fleet.RedisPool, teardown func()) { +func setupRedisForTest(t *testing.T, cluster, redir bool) (pool fleet.RedisPool) { if cluster && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") { t.Skipf("docker networking limitations prevent running redis cluster tests on %s", runtime.GOOS) } @@ -250,14 +246,14 @@ func setupRedisForTest(t *testing.T, cluster, redir bool) (pool fleet.RedisPool, _, err = conn.Do("PING") require.Nil(t, err) - teardown = func() { + t.Cleanup(func() { err := EachRedisNode(pool, func(conn redis.Conn) error { _, err := conn.Do("FLUSHDB") return err }) require.NoError(t, err) pool.Close() - } + }) - return pool, teardown + return pool } diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 50b960b84a..c0eb622efe 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -358,6 +358,21 @@ type Datastore interface { ListTeamPolicies(ctx context.Context, teamID uint) ([]*Policy, error) DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error) TeamPolicy(ctx context.Context, teamID uint, policyID uint) (*Policy, error) + + /////////////////////////////////////////////////////////////////////////////// + // Team Policies + + // Lock tries to get an atomic lock on an instance named with `name` + // and an `owner` identified by a random string per instance. + // Subsequently locking the same resource name for the same owner + // renews the lock expiration. + // It returns true, nil if it managed to obtain a lock on the instance. + // false and potentially an error otherwise. + // This must not be blocking. + Lock(ctx context.Context, name string, owner string, expiration time.Duration) (bool, error) + // Unlock tries to unlock the lock by that `name` for the specified + // `owner`. Unlocking when not holding the lock shouldn't error + Unlock(ctx context.Context, name string, owner string) error } type MigrationStatus int diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 206d4433f4..a34f3de43f 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -287,6 +287,10 @@ type DeleteTeamPoliciesFunc func(ctx context.Context, teamID uint, ids []uint) ( type TeamPolicyFunc func(ctx context.Context, teamID uint, policyID uint) (*fleet.Policy, error) +type LockFunc func(ctx context.Context, name string, owner string, expiration time.Duration) (bool, error) + +type UnlockFunc func(ctx context.Context, name string, owner string) error + type DataStore struct { NewCarveFunc NewCarveFunc NewCarveFuncInvoked bool @@ -701,6 +705,12 @@ type DataStore struct { TeamPolicyFunc TeamPolicyFunc TeamPolicyFuncInvoked bool + + LockFunc LockFunc + LockFuncInvoked bool + + UnlockFunc UnlockFunc + UnlockFuncInvoked bool } func (s *DataStore) NewCarve(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) { @@ -1392,3 +1402,13 @@ func (s *DataStore) TeamPolicy(ctx context.Context, teamID uint, policyID uint) s.TeamPolicyFuncInvoked = true return s.TeamPolicyFunc(ctx, teamID, policyID) } + +func (s *DataStore) Lock(ctx context.Context, name string, owner string, expiration time.Duration) (bool, error) { + s.LockFuncInvoked = true + return s.LockFunc(ctx, name, owner, expiration) +} + +func (s *DataStore) Unlock(ctx context.Context, name string, owner string) error { + s.UnlockFuncInvoked = true + return s.UnlockFunc(ctx, name, owner) +}