fleet/server/datastore/mysql/mysql.go
Konstantin Sykulev 6ed3ba6801
Added OTEL DB stats metrics, renamed trace attributes to expected OTEL names (#42097)
1. Added DB metrics via otelsql.RegisterDBStatsMetrics()
`db.sql.connection.open`
`db.sql.connection.max_open`
`db.sql.connection.wait`
`db.sql.connection.wait_duration`
`db.sql.connection.closed_max_idle`
`db.sql.connection.closed_max_idle_time`
`db.sql.latency.*`
2. Renamed these metrics to the names expected by SigNoz conventions:
`db.sql.connection.open` -> `db.client.connection.usage`
`db.sql.connection.max_open` -> `db.client.connection.max`
`db.sql.connection.wait` -> `db.client.connection.wait_count`
`db.sql.connection.wait_duration` -> `db.client.connection.wait_time`
`db.sql.connection.closed_max_idle` -> `db.client.connection.idle.max`
`db.sql.connection.closed_max_idle_time` -> `db.client.connection.idle.min`
3. Created a custom dashboard to display these metrics (import it via JSON).
<img width="1580" height="906" alt="Screenshot 2026-03-19 at 2 44 43 PM"
src="https://github.com/user-attachments/assets/f1b64ed6-e534-4490-8955-bc1205dd21d4"
/>
4. Fixed metrics for the service DB dashboards.
SigNoz expects:

`db.system` : Identifies the database type (e.g., postgresql, mysql,
mongodb).
`db.statement` : The actual query being executed (e.g., SELECT * FROM
users).
`db.operation` : The type of operation (e.g., SELECT, INSERT).
`service.name` : The name of the service making the call.

We needed to set the `db.system` attribute explicitly.

`db.operation` is missing because otelsql doesn't capture it by
default. We decided not to add it for now, as the dashboards work without it.
This can be a future enhancement.

<img width="1563" height="487" alt="Screenshot 2026-03-19 at 2 45 18 PM"
src="https://github.com/user-attachments/assets/51028e16-ee2c-45a9-9025-26f17b0db67a"
/>


# Checklist for submitter

## Testing
- [x] QA'd all new/changed functionality manually

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

* **New Features**
* Added a new observability dashboard for database and connection
performance metrics, including RPS, latency, connection pool saturation,
and queue statistics.
* Enhanced database metrics collection with automatic registration of
connection and query performance indicators.
* Standardized OpenTelemetry metric naming to align with industry
conventions for improved observability compatibility.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-03-20 11:07:58 -05:00

1397 lines
47 KiB
Go

// Package mysql is a MySQL implementation of the Datastore interface.
package mysql
import (
"context"
"database/sql"
"errors"
"fmt"
"log/slog"
"net"
"os"
"regexp"
"strings"
"sync"
"time"
"github.com/WatchBeam/clock"
"github.com/XSAM/otelsql"
"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
condaccessdepot "github.com/fleetdm/fleet/v4/ee/server/service/condaccess/depot"
hostidscepdepot "github.com/fleetdm/fleet/v4/ee/server/service/hostidentity/depot"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/ctxdb"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/data"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/tables"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/rdsauth"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/goose"
"github.com/fleetdm/fleet/v4/server/mdm/android"
nano_push "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/push"
scep_depot "github.com/fleetdm/fleet/v4/server/mdm/scep/depot"
common_mysql "github.com/fleetdm/fleet/v4/server/platform/mysql"
"github.com/go-sql-driver/mysql"
"github.com/hashicorp/go-multierror"
"github.com/jmoiron/sqlx"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.39.0"
)
const (
	// mySQLTimestampFormat is the Go time layout used for MySQL timestamp strings.
	mySQLTimestampFormat = "2006-01-02 15:04:05" // %Y-%m-%d %H:%M:%S

	// Migration IDs needed for fixing broken migrations that some customers encountered with fleet v4.73.2
	// See https://github.com/fleetdm/fleet/issues/33562
	fleet4732BadMigrationID1  = 20250918154557 // was 20250918154557_AddKernelHostCountsIndexForVulnQueries.go
	fleet4732GoodMigrationID1 = 20250817154557 // 20250817154557_AddKernelHostCountsIndexForVulnQueries.go
	fleet4732BadMigrationID2  = 20250904115553 // was 20250904115553_OptimizeHostScriptResultsIndex.go
	fleet4732GoodMigrationID2 = 20250816115553 // 20250816115553_OptimizeHostScriptResultsIndex.go
	fleet4731GoodMigrationID  = 20250815130115 // applied just before the two bad IDs in the known-bad v4.73.2 state
)
// Datastore is an implementation of fleet.Datastore interface backed by
// MySQL.
//
// It holds a primary connection pool used for all writes, a reader used for
// read-only statements, a cache of prepared statements, and a channel of
// pending writes consumed by writeChanLoop.
type Datastore struct {
	// replica is typed as fleet.DBReader so it cannot be used to perform writes.
	// When built via NewDBConnections without a replica config, it is the same
	// pool as primary.
	replica fleet.DBReader
	// primary is the writer connection pool; reader also returns it when
	// ctxdb.RequirePrimary is set on the context.
	primary *sqlx.DB

	logger *slog.Logger
	clock  clock.Clock
	// config is the MySQL configuration passed to NewDatastore.
	config config.MysqlConfig
	// pusher sends APNs pushes for MDM commands; it is set via WithPusher and
	// is nil until then.
	pusher nano_push.Pusher

	// Embedded Android datastore, built by NewDatastore from the same
	// primary/replica connections.
	android.Datastore

	// nil if no read replica
	readReplicaConfig *common_mysql.MysqlConfig

	// minimum interval between software last_opened_at timestamp to update the
	// database (see file software.go).
	minLastOpenedAtDiff time.Duration

	// writeCh carries write requests that writeChanLoop processes one at a
	// time, replying on each item's errCh.
	writeCh chan itemToWrite

	// stmtCacheMu protects access to stmtCache.
	stmtCacheMu sync.Mutex
	// stmtCache holds statements prepared on the replica, keyed by query text.
	stmtCache map[string]*sqlx.Stmt

	// for tests, set to override the default batch size.
	testDeleteMDMProfilesBatchSize int
	// for tests, set to override the default batch size.
	testUpsertMDMDesiredProfilesBatchSize int
	// for tests set to override the default batch size.
	testSelectMDMProfilesBatchSize int

	// set this to the execution ids of activities that should be activated in
	// the next call to activateNextUpcomingActivity, instead of picking the next
	// available activity based on normal prioritization and creation date
	// ordering.
	testActivateSpecificNextActivities []string

	// This key is used to encrypt sensitive data stored in the Fleet DB, for example MDM
	// certificates and keys.
	serverPrivateKey string
}
// WithPusher installs p as the APNs pusher the datastore uses when activating
// an upcoming activity that needs to send MDM commands. Until it is called,
// no pusher is configured.
func (ds *Datastore) WithPusher(p nano_push.Pusher) {
	ds.pusher = p
}
// reader picks the connection for read-only statements. It returns the
// replica by default and falls back to the primary only when the caller has
// marked the context with ctxdb.RequirePrimary.
func (ds *Datastore) reader(ctx context.Context) fleet.DBReader {
	if !ctxdb.IsPrimaryRequired(ctx) {
		return ds.replica
	}
	return ds.primary
}
// writer returns the connection for statements that modify data. Writes
// always go to the primary pool. ctx is not consulted today; it keeps the
// signature symmetric with reader.
func (ds *Datastore) writer(ctx context.Context) *sqlx.DB {
	return ds.primary
}
// loadOrPrepareStmt returns the cached prepared statement for query, preparing
// it on the replica and caching it on first use. It returns nil when the
// context requires the primary (the cache only covers the replica), or when
// preparing fails; the failure is logged.
//
// IMPORTANT: Adding prepare statements consumes MySQL server resources and is limited by the MySQL max_prepared_stmt_count
// system variable. This method may create 1 prepare statement for EACH database connection. Customers must be notified
// to update their MySQL configurations when additional prepare statements are added.
// For more detail, see: https://github.com/fleetdm/fleet/issues/15476
func (ds *Datastore) loadOrPrepareStmt(ctx context.Context, query string) *sqlx.Stmt {
	if ctxdb.IsPrimaryRequired(ctx) {
		// Cached statements are bound to the replica, so none can be used here.
		return nil
	}

	ds.stmtCacheMu.Lock()
	defer ds.stmtCacheMu.Unlock()

	if cached, found := ds.stmtCache[query]; found {
		return cached
	}

	prepared, err := sqlx.PreparexContext(ctx, ds.replica, query)
	if err != nil {
		ds.logger.ErrorContext(ctx, "failed to prepare statement",
			"query", query,
			"err", err,
		)
		return nil
	}
	ds.stmtCache[query] = prepared
	return prepared
}
// deleteCachedStmt closes the prepared statement cached for query, if any, and
// evicts it from the cache. A close failure is logged and the entry is
// evicted anyway.
func (ds *Datastore) deleteCachedStmt(ctx context.Context, query string) {
	ds.stmtCacheMu.Lock()
	defer ds.stmtCacheMu.Unlock()

	cached, found := ds.stmtCache[query]
	if !found {
		return
	}
	if err := cached.Close(); err != nil {
		ds.logger.ErrorContext(ctx, "failed to close prepared statement before deleting it",
			"query", query,
			"err", err,
		)
	}
	delete(ds.stmtCache, query)
}
// NewSCEPDepot builds a scep_depot.Depot backed by this Datastore, using the
// *sql.DB that underlies the primary (writer) connection pool.
func (ds *Datastore) NewSCEPDepot() (scep_depot.Depot, error) {
	writerDB := ds.primary.DB
	return newSCEPDepot(writerDB, ds)
}
// NewHostIdentitySCEPDepot builds the scep_depot.Depot used for host identity
// certificates. It is backed by this Datastore and the primary (writer)
// *sqlx.DB connection pool.
func (ds *Datastore) NewHostIdentitySCEPDepot(logger *slog.Logger, cfg *config.FleetConfig) (scep_depot.Depot, error) {
	writer := ds.primary
	return hostidscepdepot.NewHostIdentitySCEPDepot(writer, ds, logger, cfg)
}
// NewConditionalAccessSCEPDepot builds the conditional access SCEP depot. It is
// backed by this Datastore and the primary (writer) *sqlx.DB connection pool.
func (ds *Datastore) NewConditionalAccessSCEPDepot(logger *slog.Logger, cfg *config.FleetConfig) (scep_depot.Depot, error) {
	writer := ds.primary
	return condaccessdepot.NewConditionalAccessSCEPDepot(writer, ds, logger, cfg)
}
// entity identifies a database table by name.
type entity struct {
	// name is the table name.
	name string
}
// Table entities shared across this package.
var (
	hostsTable    = entity{"hosts"}
	invitesTable  = entity{"invites"}
	packsTable    = entity{"packs"}
	queriesTable  = entity{"queries"}
	sessionsTable = entity{"sessions"}
	usersTable    = entity{"users"}
)
// withRetryTxx runs fn in a transaction on the primary (writer) connection.
// It delegates to common_mysql.WithRetryTxx, which (as its name says) retries
// the transaction; see that function for the exact retry policy.
func (ds *Datastore) withRetryTxx(ctx context.Context, fn common_mysql.TxFn) (err error) {
	primary := ds.writer(ctx)
	return common_mysql.WithRetryTxx(ctx, primary, fn, ds.logger)
}
// withTx runs fn in a single transaction on the primary (writer) connection.
// common_mysql.WithTxx commits or rolls the transaction back depending on fn's
// outcome.
func (ds *Datastore) withTx(ctx context.Context, fn common_mysql.TxFn) (err error) {
	primary := ds.writer(ctx)
	return common_mysql.WithTxx(ctx, primary, fn, ds.logger)
}
// withReadTx runs fn in a read-only transaction on the reader connection (the
// replica, or the primary when ctxdb.RequirePrimary is set). The transaction
// gives a consistent snapshot of the DB, so several SELECT queries can run in
// an isolated fashion.
//
// Prefer it over withTx for such reads: MySQL applies some optimizations to
// transactions declared as read-only that it cannot apply to read-write ones.
func (ds *Datastore) withReadTx(ctx context.Context, fn common_mysql.ReadTxFn) (err error) {
	readerDB, isSQLX := ds.reader(ctx).(*sqlx.DB)
	if !isSQLX {
		return ctxerr.New(ctx, "failed to cast reader to *sqlx.DB")
	}
	return common_mysql.WithReadOnlyTxx(ctx, readerDB, fn, ds.logger)
}
// NewDBConnections creates database connections from config.
// The returned connections can be used to create multiple datastores
// that share the same underlying database connections.
//
// When opts configure a read replica, its config is validated up front and it
// gets its own IAM auth setup. Without a replica, the primary pool is also
// returned as the replica.
//
// If the replica cannot be set up after the primary pool was opened, the
// primary pool is closed before returning the error so it does not leak.
func NewDBConnections(cfg config.MysqlConfig, opts ...DBOption) (*common_mysql.DBConnections, error) {
	options := &common_mysql.DBOptions{
		MinLastOpenedAtDiff: defaultMinLastOpenedAtDiff,
		MaxAttempts:         defaultMaxAttempts,
		Logger:              slog.New(slog.DiscardHandler),
	}
	for _, setOpt := range opts {
		if setOpt == nil {
			continue
		}
		if err := setOpt(options); err != nil {
			return nil, err
		}
	}

	if err := checkAndModifyConfig(&cfg); err != nil {
		return nil, err
	}
	// Convert replica config once so that checkAndModifyConfig mutations are preserved for the later NewDB call.
	var replicaConf *config.MysqlConfig
	if options.ReplicaConfig != nil {
		replicaConf = fromCommonMysqlConfig(options.ReplicaConfig)
		if err := checkAndModifyConfig(replicaConf); err != nil {
			return nil, fmt.Errorf("replica: %w", err)
		}
	}

	// Set up IAM authentication connector factory if needed
	if err := setupIAMAuthIfNeeded(&cfg, options); err != nil {
		return nil, err
	}
	dbWriter, err := NewDB(&cfg, options)
	if err != nil {
		return nil, err
	}

	dbReader := dbWriter
	if replicaConf != nil {
		dbReader, err = openReplicaDBConnection(replicaConf, options)
		if err != nil {
			// The primary pool is already open; close it so a failed replica
			// setup does not leak its connections. Best effort: the replica
			// error is the one worth reporting to the caller.
			_ = dbWriter.Close()
			return nil, err
		}
	}

	return &common_mysql.DBConnections{Primary: dbWriter, Replica: dbReader, Options: options}, nil
}

// openReplicaDBConnection opens the read-replica pool. It starts from a copy
// of the primary's options but resets ConnectorFactory, because the replica
// may have different IAM auth requirements (region/credentials) than the
// primary. All errors are prefixed with "replica:".
func openReplicaDBConnection(replicaConf *config.MysqlConfig, options *common_mysql.DBOptions) (*sqlx.DB, error) {
	replicaOptions := *options
	replicaOptions.ConnectorFactory = nil
	if err := setupIAMAuthIfNeeded(replicaConf, &replicaOptions); err != nil {
		return nil, fmt.Errorf("replica: %w", err)
	}
	db, err := NewDB(replicaConf, &replicaOptions)
	if err != nil {
		return nil, fmt.Errorf("replica: %w", err)
	}
	return db, nil
}
// NewDatastore builds a Datastore on top of already-open database connections,
// so several bounded-context datastores can share the same pools. It also
// starts the background goroutine (writeChanLoop) that serves ds.writeCh.
func NewDatastore(conns *common_mysql.DBConnections, cfg config.MysqlConfig, c clock.Clock) (*Datastore, error) {
	opts := conns.Options
	ds := &Datastore{
		primary:             conns.Primary,
		replica:             conns.Replica,
		config:              cfg,
		clock:               c,
		logger:              opts.Logger,
		readReplicaConfig:   opts.ReplicaConfig,
		minLastOpenedAtDiff: opts.MinLastOpenedAtDiff,
		serverPrivateKey:    opts.PrivateKey,
		writeCh:             make(chan itemToWrite),
		stmtCache:           make(map[string]*sqlx.Stmt),
		Datastore:           NewAndroidDatastore(opts.Logger, conns.Primary, conns.Replica),
	}

	go ds.writeChanLoop()
	return ds, nil
}
// New creates a MySQL datastore that owns its own connections. It is a
// convenience wrapper around NewDBConnections followed by NewDatastore.
func New(cfg config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore, error) {
	conns, err := NewDBConnections(cfg, opts...)
	if err != nil {
		return nil, err
	}
	return NewDatastore(conns, cfg, c)
}
// itemToWrite is a write request processed by writeChanLoop. The result of
// the write is sent on errCh.
type itemToWrite struct {
	ctx   context.Context
	errCh chan error
	// item is either a *fleet.Host or a hostXUpdatedAt (see writeChanLoop).
	item interface{}
}

// hostXUpdatedAt requests setting the timestamp column named by what to
// updatedAt for the host with ID hostID.
type hostXUpdatedAt struct {
	hostID    uint
	updatedAt time.Time
	// what is interpolated directly into the UPDATE statement as a column
	// name, so it must come from trusted code, never from user input.
	what string
}
// writeChanLoop processes the write requests sent on ds.writeCh one at a time
// and sends each result back on the request's errCh. It returns when
// ds.writeCh is closed.
//
// Every request gets exactly one reply, including requests carrying an
// unsupported item type. Otherwise a sender waiting on errCh would block
// forever.
func (ds *Datastore) writeChanLoop() {
	for item := range ds.writeCh {
		switch actualItem := item.item.(type) {
		case *fleet.Host:
			item.errCh <- ds.UpdateHost(item.ctx, actualItem)
		case hostXUpdatedAt:
			err := ds.withRetryTxx(
				item.ctx, func(tx sqlx.ExtContext) error {
					// actualItem.what is a column name interpolated into the SQL;
					// it must come from trusted code, never from user input.
					query := fmt.Sprintf(`UPDATE hosts SET %s = ? WHERE id=?`, actualItem.what)
					_, err := tx.ExecContext(item.ctx, query, actualItem.updatedAt, actualItem.hostID)
					return err
				},
			)
			item.errCh <- ctxerr.Wrap(item.ctx, err, "updating hosts label updated at")
		default:
			item.errCh <- ctxerr.Errorf(item.ctx, "unsupported write item type %T", item.item)
		}
	}
}
// otelTracedDriverName is the name under which init registers the
// OpenTelemetry-instrumented "mysql" driver; NewDB opens connections with it.
var otelTracedDriverName string

func init() {
	var err error
	otelTracedDriverName, err = otelsql.Register("mysql",
		otelsql.WithAttributes(
			// "db.system" is set explicitly, in addition to the semconv
			// DBSystemNameMySQL attribute, because the SigNoz service DB
			// dashboards expect "db.system".
			attribute.String("db.system", "mysql"),
			semconv.DBSystemNameMySQL,
		),
		otelsql.WithSpanOptions(otelsql.SpanOptions{
			// DisableErrSkip ignores driver.ErrSkip errors which are frequently returned by the MySQL driver
			// when certain optional methods or paths are not implemented/taken.
			// For example: interpolateParams=false (the secure default) will not do a parametrized sql.conn.query directly without preparing it first, causing driver.ErrSkip
			DisableErrSkip: true,
			// Omitting span for sql.conn.reset_session since it takes ~1us and doesn't provide useful information
			OmitConnResetSession: true,
			// Omitting span for sql.rows since it is very quick and typically doesn't provide useful information beyond what's already reported by prepare/exec/query
			OmitRows: true,
		}),
		// WithSpanNameFormatter allows us to customize the span name, which is especially useful for SQL queries run outside an HTTPS transaction,
		// which do not belong to a parent span, show up as their own trace, and would otherwise be named "sql.conn.query" or "sql.conn.exec".
		otelsql.WithSpanNameFormatter(func(ctx context.Context, method otelsql.Method, query string) string {
			return formatSQLSpanName(string(method), query)
		}),
	)
	if err != nil {
		panic(err)
	}
}

// formatSQLSpanName builds the span name for a database call.
//
// It returns method alone when query is empty. Otherwise it returns
// "<method>: <query>", with the query's whitespace runs collapsed to single
// spaces and the query truncated to at most 100 bytes followed by "...".
//
// Truncation is byte-based, so the cut can land inside a multi-byte UTF-8
// character. strings.ToValidUTF8 drops any such partial character so the span
// name stays valid UTF-8, which trace exporters may otherwise reject or mangle.
func formatSQLSpanName(method, query string) string {
	if query == "" {
		return method
	}
	// Append query with extra whitespaces removed
	query = strings.Join(strings.Fields(query), " ")
	const maxQueryLen = 100
	if len(query) > maxQueryLen {
		query = strings.ToValidUTF8(query[:maxQueryLen], "") + "..."
	}
	return method + ": " + query
}
// NewDB opens a connection pool for conf. It converts conf to the
// common_mysql form and uses the OpenTelemetry-instrumented driver registered
// in init.
func NewDB(conf *config.MysqlConfig, opts *common_mysql.DBOptions) (*sqlx.DB, error) {
	commonConf := toCommonMysqlConfig(conf)
	return common_mysql.NewDB(commonConf, opts, otelTracedDriverName)
}
// toCommonMysqlConfig maps a config.MysqlConfig onto the equivalent
// common_mysql.MysqlConfig, copying every field one-to-one. conf must be
// non-nil.
func toCommonMysqlConfig(conf *config.MysqlConfig) *common_mysql.MysqlConfig {
	out := &common_mysql.MysqlConfig{
		// Connection target.
		Protocol: conf.Protocol,
		Address:  conf.Address,
		Database: conf.Database,
		// Credentials; Region enables RDS IAM auth when no password is set.
		Username:     conf.Username,
		Password:     conf.Password,
		PasswordPath: conf.PasswordPath,
		Region:       conf.Region,
		// TLS.
		TLSCert:       conf.TLSCert,
		TLSKey:        conf.TLSKey,
		TLSCA:         conf.TLSCA,
		TLSServerName: conf.TLSServerName,
		TLSConfig:     conf.TLSConfig,
		// Connection pool and session settings.
		MaxOpenConns:    conf.MaxOpenConns,
		MaxIdleConns:    conf.MaxIdleConns,
		ConnMaxLifetime: conf.ConnMaxLifetime,
		SQLMode:         conf.SQLMode,
	}
	return out
}
// toCommonLoggingConfig maps the tracing settings of a config.LoggingConfig
// onto a common_mysql.LoggingConfig. A nil input yields nil.
func toCommonLoggingConfig(conf *config.LoggingConfig) *common_mysql.LoggingConfig {
	var out *common_mysql.LoggingConfig
	if conf != nil {
		out = &common_mysql.LoggingConfig{
			TracingEnabled: conf.TracingEnabled,
			TracingType:    conf.TracingType,
		}
	}
	return out
}
// fromCommonMysqlConfig is the inverse of toCommonMysqlConfig. It maps a
// common_mysql.MysqlConfig back onto a config.MysqlConfig field by field. A
// nil input yields nil.
func fromCommonMysqlConfig(conf *common_mysql.MysqlConfig) *config.MysqlConfig {
	if conf == nil {
		return nil
	}
	out := config.MysqlConfig{
		// Connection target.
		Protocol: conf.Protocol,
		Address:  conf.Address,
		Database: conf.Database,
		// Credentials; Region enables RDS IAM auth when no password is set.
		Username:     conf.Username,
		Password:     conf.Password,
		PasswordPath: conf.PasswordPath,
		Region:       conf.Region,
		// TLS.
		TLSCert:       conf.TLSCert,
		TLSKey:        conf.TLSKey,
		TLSCA:         conf.TLSCA,
		TLSServerName: conf.TLSServerName,
		TLSConfig:     conf.TLSConfig,
		// Connection pool and session settings.
		MaxOpenConns:    conf.MaxOpenConns,
		MaxIdleConns:    conf.MaxIdleConns,
		ConnMaxLifetime: conf.ConnMaxLifetime,
		SQLMode:         conf.SQLMode,
	}
	return &out
}
// checkAndModifyConfig validates conf and fills in derived fields in place.
//
// It rejects configs that set both Password and PasswordPath. When
// PasswordPath is set, it loads the file's whitespace-trimmed contents into
// Password. When TLSCA is set, it sets TLSConfig to "custom" and registers the
// TLS settings via registerTLS.
func checkAndModifyConfig(conf *config.MysqlConfig) error {
	hasPassword, hasPasswordFile := conf.Password != "", conf.PasswordPath != ""
	if hasPassword && hasPasswordFile {
		return errors.New("A MySQL password and a MySQL password file were provided - please specify only one")
	}

	if hasPasswordFile {
		raw, err := os.ReadFile(conf.PasswordPath)
		if err != nil {
			return err
		}
		conf.Password = strings.TrimSpace(string(raw))
	}

	if conf.TLSCA == "" {
		return nil
	}
	conf.TLSConfig = "custom"
	if err := registerTLS(*conf); err != nil {
		return fmt.Errorf("register TLS config for mysql: %w", err)
	}
	return nil
}
// setupIAMAuthIfNeeded installs an RDS IAM-auth connector factory on opts when
// conf has no password (neither inline nor from a file) but does name an AWS
// region. Otherwise it leaves opts untouched.
func setupIAMAuthIfNeeded(conf *config.MysqlConfig, opts *common_mysql.DBOptions) error {
	usesPassword := conf.Password != "" || conf.PasswordPath != ""
	if usesPassword || conf.Region == "" {
		return nil
	}

	host, port, splitErr := net.SplitHostPort(conf.Address)
	if splitErr != nil {
		// The address could not be split (e.g. it has no port): treat it all
		// as the host and assume MySQL's default port.
		host, port = conf.Address, "3306"
	}

	factory, err := rdsauth.NewConnectorFactory(conf, host, port)
	if err != nil {
		return fmt.Errorf("failed to create RDS IAM auth connector factory: %w", err)
	}
	opts.ConnectorFactory = factory
	return nil
}
// MigrateTables applies all pending schema ("table") migrations on the
// primary database.
func (ds *Datastore) MigrateTables(ctx context.Context) error {
	db := ds.writer(ctx).DB
	return tables.MigrationClient.Up(db, "")
}
// MigrateData applies all pending data migrations on the primary database.
func (ds *Datastore) MigrateData(ctx context.Context) error {
	db := ds.writer(ctx).DB
	return data.MigrationClient.Up(db, "")
}
// loadMigrations reads the version IDs of the applied migrations, in the
// order they were applied (goose doesn't provide such functionality).
//
// It returns one list for "table" migrations and one for "data" migrations.
// The migration status tables are first touched through writer, which
// triggers their creation. They are then queried through reader.
func (ds *Datastore) loadMigrations(
	ctx context.Context,
	writer *sql.DB,
	reader fleet.DBReader,
) (tableRecs []int64, dataRecs []int64, err error) {
	// GetDBVersion triggers the creation of the migration status tables, so
	// run it for both clients before querying them.
	if _, err = tables.MigrationClient.GetDBVersion(writer); err != nil {
		return nil, nil, err
	}
	if _, err = data.MigrationClient.GetDBVersion(writer); err != nil {
		return nil, nil, err
	}

	// version_id > 0 skips the bootstrap migration that creates the migration tables.
	appliedQuery := func(table string) string {
		return "SELECT version_id FROM " + table + " WHERE version_id > 0 AND is_applied ORDER BY id ASC"
	}
	if err = sqlx.SelectContext(ctx, reader, &tableRecs, appliedQuery(tables.MigrationClient.TableName)); err != nil {
		return nil, nil, err
	}
	if err = sqlx.SelectContext(ctx, reader, &dataRecs, appliedQuery(data.MigrationClient.TableName)); err != nil {
		return nil, nil, err
	}
	return tableRecs, dataRecs, nil
}
// MigrationStatus reports how the migrations applied to the database compare
// with the migrations known to this build. It tolerates deployments that ran
// migrations out of order.
//
// If the database is in one of the known broken fleet v4.73.2 states, that
// status is returned instead of the regular comparison.
func (ds *Datastore) MigrationStatus(ctx context.Context) (*fleet.MigrationStatus, error) {
	if tables.MigrationClient.Migrations == nil || data.MigrationClient.Migrations == nil {
		return nil, errors.New("unexpected nil migrations list")
	}

	appliedTable, appliedData, err := ds.loadMigrations(ctx, ds.primary.DB, ds.replica)
	if err != nil {
		return nil, fmt.Errorf("cannot load migrations: %w", err)
	}

	if badState := ds.CheckFleetv4732BadMigrations(appliedTable); badState != nil {
		return badState, nil
	}

	status := compareMigrations(
		tables.MigrationClient.Migrations,
		data.MigrationClient.Migrations,
		appliedTable,
		appliedData,
	)
	return status, nil
}
// CheckFleetv4732BadMigrations detects the misnumbered migrations introduced
// in some released fleet v4.73.2 versions. It returns nil when neither bad
// migration ID has been applied.
//
// It returns a NeedsFleetv4732Fix status when the three most recently applied
// table migrations are, in order: the good v4.73.1 migration, then bad ID 2,
// then bad ID 1 (the latest). This is the known-bad state that
// FixFleetv4732Migrations repairs. Any other appearance of a bad ID yields an
// UnknownFleetv4732State status.
func (ds *Datastore) CheckFleetv4732BadMigrations(appliedTable []int64) *fleet.MigrationStatus {
	n := len(appliedTable)
	if n == 0 {
		return nil
	}

	knownBadTail := n >= 3 &&
		appliedTable[n-3] == fleet4731GoodMigrationID &&
		appliedTable[n-2] == fleet4732BadMigrationID2 &&
		appliedTable[n-1] == fleet4732BadMigrationID1
	if knownBadTail {
		return &fleet.MigrationStatus{StatusCode: fleet.NeedsFleetv4732Fix}
	}

	for _, id := range appliedTable {
		switch id {
		case fleet4732BadMigrationID1, fleet4732BadMigrationID2:
			return &fleet.MigrationStatus{StatusCode: fleet.UnknownFleetv4732State}
		}
	}
	return nil
}
// FixFleetv4732Migrations renumbers the two misnumbered fleet v4.73.2 table
// migrations to their corrected version IDs inside a single transaction.
// Each UPDATE must affect exactly one row. Otherwise fn returns an error, and
// withTx handles the rollback.
func (ds *Datastore) FixFleetv4732Migrations(ctx context.Context) error {
	renumberStmt := `UPDATE ` + tables.MigrationClient.TableName + ` SET version_id = ? WHERE version_id = ?`
	renumbers := []struct {
		from, to int64
	}{
		{from: fleet4732BadMigrationID1, to: fleet4732GoodMigrationID1},
		{from: fleet4732BadMigrationID2, to: fleet4732GoodMigrationID2},
	}

	return ds.withTx(ctx, func(tx sqlx.ExtContext) error {
		for _, r := range renumbers {
			res, err := tx.ExecContext(ctx, renumberStmt, r.to, r.from)
			if err != nil {
				return err
			}
			n, err := res.RowsAffected()
			if err != nil {
				return err
			}
			if n != 1 {
				return ctxerr.Errorf(ctx, "expected to affect 1 row for migration %d, affected %d", r.from, n)
			}
		}
		return nil
	})
}
// compareMigrations classifies the applied table/data migrations against the
// known ones. It assumes some deployments may have performed migrations out
// of order.
//
// Missing migrations take precedence over unknown ones. The code assumes a
// database cannot be missing migrations on one side ("table" or "data") while
// being ahead on the other.
func compareMigrations(knownTable goose.Migrations, knownData goose.Migrations, appliedTable, appliedData []int64) *fleet.MigrationStatus {
	if len(appliedTable) == 0 && len(appliedData) == 0 {
		return &fleet.MigrationStatus{StatusCode: fleet.NoMigrationsCompleted}
	}

	missingTable, unknownTable, equalTable := compareVersions(
		getVersionsFromMigrations(knownTable), appliedTable, knownUnknownTableMigrations,
	)
	missingData, unknownData, equalData := compareVersions(
		getVersionsFromMigrations(knownData), appliedData, knownUnknownDataMigrations,
	)

	switch {
	case equalData && equalTable:
		return &fleet.MigrationStatus{StatusCode: fleet.AllMigrationsCompleted}
	case len(missingTable) > 0 || len(missingData) > 0:
		return &fleet.MigrationStatus{
			StatusCode:   fleet.SomeMigrationsCompleted,
			MissingTable: missingTable,
			MissingData:  missingData,
		}
	default:
		// Not equal and nothing missing: there must be unknown migrations.
		return &fleet.MigrationStatus{
			StatusCode:   fleet.UnknownMigrations,
			UnknownTable: unknownTable,
			UnknownData:  unknownData,
		}
	}
}
// Migration IDs that may appear among a database's applied migrations even
// though this build no longer knows them. compareVersions drops them from the
// reported unknown migrations.
var (
	knownUnknownTableMigrations = map[int64]struct{}{
		// This migration was introduced incorrectly in fleet-v4.4.0 and its
		// timestamp was changed in fleet-v4.4.1.
		20210924114500: {},
	}
	knownUnknownDataMigrations = map[int64]struct{}{
		// This migration was present in 2.0.0, and was removed on a subsequent release.
		// Was basically running `DELETE FROM packs WHERE deleted = 1`, (such `deleted`
		// column doesn't exist anymore).
		20171212182459: {},
		// Deleted in
		// https://github.com/fleetdm/fleet/commit/fd61dcab67f341c9e47fb6cb968171650c19a681
		20161223115449: {},
		20170309091824: {},
		20171027173700: {},
		20171212182458: {},
	}
)
// unknownUnknowns returns the elements of in that are not keys of
// knownUnknowns, preserving their order. It returns nil when no element
// survives the filter.
func unknownUnknowns(in []int64, knownUnknowns map[int64]struct{}) []int64 {
	var filtered []int64
	for i := range in {
		v := in[i]
		if _, known := knownUnknowns[v]; known {
			continue
		}
		filtered = append(filtered, v)
	}
	return filtered
}
// compareVersions reports how v2 differs from v1 (neither needs to be
// ordered):
//   - missing: elements of v1 absent from v2, in v1 order.
//   - unknown: elements of v2 absent from v1 and not listed in knownUnknowns,
//     in v2 order.
//   - equal: true, with nil slices, when both lists above are empty.
func compareVersions(v1, v2 []int64, knownUnknowns map[int64]struct{}) (missing []int64, unknown []int64, equal bool) {
	toSet := func(vs []int64) map[int64]struct{} {
		set := make(map[int64]struct{}, len(vs))
		for _, v := range vs {
			set[v] = struct{}{}
		}
		return set
	}
	inV1, inV2 := toSet(v1), toSet(v2)

	for _, v := range v1 {
		if _, ok := inV2[v]; !ok {
			missing = append(missing, v)
		}
	}
	var extra []int64
	for _, v := range v2 {
		if _, ok := inV1[v]; !ok {
			extra = append(extra, v)
		}
	}
	unknown = unknownUnknowns(extra, knownUnknowns)

	if len(missing) == 0 && len(unknown) == 0 {
		return nil, nil, true
	}
	return missing, unknown, false
}
// getVersionsFromMigrations extracts the Version of each migration, in order.
// It always returns a non-nil slice.
func getVersionsFromMigrations(migrations goose.Migrations) []int64 {
	versions := make([]int64, 0, len(migrations))
	for _, m := range migrations {
		versions = append(versions, m.Version)
	}
	return versions
}
// HealthCheck returns an error if the MySQL backend is not healthy.
//
// The primary must be reachable and writable. After an AWS Aurora failover the
// old writer is demoted to a reader. Failing the check then lets the
// orchestrator (ECS, Kubernetes) restart Fleet with fresh DB connections. When
// a read replica is configured, it must also answer a trivial query.
//
// NOTE: does not receive a context as argument here, because the HealthCheck
// interface potentially affects more than the datastore layer, and I'm not
// sure we can safely identify and change them all at this moment.
func (ds *Datastore) HealthCheck() error {
	ctx := context.Background()

	var readOnly int
	if err := ds.primary.QueryRowContext(ctx, "SELECT @@read_only").Scan(&readOnly); err != nil {
		return err
	}
	if readOnly == 1 {
		// Deliberately an error so the health endpoint returns a 500.
		return errors.New("primary database is read-only, possible failover detected")
	}

	if ds.readReplicaConfig == nil {
		return nil
	}
	var one int
	return sqlx.GetContext(ctx, ds.replica, &one, "select 1")
}
// closeStmts closes every cached prepared statement and empties the cache.
// All close errors are collected into a single multierror; nil means every
// statement closed cleanly.
func (ds *Datastore) closeStmts() error {
	ds.stmtCacheMu.Lock()
	defer ds.stmtCacheMu.Unlock()

	var errs error
	for _, stmt := range ds.stmtCache {
		if closeErr := stmt.Close(); closeErr != nil {
			errs = multierror.Append(errs, closeErr)
		}
	}
	clear(ds.stmtCache)
	return errs
}
// Close releases the prepared statements and the connection pools held by the
// datastore. The replica is closed only when a read replica is configured.
// Every failure is collected into the returned multierror.
func (ds *Datastore) Close() error {
	var result error
	collect := func(e error) {
		if e != nil {
			result = multierror.Append(result, e)
		}
	}

	collect(ds.closeStmts())
	collect(ds.primary.Close())
	if ds.readReplicaConfig != nil {
		collect(ds.replica.Close())
	}
	return result
}
// appendListOptionsToSelect applies the ordering and then the pagination from
// opts to ds, returning the resulting select dataset.
//
// NOTE: This is a copy of appendListOptionsToSQL that uses the goqu package.
func appendListOptionsToSelect(ds *goqu.SelectDataset, opts fleet.ListOptions) *goqu.SelectDataset {
	return appendLimitOffsetToSelect(appendOrderByToSelect(ds, opts), opts)
}
// appendOrderByToSelect adds an ORDER BY term to ds for each comma-separated
// key in opts.OrderKey. Each key is passed through
// common_mysql.SanitizeColumn, and keys that sanitize to "" are skipped. Every
// term uses the same direction, descending only when opts.OrderDirection is
// fleet.OrderDescending.
func appendOrderByToSelect(ds *goqu.SelectDataset, opts fleet.ListOptions) *goqu.SelectDataset {
	if opts.OrderKey == "" {
		return ds
	}

	descending := opts.OrderDirection == fleet.OrderDescending
	for _, rawKey := range strings.Split(opts.OrderKey, ",") {
		column := common_mysql.SanitizeColumn(rawKey)
		if column == "" {
			continue
		}
		literal := goqu.L(column)
		var ordering exp.OrderedExpression
		if descending {
			ordering = literal.Desc()
		} else {
			ordering = literal.Asc()
		}
		ds = ds.OrderAppend(ordering)
	}
	return ds
}
// appendLimitOffsetToSelect applies pagination from opts to ds.
//
// A zero PerPage falls back to fleet.DefaultPerPage, so an unbounded query
// with many results cannot consume too much memory or hang. The offset is
// PerPage*Page and is set only when positive. With IncludeMetadata the limit
// is one row larger than the page size, presumably so the caller can detect
// whether a next page exists.
func appendLimitOffsetToSelect(ds *goqu.SelectDataset, opts fleet.ListOptions) *goqu.SelectDataset {
	pageSize := opts.PerPage
	if pageSize == 0 {
		pageSize = fleet.DefaultPerPage
	}

	if offset := pageSize * opts.Page; offset > 0 {
		ds = ds.Offset(offset)
	}

	limit := pageSize
	if opts.IncludeMetadata {
		limit++
	}
	return ds.Limit(limit)
}
// sanitizeColumn forwards to common_mysql.SanitizeColumn, keeping the
// package-local name used by callers in this package.
func sanitizeColumn(col string) string {
	sanitized := common_mysql.SanitizeColumn(col)
	return sanitized
}
// appendListOptionsToSQL appends the ordering and pagination from opts to sql
// without any pre-existing params. It is shorthand for
// appendListOptionsWithCursorToSQL(sql, nil, opts) and therefore also
// defaults opts.PerPage when it is 0.
//
// Deprecated: this method will be removed in favor of appendListOptionsWithCursorToSQL
func appendListOptionsToSQL(sql string, opts *fleet.ListOptions) (string, []any) {
	var noParams []any
	return appendListOptionsWithCursorToSQL(sql, noParams, opts)
}
// appendListOptionsToSQLSecure appends the ordering and pagination from opts
// to sql without any pre-existing params. It is shorthand for
// appendListOptionsWithCursorToSQLSecure(sql, nil, opts, allowlist).
//
// The allowlist parameter maps user-facing order key names to actual SQL column expressions.
// This prevents SQL injection and information disclosure via arbitrary column sorting.
// See common_mysql.OrderKeyAllowlist for details.
func appendListOptionsToSQLSecure(sql string, opts *fleet.ListOptions, allowlist common_mysql.OrderKeyAllowlist) (string, []any, error) {
	var noParams []any
	return appendListOptionsWithCursorToSQLSecure(sql, noParams, opts, allowlist)
}
// appendListOptionsWithCursorToSQL delegates to
// common_mysql.AppendListOptionsWithParams.
// NOTE: this mutates opts, setting PerPage to fleet.DefaultPerPage when it is 0.
//
// Deprecated: this method will be removed in favor of appendListOptionsWithCursorToSQLSecure
func appendListOptionsWithCursorToSQL(sql string, params []any, opts *fleet.ListOptions) (string, []any) {
	if opts.PerPage != 0 {
		return common_mysql.AppendListOptionsWithParams(sql, params, opts)
	}
	opts.PerPage = fleet.DefaultPerPage
	return common_mysql.AppendListOptionsWithParams(sql, params, opts)
}
// appendListOptionsWithCursorToSQLSecure delegates to
// common_mysql.AppendListOptionsWithParamsSecure.
// NOTE: this mutates opts, setting PerPage to fleet.DefaultPerPage when it is 0.
//
// The allowlist parameter maps user-facing order key names to actual SQL column expressions.
// This prevents SQL injection and information disclosure via arbitrary column sorting.
// See common_mysql.OrderKeyAllowlist for details.
func appendListOptionsWithCursorToSQLSecure(sql string, params []any, opts *fleet.ListOptions, allowlist common_mysql.OrderKeyAllowlist) (string, []any, error) {
	if opts.PerPage != 0 {
		return common_mysql.AppendListOptionsWithParamsSecure(sql, params, opts, allowlist)
	}
	opts.PerPage = fleet.DefaultPerPage
	return common_mysql.AppendListOptionsWithParamsSecure(sql, params, opts, allowlist)
}
// whereFilterHostsByTeams returns the appropriate condition to use in the WHERE
// clause to render only the appropriate teams.
//
// filter provides the filtering parameters that should be used. hostKey is the
// name/alias of the hosts table to use in generating the SQL.
//
// The result is a raw SQL boolean expression: "TRUE", "FALSE",
// "<hostKey>.team_id = N" or "<hostKey>.team_id IN (...)". hostKey is
// interpolated verbatim, so it must come from trusted code, never user input.
func (ds *Datastore) whereFilterHostsByTeams(filter fleet.TeamFilter, hostKey string) string {
	if filter.User == nil {
		// This is likely unintentional, however we would like to return no
		// results rather than panicking or returning some other error. At least
		// log.
		ds.logger.InfoContext(context.TODO(), "team filter missing user")
		return "FALSE"
	}

	// Clause returned when the user is allowed: every host, or only the
	// requested team's hosts when filter.TeamID is set.
	defaultAllowClause := "TRUE"
	if filter.TeamID != nil {
		defaultAllowClause = fmt.Sprintf("%s.team_id = %d", hostKey, *filter.TeamID)
	}

	if filter.User.GlobalRole != nil {
		switch *filter.User.GlobalRole {
		case fleet.RoleAdmin, fleet.RoleMaintainer, fleet.RoleTechnician, fleet.RoleObserverPlus:
			return defaultAllowClause
		case fleet.RoleObserver:
			if filter.IncludeObserver {
				if filter.ObserverTeamID != nil {
					// Restrict global observer to only the specified team (e.g. the live query's own team).
					// NOTE(review): this branch ignores filter.TeamID. On the per-team path
					// below, an observer role only grants access when TeamID equals
					// ObserverTeamID. Confirm this asymmetry is intended.
					return fmt.Sprintf("%s.team_id = %d", hostKey, *filter.ObserverTeamID)
				}
				return defaultAllowClause
			}
			return "FALSE"
		default:
			// Fall through to specific teams
		}
	}

	// Collect matching teams
	var idStrs []string
	// teamIDSeen records whether the requested filter.TeamID is among the
	// teams this user may see.
	var teamIDSeen bool
	for _, team := range filter.User.Teams {
		if team.Role == fleet.RoleAdmin ||
			team.Role == fleet.RoleMaintainer ||
			team.Role == fleet.RoleTechnician ||
			team.Role == fleet.RoleObserverPlus {
			idStrs = append(idStrs, fmt.Sprint(team.ID))
			if filter.TeamID != nil && *filter.TeamID == team.ID {
				teamIDSeen = true
			}
		} else if team.Role == fleet.RoleObserver && filter.IncludeObserver {
			// When ObserverTeamID is set, restrict observer access to only that team.
			// This scopes observer_can_run to the query's own team, not all observed teams.
			if filter.ObserverTeamID == nil || *filter.ObserverTeamID == team.ID {
				idStrs = append(idStrs, fmt.Sprint(team.ID))
				if filter.TeamID != nil && *filter.TeamID == team.ID {
					teamIDSeen = true
				}
			}
		}
	}

	if len(idStrs) == 0 {
		// No global role granted access and no team role matched (observer
		// teams only count when IncludeObserver is set).
		return "FALSE"
	}

	if filter.TeamID != nil {
		if teamIDSeen {
			// all good, this user has the right to see the requested team
			return defaultAllowClause
		}
		return "FALSE"
	}

	return fmt.Sprintf("%s.team_id IN (%s)", hostKey, strings.Join(idStrs, ","))
}
// whereFilterTeamWithGlobalStats is the counterpart of whereFilterHostsByTeams
// for tables whose team_id column uses 0 to mean "all teams including no
// team", e.g. software_title_host_counts. It returns a WHERE condition that
// renders only the teams the filter's user may see.
//
// filter provides the filtering parameters that should be used.
// filterTableAlias is the name/alias of the table to use in generating the
// SQL.
func (ds *Datastore) whereFilterTeamWithGlobalStats(filter fleet.TeamFilter, filterTableAlias string) string {
	// The first condition selects the aggregate row (team_id 0 with the
	// global_stats flag) and stands in for "TRUE" when no team is requested.
	// The second names the team_id column that team-scoped conditions are
	// built on.
	return ds.whereFilterGlobalOrTeamIDByTeamsWithSqlFilter(
		filter,
		fmt.Sprintf("%s.team_id = 0 AND %[1]s.global_stats = 1", filterTableAlias),
		fmt.Sprintf("%s.team_id", filterTableAlias),
	)
}
// whereFilterGlobalOrTeamIDByTeamsWithSqlFilter builds a WHERE condition
// limiting rows to the teams filter.User may see. It is for tables whose
// "all teams" selection and team column are supplied as raw SQL fragments.
//
// globalSqlFilter is returned as-is to users with a sufficient global role
// when filter.TeamID is nil. teamIDSqlFilter is the SQL expression holding the
// team ID. It is compared against filter.TeamID, or against the list of teams
// on which the user holds a qualifying team role. "FALSE" is returned when
// nothing is visible, including when filter.User is nil. Unlike
// whereFilterHostsByTeams, filter.ObserverTeamID is not consulted.
func (ds *Datastore) whereFilterGlobalOrTeamIDByTeamsWithSqlFilter(
	filter fleet.TeamFilter, globalSqlFilter string, teamIDSqlFilter string,
) string {
	user := filter.User
	if user == nil {
		// Most likely a programming error: return no rows instead of
		// panicking, but leave a trace in the logs.
		ds.logger.InfoContext(context.TODO(), "team filter missing user")
		return "FALSE"
	}

	allowed := globalSqlFilter
	if filter.TeamID != nil {
		allowed = fmt.Sprintf("%s = %d", teamIDSqlFilter, *filter.TeamID)
	}

	if user.GlobalRole != nil {
		switch *user.GlobalRole {
		case fleet.RoleAdmin, fleet.RoleMaintainer, fleet.RoleTechnician, fleet.RoleObserverPlus:
			return allowed
		case fleet.RoleObserver:
			if filter.IncludeObserver {
				return allowed
			}
			return "FALSE"
		}
		// Any other global role falls through to the per-team checks.
	}

	var teamIDs []string
	requestedTeamVisible := false
	for _, membership := range user.Teams {
		switch membership.Role {
		case fleet.RoleAdmin, fleet.RoleMaintainer, fleet.RoleTechnician, fleet.RoleObserverPlus:
			// Always qualifies.
		case fleet.RoleObserver:
			if !filter.IncludeObserver {
				continue
			}
		default:
			continue
		}
		teamIDs = append(teamIDs, fmt.Sprint(membership.ID))
		if filter.TeamID != nil && *filter.TeamID == membership.ID {
			requestedTeamVisible = true
		}
	}

	switch {
	case len(teamIDs) == 0:
		// No global role and no team roles that qualify under IncludeObserver.
		return "FALSE"
	case filter.TeamID == nil:
		return fmt.Sprintf("%s IN (%s)", teamIDSqlFilter, strings.Join(teamIDs, ","))
	case requestedTeamVisible:
		return allowed
	default:
		return "FALSE"
	}
}
// whereFilterTeams returns a WHERE condition restricting rows of the teams
// table (aliased teamKey) to the teams filter.User may see.
//
// A global admin, maintainer, technician, GitOps or observer-plus role sees
// every team ("TRUE"). So does a global observer role when
// filter.IncludeObserver is set. Otherwise the result is
// "<teamKey>.id IN (...)" over the teams where the user holds one of those
// roles. It is "FALSE" when there are none, or when filter.User is nil.
func (ds *Datastore) whereFilterTeams(filter fleet.TeamFilter, teamKey string) string {
	user := filter.User
	if user == nil {
		// Most likely a programming error: return no rows instead of
		// panicking, but leave a trace in the logs.
		ds.logger.InfoContext(context.TODO(), "team filter missing user")
		return "FALSE"
	}

	if user.GlobalRole != nil {
		switch *user.GlobalRole {
		case fleet.RoleAdmin, fleet.RoleMaintainer, fleet.RoleTechnician, fleet.RoleGitOps, fleet.RoleObserverPlus:
			return "TRUE"
		case fleet.RoleObserver:
			if filter.IncludeObserver {
				return "TRUE"
			}
			return "FALSE"
		}
		// Any other global role falls through to the per-team checks.
	}

	var visible []string
	for _, membership := range user.Teams {
		switch membership.Role {
		case fleet.RoleAdmin, fleet.RoleMaintainer, fleet.RoleTechnician, fleet.RoleGitOps, fleet.RoleObserverPlus:
			visible = append(visible, fmt.Sprint(membership.ID))
		case fleet.RoleObserver:
			if filter.IncludeObserver {
				visible = append(visible, fmt.Sprint(membership.ID))
			}
		}
	}

	if len(visible) == 0 {
		// No global role and no team roles that qualify under IncludeObserver.
		return "FALSE"
	}
	return fmt.Sprintf("%s.id IN (%s)", teamKey, strings.Join(visible, ","))
}
// whereOmitIDs returns a WHERE condition that excludes the given IDs from
// colName, or "TRUE" when omit is empty and there is nothing to exclude.
// colName is interpolated into the SQL verbatim, so it should be a trusted
// column name.
func (ds *Datastore) whereOmitIDs(colName string, omit []uint) string {
	if len(omit) == 0 {
		return "TRUE"
	}
	ids := make([]string, len(omit))
	for i, id := range omit {
		ids[i] = fmt.Sprint(id)
	}
	return fmt.Sprintf("%s NOT IN (%s)", colName, strings.Join(ids, ","))
}
// whereFilterHostsByIdentifier narrows stmt to hosts whose hostname, osquery
// host ID, node key, UUID or hardware serial equals identifier. The identifier
// is bound as a single parameter. An empty identifier leaves stmt and params
// unchanged. The hosts table must be aliased as `h` in stmt.
func (ds *Datastore) whereFilterHostsByIdentifier(identifier, stmt string, params []interface{}) (string, []interface{}) {
	if identifier != "" {
		stmt += " AND ? IN (h.hostname, h.osquery_host_id, h.node_key, h.uuid, h.hardware_serial)"
		params = append(params, identifier)
	}
	return stmt, params
}
// registerTLS builds a TLS configuration from the client certificate, key, CA
// and server name in conf. It registers that configuration with the mysql
// driver under the name conf.TLSConfig (presumably referenced by that name
// from the connection DSN; confirm with the caller).
func registerTLS(conf config.MysqlConfig) error {
	settings := config.TLS{
		TLSCert:       conf.TLSCert,
		TLSKey:        conf.TLSKey,
		TLSCA:         conf.TLSCA,
		TLSServerName: conf.TLSServerName,
	}
	clientTLS, err := settings.ToTLSConfig()
	if err != nil {
		return err
	}
	if regErr := mysql.RegisterTLSConfig(conf.TLSConfig, clientTLS); regErr != nil {
		return fmt.Errorf("register mysql tls config: %w", regErr)
	}
	return nil
}
// isForeignKeyError checks if the provided error is a MySQL child foreign key
// error (Error #1452)
func isChildForeignKeyError(err error) bool {
err = ctxerr.Cause(err)
mysqlErr, ok := err.(*mysql.MySQLError)
if !ok {
return false
}
// https://dev.mysql.com/doc/refman/5.7/en/error-messages-server.html#error_er_no_referenced_row_2
const ER_NO_REFERENCED_ROW_2 = 1452
return mysqlErr.Number == ER_NO_REFERENCED_ROW_2
}
// patternReplacer turns a raw search term into the pattern that
// searchLikePattern binds to every `<column> LIKE ?` placeholder it adds.
type patternReplacer func(string) string
// likePattern returns a pattern that matches, with LIKE, any value containing
// m as a literal substring.
//
// The LIKE metacharacters `_` and `%` in m are escaped, and so is the
// backslash itself. MySQL uses `\` as the default LIKE escape character, so an
// unescaped backslash in m would escape the character after it. At the end of
// m, it would turn the trailing `%` wildcard into a literal percent sign.
func likePattern(m string) string {
	// Escape backslashes first so the escapes added below are not doubled.
	m = strings.ReplaceAll(m, `\`, `\\`)
	m = strings.ReplaceAll(m, "_", `\_`)
	m = strings.ReplaceAll(m, "%", `\%`)
	return "%" + m + "%"
}
// noneReplacer is the identity patternReplacer: it returns the search term
// untouched. It is meant for callers that have already built a complete LIKE
// pattern themselves.
func noneReplacer(s string) string {
	unchanged := s
	return unchanged
}
// searchLike adds SQL and parameters for a "search" using LIKE syntax.
//
// The input columns must be sanitized if they are provided by the user.
func searchLike(sql string, params []interface{}, match string, columns ...string) (string, []interface{}) {
return searchLikePattern(sql, params, match, likePattern, columns...)
}
// searchLikePattern appends ` AND (c1 LIKE ? OR c2 LIKE ? ...)` to sql, one
// alternative per column. It appends the pattern produced by replacer(match)
// to params once per column. When there is no term or no column to search,
// sql and params are returned unchanged and replacer is not called.
func searchLikePattern(sql string, params []interface{}, match string, replacer patternReplacer, columns ...string) (string, []interface{}) {
	if match == "" || len(columns) == 0 {
		return sql, params
	}

	pattern := replacer(match)
	conditions := make([]string, len(columns))
	for i, col := range columns {
		conditions[i] = col + " LIKE ?"
		params = append(params, pattern)
	}
	return sql + " AND (" + strings.Join(conditions, " OR ") + ")", params
}
// nonascii matches an ASCII character immediately followed by one or more
// non-ASCII characters. The first group, `[[:ascii:]]`, matches any character
// in the ASCII range (0 to 127). The second, `[^[:ascii:]]+`, matches a run of
// characters outside that range.
//
// A match therefore needs at least one ASCII character directly before the
// non-ASCII run. A string made only of non-ASCII characters does not match.
//
// nonacsiiReplace matches a single non-ASCII character. replaceMatchAny uses
// it to substitute each such character with `_`.
var (
	nonascii        = regexp.MustCompile(`(?P<ascii>[[:ascii:]])(?P<nonascii>[^[:ascii:]]+)`)
	nonacsiiReplace = regexp.MustCompile(`[^[:ascii:]]`)
)
// hostSearchLike is searchLike for host queries. Besides the given columns, it
// also matches hosts that have an email in host_emails containing match, so
// any search can surface results from human-host mapping information.
//
// Note: the host from the `hosts` table must be aliased to `h` in `sql`.
func hostSearchLike(sql string, params []any, match string, columns ...string) (string, []any) {
	query, args := searchLike(sql, params, match, columns...)

	// searchLike only adds its clause when it has both a term and columns.
	if match == "" || len(columns) == 0 {
		return query, args
	}

	// The query now ends with the ")" closing searchLike's OR group. Reopen
	// that group and append the host_emails lookup as one more alternative.
	const emailClause = " OR (" + ` EXISTS (SELECT 1 FROM host_emails he WHERE he.host_id = h.id AND he.email LIKE ?)))`
	query = strings.TrimSuffix(query, ")") + emailClause
	args = append(args, likePattern(match))
	return query, args
}
// hostSearchLikeAny adds a search condition like searchLike, except that every
// non-ASCII character in match becomes the `_` single-character LIKE wildcard
// (see buildWildcardMatchPhrase). Unlike hostSearchLike, it does not also
// search host_emails.
//
// An empty match leaves sql and params untouched, consistent with searchLike.
// The wildcard phrase built from "" is "%%", which is non-empty. Without this
// guard, an empty search would still add a `col LIKE '%%'` condition, and that
// condition filters out rows whose searched columns are all NULL.
func hostSearchLikeAny(sql string, params []interface{}, match string, columns ...string) (string, []interface{}) {
	if match == "" {
		return sql, params
	}
	return searchLikePattern(sql, params, buildWildcardMatchPhrase(match), noneReplacer, columns...)
}
func buildWildcardMatchPhrase(matchQuery string) string {
return replaceMatchAny(likePattern(matchQuery))
}
// hasNonASCIIRegex reports whether s contains an ASCII character immediately
// followed by one or more non-ASCII characters (the nonascii pattern).
//
// NOTE(review): the pattern requires a preceding ASCII character. A string
// made only of non-ASCII characters (e.g. "é") therefore returns false, and so
// does one whose non-ASCII run comes first (e.g. "éa"). Confirm callers expect
// that.
func hasNonASCIIRegex(s string) bool {
	return nonascii.MatchString(s)
}
// replaceMatchAny replaces each non-ASCII character in s with `_`, the LIKE
// single-character wildcard.
func replaceMatchAny(s string) string {
	const wildcard = "_"
	return nonacsiiReplace.ReplaceAllString(s, wildcard)
}
// InnoDBStatus returns the Status column of `show engine innodb status`.
// It queries the writer connection so the report comes from the main DB node.
// If the query fails with an access-denied error, an *accessDeniedError is
// returned; other errors are wrapped.
func (ds *Datastore) InnoDBStatus(ctx context.Context) (string, error) {
	var row struct {
		Type   string `db:"Type"`
		Name   string `db:"Name"`
		Status string `db:"Status"`
	}

	if err := ds.writer(ctx).GetContext(ctx, &row, "show engine innodb status"); err != nil {
		if !isMySQLAccessDenied(err) {
			return "", ctxerr.Wrap(ctx, err, "getting innodb status")
		}
		// Reading InnoDB status requires the PROCESS privilege, which a DB
		// admin can grant with e.g.:
		//   GRANT PROCESS,SELECT ON *.* TO 'fleet'@'%';
		return "", &accessDeniedError{
			Message:     "getting innodb status: DB user must have global PROCESS and SELECT privilege",
			InternalErr: err,
		}
	}
	return row.Status, nil
}
// ProcessList returns the rows of `show processlist`. It reads through the
// writer connection so the list comes from the main DB node.
func (ds *Datastore) ProcessList(ctx context.Context) ([]fleet.MySQLProcess, error) {
	var procs []fleet.MySQLProcess
	if err := ds.writer(ctx).SelectContext(ctx, &procs, "show processlist"); err != nil {
		return nil, ctxerr.Wrap(ctx, err, "Getting process list")
	}
	return procs, nil
}
// insertOnDuplicateDidInsertOrUpdate reports whether an
// INSERT ... ON DUPLICATE KEY UPDATE statement inserted a new row or changed
// an existing one. It returns false when the row was merely set to its
// current values.
//
// Per the MySQL documentation
// (https://dev.mysql.com/doc/refman/8.4/en/insert-on-duplicate.html), the
// affected-rows value per row is 1 for an insert, 2 for an update of an
// existing row, and 0 when an existing row is set to its current values.
// With the CLIENT_FOUND_ROWS flag, that last case reports 1 instead of 0. If
// the table has an AUTO_INCREMENT column and the statement inserts or updates
// a row, LAST_INSERT_ID() returns the AUTO_INCREMENT value.
//
// The connection string sets CLIENT_FOUND_ROWS (see
// generateMysqlConnectionString in this package). An unchanged existing row
// therefore reports 1 affected row, but with a last inserted id of 0.
//
// With our mysql driver, Result.LastInsertId and Result.RowsAffected never
// return an error. The values are captured at Exec time, and the result just
// returns the integers it already holds:
// https://github.com/go-sql-driver/mysql/blob/bcc459a906419e2890a50fc2c99ea6dd927a88f2/result.go
func insertOnDuplicateDidInsertOrUpdate(res sql.Result) bool {
	lastID, _ := res.LastInsertId()
	affected, _ := res.RowsAffected()
	if lastID == 0 {
		// No insert or update happened (e.g. the row was set to its current values).
		return false
	}
	// A row was inserted or updated, and at least one row was reported as found.
	return affected > 0
}
// parameterizedStmt pairs a SQL statement with the arguments bound to its
// placeholders.
type parameterizedStmt struct {
	// Statement is the SQL text, containing placeholders for Args.
	Statement string
	// Args are passed positionally (variadically) along with Statement.
	Args []interface{}
}
// optimisticGetOrInsert looks up a row's ID by a unique key that most likely
// exists already, so the insert is expected to be rare and the read to succeed
// most of the time. It proceeds as follows:
//  1. Read the ID from the read replica.
//  2. If no row is found, insert it through the primary.
//  3. If the insert fails with a duplicate key, read the ID again, this time
//     from the primary.
//
// The read statement must only SELECT the id column.
func (ds *Datastore) optimisticGetOrInsert(ctx context.Context, readStmt, insertStmt *parameterizedStmt) (id uint, err error) {
	primary := ds.writer(ctx)
	return ds.optimisticGetOrInsertWithWriter(ctx, primary, readStmt, insertStmt)
}
// optimisticGetOrInsertWithWriter behaves like optimisticGetOrInsert, but it
// performs the insert and the fallback read with the provided writer instead
// of ds.writer. This makes it usable from inside a transaction. The first read
// still goes to the read replica.
func (ds *Datastore) optimisticGetOrInsertWithWriter(ctx context.Context, writer sqlx.ExtContext, readStmt, insertStmt *parameterizedStmt) (id uint, err error) { //nolint: gocritic // it's ok in this case to use ds.reader even if we receive an ExtContext
	selectID := func(q sqlx.QueryerContext) (uint, error) {
		var found uint
		getErr := sqlx.GetContext(ctx, q, &found, readStmt.Statement, readStmt.Args...)
		return found, getErr
	}

	// 1. Read from the replica first: the row most likely exists already.
	found, readErr := selectID(ds.reader(ctx))
	if readErr == nil {
		return found, nil
	}
	if !errors.Is(readErr, sql.ErrNoRows) {
		return 0, ctxerr.Wrap(ctx, readErr, "get id from reader")
	}

	// 2. Not there yet, so insert it.
	res, insertErr := writer.ExecContext(ctx, insertStmt.Statement, insertStmt.Args...)
	if insertErr == nil {
		// With go-sql-driver/mysql, LastInsertId returns a value captured at
		// Exec time and does not fail, so its error is ignored.
		newID, _ := res.LastInsertId()
		return uint(newID), nil //nolint:gosec // dismiss G115
	}
	if !IsDuplicate(insertErr) {
		return 0, ctxerr.Wrap(ctx, insertErr, "insert")
	}

	// 3. The row was created between the select and the insert. Read it
	// again, this time from the primary so replication lag cannot hide it.
	found, readErr = selectID(writer)
	if readErr != nil {
		return 0, ctxerr.Wrap(ctx, readErr, "get id from writer")
	}
	return found, nil
}
// batchProcessDB abstracts the batch processing logic, for a given payload:
//
//   - generateValueArgs is called for each item. It returns the SQL fragment
//     for that item (e.g. its placeholders) and one argument per placeholder.
//     Fragments are concatenated as-is, with no separator added between them.
//   - executeBatch is called with the concatenated fragments and arguments of
//     every batchSize items, and once more for any remaining items. Each batch
//     gets a freshly allocated args slice, so executeBatch may retain it.
//     A batchSize of 0 or less results in one call per item.
//
// Processing stops at the first error returned by executeBatch, and that error
// is returned. An empty payload is a no-op.
//
// TODO(roberto): use this function in all the functions where we do ad-hoc
// batch processing.
func batchProcessDB[T any](
	payload []T,
	batchSize int,
	generateValueArgs func(T) (string, []any),
	executeBatch func(string, []any) error,
) error {
	if len(payload) == 0 {
		return nil
	}

	var (
		args       []any
		sb         strings.Builder
		batchCount int
	)

	resetBatch := func() {
		batchCount = 0
		// Allocate a fresh slice instead of truncating with args[:0].
		// executeBatch may hold on to the slice it received, and reusing its
		// backing array would silently overwrite those arguments with the
		// next batch's values.
		args = make([]any, 0, len(args))
		// strings.Builder.Reset drops its buffer, so the string already passed
		// to executeBatch is unaffected.
		sb.Reset()
	}

	for _, item := range payload {
		valuePart, itemArgs := generateValueArgs(item)
		args = append(args, itemArgs...)
		sb.WriteString(valuePart)
		batchCount++

		if batchCount >= batchSize {
			if err := executeBatch(sb.String(), args); err != nil {
				return err
			}
			resetBatch()
		}
	}

	if batchCount > 0 {
		if err := executeBatch(sb.String(), args); err != nil {
			return err
		}
	}

	return nil
}