fleet/server/datastore/mysql/common_mysql/common.go
Victor Lyuboslavsky 276af0f5b0
Refactored RDS IAM authentication logic into a dedicated rdsauth package (#36847)
Simplified and modularized IAM auth setup for MySQL connections.

<!-- Add the related story/sub-task/bug number, like Resolves #123, or
remove if NA -->
**Related issue:** Resolves #36846 

Manually QA'd by setting up an RDS instance with IAM database authentication and running Fleet as follows:
```
FLEET_MYSQL_ADDRESS=fleet-iam-test-public.xxxxxxxxx.us-east-2.rds.amazonaws.com:3306 \
  FLEET_MYSQL_USERNAME=fleet_iam \
  FLEET_MYSQL_DATABASE=fleet \
  FLEET_MYSQL_REGION=us-east-2 \
./build/fleet serve
```

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files)
for more information.



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Refactor**
  * Reorganized IAM authentication infrastructure for RDS databases to
    improve code organization and maintainability.
  * Enhanced the database connection layer to support flexible
    authentication configuration methods while maintaining full backward
    compatibility with existing configurations.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-12-10 16:21:35 -06:00

220 lines
6.5 KiB
Go

package common_mysql
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"net/url"
"time"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/go-kit/log"
"github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"github.com/ngrok/sqlmw"
)
// ConnectorFactory builds the driver.Connector that NewDB opens the pool from
// (via sql.OpenDB) instead of opening the DSN through a registered driver name.
// It receives the fully generated DSN and a logger.
//
// Expressing this as a plain function type lets callers plug in custom
// authentication (for example AWS IAM) without this package depending on the
// code that implements it.
type ConnectorFactory func(dataSourceName string, l log.Logger) (driver.Connector, error)
// TestSQLMode is the sql_mode used for tests.
//
// It combines the modes that make up MySQL's ANSI mode (REAL_AS_FLOAT,
// PIPES_AS_CONCAT, ANSI_QUOTES, IGNORE_SPACE, ONLY_FULL_GROUP_BY) with the
// strict modes MySQL 8.0 enables by default (STRICT_TRANS_TABLES,
// NO_ZERO_IN_DATE, NO_ZERO_DATE, ERROR_FOR_DIVISION_BY_ZERO,
// NO_ENGINE_SUBSTITUTION), so tests behave like production.
//
// The mode list is comma-separated, so the whole value is wrapped in single
// quotes to survive being passed as one MySQL DSN parameter.
//
// Reference: https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html
const TestSQLMode = "'" +
	"REAL_AS_FLOAT,PIPES_AS_CONCAT,ANSI_QUOTES,IGNORE_SPACE,ONLY_FULL_GROUP_BY," +
	"STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION" +
	"'"
// DBOptions controls how NewDB opens and verifies a database connection pool.
type DBOptions struct {
	// MaxAttempts configures how many times NewDB pings the DB before giving
	// up. It is the total number of attempts (not the number of retries); a
	// value <= 0 means no ping is performed at all.
	MaxAttempts int
	// Logger receives NewDB's "could not connect to db" messages and is passed
	// to ConnectorFactory.
	Logger log.Logger
	// ReplicaConfig is not read by NewDB; presumably used by callers to open a
	// read replica — verify against callers.
	ReplicaConfig *config.MysqlConfig
	// Interceptor, when non-nil, makes NewDB register a "mysql-mw" driver that
	// wraps the MySQL driver with this interceptor, and use it. sql.Register
	// panics on a duplicate name, so NewDB panics if called twice in the same
	// process with an Interceptor set.
	Interceptor sqlmw.Interceptor
	// TracingConfig selects a tracing-aware driver when tracing is enabled: the
	// otelDriverName passed to NewDB for TracingType "opentelemetry", and
	// "apm/mysql" for any other type. An Interceptor takes precedence.
	TracingConfig *config.LoggingConfig
	// MinLastOpenedAtDiff is not read in this file — TODO(review): document
	// where and how it is used.
	MinLastOpenedAtDiff time.Duration
	// SqlMode, when non-empty, overwrites conf.SQLMode on the caller's config
	// before NewDB builds the DSN.
	SqlMode string
	// PrivateKey is not read in this file — TODO(review): document where and
	// how it is used.
	PrivateKey string
	// ConnectorFactory is an optional factory for creating custom database connectors.
	// When set, NewDB opens the pool with sql.OpenDB on the returned connector
	// instead of sqlx.Open. In that case the driver name chosen from
	// TracingConfig/Interceptor is only handed to sqlx.NewDb and is not used to
	// open connections.
	ConnectorFactory ConnectorFactory
}
// NewDB opens a MySQL connection pool described by conf and verifies it with
// Ping. It makes up to opts.MaxAttempts attempts, with a linearly increasing
// delay between them (0s, 1s, 2s, ...).
//
// Driver selection:
//   - "mysql" by default;
//   - otelDriverName when tracing is enabled with type "opentelemetry", or
//     "apm/mysql" for any other enabled tracing type;
//   - "mysql-mw" (registered here, wrapping the MySQL driver with
//     opts.Interceptor) when an interceptor is set.
//
// When opts.ConnectorFactory is set, connections come from the connector it
// returns, and the selected driver name is only passed to sqlx.NewDb.
//
// If opts.SqlMode is non-empty, it overwrites conf.SQLMode on the caller's
// config. If every ping fails, the pool is closed and the last ping error is
// returned. With opts.MaxAttempts <= 0, no ping is performed.
func NewDB(conf *config.MysqlConfig, opts *DBOptions, otelDriverName string) (*sqlx.DB, error) {
	driverName := "mysql"
	if opts.TracingConfig != nil && opts.TracingConfig.TracingEnabled {
		if opts.TracingConfig.TracingType == "opentelemetry" {
			driverName = otelDriverName
		} else {
			driverName = "apm/mysql"
		}
	}
	if opts.Interceptor != nil {
		// NOTE: sql.Register panics when the same driver name is registered
		// twice, so this branch can run at most once per process.
		driverName = "mysql-mw"
		sql.Register(driverName, sqlmw.Driver(mysql.MySQLDriver{}, opts.Interceptor))
	}
	if opts.SqlMode != "" {
		conf.SQLMode = opts.SqlMode
	}

	// Fall back to a no-op logger so a missing Logger cannot cause a nil
	// dereference while reporting connection failures.
	logger := opts.Logger
	if logger == nil {
		logger = log.NewNopLogger()
	}

	dsn := generateMysqlConnectionString(*conf)

	var db *sqlx.DB
	if opts.ConnectorFactory != nil {
		connector, err := opts.ConnectorFactory(dsn, logger)
		if err != nil {
			return nil, fmt.Errorf("failed to create connector: %w", err)
		}
		db = sqlx.NewDb(sql.OpenDB(connector), driverName)
	} else {
		var err error
		db, err = sqlx.Open(driverName, dsn)
		if err != nil {
			return nil, err
		}
	}

	db.SetMaxIdleConns(conf.MaxIdleConns)
	db.SetMaxOpenConns(conf.MaxOpenConns)
	db.SetConnMaxLifetime(time.Second * time.Duration(conf.ConnMaxLifetime))

	var pingErr error
	for attempt := 0; attempt < opts.MaxAttempts; attempt++ {
		if pingErr = db.Ping(); pingErr == nil {
			break // we're connected!
		}
		if attempt == opts.MaxAttempts-1 {
			// Out of attempts: report the failure now instead of sleeping first.
			break
		}
		interval := time.Duration(attempt) * time.Second
		logger.Log("mysql", fmt.Sprintf(
			"could not connect to db: %v, sleeping %v", pingErr, interval))
		time.Sleep(interval)
	}
	if pingErr != nil {
		// The caller never receives db, so close it here to release the pool.
		// The ping error is the one worth reporting, so a Close error is
		// deliberately ignored.
		_ = db.Close()
		return nil, pingErr
	}
	return db, nil
}
// generateMysqlConnectionString returns a go-sql-driver/mysql DSN of the form
// "username:password@protocol(address)/database?params" for conf.
//
// Every DSN sets the same base parameters:
//   - collation utf8mb4_unicode_ci
//   - parseTime
//   - UTC loc and session time_zone
//   - clientFoundRows
//   - allowNativePasswords
//   - group_concat_max_len of 4 MiB
//   - multiStatements
//
// sql_mode is added when conf.SQLMode is set.
//
// RDS IAM authentication is assumed when neither a password nor a password
// file is configured but a region is. In that mode allowCleartextPasswords is
// enabled, so the connection must use TLS. An explicit conf.TLSConfig is
// honored; otherwise the "rdsmysql" TLS config is used. That config is
// presumably registered elsewhere via mysql.RegisterTLSConfig — not visible
// here.
func generateMysqlConnectionString(conf config.MysqlConfig) string {
	params := url.Values{
		// using collation implicitly sets the charset too
		// and it's the recommended way to do it per the
		// driver documentation:
		// https://github.com/go-sql-driver/mysql#charset
		"collation":            []string{"utf8mb4_unicode_ci"},
		"parseTime":            []string{"true"},
		"loc":                  []string{"UTC"},
		"time_zone":            []string{"'-00:00'"},
		"clientFoundRows":      []string{"true"},
		"allowNativePasswords": []string{"true"},
		"group_concat_max_len": []string{"4194304"},
		"multiStatements":      []string{"true"},
	}

	tlsConfig := conf.TLSConfig
	usesIAMAuth := conf.Password == "" && conf.PasswordPath == "" && conf.Region != ""
	if usesIAMAuth {
		params.Set("allowCleartextPasswords", "true")
		// Cleartext credentials must never travel without TLS: keep the
		// caller's TLS config if given, otherwise default to "rdsmysql".
		if tlsConfig == "" {
			tlsConfig = "rdsmysql"
		}
	}
	if tlsConfig != "" {
		params.Set("tls", tlsConfig)
	}

	if conf.SQLMode != "" {
		params.Set("sql_mode", conf.SQLMode)
	}

	return fmt.Sprintf(
		"%s:%s@%s(%s)/%s?%s",
		conf.Username,
		conf.Password,
		conf.Protocol,
		conf.Address,
		conf.Database,
		params.Encode(),
	)
}
// WithTxx runs fn inside a new transaction opened on db.
//
// Outcomes:
//   - If fn succeeds, the transaction is committed.
//   - If fn returns an error, the transaction is rolled back and fn's error is
//     returned. If the rollback itself fails for a reason other than
//     sql.ErrTxDone, fn's error is wrapped with a message that includes the
//     rollback error.
//   - If fn panics, the transaction is rolled back (a rollback failure is
//     logged to logger) and the panic is re-raised.
func WithTxx(ctx context.Context, db *sqlx.DB, fn TxFn, logger log.Logger) error {
	tx, beginErr := db.BeginTxx(ctx, nil)
	if beginErr != nil {
		return ctxerr.Wrap(ctx, beginErr, "create transaction")
	}

	defer func() {
		p := recover()
		if p == nil {
			return
		}
		// Roll back before re-panicking so the transaction is not left open.
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			logger.Log("err", rollbackErr, "msg", "error encountered during transaction panic rollback")
		}
		panic(p)
	}()

	fnErr := fn(tx)
	if fnErr == nil {
		if commitErr := tx.Commit(); commitErr != nil {
			return ctxerr.Wrap(ctx, commitErr, "commit transaction")
		}
		return nil
	}

	switch rbErr := tx.Rollback(); {
	case rbErr == nil, rbErr == sql.ErrTxDone:
		return fnErr
	default:
		return ctxerr.Wrapf(ctx, fnErr, "got err '%s' rolling back after err", rbErr.Error())
	}
}
// WithReadOnlyTxx runs fn inside a read-only transaction opened on reader with
// REPEATABLE READ isolation.
//
// Outcomes:
//   - If fn succeeds, the transaction is committed.
//   - If fn returns an error, the transaction is rolled back and fn's error is
//     returned. If the rollback itself fails for a reason other than
//     sql.ErrTxDone, fn's error is wrapped with a message that includes the
//     rollback error.
//   - If fn panics, the transaction is rolled back (a rollback failure is
//     logged to logger) and the panic is re-raised.
func WithReadOnlyTxx(ctx context.Context, reader *sqlx.DB, fn ReadTxFn, logger log.Logger) error {
	txOpts := &sql.TxOptions{
		Isolation: sql.LevelRepeatableRead,
		ReadOnly:  true,
	}
	tx, beginErr := reader.BeginTxx(ctx, txOpts)
	if beginErr != nil {
		return ctxerr.Wrap(ctx, beginErr, "create read-only transaction")
	}

	defer func() {
		p := recover()
		if p == nil {
			return
		}
		// Roll back before re-panicking so the transaction is not left open.
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			logger.Log("err", rollbackErr, "msg", "error encountered during read-only transaction panic rollback")
		}
		panic(p)
	}()

	fnErr := fn(tx)
	if fnErr == nil {
		if commitErr := tx.Commit(); commitErr != nil {
			return ctxerr.Wrap(ctx, commitErr, "commit read-only transaction")
		}
		return nil
	}

	rbErr := tx.Rollback()
	if rbErr == nil || rbErr == sql.ErrTxDone {
		return fnErr
	}
	return ctxerr.Wrapf(ctx, fnErr, "got err '%s' rolling back read-only transaction after err", rbErr.Error())
}
// DefaultNonZeroTime is a placeholder timestamp, in RFC 3339 form, used where
// a "zero" time would otherwise be stored.
//
// MySQL is picky about zero or very old timestamp values. This value is far
// enough in the past to read as a placeholder while (hopefully) being accepted
// by most MySQL configurations.
//
// NOTE: #3229 proposes a better fix that uses *time.Time for
// ScheduledQueryStats.LastExecuted.
var DefaultNonZeroTime = "2000-01-01T00:00:00Z"

// GetDefaultNonZeroTime returns midnight UTC on January 1, 2000 — the same
// instant as the initial value of DefaultNonZeroTime — as a time.Time.
func GetDefaultNonZeroTime() time.Time {
	const year = 2000
	return time.Date(year, time.January, 1, 0, 0, 0, 0, time.UTC)
}