mirror of
https://github.com/fleetdm/fleet
synced 2026-05-01 10:27:35 +00:00
Simplified and modularized IAM auth setup for MySQL connections. <!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #36846 Manually QA'ed by setting up RDS with IAM and running Fleet like: ``` FLEET_MYSQL_ADDRESS=fleet-iam-test-public.xxxxxxxxx.us-east-2.rds.amazonaws.com:3306 \ FLEET_MYSQL_USERNAME=fleet_iam \ FLEET_MYSQL_DATABASE=fleet \ FLEET_MYSQL_REGION=us-east-2 \ ./build/fleet serve ``` # Checklist for submitter If some of the following don't apply, delete the relevant line. - [x] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files) for more information. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Reorganized IAM authentication infrastructure for RDS databases to improve code organization and maintainability. * Enhanced the database connection layer to support flexible authentication configuration methods while maintaining full backward compatibility with existing configurations. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
220 lines
6.5 KiB
Go
220 lines
6.5 KiB
Go
package common_mysql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"fmt"
|
|
"net/url"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/server/config"
|
|
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
|
"github.com/go-kit/log"
|
|
"github.com/go-sql-driver/mysql"
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/ngrok/sqlmw"
|
|
)
|
|
|
|
// ConnectorFactory creates a driver.Connector for custom database authentication.
// This allows injecting authentication mechanisms (like AWS IAM) without adding
// dependencies to this package.
//
// When DBOptions.ConnectorFactory is set, NewDB calls it with the DSN generated
// from the MySQL config and with DBOptions.Logger. NewDB then builds the
// connection pool with sql.OpenDB on the returned connector. A non-nil error
// aborts NewDB.
type ConnectorFactory func(dsn string, logger log.Logger) (driver.Connector, error)
|
|
|
|
// TestSQLMode is the sql_mode used for testing. It combines the components of
// MySQL's ANSI mode (REAL_AS_FLOAT, PIPES_AS_CONCAT, ANSI_QUOTES, IGNORE_SPACE,
// ONLY_FULL_GROUP_BY) with the MySQL 8.0 default strict modes, so tests match
// production behavior.
//
// The list is wrapped in single quotes because it contains commas and is
// passed to MySQL through the DSN.
//
// Reference: https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html
const TestSQLMode = "'" +
	// ANSI mode components.
	"REAL_AS_FLOAT,PIPES_AS_CONCAT,ANSI_QUOTES,IGNORE_SPACE,ONLY_FULL_GROUP_BY," +
	// MySQL 8.0 default strict modes not already listed above.
	"STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION" +
	"'"
|
|
|
|
// DBOptions holds the settings, beyond the MysqlConfig, that NewDB uses to
// open and verify a database handle.
type DBOptions struct {
	// MaxAttempts configures how many times NewDB pings the DB before giving
	// up. It counts every attempt, including the first. Zero or less skips the
	// ping check entirely.
	MaxAttempts int
	// Logger receives NewDB's connection-retry messages. It is also passed to
	// ConnectorFactory.
	Logger log.Logger
	// ReplicaConfig is not read by NewDB. Presumably it holds the read-replica
	// connection settings used by callers (TODO confirm).
	ReplicaConfig *config.MysqlConfig
	// Interceptor, when non-nil, makes NewDB wrap the MySQL driver with it,
	// register the result under the driver name "mysql-mw", and use that driver
	// instead of any tracing driver.
	// NOTE(review): database/sql's sql.Register panics on a duplicate name, so
	// only one NewDB call per process can set this.
	Interceptor sqlmw.Interceptor
	// TracingConfig, when tracing is enabled, selects NewDB's driver. If
	// TracingType is "opentelemetry", NewDB uses the OpenTelemetry driver name
	// passed to it; otherwise it uses "apm/mysql".
	TracingConfig *config.LoggingConfig
	// MinLastOpenedAtDiff is not read by NewDB. Its meaning is defined by
	// callers (TODO document).
	MinLastOpenedAtDiff time.Duration
	// SqlMode, when non-empty, is copied into the MysqlConfig's SQLMode before
	// the DSN is built. This modifies the config the caller passed to NewDB.
	SqlMode string
	// PrivateKey is not read by NewDB. Presumably callers use it for
	// application-level encryption (TODO confirm).
	PrivateKey string
	// ConnectorFactory is an optional factory for creating custom database connectors.
	// When set, it's used instead of the standard connection method.
	// NOTE(review): NewDB then builds the pool with sql.OpenDB on the returned
	// connector, so the tracing or interceptor driver NewDB selects does not
	// wrap those connections.
	ConnectorFactory ConnectorFactory
}
|
|
|
|
// NewDB opens a *sqlx.DB for the MySQL server described by conf and waits
// until it answers a ping.
//
// The database/sql driver is chosen from opts:
//   - "mysql" by default;
//   - otelDriverName when tracing is enabled with TracingType "opentelemetry",
//     or "apm/mysql" for any other enabled tracing type;
//   - "mysql-mw" when opts.Interceptor is set. This takes precedence over
//     tracing and registers the driver on every such call.
//
// If opts.SqlMode is non-empty, it is written into conf.SQLMode before the DSN
// is built, so the caller's config is modified. When opts.ConnectorFactory is
// set, the pool is built from the connector it returns instead of from the
// driver name.
//
// Pool limits come from conf. The database is then pinged up to
// opts.MaxAttempts times. After a failed attempt (other than the last), NewDB
// logs the error and sleeps `attempt` seconds (0s, 1s, 2s, ...). If every
// attempt fails, the pool is closed and the last ping error is returned. A
// MaxAttempts of zero or less skips the ping entirely.
func NewDB(conf *config.MysqlConfig, opts *DBOptions, otelDriverName string) (*sqlx.DB, error) {
	driverName := "mysql"

	if opts.TracingConfig != nil && opts.TracingConfig.TracingEnabled {
		if opts.TracingConfig.TracingType == "opentelemetry" {
			driverName = otelDriverName
		} else {
			driverName = "apm/mysql"
		}
	}
	if opts.Interceptor != nil {
		// NOTE: sql.Register panics if "mysql-mw" is already registered, so
		// only one NewDB call per process can take this path.
		driverName = "mysql-mw"
		sql.Register(driverName, sqlmw.Driver(mysql.MySQLDriver{}, opts.Interceptor))
	}
	if opts.SqlMode != "" {
		conf.SQLMode = opts.SqlMode
	}

	dsn := generateMysqlConnectionString(*conf)

	var db *sqlx.DB
	if opts.ConnectorFactory != nil {
		connector, err := opts.ConnectorFactory(dsn, opts.Logger)
		if err != nil {
			return nil, fmt.Errorf("failed to create connector: %w", err)
		}
		// Connections come from the connector, so the tracing/interceptor driver
		// selected above does not wrap them; driverName is only handed to sqlx.
		db = sqlx.NewDb(sql.OpenDB(connector), driverName)
	} else {
		var err error
		db, err = sqlx.Open(driverName, dsn)
		if err != nil {
			return nil, err
		}
	}

	db.SetMaxIdleConns(conf.MaxIdleConns)
	db.SetMaxOpenConns(conf.MaxOpenConns)
	db.SetConnMaxLifetime(time.Second * time.Duration(conf.ConnMaxLifetime))

	var dbError error
	for attempt := 0; attempt < opts.MaxAttempts; attempt++ {
		dbError = db.Ping()
		if dbError == nil {
			// we're connected!
			break
		}
		if attempt == opts.MaxAttempts-1 {
			// Out of attempts: sleeping now would only delay reporting the error.
			break
		}
		interval := time.Duration(attempt) * time.Second
		if opts.Logger != nil {
			opts.Logger.Log("mysql", fmt.Sprintf(
				"could not connect to db: %v, sleeping %v", dbError, interval))
		}
		time.Sleep(interval)
	}

	if dbError != nil {
		// The caller receives no handle, so release the pool (and the background
		// connection opener database/sql starts for it) instead of leaking it.
		_ = db.Close()
		return nil, dbError
	}
	return db, nil
}
|
|
|
|
// generateMysqlConnectionString returns a go-sql-driver/mysql DSN of the form
// "username:password@protocol(address)/database?params" built from conf.
//
// In addition to the fixed session parameters below, it sets:
//   - allowCleartextPasswords=true when neither Password nor PasswordPath is
//     set but Region is (treated here as IAM authentication);
//   - tls=<conf.TLSConfig> whenever TLSConfig is set, and otherwise
//     tls=rdsmysql in the IAM case;
//   - sql_mode=<conf.SQLMode> when SQLMode is set.
//
// The query string comes from url.Values.Encode, so values are URL-escaped and
// keys are sorted.
func generateMysqlConnectionString(conf config.MysqlConfig) string {
	params := url.Values{
		// using collation implicitly sets the charset too
		// and it's the recommended way to do it per the
		// driver documentation:
		// https://github.com/go-sql-driver/mysql#charset
		"collation":            []string{"utf8mb4_unicode_ci"},
		"parseTime":            []string{"true"},
		"loc":                  []string{"UTC"},
		"time_zone":            []string{"'-00:00'"},
		"clientFoundRows":      []string{"true"},
		"allowNativePasswords": []string{"true"},
		"group_concat_max_len": []string{"4194304"},
		"multiStatements":      []string{"true"},
	}

	// With no static password but a region, the password is expected to be an
	// IAM auth token, which the server receives via the cleartext auth plugin.
	useIAM := conf.Password == "" && conf.PasswordPath == "" && conf.Region != ""
	if useIAM {
		params.Set("allowCleartextPasswords", "true")
	}

	switch {
	case conf.TLSConfig != "":
		// An explicit TLS setting is honored in every mode, including IAM.
		// Dropping it there would send the cleartext token without the TLS the
		// operator asked for.
		params.Set("tls", conf.TLSConfig)
	case useIAM:
		// Default the IAM case to the "rdsmysql" TLS config so the token is not
		// sent unencrypted. Presumably that config is registered elsewhere with
		// the RDS CA bundle (TODO confirm).
		params.Set("tls", "rdsmysql")
	}

	if conf.SQLMode != "" {
		params.Set("sql_mode", conf.SQLMode)
	}

	return fmt.Sprintf(
		"%s:%s@%s(%s)/%s?%s",
		conf.Username,
		conf.Password,
		conf.Protocol,
		conf.Address,
		conf.Database,
		params.Encode(),
	)
}
|
|
|
|
// WithTxx runs fn inside a new read-write transaction on db.
//
// If fn returns nil, the transaction is committed. If fn returns an error, the
// transaction is rolled back and fn's error is returned; a rollback failure
// other than sql.ErrTxDone is folded into that error's message. If fn panics,
// the transaction is rolled back (any rollback failure is logged to logger)
// and the panic is re-raised.
func WithTxx(ctx context.Context, db *sqlx.DB, fn TxFn, logger log.Logger) error {
	tx, beginErr := db.BeginTxx(ctx, nil)
	if beginErr != nil {
		return ctxerr.Wrap(ctx, beginErr, "create transaction")
	}

	// A panicking fn must not leave the transaction open: roll back, then re-panic.
	defer func() {
		p := recover()
		if p == nil {
			return
		}
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			logger.Log("err", rollbackErr, "msg", "error encountered during transaction panic rollback")
		}
		panic(p)
	}()

	fnErr := fn(tx)
	if fnErr == nil {
		if commitErr := tx.Commit(); commitErr != nil {
			return ctxerr.Wrap(ctx, commitErr, "commit transaction")
		}
		return nil
	}

	// sql.ErrTxDone means the transaction was already committed or rolled
	// back, which is not a rollback failure worth reporting.
	rollbackErr := tx.Rollback()
	if rollbackErr == nil || rollbackErr == sql.ErrTxDone {
		return fnErr
	}
	return ctxerr.Wrapf(ctx, fnErr, "got err '%s' rolling back after err", rollbackErr.Error())
}
|
|
|
|
// WithReadOnlyTxx runs fn inside a read-only transaction opened on reader with
// REPEATABLE READ isolation.
//
// If fn returns nil, the transaction is committed. If fn returns an error, the
// transaction is rolled back and fn's error is returned; a rollback failure
// other than sql.ErrTxDone is folded into that error's message. If fn panics,
// the transaction is rolled back (any rollback failure is logged to logger)
// and the panic is re-raised.
func WithReadOnlyTxx(ctx context.Context, reader *sqlx.DB, fn ReadTxFn, logger log.Logger) error {
	readOnlyOpts := &sql.TxOptions{
		ReadOnly:  true,
		Isolation: sql.LevelRepeatableRead,
	}
	tx, beginErr := reader.BeginTxx(ctx, readOnlyOpts)
	if beginErr != nil {
		return ctxerr.Wrap(ctx, beginErr, "create read-only transaction")
	}

	// A panicking fn must not leave the transaction open: roll back, then re-panic.
	defer func() {
		p := recover()
		if p == nil {
			return
		}
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			logger.Log("err", rollbackErr, "msg", "error encountered during read-only transaction panic rollback")
		}
		panic(p)
	}()

	fnErr := fn(tx)
	if fnErr == nil {
		if commitErr := tx.Commit(); commitErr != nil {
			return ctxerr.Wrap(ctx, commitErr, "commit read-only transaction")
		}
		return nil
	}

	// sql.ErrTxDone means the transaction was already committed or rolled
	// back, which is not a rollback failure worth reporting.
	rollbackErr := tx.Rollback()
	if rollbackErr == nil || rollbackErr == sql.ErrTxDone {
		return fnErr
	}
	return ctxerr.Wrapf(ctx, fnErr, "got err '%s' rolling back read-only transaction after err", rollbackErr.Error())
}
|
|
|
|
var (
	// DefaultNonZeroTime is a placeholder timestamp, 2000-01-01T00:00:00Z in
	// RFC 3339 form, used in place of a zero or very old time value.
	//
	// MySQL is really particular about using zero values or old values for
	// timestamps, so this default is plenty far in the past but hopefully
	// accepted by most MySQL configurations.
	//
	// NOTE: #3229 proposes a better fix that uses *time.Time for
	// ScheduledQueryStats.LastExecuted.
	DefaultNonZeroTime = "2000-01-01T00:00:00Z"
)
|
|
|
|
// GetDefaultNonZeroTime returns the placeholder timestamp 2000-01-01 00:00:00
// UTC as a time.Time. This is the same instant that DefaultNonZeroTime spells
// out as a string.
func GetDefaultNonZeroTime() time.Time {
	year, month, day := 2000, time.January, 1
	return time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
}
|