diff --git a/.github/workflows/test-go.yaml b/.github/workflows/test-go.yaml index eee7d96cc1..10a9c317a6 100644 --- a/.github/workflows/test-go.yaml +++ b/.github/workflows/test-go.yaml @@ -68,7 +68,7 @@ jobs: # Pre-starting dependencies here means they are ready to go when we need them. - name: Start Infra Dependencies # Use & to background this - run: FLEET_MYSQL_IMAGE=${{ matrix.mysql }} docker-compose -f docker-compose.yml -f docker-compose-redis-cluster.yml up -d mysql_test redis redis-cluster-1 redis-cluster-2 redis-cluster-3 redis-cluster-4 redis-cluster-5 redis-cluster-6 redis-cluster-setup minio saml_idp mailhog mailpit smtp4dev_test & + run: FLEET_MYSQL_IMAGE=${{ matrix.mysql }} docker-compose -f docker-compose.yml -f docker-compose-redis-cluster.yml up -d mysql_test mysql_replica_test redis redis-cluster-1 redis-cluster-2 redis-cluster-3 redis-cluster-4 redis-cluster-5 redis-cluster-6 redis-cluster-setup minio saml_idp mailhog mailpit smtp4dev_test & - name: Add TLS certificate for SMTP Tests run: | @@ -101,6 +101,12 @@ jobs: sleep 1 done echo "mysql is ready" + echo "waiting for mysql replica..." + until docker-compose exec -T mysql_replica_test sh -c "mysql -uroot -p\"\${MYSQL_ROOT_PASSWORD}\" -e \"SELECT 1=1\" fleet" &> /dev/null; do + echo "." + sleep 1 + done + echo "mysql replica is ready" - name: Run Go Tests run: | @@ -109,6 +115,7 @@ jobs: NETWORK_TEST=1 \ REDIS_TEST=1 \ MYSQL_TEST=1 \ + MYSQL_REPLICA_TEST=1 \ MINIO_STORAGE_TEST=1 \ SAML_IDP_TEST=1 \ MAIL_TEST=1 \ diff --git a/changes/18838-additional-db-optimizations b/changes/18838-additional-db-optimizations index cff51b6313..97be894d07 100644 --- a/changes/18838-additional-db-optimizations +++ b/changes/18838-additional-db-optimizations @@ -2,3 +2,4 @@ MySQL query optimizations: - During software ingestion, switched to updating last_opened_at as a batch (for 1 host). - Removed DELETE FROM software statement that ran for every host update (when software was deleted). The cleanup of unused software is now only done during the vulnerability job. - `/api/v1/fleet/software/versions/:id` endpoint can return software even if it has been recently deleted from all hosts. During hourly cleanup, this software item will be removed from the database. +- Moved aggregated query stats calculations to read replica DB to reduce load on the master. diff --git a/cmd/fleetctl/query_test.go b/cmd/fleetctl/query_test.go index bfd41dc428..f697b39160 100644 --- a/cmd/fleetctl/query_test.go +++ b/cmd/fleetctl/query_test.go @@ -105,10 +105,15 @@ func TestSavedLiveQuery(t *testing.T) { return true, nil } var GetLiveQueryStatsFuncWg sync.WaitGroup - GetLiveQueryStatsFuncWg.Add(1) + GetLiveQueryStatsFuncWg.Add(2) ds.GetLiveQueryStatsFunc = func(ctx context.Context, queryID uint, hostIDs []uint) ([]*fleet.LiveQueryStats, error) { + stats := []*fleet.LiveQueryStats{ + { + LastExecuted: time.Now(), + }, + } GetLiveQueryStatsFuncWg.Done() - return nil, nil + return stats, nil } var UpdateLiveQueryStatsFuncWg sync.WaitGroup UpdateLiveQueryStatsFuncWg.Add(1) diff --git a/docker-compose.yml b/docker-compose.yml index 5684e856d8..bb7e3066a3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -46,7 +46,7 @@ services: # These 3 keys run MySQL with GTID consistency enforced to avoid issues with production deployments that use it. "--enforce-gtid-consistency=ON", "--log-bin=bin.log", - "--server-id=master-01", + "--server-id=1", # Required for storage of Apple MDM bootstrap packages. "--max_allowed_packet=536870912" ] @@ -57,6 +57,34 @@ services: - /var/lib/mysql:rw,noexec,nosuid - /tmpfs + mysql_replica_test: + image: ${FLEET_MYSQL_IMAGE:-mysql:5.7} + platform: ${FLEET_MYSQL_PLATFORM:-linux/x86_64} + # innodb-file-per-table=OFF gives ~20% speedup for test runs. + command: + [ + "mysqld", + "--datadir=/tmpfs", + "--slow_query_log=1", + "--log_output=TABLE", + "--log-queries-not-using-indexes", + "--innodb-file-per-table=OFF", + "--table-definition-cache=8192", + # These 3 keys run MySQL with GTID consistency enforced to avoid issues with production deployments that use it. + "--enforce-gtid-consistency=ON", + "--log-bin=bin.log", + "--server-id=2", + # Required for storage of Apple MDM bootstrap packages. + "--max_allowed_packet=536870912" + ] + environment: *mysql-default-environment + ports: + # ports 3308 and 3309 are used by the main and replica MySQL containers in tools/mysql-replica-testing/docker-compose.yml + - "3310:3306" + tmpfs: + - /var/lib/mysql:rw,noexec,nosuid + - /tmpfs + # Unauthenticated SMTP server. mailhog: image: mailhog/mailhog:latest diff --git a/server/datastore/mysql/aggregated_stats.go b/server/datastore/mysql/aggregated_stats.go index f908c95fc9..9523904132 100644 --- a/server/datastore/mysql/aggregated_stats.go +++ b/server/datastore/mysql/aggregated_stats.go @@ -95,20 +95,20 @@ func (ds *Datastore) UpdateQueryAggregatedStats(ctx context.Context) error { // CalculateAggregatedPerfStatsPercentiles calculates the aggregated user/system time performance statistics for the given query. func (ds *Datastore) CalculateAggregatedPerfStatsPercentiles(ctx context.Context, aggregate fleet.AggregatedStatsType, queryID uint) error { - tx := ds.writer(ctx) + reader := ds.reader(ctx) var totalExecutions int statsMap := make(map[string]interface{}) // many queries is not ideal, but getting both values and totals in the same query was a bit more complicated // so I went for the simpler approach first, we can optimize later - if err := setP50AndP95Map(ctx, tx, aggregate, "user_time", queryID, statsMap); err != nil { + if err := setP50AndP95Map(ctx, reader, aggregate, "user_time", queryID, statsMap); err != nil { return err } - if err := setP50AndP95Map(ctx, tx, aggregate, "system_time", queryID, statsMap); err != nil { + if err := setP50AndP95Map(ctx, reader, aggregate, "system_time", queryID, statsMap); err != nil { return err } - err := sqlx.GetContext(ctx, tx, &totalExecutions, getTotalExecutionsQuery(aggregate), queryID) + err := sqlx.GetContext(ctx, reader, &totalExecutions, getTotalExecutionsQuery(aggregate), queryID) if err != nil { return ctxerr.Wrapf(ctx, err, "getting total executions for %s %d", aggregate, queryID) } @@ -122,7 +122,8 @@ func (ds *Datastore) CalculateAggregatedPerfStatsPercentiles(ctx context.Context // NOTE: this function gets called for query and scheduled_query, so the id // refers to a query/scheduled_query id, and it never computes "global" // stats. For that reason, we always set global_stats=0. - _, err = tx.ExecContext(ctx, + _, err = ds.writer(ctx).ExecContext( + ctx, ` INSERT INTO aggregated_stats(id, type, global_stats, json_value) VALUES (?, ?, 0, ?) diff --git a/server/datastore/mysql/mysql_test.go b/server/datastore/mysql/mysql_test.go index 4222d62702..5967fdfdc2 100644 --- a/server/datastore/mysql/mysql_test.go +++ b/server/datastore/mysql/mysql_test.go @@ -43,7 +43,7 @@ func TestDatastoreReplica(t *testing.T) { }) t.Run("replica", func(t *testing.T) { - opts := &DatastoreTestOptions{Replica: true} + opts := &DatastoreTestOptions{DummyReplica: true} ds := CreateMySQLDSWithOptions(t, opts) defer ds.Close() require.NotEqual(t, ds.reader(ctx), ds.writer(ctx)) diff --git a/server/datastore/mysql/queries.go b/server/datastore/mysql/queries.go index 2ad343922c..47ceb0e03b 100644 --- a/server/datastore/mysql/queries.go +++ b/server/datastore/mysql/queries.go @@ -673,7 +673,7 @@ func (ds *Datastore) IsSavedQuery(ctx context.Context, queryID uint) (bool, erro // GetLiveQueryStats returns the live query stats for the given query and hosts. func (ds *Datastore) GetLiveQueryStats(ctx context.Context, queryID uint, hostIDs []uint) ([]*fleet.LiveQueryStats, error) { stmt, args, err := sqlx.In( - `SELECT host_id, average_memory, executions, system_time, user_time, wall_time, output_size + `SELECT host_id, average_memory, executions, system_time, user_time, wall_time, output_size, last_executed FROM scheduled_query_stats WHERE host_id IN (?) AND scheduled_query_id = ? AND query_type = ? `, hostIDs, queryID, statsLiveQueryType, @@ -696,8 +696,8 @@ func (ds *Datastore) UpdateLiveQueryStats(ctx context.Context, queryID uint, sta } // Bulk insert/update - const valueStr = "(?,?,?,?,?,?,?,?,?,?,?)," - stmt := "REPLACE INTO scheduled_query_stats (scheduled_query_id, host_id, query_type, executions, average_memory, system_time, user_time, wall_time, output_size, denylisted, schedule_interval) VALUES " + + const valueStr = "(?,?,?,?,?,?,?,?,?,?,?,?)," + stmt := "REPLACE INTO scheduled_query_stats (scheduled_query_id, host_id, query_type, executions, average_memory, system_time, user_time, wall_time, output_size, denylisted, schedule_interval, last_executed) VALUES " + strings.Repeat(valueStr, len(stats)) stmt = strings.TrimSuffix(stmt, ",") @@ -705,7 +705,7 @@ func (ds *Datastore) UpdateLiveQueryStats(ctx context.Context, queryID uint, sta for _, s := range stats { args = append( args, queryID, s.HostID, statsLiveQueryType, s.Executions, s.AverageMemory, s.SystemTime, s.UserTime, s.WallTime, s.OutputSize, - 0, 0, + 0, 0, s.LastExecuted, ) } _, err := ds.writer(ctx).ExecContext(ctx, stmt, args...) diff --git a/server/datastore/mysql/queries_test.go b/server/datastore/mysql/queries_test.go index f43c07fe7e..cedec4ad89 100644 --- a/server/datastore/mysql/queries_test.go +++ b/server/datastore/mysql/queries_test.go @@ -178,15 +178,26 @@ func testQueriesDelete(t *testing.T, ds *Datastore) { require.NoError(t, err) require.NotNil(t, query) assert.NotEqual(t, query.ID, 0) + lastExecuted := time.Now().Add(-time.Hour).Round(time.Second) // TIMESTAMP precision is seconds by default in MySQL err = ds.UpdateLiveQueryStats( context.Background(), query.ID, []*fleet.LiveQueryStats{ { - HostID: hostID, - Executions: 1, + HostID: hostID, + Executions: 1, + LastExecuted: lastExecuted, }, }, ) require.NoError(t, err) + // Check that the stats were saved correctly + stats, err := ds.GetLiveQueryStats(context.Background(), query.ID, []uint{hostID}) + require.NoError(t, err) + require.Len(t, stats, 1) + assert.Equal(t, hostID, stats[0].HostID) + assert.Equal(t, uint64(1), stats[0].Executions) + assert.False(t, lastExecuted.Before(stats[0].LastExecuted)) + assert.Equal(t, lastExecuted.UTC(), stats[0].LastExecuted.UTC()) + err = ds.CalculateAggregatedPerfStatsPercentiles(context.Background(), fleet.AggregatedStatsTypeScheduledQuery, query.ID) require.NoError(t, err) diff --git a/server/datastore/mysql/testing_utils.go b/server/datastore/mysql/testing_utils.go index b46e07028a..cac7f67ff7 100644 --- a/server/datastore/mysql/testing_utils.go +++ b/server/datastore/mysql/testing_utils.go @@ -9,6 +9,7 @@ import ( "os/exec" "path" "runtime" + "strconv" "strings" "testing" "text/tabwriter" @@ -16,6 +17,7 @@ import ( "github.com/WatchBeam/clock" "github.com/fleetdm/fleet/v4/server/config" + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/go-kit/kit/log" "github.com/google/uuid" @@ -28,6 +30,7 @@ const ( testPassword = "toor" testAddress = "localhost:3307" testReplicaDatabaseSuffix = "_replica" + testReplicaAddress = "localhost:3310" ) func connectMySQL(t testing.TB, testName string, opts *DatastoreTestOptions) *Datastore { @@ -40,7 +43,7 @@ func connectMySQL(t testing.TB, testName string, opts *DatastoreTestOptions) *Da // Create datastore client var replicaOpt DBOption - if opts.Replica { + if opts.DummyReplica { replicaConf := config replicaConf.Database += testReplicaDatabaseSuffix replicaOpt = Replica(&replicaConf) @@ -57,14 +60,23 @@ func connectMySQL(t testing.TB, testName string, opts *DatastoreTestOptions) *Da ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1), replicaOpt, SQLMode("ANSI")) require.Nil(t, err) - if opts.Replica { - setupReadReplica(t, testName, ds, opts) + if opts.DummyReplica { + setupDummyReplica(t, testName, ds, opts) + } + if opts.RealReplica { + replicaOpts := &dbOptions{ + minLastOpenedAtDiff: defaultMinLastOpenedAtDiff, + maxAttempts: 1, + logger: log.NewNopLogger(), + sqlMode: "ANSI", + } + setupRealReplica(t, testName, ds, replicaOpts) } return ds } -func setupReadReplica(t testing.TB, testName string, ds *Datastore, opts *DatastoreTestOptions) { +func setupDummyReplica(t testing.TB, testName string, ds *Datastore, opts *DatastoreTestOptions) { t.Helper() // create the context that will cancel the replication goroutine on test exit @@ -185,6 +197,96 @@ func setupReadReplica(t testing.TB, testName string, ds *Datastore, opts *Datast } } +func setupRealReplica(t testing.TB, testName string, ds *Datastore, options *dbOptions) { + t.Helper() + const replicaUser = "replicator" + const replicaPassword = "rotacilper" + + t.Cleanup( + func() { + // Stop slave + if out, err := exec.Command( + "docker-compose", "exec", "-T", "mysql_replica_test", + // Command run inside container + "mysql", + "-u"+testUsername, "-p"+testPassword, + "-e", + "STOP SLAVE; RESET SLAVE ALL;", + ).CombinedOutput(); err != nil { + t.Log(err) + t.Log(string(out)) + } + }, + ) + + ctx := context.Background() + + // Create replication user + _, err := ds.primary.ExecContext(ctx, fmt.Sprintf("DROP USER IF EXISTS '%s'", replicaUser)) + require.NoError(t, err) + _, err = ds.primary.ExecContext(ctx, fmt.Sprintf("CREATE USER '%s'@'%%' IDENTIFIED BY '%s'", replicaUser, replicaPassword)) + require.NoError(t, err) + _, err = ds.primary.ExecContext(ctx, fmt.Sprintf("GRANT REPLICATION SLAVE ON *.* TO '%s'@'%%'", replicaUser)) + require.NoError(t, err) + _, err = ds.primary.ExecContext(ctx, "FLUSH PRIVILEGES") + require.NoError(t, err) + + // Retrieve master binary log coordinates + ms, err := ds.MasterStatus(ctx) + require.NoError(t, err) + + // Get MySQL version + var version string + err = ds.primary.GetContext(ctx, &version, "SELECT VERSION()") + require.NoError(t, err) + using57 := strings.HasPrefix(version, "5.7") + extraMasterOptions := "" + if !using57 { + extraMasterOptions = "GET_MASTER_PUBLIC_KEY=1," // needed for MySQL 8.0 caching_sha2_password authentication + } + + // Configure slave and start replication + if out, err := exec.Command( + "docker-compose", "exec", "-T", "mysql_replica_test", + // Command run inside container + "mysql", + "-u"+testUsername, "-p"+testPassword, + "-e", + fmt.Sprintf( + ` + STOP SLAVE; + RESET SLAVE ALL; + CHANGE MASTER TO + %s + MASTER_HOST='mysql_test', + MASTER_USER='%s', + MASTER_PASSWORD='%s', + MASTER_LOG_FILE='%s', + MASTER_LOG_POS=%d; + START SLAVE; + `, extraMasterOptions, replicaUser, replicaPassword, ms.File, ms.Position, + ), + ).CombinedOutput(); err != nil { + t.Error(err) + t.Error(string(out)) + t.FailNow() + } + + // Connect to the replica + replicaConfig := config.MysqlConfig{ + Username: testUsername, + Password: testPassword, + Database: testName, + Address: testReplicaAddress, + } + require.NoError(t, checkConfig(&replicaConfig)) + replica, err := newDB(&replicaConfig, options) + require.NoError(t, err) + ds.replica = replica + ds.readReplicaConfig = &replicaConfig + +} + // initializeDatabase loads the dumped schema into a newly created database in // MySQL. This is much faster than running the full set of migrations on each // test. @@ -200,7 +302,7 @@ func initializeDatabase(t testing.TB, testName string, opts *DatastoreTestOption // execute the schema for the test db, and once more for the replica db if // that option is set. dbs := []string{testName} - if opts.Replica { + if opts.DummyReplica { dbs = append(dbs, testName+testReplicaDatabaseSuffix) } for _, dbName := range dbs { @@ -221,20 +323,42 @@ func initializeDatabase(t testing.TB, testName string, opts *DatastoreTestOption t.FailNow() } } + if opts.RealReplica { + // Load schema from dumpfile + if out, err := exec.Command( + "docker-compose", "exec", "-T", "mysql_replica_test", + // Command run inside container + "mysql", + "-u"+testUsername, "-p"+testPassword, + "-e", + fmt.Sprintf( + "DROP DATABASE IF EXISTS %s; CREATE DATABASE %s; USE %s; SET FOREIGN_KEY_CHECKS=0; %s;", + testName, testName, testName, schema, + ), + ).CombinedOutput(); err != nil { + t.Error(err) + t.Error(string(out)) + t.FailNow() + } + } + return connectMySQL(t, testName, opts) } // DatastoreTestOptions configures how the test datastore is created // by CreateMySQLDSWithOptions. type DatastoreTestOptions struct { - // Replica indicates that a read replica test database should be created. - Replica bool + // DummyReplica indicates that a read replica test database should be created. + DummyReplica bool // RunReplication is the function to call to execute the replication of all // missing changes from the primary to the replica. The function is created // and set automatically by CreateMySQLDSWithOptions. The test is in full - // control of when the replication is executed. + // control of when the replication is executed. Only applies to DummyReplica. RunReplication func() + + // RealReplica indicates that the replica should be a real DB replica, with a dedicated connection. + RealReplica bool } func createMySQLDSWithOptions(t testing.TB, opts *DatastoreTestOptions) *Datastore { @@ -242,15 +366,21 @@ func createMySQLDSWithOptions(t testing.TB, opts *DatastoreTestOptions) *Datasto t.Skip("MySQL tests are disabled") } - if tt, ok := t.(*testing.T); ok { - tt.Parallel() - } - if opts == nil { // so it is never nil in internal helper functions opts = new(DatastoreTestOptions) } + if tt, ok := t.(*testing.T); ok && !opts.RealReplica { + tt.Parallel() + } + + if opts.RealReplica { + if _, ok := os.LookupEnv("MYSQL_REPLICA_TEST"); !ok { + t.Skip("MySQL replica tests are disabled. Set env var MYSQL_REPLICA_TEST=1 to enable.") + } + } + pc, _, _, ok := runtime.Caller(2) details := runtime.FuncForPC(pc) if !ok || details == nil { @@ -487,3 +617,57 @@ func SetOrderedCreatedAtTimestamps(t testing.TB, ds *Datastore, afterTime time.T } return now } + +// MasterStatus is a struct that holds the file and position of the master, retrieved by SHOW MASTER STATUS +type MasterStatus struct { + File string + Position uint64 +} + +func (ds *Datastore) MasterStatus(ctx context.Context) (MasterStatus, error) { + + rows, err := ds.writer(ctx).Query("SHOW MASTER STATUS") + if err != nil { + return MasterStatus{}, ctxerr.Wrap(ctx, err, "show master status") + } + defer rows.Close() + + // Since we don't control the column names, and we want to be future compatible, + // we only scan for the columns we care about. + ms := MasterStatus{} + // Get the column names from the query + columns, err := rows.Columns() + if err != nil { + return ms, ctxerr.Wrap(ctx, err, "get columns") + } + numberOfColumns := len(columns) + for rows.Next() { + cols := make([]interface{}, numberOfColumns) + for i := range cols { + cols[i] = new(string) + } + err := rows.Scan(cols...) + if err != nil { + return ms, ctxerr.Wrap(ctx, err, "scan row") + } + for i, col := range cols { + switch columns[i] { + case "File": + ms.File = *col.(*string) + case "Position": + ms.Position, err = strconv.ParseUint(*col.(*string), 10, 64) + if err != nil { + return ms, ctxerr.Wrap(ctx, err, "parse Position") + } + + } + } + } + if err := rows.Err(); err != nil { + return ms, ctxerr.Wrap(ctx, err, "rows error") + } + if ms.File == "" || ms.Position == 0 { + return ms, ctxerr.New(ctx, "missing required fields in master status") + } + return ms, nil +} diff --git a/server/fleet/queries.go b/server/fleet/queries.go index e2aab98d88..4b4d3eaf70 100644 --- a/server/fleet/queries.go +++ b/server/fleet/queries.go @@ -157,13 +157,14 @@ func (q *Query) Copy() *Query { type LiveQueryStats struct { // host_id, average_memory, execution, system_time, user_time - HostID uint `db:"host_id"` - Executions uint64 `db:"executions"` - AverageMemory uint64 `db:"average_memory"` - SystemTime uint64 `db:"system_time"` - UserTime uint64 `db:"user_time"` - WallTime uint64 `db:"wall_time"` - OutputSize uint64 `db:"output_size"` + HostID uint `db:"host_id"` + Executions uint64 `db:"executions"` + AverageMemory uint64 `db:"average_memory"` + SystemTime uint64 `db:"system_time"` + UserTime uint64 `db:"user_time"` + WallTime uint64 `db:"wall_time"` + OutputSize uint64 `db:"output_size"` + LastExecuted time.Time `db:"last_executed"` } var ( diff --git a/server/service/service_campaign_test.go b/server/service/service_campaign_test.go index 17c9a993e5..25e2ae1896 100644 --- a/server/service/service_campaign_test.go +++ b/server/service/service_campaign_test.go @@ -167,9 +167,12 @@ func TestStreamCampaignResultsClosesReditOnWSClose(t *testing.T) { require.Equal(t, prevActiveConn-1, newActiveConn) } -func TestUpdateStats(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer mysql.TruncateTables(t, ds) +func testUpdateStats(t *testing.T, ds *mysql.Datastore, usingReplica bool) { + t.Cleanup( + func() { + overwriteLastExecuted = false + }, + ) s, ctx := newTestService(t, ds, nil, nil) svc := s.(validationMiddleware).Service.(*Service) @@ -224,11 +227,47 @@ func TestUpdateStats(t *testing.T) { hostIDs = append(hostIDs, i) } tracker.saveStats = true + // We overwrite the last executed time to ensure that these stats have a different timestamp than later stats + overwriteLastExecuted = true + overwriteLastExecutedTime = time.Now().Add(-2 * time.Second).Round(time.Second) svc.updateStats(ctx, queryID, svc.logger, &tracker, false) assert.True(t, tracker.saveStats) assert.Equal(t, 0, len(tracker.stats)) assert.True(t, tracker.aggregationNeeded) + // Aggregate stats + svc.updateStats(ctx, queryID, svc.logger, &tracker, true) + overwriteLastExecuted = false + + // Check that aggregated stats were created. Since we read aggregated stats from the replica, we may need to wait for it to catch up. + var err error + var aggStats fleet.AggregatedStats + done := make(chan struct{}, 1) + go func() { + for { + aggStats, err = mysql.GetAggregatedStats(ctx, svc.ds.(*mysql.Datastore), fleet.AggregatedStatsTypeScheduledQuery, queryID) + if usingReplica && err != nil { + time.Sleep(30 * time.Millisecond) + } else { + done <- struct{}{} + return + } + } + }() + select { + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for aggregated stats") + case <-done: + // Continue + } + require.NoError(t, err) + assert.Equal(t, statsBatchSize, int(*aggStats.TotalExecutions)) + // Sanity checks. Complete testing done in aggregated_stats_test.go + assert.True(t, *aggStats.SystemTimeP50 > 0) + assert.True(t, *aggStats.SystemTimeP95 > 0) + assert.True(t, *aggStats.UserTimeP50 > 0) + assert.True(t, *aggStats.UserTimeP95 > 0) + // Get the stats from DB and make sure they match currentStats, err := svc.ds.GetLiveQueryStats(ctx, queryID, hostIDs) assert.NoError(t, err) @@ -245,17 +284,6 @@ func TestUpdateStats(t *testing.T) { assert.Equal(t, myMemory, myStat.AverageMemory) assert.Equal(t, myOutputSize, myStat.OutputSize) - // Aggregate stats - svc.updateStats(ctx, queryID, svc.logger, &tracker, true) - aggStats, err := mysql.GetAggregatedStats(ctx, svc.ds.(*mysql.Datastore), fleet.AggregatedStatsTypeScheduledQuery, queryID) - require.NoError(t, err) - assert.Equal(t, statsBatchSize, int(*aggStats.TotalExecutions)) - // Sanity checks. Complete testing done in aggregated_stats_test.go - assert.True(t, *aggStats.SystemTimeP50 > 0) - assert.True(t, *aggStats.SystemTimeP95 > 0) - assert.True(t, *aggStats.UserTimeP50 > 0) - assert.True(t, *aggStats.UserTimeP95 > 0) - // Write new stats (update) for the same query/hosts myNewWallTime := uint64(15) myNewUserTime := uint64(16) @@ -281,8 +309,8 @@ func TestUpdateStats(t *testing.T) { hostID: i, Stats: &fleet.Stats{ WallTimeMs: rand.Uint64(), - UserTime: rand.Uint64(), - SystemTime: rand.Uint64(), + UserTime: rand.Uint64() % 100, // Keep these values small to ensure the update will be noticeable + SystemTime: rand.Uint64() % 100, // Keep these values small to ensure the update will be noticeable Memory: rand.Uint64(), }, outputSize: rand.Uint64(), @@ -295,6 +323,42 @@ func TestUpdateStats(t *testing.T) { assert.Equal(t, 0, len(tracker.stats)) assert.False(t, tracker.aggregationNeeded) + // Check that aggregated stats were updated. Since we read aggregated stats from the replica, we may need to wait for it to catch up. + var newAggStats fleet.AggregatedStats + done = make(chan struct{}, 1) + go func() { + for { + newAggStats, err = mysql.GetAggregatedStats(ctx, svc.ds.(*mysql.Datastore), fleet.AggregatedStatsTypeScheduledQuery, queryID) + if usingReplica && (*aggStats.SystemTimeP50 == *newAggStats.SystemTimeP50 || + *aggStats.SystemTimeP95 == *newAggStats.SystemTimeP95 || + *aggStats.UserTimeP50 == *newAggStats.UserTimeP50 || + *aggStats.UserTimeP95 == *newAggStats.UserTimeP95) { + time.Sleep(30 * time.Millisecond) + } else { + done <- struct{}{} + return + } + } + }() + select { + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for aggregated stats") + case <-done: + // Continue + } + + require.NoError(t, err) + assert.Equal(t, statsBatchSize*2, int(*newAggStats.TotalExecutions)) + // Sanity checks. Complete testing done in aggregated_stats_test.go + assert.True(t, *newAggStats.SystemTimeP50 > 0) + assert.True(t, *newAggStats.SystemTimeP95 > 0) + assert.True(t, *newAggStats.UserTimeP50 > 0) + assert.True(t, *newAggStats.UserTimeP95 > 0) + assert.NotEqual(t, *aggStats.SystemTimeP50, *newAggStats.SystemTimeP50) + assert.NotEqual(t, *aggStats.SystemTimeP95, *newAggStats.SystemTimeP95) + assert.NotEqual(t, *aggStats.UserTimeP50, *newAggStats.UserTimeP50) + assert.NotEqual(t, *aggStats.UserTimeP95, *newAggStats.UserTimeP95) + // Check that stats were updated currentStats, err = svc.ds.GetLiveQueryStats(ctx, queryID, []uint{myHostID}) assert.NoError(t, err) @@ -307,16 +371,21 @@ func TestUpdateStats(t *testing.T) { assert.Equal(t, mySystemTime+myNewSystemTime, myStat.SystemTime) assert.Equal(t, (myMemory+myNewMemory)/2, myStat.AverageMemory) assert.Equal(t, myOutputSize+myNewOutputSize, myStat.OutputSize) +} - // Check that aggregated stats were updated - aggStats, err = mysql.GetAggregatedStats(ctx, svc.ds.(*mysql.Datastore), fleet.AggregatedStatsTypeScheduledQuery, queryID) - require.NoError(t, err) - assert.Equal(t, statsBatchSize*2, int(*aggStats.TotalExecutions)) - // Sanity checks. Complete testing done in aggregated_stats_test.go - assert.True(t, *aggStats.SystemTimeP50 > 0) - assert.True(t, *aggStats.SystemTimeP95 > 0) - assert.True(t, *aggStats.UserTimeP50 > 0) - assert.True(t, *aggStats.UserTimeP95 > 0) +func TestUpdateStats(t *testing.T) { + ds := mysql.CreateMySQLDS(t) + defer mysql.TruncateTables(t, ds) + testUpdateStats(t, ds, false) +} + +func TestUpdateStatsOnReplica(t *testing.T) { + opts := &mysql.DatastoreTestOptions{ + RealReplica: true, + } + ds := mysql.CreateMySQLDSWithOptions(t, opts) + defer mysql.TruncateTables(t, ds) + testUpdateStats(t, ds, true) } func TestCalculateOutputSize(t *testing.T) { diff --git a/server/service/service_campaigns.go b/server/service/service_campaigns.go index 6970e4ee92..cee447e762 100644 --- a/server/service/service_campaigns.go +++ b/server/service/service_campaigns.go @@ -45,6 +45,7 @@ type statsTracker struct { saveStats bool aggregationNeeded bool stats []statsToSave + lastStatsEntry *fleet.LiveQueryStats } func (svc Service) StreamCampaignResults(ctx context.Context, conn *websocket.Conn, campaignID uint) { @@ -298,6 +299,10 @@ func calculateOutputSize(perfStatsTracker *statsTracker, res *fleet.DistributedQ return outputSize } +// overwriteLastExecuted is used for testing purposes to overwrite the last executed time of the live query stats. +var overwriteLastExecuted = false +var overwriteLastExecutedTime time.Time + func (svc Service) updateStats( ctx context.Context, queryID uint, logger log.Logger, tracker *statsTracker, aggregateStats bool, ) { @@ -327,6 +332,12 @@ func (svc Service) updateStats( } // Update stats + // We round to the nearest second because MySQL default precision of TIMESTAMP is 1 second. + // We could alter the table to increase precision. However, this precision granularity is sufficient for the live query stats use case. + lastExecuted := time.Now().Round(time.Second) + if overwriteLastExecuted { + lastExecuted = overwriteLastExecutedTime + } for _, gatheredStats := range tracker.stats { stats, ok := statsMap[gatheredStats.hostID] if !ok { @@ -338,6 +349,7 @@ func (svc Service) updateStats( UserTime: gatheredStats.UserTime, WallTime: gatheredStats.WallTimeMs, OutputSize: gatheredStats.outputSize, + LastExecuted: lastExecuted, } currentStats = append(currentStats, &newStats) } else { @@ -348,6 +360,7 @@ func (svc Service) updateStats( stats.UserTime = stats.UserTime + gatheredStats.UserTime stats.WallTime = stats.WallTime + gatheredStats.WallTimeMs stats.OutputSize = stats.OutputSize + gatheredStats.outputSize + stats.LastExecuted = lastExecuted } } @@ -359,12 +372,56 @@ func (svc Service) updateStats( return } + tracker.lastStatsEntry = currentStats[0] tracker.aggregationNeeded = true tracker.stats = nil } // Do aggregation if aggregateStats && tracker.aggregationNeeded { + // Since we just wrote new stats, we need the write data to sync to the replica before calculating aggregated stats. + // The calculations are done on the replica to reduce the load on the master. + // Although this check is not necessary if replica is not used, we leave it in for consistency and to ensure the code is exercised in dev/test environments. + // To sync with the replica, we read the last stats entry from the replica and compare the timestamp to what was written on the master. + if tracker.lastStatsEntry != nil { // This check is just to be safe. It should never be nil. + done := make(chan error, 1) + stop := make(chan struct{}, 1) + go func() { + var stats []*fleet.LiveQueryStats + var err error + for { + select { + case <-stop: + return + default: + stats, err = svc.ds.GetLiveQueryStats(ctx, queryID, []uint{tracker.lastStatsEntry.HostID}) + if err != nil { + done <- err + return + } + if !(len(stats) == 0 || stats[0].LastExecuted.Before(tracker.lastStatsEntry.LastExecuted)) { + // Replica is in sync with the last query stats update + done <- nil + return + } + time.Sleep(30 * time.Millisecond) // We see the replication time less than 30 ms in production. + } + } + }() + select { + case err := <-done: + if err != nil { + level.Error(logger).Log("msg", "error syncing replica to master", "err", err) + tracker.saveStats = false + return + } + case <-time.After(5 * time.Second): + stop <- struct{}{} + level.Error(logger).Log("msg", "replica sync timeout: replica did not catch up to the master in 5 seconds") + // We proceed with the aggregation even if the replica is not in sync. + } + } + err := svc.ds.CalculateAggregatedPerfStatsPercentiles(ctx, fleet.AggregatedStatsTypeScheduledQuery, queryID) if err != nil { level.Error(logger).Log("msg", "error aggregating performance stats", "err", err)