From e2194be61ce6e47f9c386710d98a38166fd883ba Mon Sep 17 00:00:00 2001 From: gillespi314 <73313222+gillespi314@users.noreply.github.com> Date: Wed, 10 Aug 2022 11:00:56 -0500 Subject: [PATCH] Add `schedule` package and refactor cron jobs for cleanups, aggregations, and usage statistics (#6618) --- changes/scheduler | 2 + cmd/fleet/cron.go | 193 ++++++++------ cmd/fleet/serve.go | 59 ++--- server/service/schedule/schedule.go | 309 +++++++++++++++++++++++ server/service/schedule/schedule_test.go | 292 +++++++++++++++++++++ 5 files changed, 747 insertions(+), 108 deletions(-) create mode 100644 changes/scheduler create mode 100644 server/service/schedule/schedule.go create mode 100644 server/service/schedule/schedule_test.go diff --git a/changes/scheduler b/changes/scheduler new file mode 100644 index 0000000000..b97b704856 --- /dev/null +++ b/changes/scheduler @@ -0,0 +1,2 @@ +- Added new `schedule` package and refactored `cronDB` (which included cron operations for + cleanups, aggregations, and usage statistics) to use the new package. 
diff --git a/cmd/fleet/cron.go b/cmd/fleet/cron.go index 521473b8a6..ea946b526c 100644 --- a/cmd/fleet/cron.go +++ b/cmd/fleet/cron.go @@ -9,11 +9,13 @@ import ( "time" "github.com/fleetdm/fleet/v4/pkg/fleethttp" + "github.com/fleetdm/fleet/v4/server" "github.com/fleetdm/fleet/v4/server/config" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/policies" "github.com/fleetdm/fleet/v4/server/service/externalsvc" + "github.com/fleetdm/fleet/v4/server/service/schedule" "github.com/fleetdm/fleet/v4/server/vulnerabilities" "github.com/fleetdm/fleet/v4/server/vulnerabilities/oval" "github.com/fleetdm/fleet/v4/server/webhooks" @@ -29,81 +31,6 @@ func errHandler(ctx context.Context, logger kitlog.Logger, msg string, err error ctxerr.Handle(ctx, err) } -func cronDB(ctx context.Context, ds fleet.Datastore, logger kitlog.Logger, identifier string, config *config.FleetConfig, license *fleet.LicenseInfo, enrollHostLimiter fleet.EnrollHostLimiter) { - logger = kitlog.With(logger, "cron", lockKeyLeader) - - ticker := time.NewTicker(10 * time.Second) - for { - level.Debug(logger).Log("waiting", "on ticker") - select { - case <-ticker.C: - level.Debug(logger).Log("waiting", "done") - ticker.Reset(1 * time.Hour) - case <-ctx.Done(): - level.Debug(logger).Log("exit", "done with cron.") - return - } - - if locked, err := ds.Lock(ctx, lockKeyLeader, identifier, 1*time.Hour); err != nil { - level.Error(logger).Log("msg", "Error acquiring lock", "err", err) - continue - } else if !locked { - level.Debug(logger).Log("msg", "Not the leader. 
Skipping...") - continue - } - - _, err := ds.CleanupDistributedQueryCampaigns(ctx, time.Now()) - if err != nil { - errHandler(ctx, logger, "cleaning distributed query campaigns", err) - } - _, err = ds.CleanupIncomingHosts(ctx, time.Now()) - if err != nil { - errHandler(ctx, logger, "cleaning incoming hosts", err) - } - _, err = ds.CleanupCarves(ctx, time.Now()) - if err != nil { - errHandler(ctx, logger, "cleaning carves", err) - } - err = ds.UpdateQueryAggregatedStats(ctx) - if err != nil { - errHandler(ctx, logger, "aggregating query stats", err) - } - err = ds.UpdateScheduledQueryAggregatedStats(ctx) - if err != nil { - errHandler(ctx, logger, "aggregating scheduled query stats", err) - } - _, err = ds.CleanupExpiredHosts(ctx) - if err != nil { - errHandler(ctx, logger, "cleaning expired hosts", err) - } - err = ds.GenerateAggregatedMunkiAndMDM(ctx) - if err != nil { - errHandler(ctx, logger, "aggregating munki and mdm data", err) - } - err = ds.CleanupPolicyMembership(ctx, time.Now()) - if err != nil { - errHandler(ctx, logger, "cleanup policy membership", err) - } - err = ds.UpdateOSVersions(ctx) - if err != nil { - errHandler(ctx, logger, "update os versions", err) - } - err = enrollHostLimiter.SyncEnrolledHostIDs(ctx) - if err != nil { - errHandler(ctx, logger, "sync enrolled host ids", err) - } - - // NOTE(mna): this is not a route from the fleet server (not in server/service/handler.go) so it - // will not automatically support the /latest/ versioning. Leaving it as /v1/ for that reason. 
- err = trySendStatistics(ctx, ds, fleet.StatisticsFrequency, "https://fleetdm.com/api/v1/webhooks/receive-usage-analytics", *config, license) - if err != nil { - errHandler(ctx, logger, "sending statistics", err) - } - - level.Debug(logger).Log("loop", "done") - } -} - func cronVulnerabilities( ctx context.Context, ds fleet.Datastore, @@ -689,3 +616,119 @@ func newFailerClient(forcedFailures string) *worker.TestAutomationFailer { } return failerClient } + +func startCleanupsAndAggregationSchedule( + ctx context.Context, instanceID string, ds fleet.Datastore, logger kitlog.Logger, enrollHostLimiter fleet.EnrollHostLimiter, +) { + schedule.New( + ctx, "cleanups_then_aggregation", instanceID, 1*time.Hour, ds, + // Using leader for the lock to be backwards compatilibity with old deployments. + schedule.WithAltLockID("leader"), + schedule.WithLogger(kitlog.With(logger, "cron", "cleanups_then_aggregation")), + // Run cleanup jobs first. + schedule.WithJob( + "distributed_query_campaings", + func(ctx context.Context) error { + _, err := ds.CleanupDistributedQueryCampaigns(ctx, time.Now()) + return err + }, + ), + schedule.WithJob( + "incoming_hosts", + func(ctx context.Context) error { + _, err := ds.CleanupIncomingHosts(ctx, time.Now()) + return err + }, + ), + schedule.WithJob( + "carves", + func(ctx context.Context) error { + _, err := ds.CleanupCarves(ctx, time.Now()) + return err + }, + ), + schedule.WithJob( + "expired_hosts", + func(ctx context.Context) error { + _, err := ds.CleanupExpiredHosts(ctx) + return err + }, + ), + schedule.WithJob( + "policy_membership", + func(ctx context.Context) error { + return ds.CleanupPolicyMembership(ctx, time.Now()) + }, + ), + schedule.WithJob( + "sync_enrolled_host_ids", + func(ctx context.Context) error { + return enrollHostLimiter.SyncEnrolledHostIDs(ctx) + }, + ), + // Run aggregation jobs after cleanups. 
+ schedule.WithJob( + "query_aggregated_stats", + func(ctx context.Context) error { + return ds.UpdateQueryAggregatedStats(ctx) + }, + ), + schedule.WithJob( + "scheduled_query_aggregated_stats", + func(ctx context.Context) error { + return ds.UpdateScheduledQueryAggregatedStats(ctx) + }, + ), + schedule.WithJob( + "aggregated_munki_and_mdm", + func(ctx context.Context) error { + return ds.GenerateAggregatedMunkiAndMDM(ctx) + }, + ), + schedule.WithJob( + "update_os_versions", + func(ctx context.Context) error { + return ds.UpdateOSVersions(ctx) + }, + ), + ).Start() +} + +func startSendStatsSchedule(ctx context.Context, instanceID string, ds fleet.Datastore, config config.FleetConfig, license *fleet.LicenseInfo, logger kitlog.Logger) { + schedule.New( + ctx, "stats", instanceID, 1*time.Hour, ds, + schedule.WithLogger(kitlog.With(logger, "cron", "stats")), + schedule.WithJob( + "try_send_statistics", + func(ctx context.Context) error { + // NOTE(mna): this is not a route from the fleet server (not in server/service/handler.go) so it + // will not automatically support the /latest/ versioning. Leaving it as /v1/ for that reason. 
+ return trySendStatistics(ctx, ds, fleet.StatisticsFrequency, "https://fleetdm.com/api/v1/webhooks/receive-usage-analytics", config, license) + }, + ), + ).Start() +} + +func trySendStatistics(ctx context.Context, ds fleet.Datastore, frequency time.Duration, url string, config config.FleetConfig, license *fleet.LicenseInfo) error { + ac, err := ds.AppConfig(ctx) + if err != nil { + return err + } + if !ac.ServerSettings.EnableAnalytics { + return nil + } + + stats, shouldSend, err := ds.ShouldSendStatistics(ctx, frequency, config, license) + if err != nil { + return err + } + if !shouldSend { + return nil + } + + err = server.PostJSONWithTimeout(ctx, url, stats) + if err != nil { + return err + } + return ds.RecordStatisticsSent(ctx) +} diff --git a/cmd/fleet/serve.go b/cmd/fleet/serve.go index 5adbbf3767..8c4c0145cc 100644 --- a/cmd/fleet/serve.go +++ b/cmd/fleet/serve.go @@ -25,6 +25,7 @@ import ( "github.com/fleetdm/fleet/v4/ee/server/licensing" eeservice "github.com/fleetdm/fleet/v4/ee/server/service" "github.com/fleetdm/fleet/v4/server" + "github.com/fleetdm/fleet/v4/server/config" configpkg "github.com/fleetdm/fleet/v4/server/config" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/datastore/cached_mysql" @@ -395,7 +396,14 @@ the way that the Fleet server works. 
} } - runCrons(ctx, ds, task, kitlog.With(logger, "component", "crons"), config, license, failingPolicySet, redisWrapperDS) + instanceID, err := server.GenerateRandomText(64) + if err != nil { + initFatal(errors.New("Error generating random instance identifier"), "") + } + runCrons(ctx, ds, task, kitlog.With(logger, "component", "crons"), config, license, failingPolicySet, instanceID) + if err := startSchedules(ctx, ds, logger, config, license, redisWrapperDS, instanceID); err != nil { + initFatal(err, "failed to register schedules") + } // Flush seen hosts every second hostsAsyncCfg := config.Osquery.AsyncConfigForTask(configpkg.AsyncTaskHostLastSeen) @@ -645,37 +653,13 @@ func basicAuthHandler(username, password string, next http.Handler) http.Handler } const ( - lockKeyLeader = "leader" lockKeyVulnerabilities = "vulnerabilities" lockKeyWebhooksHostStatus = "webhooks" // keeping this name for backwards compatibility. lockKeyWebhooksFailingPolicies = "webhooks:global_failing_policies" lockKeyWorker = "worker" ) -func trySendStatistics(ctx context.Context, ds fleet.Datastore, frequency time.Duration, url string, config configpkg.FleetConfig, license *fleet.LicenseInfo) error { - ac, err := ds.AppConfig(ctx) - if err != nil { - return err - } - if !ac.ServerSettings.EnableAnalytics { - return nil - } - - stats, shouldSend, err := ds.ShouldSendStatistics(ctx, frequency, config, license) - if err != nil { - return err - } - if !shouldSend { - return nil - } - - err = server.PostJSONWithTimeout(ctx, url, stats) - if err != nil { - return err - } - return ds.RecordStatisticsSent(ctx) -} - +// runCrons runs cron jobs not yet ported to use the schedule package (startSchedules) func runCrons( ctx context.Context, ds fleet.Datastore, @@ -684,23 +668,32 @@ func runCrons( config configpkg.FleetConfig, license *fleet.LicenseInfo, failingPoliciesSet fleet.FailingPolicySet, - enrollHostLimiter fleet.EnrollHostLimiter, + ourIdentifier string, ) { - ourIdentifier, err := 
server.GenerateRandomText(64) - if err != nil { - initFatal(ctxerr.New(ctx, "generating random instance identifier"), "") - } - // StartCollectors starts a goroutine per collector, using ctx to cancel. task.StartCollectors(ctx, kitlog.With(logger, "cron", "async_task")) - go cronDB(ctx, ds, kitlog.With(logger, "cron", "cleanups"), ourIdentifier, &config, license, enrollHostLimiter) go cronVulnerabilities( ctx, ds, kitlog.With(logger, "cron", "vulnerabilities"), ourIdentifier, &config.Vulnerabilities) go cronWebhooks(ctx, ds, kitlog.With(logger, "cron", "webhooks"), ourIdentifier, failingPoliciesSet, 1*time.Hour) go cronWorker(ctx, ds, kitlog.With(logger, "cron", "worker"), ourIdentifier) } +func startSchedules( + ctx context.Context, + ds fleet.Datastore, + logger kitlog.Logger, + config config.FleetConfig, + license *fleet.LicenseInfo, + enrollHostLimiter fleet.EnrollHostLimiter, + instanceID string, +) error { + startCleanupsAndAggregationSchedule(ctx, instanceID, ds, logger, enrollHostLimiter) + startSendStatsSchedule(ctx, instanceID, ds, config, license, logger) + + return nil +} + // Support for TLS security profiles, we set up the TLS configuation based on // value supplied to server_tls_compatibility command line flag. The default // profile is 'modern'. diff --git a/server/service/schedule/schedule.go b/server/service/schedule/schedule.go new file mode 100644 index 0000000000..91e98b9b32 --- /dev/null +++ b/server/service/schedule/schedule.go @@ -0,0 +1,309 @@ +// Package schedule allows periodic run of a list of jobs. +// +// Type Schedule allows grouping a set of Jobs to run at specific intervals. +// Each Job is executed serially in the order they were added to the Schedule. 
+package schedule + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + "github.com/getsentry/sentry-go" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/log/level" +) + +// ReloadInterval reloads and returns a new interval. +type ReloadInterval func(ctx context.Context) (time.Duration, error) + +// Schedule runs a list of jobs serially at a given schedule. +// +// Each job is executed one after the other in the order they were added. +// If one of the job fails, an error is logged and the scheduler +// continues with the next. +type Schedule struct { + ctx context.Context + name string + instanceID string + logger log.Logger + + schedIntervalMu sync.Mutex // protects schedInterval. + schedInterval time.Duration + + done chan struct{} + + configReloadInterval time.Duration + configReloadIntervalFn ReloadInterval + + locker Locker + + altLockName string + + jobs []Job +} + +// JobFn is the signature of a Job. +type JobFn func(context.Context) error + +// Job represents a job that can be added to Scheduler. +type Job struct { + // ID is the unique identifier for the job. + ID string + // Fn is the job itself. + Fn JobFn +} + +// Locker allows a Schedule to acquire a lock before running jobs. +type Locker interface { + Lock(ctx context.Context, scheduleName string, scheduleInstanceID string, expiration time.Duration) (bool, error) +} + +// Option allows configuring a Schedule. +type Option func(*Schedule) + +// WithLogger sets a logger for the Schedule. +func WithLogger(l log.Logger) Option { + return func(s *Schedule) { + s.logger = log.With(l, "schedule", s.name) + } +} + +// WithConfigReloadInterval allows setting a reload interval function, +// that will allow updating the interval of a running schedule. +// +// If not set, then the schedule performs no interval reloading. 
+func WithConfigReloadInterval(interval time.Duration, fn ReloadInterval) Option { + return func(s *Schedule) { + s.configReloadInterval = interval + s.configReloadIntervalFn = fn + } +} + +// WithAltLockID sets an alternative identifier to use when acquiring the lock. +// +// If not set, then the Schedule's name is used for acquiring the lock. +func WithAltLockID(name string) Option { + return func(s *Schedule) { + s.altLockName = name + } +} + +// WithJob adds a job to the Schedule. +// +// Each job is executed in the order they are added. +func WithJob(id string, fn JobFn) Option { + return func(s *Schedule) { + s.jobs = append(s.jobs, Job{ + ID: id, + Fn: fn, + }) + } +} + +// New creates and returns a Schedule. +// Jobs are added with the WithJob Option. +// +// The jobs are executed serially in order at the provided interval. +// +// The provided locker is used to acquire/release a lock before running the jobs. +// The provided name and instanceID of the Schedule is used as the locking identifier. +func New( + ctx context.Context, + name string, + instanceID string, + interval time.Duration, + locker Locker, + opts ...Option, +) *Schedule { + sch := &Schedule{ + ctx: ctx, + name: name, + instanceID: instanceID, + logger: log.NewNopLogger(), + done: make(chan struct{}), + configReloadInterval: 1 * time.Hour, // by default we will check for updated config once per hour + schedInterval: interval, + locker: locker, + } + for _, fn := range opts { + fn(sch) + } + return sch +} + +// Start starts running the added jobs. +// +// All jobs must be added before calling Start. +func (s *Schedule) Start() { + var m sync.Mutex // protects currentStart and currentWait. 
+ currentStart := time.Now() + currentWait := 10 * time.Second + + getWaitTimes := func() (start time.Time, wait time.Duration) { + m.Lock() + defer m.Unlock() + + return currentStart, currentWait + } + + setWaitTimes := func(start time.Time, wait time.Duration) { + m.Lock() + defer m.Unlock() + + currentStart = start + currentWait = wait + } + + if schedInterval := s.getSchedInterval(); schedInterval < currentWait { + setWaitTimes(currentStart, schedInterval) + } + + var g sync.WaitGroup + + schedTicker := time.NewTicker(currentWait) + g.Add(+1) + go func() { + defer g.Done() + + for { + _, currWait := getWaitTimes() + level.Debug(s.logger).Log("msg", "waiting", "current wait time", currWait) + + select { + case <-s.ctx.Done(): + return + + case <-schedTicker.C: + level.Debug(s.logger).Log("waiting", "done") + + schedInterval := s.getSchedInterval() + schedTicker.Reset(schedInterval) + + newStart := time.Now() + newWait := schedInterval + setWaitTimes(newStart, newWait) + + if ok := s.acquireLock(); !ok { + continue + } + + for _, job := range s.jobs { + level.Debug(s.logger).Log("msg", "starting", "jobID", job.ID) + if err := runJob(s.ctx, job.Fn); err != nil { + level.Error(s.logger).Log("err", job.ID, "details", err) + sentry.CaptureException(err) + ctxerr.Handle(s.ctx, err) + } + } + } + } + }() + + // Periodically check for config updates and resets the schedInterval for the previous loop. 
+ g.Add(+1) + go func() { + defer g.Done() + configTicker := time.NewTicker(200 * time.Millisecond) + + for { + select { + case <-s.ctx.Done(): + return + case <-configTicker.C: + level.Debug(s.logger).Log("msg", "config reload check") + + configTicker.Reset(s.configReloadInterval) + + schedInterval := s.getSchedInterval() + currStart, _ := getWaitTimes() + + if s.configReloadIntervalFn == nil { + level.Debug(s.logger).Log("msg", "config reload interval method not set") + continue + } + + newInterval, err := s.configReloadIntervalFn(s.ctx) + if err != nil { + level.Error(s.logger).Log("msg", "schedule interval config reload failed", "err", err) + sentry.CaptureException(err) + continue + } + if schedInterval == newInterval { + level.Debug(s.logger).Log("msg", "schedule interval unchanged") + continue + } + s.setSchedInterval(newInterval) + + newWait := 10 * time.Millisecond + if time.Since(currStart) < newInterval { + newWait = newInterval - time.Since(currStart) + } + setWaitTimes(currStart, newWait) + schedTicker.Reset(newWait) + + level.Debug(s.logger).Log("new schedule interval", newInterval, "new wait", newWait) + } + } + }() + + go func() { + g.Wait() + level.Debug(s.logger).Log("msg", "done") + close(s.done) // communicates that the scheduler has finished running its goroutines + }() +} + +// runJob executes the job function with panic recovery +func runJob(ctx context.Context, fn JobFn) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + } + }() + + if err := fn(ctx); err != nil { + return err + } + return nil +} + +// Done returns a channel that will be closed when the scheduler's context is done +// and it has finished running its goroutines. 
+func (s *Schedule) Done() <-chan struct{} { + return s.done +} + +func (s *Schedule) getSchedInterval() time.Duration { + s.schedIntervalMu.Lock() + defer s.schedIntervalMu.Unlock() + + return s.schedInterval +} + +func (s *Schedule) setSchedInterval(interval time.Duration) { + s.schedIntervalMu.Lock() + defer s.schedIntervalMu.Unlock() + + s.schedInterval = interval +} + +func (s *Schedule) acquireLock() bool { + name := s.name + if s.altLockName != "" { + name = s.altLockName + } + locked, err := s.locker.Lock(s.ctx, name, s.instanceID, s.getSchedInterval()) + if err != nil { + level.Error(s.logger).Log("msg", "lock failed", "err", err) + sentry.CaptureException(err) + return false + } + if locked { + return true + } + level.Debug(s.logger).Log("msg", "not the lock leader, skipping") + return false +} diff --git a/server/service/schedule/schedule_test.go b/server/service/schedule/schedule_test.go new file mode 100644 index 0000000000..0c9c5a6404 --- /dev/null +++ b/server/service/schedule/schedule_test.go @@ -0,0 +1,292 @@ +package schedule + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +type nopLocker struct{} + +func (nopLocker) Lock(context.Context, string, string, time.Duration) (bool, error) { + return true, nil +} + +func TestNewSchedule(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + jobRan := false + s := New(ctx, "test_new_schedule", "test_instance", 10*time.Millisecond, nopLocker{}, + WithJob("test_job", func(ctx context.Context) error { + jobRan = true + return nil + }), + ) + s.Start() + + time.Sleep(1 * time.Second) + cancel() + + select { + case <-s.Done(): + require.True(t, jobRan) + case <-time.After(5 * time.Second): + t.Error("timeout") + } +} + +type counterLocker struct { + mu sync.Mutex + count int +} + +func (l *counterLocker) Lock(context.Context, string, string, time.Duration) (bool, error) { + l.mu.Lock() 
+ defer l.mu.Unlock() + + l.count = l.count + 1 + return true, nil +} + +func TestScheduleLocker(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + locker := counterLocker{} + jobRunCount := 0 + s := New(ctx, "test_schedule_locker", "test_instance", 10*time.Millisecond, &locker, + WithJob("test_job", func(ctx context.Context) error { + jobRunCount++ + return nil + }), + ) + s.Start() + + time.Sleep(1 * time.Second) + cancel() + + select { + case <-s.Done(): + require.Equal(t, locker.count, jobRunCount) + case <-time.After(5 * time.Second): + t.Error("timeout") + } +} + +func TestMultipleSchedules(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + var ss []*Schedule + + var m sync.Mutex + jobRun := make(map[string]struct{}) + setJobRun := func(id string) { + m.Lock() + defer m.Unlock() + + jobRun[id] = struct{}{} + } + var jobNames []string + + for _, tc := range []struct { + name string + instanceID string + interval time.Duration + jobs []Job + }{ + { + name: "test_schedule_1", + instanceID: "test_instance", + interval: 10 * time.Millisecond, + jobs: []Job{ + { + ID: "test_job_1", + Fn: func(ctx context.Context) error { + setJobRun("test_job_1") + return nil + }, + }, + }, + }, + { + name: "test_schedule_2", + instanceID: "test_instance", + interval: 100 * time.Millisecond, + jobs: []Job{ + { + ID: "test_job_2", + Fn: func(ctx context.Context) error { + setJobRun("test_job_2") + return nil + }, + }, + }, + }, + { + name: "test_schedule_3", + instanceID: "test_instance", + interval: 100 * time.Millisecond, + jobs: []Job{ + { + ID: "test_job_3", + Fn: func(ctx context.Context) error { + setJobRun("test_job_3") + return errors.New("job 3") // job 3 fails, job 4 should still run. 
+ }, + }, + { + ID: "test_job_4", + Fn: func(ctx context.Context) error { + setJobRun("test_job_4") + return nil + }, + }, + }, + }, + } { + var opts []Option + for _, job := range tc.jobs { + opts = append(opts, WithJob(job.ID, job.Fn)) + jobNames = append(jobNames, job.ID) + } + s := New(ctx, tc.name, tc.instanceID, tc.interval, nopLocker{}, opts...) + s.Start() + ss = append(ss, s) + } + + time.Sleep(1 * time.Second) + cancel() + + for i, s := range ss { + select { + case <-s.Done(): + // OK + case <-time.After(1 * time.Second): + t.Errorf("timeout: %d", i) + } + } + for _, s := range jobNames { + _, ok := jobRun[s] + require.True(t, ok, "job: %s", s) + } +} + +func TestMultipleJobsInOrder(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + jobs := make(chan int) + + s := New(ctx, "test_schedule", "test_instance", 100*time.Millisecond, nopLocker{}, + WithJob("test_job_1", func(ctx context.Context) error { + jobs <- 1 + return nil + }), + WithJob("test_job_2", func(ctx context.Context) error { + jobs <- 2 + return errors.New("test_job_2") + }), + WithJob("test_job_3", func(ctx context.Context) error { + jobs <- 3 + return nil + }), + ) + s.Start() + + var g errgroup.Group + g.Go(func() error { + i := 1 + for { + select { + case job, ok := <-jobs: + if !ok { + return nil + } + if job != i { + return fmt.Errorf("mismatch id: %d vs %d", job, i) + } + i++ + if i == 4 { + i = 1 + } + case <-time.After(5 * time.Second): + return fmt.Errorf("timeout: %d", i) + } + } + }) + + time.Sleep(1 * time.Second) + cancel() + select { + case <-s.Done(): + close(jobs) + case <-time.After(5 * time.Second): + t.Error("timeout") + } + + err := g.Wait() + require.NoError(t, err) +} + +func TestConfigReloadCheck(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + jobRan := false + s := New(ctx, "test_schedule", "test_instance", 200*time.Millisecond, nopLocker{}, + WithConfigReloadInterval(100*time.Millisecond, func(_ context.Context) 
(time.Duration, error) { + return 50 * time.Millisecond, nil + }), + WithJob("test_job", func(ctx context.Context) error { + jobRan = true + return nil + }), + ) + + require.Equal(t, s.getSchedInterval(), 200*time.Millisecond) + require.Equal(t, s.configReloadInterval, 100*time.Millisecond) + + s.Start() + + time.Sleep(1 * time.Second) + cancel() + + select { + case <-s.Done(): + require.Equal(t, s.getSchedInterval(), 50*time.Millisecond) + require.Equal(t, s.configReloadInterval, 100*time.Millisecond) + require.True(t, jobRan) + case <-time.After(5 * time.Second): + t.Error("timeout") + } +} + +func TestJobPanicRecover(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + jobRan := false + + s := New(ctx, "test_new_schedule", "test_instance", 10*time.Millisecond, nopLocker{}, + WithJob("job_1", func(ctx context.Context) error { + panic("job_1") + }), + WithJob("job_2", func(ctx context.Context) error { + jobRan = true + return nil + })) + s.Start() + + time.Sleep(1 * time.Second) + cancel() + + select { + case <-s.Done(): + // job 2 should still run even though job 1 panicked + require.True(t, jobRan) + case <-time.After(5 * time.Second): + t.Error("timeout") + } +}