fleet/tools/mysql-tests/rds/iam_auth.go
Scott Gress 602f5a470b
Feat 1817 add iam auth to mysql and redis (#32488)
for #1817 

# Details

This PR gives Fleet servers the ability to connect to RDS MySQL and
Elasticache Redis via AWS [Identity and Access Management
(IAM)](https://aws.amazon.com/iam/). It is based almost entirely on the
work of @titanous, branched from his [original pull
request](https://github.com/fleetdm/fleet/pull/31075). The main
differences between his branch and this are:

1. Removal of auto-detection of AWS region (and cache name for
Elasticache) in favor of specifying these values in configuration. The
auto-detection is admittedly handy but parsing AWS host URLs is not
considered a best practice.
2. Relying on the existence of these new configs to determine whether or
not to connect via IAM. This sidesteps a thorny issue of whether to try
an IAM-based Elasticache connection when a password is not supplied,
since this is technically a valid setup.

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [X] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files)
for more information.

## Testing

- [X] Added/updated automated tests
- [X] QA'd all new/changed functionality manually - besides using
@titanous's excellent test tool, I verified the following end-to-end:
  - [X] regular (non RDS) MySQL connection
  - [X] RDS MySQL connection using username/password
  - [X] RDS MySQL connection using IAM (no role)
  - [X] RDS MySQL connection using IAM (assuming role)
  - [X] regular (non Elasticache) Redis connection
  - [X] Elasticache Redis connection using username/password
  - [X] Elasticache Redis connection using NO password (without IAM)
  - [X] Elasticache Redis connection using IAM (no role)
  - [X] Elasticache Redis connection using IAM (assuming role)

---------

Co-authored-by: Jonathan Rudenberg <jonathan@titanous.com>
Co-authored-by: Noah Talerman <47070608+noahtalerman@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2025-09-04 10:08:47 -05:00

91 lines
2.4 KiB
Go

//nolint:gocritic // Test tool, not production code
package main
import (
"context"
"flag"
"fmt"
"log"
"os"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
kitlog "github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/jmoiron/sqlx"
)
// Command-line flags accepted by the tool. Each one is read in main after
// flag.Parse has run.
var (
	// endpointFlag is the RDS host name without a port; main refuses to run
	// when it is empty.
	endpointFlag = flag.String("endpoint", "", "RDS endpoint address (without port)")
	// portFlag is joined to the endpoint with ":" to form the dial address.
	portFlag = flag.String("port", "3306", "Database port")

	// userFlag is the database user to connect as; main refuses to run when
	// it is empty.
	userFlag = flag.String("user", "fleet_iam_user", "Username for IAM authentication")
	// dbNameFlag is the database (schema) selected for the connection.
	dbNameFlag = flag.String("db", "fleet", "Database name")

	// regionFlag is copied into the MySQL config's Region field.
	regionFlag = flag.String("region", "", "AWS region")
	// assumeRoleFlag and externalIDFlag are optional and are copied into the
	// MySQL config's StsAssumeRoleArn and StsExternalID fields.
	assumeRoleFlag = flag.String("assume-role", "", "STS assume role ARN (optional)")
	externalIDFlag = flag.String("external-id", "", "STS external ID (optional)")
)
// main parses the command-line flags, opens a MySQL connection to the given
// RDS endpoint through common_mysql.NewDB, runs a single test query against
// it and logs the outcome.
//
// The process exits with status 1 when a required flag (-endpoint, -user) is
// missing, when the connection cannot be established, or when the test query
// fails.
func main() {
	flag.Parse()

	if *endpointFlag == "" {
		log.Fatal("RDS endpoint is required (-endpoint flag)")
	}
	if *userFlag == "" {
		log.Fatal("Username is required (-user flag)")
	}

	logger := level.NewFilter(kitlog.NewLogfmtLogger(os.Stderr), level.AllowDebug())

	// Configure the MySQL connection. No password is set here. An unset
	// -region leaves Region at its zero value, exactly as if it had never
	// been assigned.
	// NOTE(review): an empty Region presumably means the connection is not
	// made via IAM at all — confirm against common_mysql.NewDB.
	mysqlConfig := &config.MysqlConfig{
		Protocol:         "tcp",
		Address:          fmt.Sprintf("%s:%s", *endpointFlag, *portFlag),
		Username:         *userFlag,
		Database:         *dbNameFlag,
		Region:           *regionFlag,
		StsAssumeRoleArn: *assumeRoleFlag,
		StsExternalID:    *externalIDFlag,
	}

	dbOpts := &common_mysql.DBOptions{
		MaxAttempts: 3,
		Logger:      logger,
	}

	log.Printf("Connecting to RDS at %s:%s with IAM auth for user %s", *endpointFlag, *portFlag, *userFlag)
	if *assumeRoleFlag != "" {
		log.Printf("Using assume role: %s", *assumeRoleFlag)
	}

	log.Println("📋 Testing connection with IAM token...")
	db, err := common_mysql.NewDB(mysqlConfig, dbOpts, "")
	if err != nil {
		log.Printf("❌ Connection failed: %v", err)
		os.Exit(1)
	}

	// os.Exit does not run deferred functions, so the handle is closed
	// explicitly before the exit status is decided rather than via defer.
	testErr := testConnection(db)
	if closeErr := db.Close(); closeErr != nil {
		log.Printf("Closing database connection: %v", closeErr)
	}
	if testErr != nil {
		log.Printf("❌ Test failed: %v", testErr)
		os.Exit(1)
	}

	log.Println("✅ Connection successful!")
}
// testConnection runs "SELECT VERSION()" on db and logs the value returned.
// It returns a wrapped error when the query or the scan of its single column
// fails, and nil otherwise.
func testConnection(db *sqlx.DB) error {
	var dbVersion string

	// A single round trip is enough to prove the connection is usable.
	row := db.QueryRowContext(context.Background(), "SELECT VERSION()")
	scanErr := row.Scan(&dbVersion)
	if scanErr != nil {
		return fmt.Errorf("failed to query version: %w", scanErr)
	}

	log.Printf(" Database version: %s", dbVersion)
	return nil
}