Atomic vulnerability count calculations (#35317)

This commit is contained in:
Tim Lee 2025-11-12 13:09:34 -07:00 committed by GitHub
parent 542e8ff259
commit 188a91cf4d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 1129 additions and 50 deletions

View file

@ -0,0 +1,2 @@
- fixed issue where vulnerabilities would occasionally show as missing
- added vulnerability seeding and performance testing tools

View file

@ -540,48 +540,28 @@ func getVulnHostCountQuery(scope CountScope) string {
}
func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context, maxRoutines int) error {
// set all counts to 0 to later identify rows to delete
_, err := ds.writer(ctx).ExecContext(ctx, "UPDATE vulnerability_host_counts SET host_count = 0")
if err != nil {
return ctxerr.Wrap(ctx, err, "initializing vulnerability host counts")
}
globalHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, GlobalCount, maxRoutines)
if err != nil {
return ctxerr.Wrap(ctx, err, "fetching global vulnerability host counts")
}
err = ds.batchInsertHostCounts(ctx, globalHostCounts)
if err != nil {
return ctxerr.Wrap(ctx, err, "inserting global vulnerability host counts")
}
teamHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, TeamCount, maxRoutines)
if err != nil {
return ctxerr.Wrap(ctx, err, "fetching team vulnerability host counts")
}
err = ds.batchInsertHostCounts(ctx, teamHostCounts)
if err != nil {
return ctxerr.Wrap(ctx, err, "inserting team vulnerability host counts")
}
noTeamHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, NoTeamCount, maxRoutines)
if err != nil {
return ctxerr.Wrap(ctx, err, "fetching no team vulnerability host counts")
}
err = ds.batchInsertHostCounts(ctx, noTeamHostCounts)
if err != nil {
return ctxerr.Wrap(ctx, err, "inserting team vulnerability host counts")
counts := vulnerabilityCounts{
Global: globalHostCounts,
Team: teamHostCounts,
NoTeam: noTeamHostCounts,
}
err = ds.cleanupVulnerabilityHostCounts(ctx)
if err != nil {
return ctxerr.Wrap(ctx, err, "cleaning up vulnerability host counts")
}
return nil
return ds.atomicTableSwapVulnerabilityCounts(ctx, counts)
}
type hostCount struct {
@ -591,46 +571,109 @@ type hostCount struct {
GlobalStats bool `db:"global_stats"`
}
func (ds *Datastore) cleanupVulnerabilityHostCounts(ctx context.Context) error {
_, err := ds.writer(ctx).ExecContext(ctx, "DELETE FROM vulnerability_host_counts WHERE host_count = 0")
if err != nil {
return fmt.Errorf("deleting zero host count entries: %w", err)
}
return nil
type vulnerabilityCounts struct {
Global []hostCount
Team []hostCount
NoTeam []hostCount
}
func (ds *Datastore) batchInsertHostCounts(ctx context.Context, counts []hostCount) error {
const (
	// vulnerabilityHostCountsSwapTable is the staging table that is filled with
	// freshly computed counts and then renamed over vulnerability_host_counts.
	vulnerabilityHostCountsSwapTable = "vulnerability_host_counts_swap"
	// vulnerabilityHostCountsSwapTableSchema creates the staging table with the
	// same definition as vulnerability_host_counts (CREATE TABLE ... LIKE).
	vulnerabilityHostCountsSwapTableSchema = `CREATE TABLE IF NOT EXISTS ` + vulnerabilityHostCountsSwapTable + ` LIKE vulnerability_host_counts`
)
// atomicTableSwapVulnerabilityCounts replaces the contents of
// vulnerability_host_counts with counts using a table swap:
//  1. Recreate the swap table and populate it with the new data.
//  2. Drop any vulnerability_host_counts_old left behind by an earlier run.
//  3. Rename the live table to vulnerability_host_counts_old and the swap
//     table to vulnerability_host_counts in a single RENAME TABLE statement.
//  4. Drop vulnerability_host_counts_old.
//
// Empty groups in counts are skipped; if every group is empty the live table
// is replaced by an empty one.
//
// NOTE(review): MySQL documents DROP TABLE, CREATE TABLE and RENAME TABLE as
// statements that cause an implicit commit, so the withRetryTxx wrappers
// presumably provide retry behavior rather than rollback — confirm.
func (ds *Datastore) atomicTableSwapVulnerabilityCounts(ctx context.Context, counts vulnerabilityCounts) error {
	err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		// Create/recreate the swap table fresh so rows from an aborted run are
		// never carried over.
		_, err := tx.ExecContext(ctx, "DROP TABLE IF EXISTS "+vulnerabilityHostCountsSwapTable)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "dropping existing swap table")
		}
		_, err = tx.ExecContext(ctx, vulnerabilityHostCountsSwapTableSchema)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "creating swap table")
		}

		// Insert each group of counts separately
		if len(counts.Global) > 0 {
			err = ds.insertHostCountsIntoTable(ctx, tx, counts.Global, vulnerabilityHostCountsSwapTable)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "populating swap table with global counts")
			}
		}
		if len(counts.Team) > 0 {
			err = ds.insertHostCountsIntoTable(ctx, tx, counts.Team, vulnerabilityHostCountsSwapTable)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "populating swap table with team counts")
			}
		}
		if len(counts.NoTeam) > 0 {
			err = ds.insertHostCountsIntoTable(ctx, tx, counts.NoTeam, vulnerabilityHostCountsSwapTable)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "populating swap table with no-team counts")
			}
		}
		return nil
	})
	if err != nil {
		return err
	}

	// Atomic table swap using RENAME TABLE
	return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		// If an earlier run renamed the tables but did not get as far as dropping
		// the old one, vulnerability_host_counts_old still exists and the RENAME
		// below would fail on every subsequent run. It only ever holds superseded
		// data, so it is safe to discard.
		_, err := tx.ExecContext(ctx, "DROP TABLE IF EXISTS vulnerability_host_counts_old")
		if err != nil {
			return ctxerr.Wrap(ctx, err, "dropping stale old table")
		}

		_, err = tx.ExecContext(ctx, fmt.Sprintf(`
RENAME TABLE
vulnerability_host_counts TO vulnerability_host_counts_old,
%s TO vulnerability_host_counts
`, vulnerabilityHostCountsSwapTable))
		if err != nil {
			return ctxerr.Wrap(ctx, err, "atomic table swap")
		}

		// Clean up old table (drop it)
		_, err = tx.ExecContext(ctx, "DROP TABLE vulnerability_host_counts_old")
		if err != nil {
			return ctxerr.Wrap(ctx, err, "dropping old table")
		}
		return nil
	})
}
// insertHostCountsIntoTable inserts counts into specified table
func (ds *Datastore) insertHostCountsIntoTable(ctx context.Context, tx sqlx.ExtContext, counts []hostCount, tableName string) error {
if len(counts) == 0 {
return nil
}
insertStmt := "INSERT INTO vulnerability_host_counts (team_id, cve, host_count, global_stats) VALUES "
var insertArgs []interface{}
insertStmt := fmt.Sprintf("INSERT INTO %s (team_id, cve, host_count, global_stats) VALUES ", tableName)
chunkSize := 100
// Use smaller chunks to avoid parameter limits
chunkSize := 500
for i := 0; i < len(counts); i += chunkSize {
end := i + chunkSize
if end > len(counts) {
end = len(counts)
}
end := min(i+chunkSize, len(counts))
valueStrings := make([]string, 0, end-i)
chunkArgs := make([]interface{}, 0, (end-i)*4)
valueStrings := make([]string, 0, chunkSize)
for _, count := range counts[i:end] {
valueStrings = append(valueStrings, "(?, ?, ?, ?)")
insertArgs = append(insertArgs, count.TeamID, count.CVE, count.HostCount, count.GlobalStats)
chunkArgs = append(chunkArgs, count.TeamID, count.CVE, count.HostCount, count.GlobalStats)
}
insertStmt += strings.Join(valueStrings, ", ")
insertStmt += " ON DUPLICATE KEY UPDATE host_count = VALUES(host_count);"
_, err := ds.writer(ctx).ExecContext(ctx, insertStmt, insertArgs...)
fullStmt := insertStmt + strings.Join(valueStrings, ", ")
_, err := tx.ExecContext(ctx, fullStmt, chunkArgs...)
if err != nil {
return fmt.Errorf("inserting host counts: %w", err)
return fmt.Errorf("inserting host counts chunk %d-%d into %s: %w", i, end-1, tableName, err)
}
insertStmt = "INSERT INTO vulnerability_host_counts (team_id, cve, host_count, global_stats) VALUES "
insertArgs = nil
}
return nil

View file

@ -0,0 +1,135 @@
# Vulnerability Performance Testing Tools
This directory contains tools for testing the performance of Fleet's vulnerability-related datastore methods.
## Tools
### Seeder (`seeder/volume_vuln_seeder.go`)
Seeds the database with test data for performance testing.
**Usage:**
```bash
go run seeder/volume_vuln_seeder.go [options]
```
**Options:**
- `-hosts=N` - Number of hosts to create (default: 100)
- `-teams=N` - Number of teams to create (default: 5)
- `-cves=N` - Total number of unique CVEs in the system (default: 500)
- `-software=N` - Total number of unique software packages (default: 500)
- `-help` - Show help information
- `-verbose` - Enable verbose timing output for each step
**Example:**
```bash
go run seeder/volume_vuln_seeder.go -hosts=1000 -teams=10 -cves=2000 -software=4000
```
### Performance Tester (`tester/performance_tester.go`)
Benchmarks any Fleet datastore method with statistical analysis.
**Usage:**
```bash
go run tester/performance_tester.go [options]
```
**Options:**
- `-funcs=NAME[,NAME2,...]` - Comma-separated list of test functions (default: "UpdateVulnerabilityHostCounts")
- `-iterations=N` - Number of iterations per test (default: 5)
- `-verbose` - Show timing for each iteration
- `-details` - Show detailed statistics including percentiles
- `-list` - List available test functions
- `-help` - Show help information
**Available Test Functions:**
- `UpdateVulnerabilityHostCounts` - Test vulnerability host count updates
### Adding New Test Functions
To add support for additional datastore methods, edit the `testFunctions` map in `tester/performance_tester.go`:
```go
var testFunctions = map[string]TestFunction{
// Existing functions...
// Add new function
"CountHosts": func(ctx context.Context, ds *mysql.Datastore) error {
_, err := ds.CountHosts(ctx, fleet.TeamFilter{User: &fleet.User{}}, fleet.HostListOptions{})
return err
},
// Add function with parameters
"ListHosts:100": func(ctx context.Context, ds *mysql.Datastore) error {
_, err := ds.ListHosts(ctx, fleet.TeamFilter{User: &fleet.User{}}, fleet.HostListOptions{
ListOptions: fleet.ListOptions{Page: 0, PerPage: 100},
})
return err
},
}
```
Each function should:
1. Accept `context.Context` and `*mysql.Datastore` as parameters
2. Return only an `error`
3. Handle any return values from the datastore method (discard non-error returns)
4. Use meaningful parameter values for realistic testing
**Examples:**
```bash
# Test single function with details
go run tester/performance_tester.go -funcs=UpdateVulnerabilityHostCounts -iterations=10 -details
# Test different batch sizes (these `NAME:N` variants are not built in; register them in `testFunctions` first)
go run tester/performance_tester.go -funcs=UpdateVulnerabilityHostCounts:5,UpdateVulnerabilityHostCounts:20 -iterations=5
# Verbose output
go run tester/performance_tester.go -funcs=UpdateVulnerabilityHostCounts -verbose
```
## Performance Analysis
The performance tester reports the following metrics:
- **Total time** - Sum of all successful iterations
- **Average time** - Mean execution time
- **Min/Max time** - Fastest and slowest iterations
- **Success rate** - Percentage of successful vs failed iterations
- **Percentiles** - P50, P90, P99 response times (with `-details`)
## Typical Workflow
1. **Seed test data:**
```bash
go run seeder/volume_vuln_seeder.go -hosts=1000 -teams=10 -cves=2000 -software=4000
```
2. **Test baseline performance:**
```bash
go run tester/performance_tester.go -funcs=UpdateVulnerabilityHostCounts -iterations=10 -details
```
3. **Make code changes to optimize**
4. **Test optimized performance:**
```bash
go run tester/performance_tester.go -funcs=UpdateVulnerabilityHostCounts -iterations=10 -details
```
5. **Compare results**
## Notes
- The seeder is not idempotent - run `make db-reset` to reset the database before reseeding.

View file

@ -0,0 +1,656 @@
package main
import (
"context"
"errors"
"flag"
"fmt"
"log"
"math/rand"
"os"
"strings"
"time"
"github.com/WatchBeam/clock"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
)
// retryOnDeadlock runs operation until it succeeds, fails with an error that
// does not look like a MySQL deadlock, or has been attempted maxRetries+1
// times. A deadlock is recognized by the error text containing
// "Deadlock found" or "1213". Between attempts it sleeps for a random
// 10-59ms, doubled for every attempt already made. The last error is returned.
func retryOnDeadlock(operation func() error, maxRetries int) error {
	var lastErr error
	for try := 0; try <= maxRetries; try++ {
		if lastErr = operation(); lastErr == nil {
			return nil
		}

		msg := lastErr.Error()
		deadlocked := strings.Contains(msg, "Deadlock found") || strings.Contains(msg, "1213")
		if !deadlocked || try == maxRetries {
			// Either not retryable or the retry budget is spent.
			return lastErr
		}

		// Exponential backoff with jitter
		// #nosec G404 - weak random is acceptable for retry backoff
		jitter := time.Duration(10+rand.Intn(50)) * time.Millisecond
		time.Sleep(jitter * time.Duration(1<<try))
	}
	return lastErr
}
// timeStep runs fn once, measures how long it took and prints the duration
// labelled with name. In verbose mode a "Starting" line is printed first and a
// "Completed" line afterwards; otherwise a single "name: duration" line is
// printed. The error returned by fn is passed through unchanged.
func timeStep(name string, verbose bool, fn func() error) error {
	if verbose {
		fmt.Printf("Starting: %s...\n", name)
	}

	began := time.Now()
	stepErr := fn()
	elapsed := time.Since(began)

	switch {
	case verbose:
		fmt.Printf("Completed: %s in %v\n", name, elapsed)
	default:
		fmt.Printf("%s: %v\n", name, elapsed)
	}
	return stepErr
}
var (
	// MySQL config
	// These values are used both for the Fleet datastore opened in main and
	// for the direct sqlx connections opened by getDB.
	mysqlAddr = "localhost:3306"
	mysqlUser = "fleet"
	mysqlPass = "insecure"
	mysqlDB = "fleet"
)
// Common CVE patterns for realistic data
// Each entry is a fmt format string taking one integer (the CVE sequence
// number, zero-padded to at least four digits); generateCVEs picks among them.
var cvePatterns = []string{
	"CVE-2024-%04d", "CVE-2023-%04d", "CVE-2022-%04d", "CVE-2021-%04d",
}
// batchCreateHosts creates multiple hosts in batches for better performance
//
// It inserts hostCount rows into the hosts table, 100 per INSERT statement,
// using a direct sqlx connection obtained from getDB. Host i is named
// "test-host-<i>" (node key, hostname, computer name and uuid) and, when teams
// is non-empty, two out of every three hosts are assigned to a team in
// round-robin order. After each insert the rows are read back by uuid to learn
// their IDs and a matching host_display_names row is inserted for each.
//
// It returns the hosts that were read back, or the first error encountered;
// batches written before the error are not undone.
func batchCreateHosts(ctx context.Context, ds *mysql.Datastore, hostCount int, teams []*fleet.Team, verbose bool) ([]*fleet.Host, error) {
	batchSize := 100 // Insert 100 hosts per transaction
	var allHosts []*fleet.Host
	// One timestamp is shared by every host created in this call.
	now := time.Now()
	db, err := getDB(ds)
	if err != nil {
		return nil, fmt.Errorf("get DB connection: %w", err)
	}
	defer db.Close()
	for batchStart := 0; batchStart < hostCount; batchStart += batchSize {
		batchEnd := batchStart + batchSize
		if batchEnd > hostCount {
			batchEnd = hostCount
		}
		// Prepare batch insert
		var args []interface{}
		var placeholders []string
		for i := batchStart; i < batchEnd; i++ {
			var teamID *uint
			if len(teams) > 0 && i%3 != 0 { // 2/3 of hosts have teams
				teamID = &teams[i%len(teams)].ID
			}
			identifier := fmt.Sprintf("test-host-%d", i)
			osqueryHostID := fmt.Sprintf("osquery-host-%d", i)
			// 21 placeholders: one per column of the INSERT below, in the same order
			// as the arguments appended next.
			placeholders = append(placeholders, "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
			args = append(args,
				osqueryHostID, // osquery_host_id
				now, // detail_updated_at
				now, // label_updated_at
				now, // policy_updated_at
				identifier, // node_key
				identifier, // hostname
				identifier, // computer_name
				identifier, // uuid
				"ubuntu", // platform
				"", // platform_like
				"", // osquery_version
				"Ubuntu 20.04.6 LTS", // os_version
				0, // uptime
				0, // memory
				teamID, // team_id
				0, // distributed_interval
				0, // logger_tls_period
				0, // config_tls_refresh
				false, // refetch_requested
				"", // hardware_serial
				nil, // refetch_critical_queries_until (can be NULL)
			)
		}
		// Execute batch insert (exactly matching the NewHost function)
		sql := `INSERT INTO hosts (
osquery_host_id, detail_updated_at, label_updated_at, policy_updated_at,
node_key, hostname, computer_name, uuid, platform, platform_like,
osquery_version, os_version, uptime, memory, team_id,
distributed_interval, logger_tls_period, config_tls_refresh,
refetch_requested, hardware_serial, refetch_critical_queries_until
) VALUES ` + strings.Join(placeholders, ", ")
		_, err := db.ExecContext(ctx, sql, args...)
		if err != nil {
			return nil, fmt.Errorf("batch insert hosts %d-%d: %w", batchStart, batchEnd-1, err)
		}
		// Fetch the created hosts to get their IDs
		var batchHosts []fleet.Host
		var uuids []string
		for i := batchStart; i < batchEnd; i++ {
			// The uuid literals are generated here (never user input), so they are
			// quoted and inlined into the IN list directly.
			uuids = append(uuids, fmt.Sprintf("'test-host-%d'", i))
		}
		err = sqlx.SelectContext(ctx, db, &batchHosts,
			"SELECT id, uuid, hostname, computer_name, node_key, detail_updated_at, label_updated_at, policy_updated_at, platform, os_version, team_id FROM hosts WHERE uuid IN ("+strings.Join(uuids, ",")+")")
		if err != nil {
			return nil, fmt.Errorf("fetch created hosts: %w", err)
		}
		// Insert host_display_names for the created hosts
		if len(batchHosts) > 0 {
			var displayNamePlaceholders []string
			var displayNameArgs []interface{}
			for _, host := range batchHosts {
				displayName := host.Hostname // Use hostname as display name (same logic as migration)
				if host.ComputerName != "" {
					displayName = host.ComputerName
				}
				displayNamePlaceholders = append(displayNamePlaceholders, "(?, ?)")
				displayNameArgs = append(displayNameArgs, host.ID, displayName)
			}
			displayNameSQL := "INSERT INTO host_display_names (host_id, display_name) VALUES " + strings.Join(displayNamePlaceholders, ", ")
			_, err = db.ExecContext(ctx, displayNameSQL, displayNameArgs...)
			if err != nil {
				return nil, fmt.Errorf("batch insert host display names %d-%d: %w", batchStart, batchEnd-1, err)
			}
		}
		// Convert to pointers and add to result
		// (batchHosts is declared per batch, so these pointers stay valid.)
		for i := range batchHosts {
			allHosts = append(allHosts, &batchHosts[i])
		}
		// Progress is reported every 500 hosts and once more at the end.
		if verbose && (batchEnd%500 == 0 || batchEnd == hostCount) {
			fmt.Printf(" Created %d/%d hosts\n", batchEnd, hostCount)
		}
	}
	return allHosts, nil
}
// batchUpdateHostOS records an Ubuntu operating system for every host through
// ds.UpdateHostOperatingSystem, retrying each call up to three times when it
// hits a deadlock. The patch component of the OS version cycles through 0-9 by
// host index so that several distinct OS rows are produced. Hosts are walked in
// groups of 100, which only affects how often progress is printed in verbose
// mode. The first failing update aborts the run and is returned.
func batchUpdateHostOS(ctx context.Context, ds *mysql.Datastore, hosts []*fleet.Host, verbose bool) error {
	const hostsPerBatch = 100
	total := len(hosts)

	for from := 0; from < total; from += hostsPerBatch {
		to := min(from+hostsPerBatch, total)

		for idx := from; idx < to; idx++ {
			target := hosts[idx]
			osInfo := fleet.OperatingSystem{
				Name:           "Ubuntu",
				Version:        fmt.Sprintf("20.04.%d", idx%10), // Vary the version slightly
				Platform:       "ubuntu",
				Arch:           "x86_64",
				KernelVersion:  "5.4.0-148-generic",
				DisplayVersion: "20.04",
			}
			update := func() error {
				return ds.UpdateHostOperatingSystem(ctx, target.ID, osInfo)
			}
			if err := retryOnDeadlock(update, 3); err != nil {
				return fmt.Errorf("update host %d OS: %w", idx, err)
			}
		}

		reportProgress := to%500 == 0 || to == total
		if verbose && reportProgress {
			fmt.Printf(" Updated OS for %d/%d hosts\n", to, total)
		}
	}
	return nil
}
// batchInstallSoftware installs software on hosts in batches for better performance
//
// It works in three phases over a direct sqlx connection from getDB:
//  1. Upsert every entry of vulnerableSoftware into the software table, 100
//     rows per statement.
//  2. Look up the ID of each software row by (name, version, source).
//  3. For every host, delete its existing host_software rows and insert a
//     random 20-80% subset of the software.
//
// Hosts are processed in groups of 50 with a 50ms pause after each group. The
// first error aborts the run and is returned; earlier writes are not undone.
func batchInstallSoftware(ctx context.Context, ds *mysql.Datastore, hosts []*fleet.Host, vulnerableSoftware []fleet.Software, verbose bool) error {
	db, err := getDB(ds)
	if err != nil {
		return fmt.Errorf("get DB connection: %w", err)
	}
	defer db.Close()
	// First, create all software entries using INSERT ... ON DUPLICATE KEY UPDATE
	softwareMap := make(map[string]uint) // name+version+source -> software_id
	if len(vulnerableSoftware) > 0 {
		// Insert software in smaller batches to avoid max_allowed_packet issues
		batchSize := 100
		for batchStart := 0; batchStart < len(vulnerableSoftware); batchStart += batchSize {
			batchEnd := batchStart + batchSize
			if batchEnd > len(vulnerableSoftware) {
				batchEnd = len(vulnerableSoftware)
			}
			var placeholders []string
			var args []interface{}
			for i := batchStart; i < batchEnd; i++ {
				software := vulnerableSoftware[i]
				// Seven column values followed by three values that feed the checksum
				// expression: ten placeholders, ten arguments.
				placeholders = append(placeholders, "(?, ?, ?, ?, ?, ?, ?, UNHEX(MD5(CONCAT(COALESCE(?, ''), COALESCE(?, ''), ?))))")
				args = append(args,
					software.Name,
					software.Version,
					software.Source,
					software.BundleIdentifier,
					software.Release,
					software.Vendor,
					software.Arch,
					// Checksum calculation args (name, version, source)
					software.Name,
					software.Version,
					software.Source,
				)
			}
			// `release` is quoted with backticks in the column list, hence the
			// spliced-in double-quoted literal.
			sql := `INSERT INTO software (name, version, source, bundle_identifier, ` + "`release`" + `, vendor, arch, checksum)
VALUES ` + strings.Join(placeholders, ", ") + `
ON DUPLICATE KEY UPDATE name = VALUES(name)`
			_, err := db.ExecContext(ctx, sql, args...)
			if err != nil {
				return fmt.Errorf("batch insert software batch %d-%d: %w", batchStart, batchEnd-1, err)
			}
			if verbose && (batchEnd%200 == 0 || batchEnd == len(vulnerableSoftware)) {
				fmt.Printf(" Inserted software batch %d/%d\n", batchEnd, len(vulnerableSoftware))
			}
		}
	}
	// Now get all software IDs (this approach guarantees we find them)
	// One query per software entry; the key format must match the lookup below.
	for _, software := range vulnerableSoftware {
		key := fmt.Sprintf("%s|%s|%s", software.Name, software.Version, software.Source)
		var softwareID uint
		err := sqlx.GetContext(ctx, db, &softwareID,
			"SELECT id FROM software WHERE name = ? AND version = ? AND source = ? LIMIT 1",
			software.Name, software.Version, software.Source)
		if err != nil {
			return fmt.Errorf("get software ID for %s: %w", software.Name, err)
		}
		softwareMap[key] = softwareID
	}
	// Now batch install software on hosts
	batchSize := 50 // Smaller batches to avoid deadlocks
	for batchStart := 0; batchStart < len(hosts); batchStart += batchSize {
		batchEnd := batchStart + batchSize
		if batchEnd > len(hosts) {
			batchEnd = len(hosts)
		}
		// Process this batch of hosts
		for i := batchStart; i < batchEnd; i++ {
			host := hosts[i]
			// Each host gets 20-80% of the software
			// #nosec G404 - weak random is acceptable for test data generation
			pct := 0.2 + rand.Float64()*0.6
			hostSoftwareCount := int(float64(len(vulnerableSoftware)) * pct)
			// Randomly select which software this host has
			// A private copy is shuffled so the caller's slice order is untouched;
			// the first hostSoftwareCount entries of the shuffle are the selection.
			hostVulnSoftware := make([]fleet.Software, len(vulnerableSoftware))
			copy(hostVulnSoftware, vulnerableSoftware)
			rand.Shuffle(len(hostVulnSoftware), func(i, j int) {
				hostVulnSoftware[i], hostVulnSoftware[j] = hostVulnSoftware[j], hostVulnSoftware[i]
			})
			hostSoftware := hostVulnSoftware[:hostSoftwareCount]
			// Clear existing software for this host
			_, err := db.ExecContext(ctx, "DELETE FROM host_software WHERE host_id = ?", host.ID)
			if err != nil {
				return fmt.Errorf("clear host %d software: %w", host.ID, err)
			}
			// Batch insert software for this host
			if len(hostSoftware) > 0 {
				var placeholders []string
				var args []interface{}
				for _, software := range hostSoftware {
					key := fmt.Sprintf("%s|%s|%s", software.Name, software.Version, software.Source)
					softwareID, exists := softwareMap[key]
					if !exists {
						continue // Skip if software ID not found
					}
					placeholders = append(placeholders, "(?, ?)")
					args = append(args, host.ID, softwareID)
				}
				if len(placeholders) > 0 {
					// INSERT IGNORE: rows that would violate a key are skipped silently.
					sql := "INSERT IGNORE INTO host_software (host_id, software_id) VALUES " + strings.Join(placeholders, ", ")
					_, err := db.ExecContext(ctx, sql, args...)
					if err != nil {
						return fmt.Errorf("batch insert software for host %d: %w", host.ID, err)
					}
				}
			}
		}
		if verbose && (batchEnd%200 == 0 || batchEnd == len(hosts)) {
			fmt.Printf(" Installed software on %d/%d hosts\n", batchEnd, len(hosts))
		}
		// Small delay between batches to reduce DB pressure
		time.Sleep(50 * time.Millisecond)
	}
	return nil
}
// createTestTeam creates a team that has only its name set and returns
// whatever ds.NewTeam reports for it.
func createTestTeam(ctx context.Context, ds *mysql.Datastore, name string) (*fleet.Team, error) {
	return ds.NewTeam(ctx, &fleet.Team{Name: name})
}
// generateVulnerableSoftware builds softwareCount in-memory software entries
// named "vulnerable-package-<index>", each with a random "1.<0-99>.0" version
// and the source "Package (deb)". Nothing is written to the database here.
func generateVulnerableSoftware(softwareCount int) []fleet.Software {
	var packages []fleet.Software
	for idx := 0; idx < softwareCount; idx++ {
		// #nosec G404 - weak random is acceptable for test data generation
		minor := rand.Intn(100)

		// Create vulnerable software
		pkg := fleet.Software{
			Name:    fmt.Sprintf("vulnerable-package-%d", idx),
			Version: fmt.Sprintf("1.%d.0", minor),
			Source:  "Package (deb)",
		}
		packages = append(packages, pkg)
	}
	return packages
}
// generateCVEs returns cveCount distinct CVE identifiers built from
// cvePatterns with a random year pattern and a random sequence number.
//
// Identifiers are drawn at random and duplicates are rejected, so the result
// always holds exactly cveCount unique entries (the seeder advertises "unique
// CVEs", and the mapping tables are filled with INSERT IGNORE, which would
// otherwise silently collapse repeats). Sequence numbers are taken from
// 1..9999; when cveCount is larger than that the range is widened to
// 1..cveCount so enough identifiers exist and rejection sampling stays cheap.
// A cveCount of zero or less yields a nil slice.
func generateCVEs(cveCount int) []string {
	if cveCount <= 0 {
		return nil
	}

	// With len(cvePatterns) year patterns the identifier space is at least
	// len(cvePatterns) times larger than cveCount, so the loop below terminates
	// quickly.
	maxSeq := max(9999, cveCount)

	cves := make([]string, 0, cveCount)
	seen := make(map[string]struct{}, cveCount)
	for len(cves) < cveCount {
		// #nosec G404 - weak random is acceptable for test data generation
		pattern := cvePatterns[rand.Intn(len(cvePatterns))]
		// #nosec G404 - weak random is acceptable for test data generation
		cveID := fmt.Sprintf(pattern, rand.Intn(maxSeq)+1)
		if _, dup := seen[cveID]; dup {
			continue
		}
		seen[cveID] = struct{}{}
		cves = append(cves, cveID)
	}
	return cves
}
// getDB opens a fresh sqlx handle to the MySQL instance described by the
// package-level mysql* settings, for queries that bypass the datastore. The ds
// argument is accepted but not consulted. The caller is expected to Close the
// returned handle.
func getDB(ds *mysql.Datastore) (*sqlx.DB, error) {
	const dsnOptions = "?charset=utf8mb4&parseTime=True&loc=Local"

	conn := config.MysqlConfig{
		Protocol: "tcp",
		Address:  mysqlAddr,
		Username: mysqlUser,
		Password: mysqlPass,
		Database: mysqlDB,
	}

	// user:pass@proto(addr)/db?options
	credentials := conn.Username + ":" + conn.Password
	endpoint := conn.Protocol + "(" + conn.Address + ")"
	return sqlx.Open("mysql", credentials+"@"+endpoint+"/"+conn.Database+dsnOptions)
}
// seedSoftwareCVEs links every CVE in cves to between one and three software
// rows by inserting into software_cve. Candidate software is limited to the
// first 1000 IDs returned by the software table; CVE number i is attached to
// consecutive candidates starting at position i (wrapping around). Each insert
// uses INSERT IGNORE and is retried up to three times on deadlock. An error is
// returned when no software exists or when an insert ultimately fails.
func seedSoftwareCVEs(ctx context.Context, ds *mysql.Datastore, cves []string) error {
	db, err := getDB(ds)
	if err != nil {
		return fmt.Errorf("get DB connection: %w", err)
	}
	defer db.Close()

	// First, get software IDs that exist
	var candidateIDs []uint
	if err := sqlx.SelectContext(ctx, db, &candidateIDs, "SELECT id FROM software LIMIT 1000"); err != nil {
		return fmt.Errorf("fetch software IDs: %w", err)
	}
	pool := len(candidateIDs)
	if pool == 0 {
		return errors.New("no software found - run seedVulnerabilities first")
	}

	// Insert software_cve mappings
	for cveIdx, cve := range cves {
		// Each CVE affects 1-3 software packages, never more than exist.
		// #nosec G404 - weak random is acceptable for test data generation
		fanOut := min(1+rand.Intn(3), pool)

		for offset := 0; offset < fanOut; offset++ {
			target := candidateIDs[(cveIdx+offset)%pool]
			insert := func() error {
				_, execErr := db.ExecContext(ctx,
					"INSERT IGNORE INTO software_cve (software_id, cve) VALUES (?, ?)",
					target, cve)
				return execErr
			}
			if err := retryOnDeadlock(insert, 3); err != nil {
				return fmt.Errorf("insert software_cve for %s: %w", cve, err)
			}
		}
	}
	return nil
}
// ensureOSRecordsExist counts the rows in operating_systems and prints the
// result, adding a hint when the table is empty. Despite its name it creates
// nothing; an error is returned only when the connection or the count query
// fails.
func ensureOSRecordsExist(ctx context.Context, ds *mysql.Datastore) error {
	db, err := getDB(ds)
	if err != nil {
		return fmt.Errorf("get DB connection: %w", err)
	}
	defer db.Close()

	// Check if we have any OS records
	var osCount int
	if err := sqlx.GetContext(ctx, db, &osCount, "SELECT COUNT(*) FROM operating_systems"); err != nil {
		return fmt.Errorf("count operating systems: %w", err)
	}
	fmt.Printf("Found %d existing OS records\n", osCount)

	if osCount != 0 {
		return nil
	}
	fmt.Printf("No OS records found. The hosts may not have had their OS info updated yet.\n")
	fmt.Printf("Try running with fewer hosts first, or check that UpdateHostOperatingSystem was called.\n")
	return nil
}
// seedOSVulnerabilities attaches roughly 30% of cves to operating systems by
// inserting into operating_system_vulnerabilities. Up to 100 operating system
// IDs are loaded and CVE number i, when selected, is mapped to the ID at
// position i modulo that count. Inserts use INSERT IGNORE and are retried up
// to three times on deadlock. An empty operating_systems table is reported
// with a warning and treated as success.
func seedOSVulnerabilities(ctx context.Context, ds *mysql.Datastore, cves []string) error {
	db, err := getDB(ds)
	if err != nil {
		return fmt.Errorf("get DB connection: %w", err)
	}
	defer db.Close()

	// Get OS IDs
	var knownOS []uint
	if err := sqlx.SelectContext(ctx, db, &knownOS, "SELECT id FROM operating_systems LIMIT 100"); err != nil {
		return fmt.Errorf("fetch OS IDs: %w", err)
	}
	osTotal := len(knownOS)
	if osTotal == 0 {
		fmt.Printf("Warning: No operating systems found in database. Skipping OS vulnerability seeding.\n")
		fmt.Printf("This might be normal if your test setup doesn't include OS vulnerability testing.\n")
		return nil // Don't fail, just skip OS vulnerabilities
	}
	fmt.Printf("Found %d operating systems to map vulnerabilities to\n", osTotal)

	// Insert OS vulnerabilities (about 30% of CVEs affect OS)
	for idx, cve := range cves {
		// #nosec G404 - weak random is acceptable for test data generation
		if rand.Float64() >= 0.3 {
			continue // the other ~70% of CVEs do not touch the OS
		}

		target := knownOS[idx%osTotal]
		insert := func() error {
			_, execErr := db.ExecContext(ctx,
				"INSERT IGNORE INTO operating_system_vulnerabilities (operating_system_id, cve) VALUES (?, ?)",
				target, cve)
			return execErr
		}
		if err := retryOnDeadlock(insert, 3); err != nil {
			return fmt.Errorf("insert OS vulnerability for %s: %w", cve, err)
		}
	}
	return nil
}
// seedVulnerabilities runs the whole seeding pipeline, timing every stage with
// timeStep: create teams, create hosts, set host operating systems, generate
// CVE IDs and software in memory, install software on hosts, map CVEs to
// software, report on OS records, and map CVEs to operating systems. It stops
// at the first stage that fails and returns that stage's error; errors from
// the last three stages are wrapped with the stage name.
func seedVulnerabilities(ctx context.Context, ds *mysql.Datastore, hostCount, teamCount, cveCount, softwareCount int, verbose bool) error {
	var teams []*fleet.Team
	createTeams := func() error {
		for n := 0; n < teamCount; n++ {
			created, err := createTestTeam(ctx, ds, fmt.Sprintf("test-team-%d", n))
			if err != nil {
				return fmt.Errorf("create team %d: %w", n, err)
			}
			teams = append(teams, created)
		}
		return nil
	}
	if err := timeStep(fmt.Sprintf("Creating %d teams", teamCount), verbose, createTeams); err != nil {
		return err
	}

	var hosts []*fleet.Host
	createHosts := func() error {
		created, err := batchCreateHosts(ctx, ds, hostCount, teams, verbose)
		hosts = created
		return err
	}
	if err := timeStep(fmt.Sprintf("Creating %d hosts", hostCount), verbose, createHosts); err != nil {
		return err
	}

	if err := timeStep("Updating host operating systems", verbose, func() error {
		return batchUpdateHostOS(ctx, ds, hosts, verbose)
	}); err != nil {
		return err
	}

	var (
		cves               []string
		vulnerableSoftware []fleet.Software
	)
	generate := func() error {
		cves = generateCVEs(cveCount)
		vulnerableSoftware = generateVulnerableSoftware(softwareCount)
		return nil
	}
	if err := timeStep(fmt.Sprintf("Generating %d CVEs and %d software packages", cveCount, softwareCount), verbose, generate); err != nil {
		return err
	}

	// Install software on each host using batch approach
	if err := timeStep(fmt.Sprintf("Installing software on %d hosts", len(hosts)), verbose, func() error {
		return batchInstallSoftware(ctx, ds, hosts, vulnerableSoftware, verbose)
	}); err != nil {
		return err
	}

	if err := timeStep("Seeding software-CVE mappings", verbose, func() error {
		return seedSoftwareCVEs(ctx, ds, cves)
	}); err != nil {
		return fmt.Errorf("seed software CVEs: %w", err)
	}

	if err := timeStep("Ensuring OS records exist", verbose, func() error {
		return ensureOSRecordsExist(ctx, ds)
	}); err != nil {
		return fmt.Errorf("ensure OS records: %w", err)
	}

	if err := timeStep("Seeding OS vulnerabilities", verbose, func() error {
		return seedOSVulnerabilities(ctx, ds, cves)
	}); err != nil {
		return fmt.Errorf("seed OS vulnerabilities: %w", err)
	}
	return nil
}
// main parses the seeder's flags, connects to the local Fleet MySQL database
// and seeds it with teams, hosts, software and vulnerabilities.
//
// With -help it prints a description of the generated data model and exits.
// The process exits with a non-zero status when the datastore cannot be opened
// or when seeding fails, so scripts can detect a failed run.
func main() {
	var (
		hostCount     = flag.Int("hosts", 100, "Number of hosts to create")
		teamCount     = flag.Int("teams", 5, "Number of teams to create")
		cveCount      = flag.Int("cves", 500, "Total number of unique CVEs in the system")
		softwareCount = flag.Int("software", 500, "Total number of unique software packages (each host gets 20-80% randomly)")
		help          = flag.Bool("help", false, "Show help information")
		verbose       = flag.Bool("verbose", false, "Enable verbose timing output for each step")
	)
	flag.Parse()

	if *help {
		fmt.Printf("Fleet Test Data Seeder\n\n")
		fmt.Printf("This tool creates test data for Fleet performance testing.\n\n")
		fmt.Printf("Data model:\n")
		fmt.Printf("- Creates %d total unique software packages\n", *softwareCount)
		fmt.Printf("- Each host gets 20-80%% of software packages randomly assigned\n")
		fmt.Printf("- Creates %d total unique CVEs\n", *cveCount)
		fmt.Printf("- Each CVE affects 1-3 random software packages\n")
		fmt.Printf("- About 30%% of CVEs also affect the operating system\n")
		fmt.Printf("- Host vulnerability counts depend on which software they have installed\n\n")
		fmt.Printf("Examples:\n")
		fmt.Printf(" %s -hosts=1000 -teams=10 -cves=500 -software=1000\n", os.Args[0])
		fmt.Printf(" %s -hosts=5000 -teams=20 -cves=2000 -software=5000 -verbose\n", os.Args[0])
		fmt.Printf("\n")
		fmt.Printf("After seeding data, use performance_tester.go to test datastore methods.\n")
		fmt.Printf("\n")
		flag.Usage()
		return
	}

	ctx := context.Background()

	// Connect to datastore
	ds, err := mysql.New(config.MysqlConfig{
		Protocol: "tcp",
		Address:  mysqlAddr,
		Username: mysqlUser,
		Password: mysqlPass,
		Database: mysqlDB,
	}, clock.C)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Printf("Seeding test data...\n")
	fmt.Printf("Configuration: %d hosts, %d teams, %d CVEs, %d software packages\n", *hostCount, *teamCount, *cveCount, *softwareCount)

	seedErr := seedVulnerabilities(ctx, ds, *hostCount, *teamCount, *cveCount, *softwareCount, *verbose)

	// Close explicitly instead of deferring: os.Exit below does not run deferred
	// calls. The close error is ignored because the process is about to end and
	// there is nothing useful to do with it.
	_ = ds.Close()

	if seedErr != nil {
		fmt.Printf("Failed to seed vulnerabilities: %v\n", seedErr)
		os.Exit(1)
	}

	fmt.Printf("Data seeding complete!\n")
	fmt.Printf("\nUse performance_tester.go to test datastore methods with this data.\n")
	fmt.Printf("Example: go run tester/performance_tester.go -funcs=UpdateVulnerabilityHostCounts -iterations=5\n")
	fmt.Println("Done.")
}

View file

@ -0,0 +1,243 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"os"
"sort"
"strings"
"time"
"github.com/WatchBeam/clock"
_ "github.com/go-sql-driver/mysql"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
)
var (
	// MySQL config
	// Connection settings for the database the tester runs against; main passes
	// them to mysql.New.
	mysqlAddr = "localhost:3306"
	mysqlUser = "fleet"
	mysqlPass = "insecure"
	mysqlDB = "fleet"
)
// TestFunction represents a datastore method to test
// It receives the context and datastore to use and reports only an error; any
// other results of the wrapped method are discarded by the wrapper.
type TestFunction func(context.Context, *mysql.Datastore) error
// All available test functions
// Keys are the exact names accepted by the -funcs flag; a name that is not a
// key here is reported as unknown by runPerformanceTest.
var testFunctions = map[string]TestFunction{
	"UpdateVulnerabilityHostCounts": func(ctx context.Context, ds *mysql.Datastore) error {
		// The second argument is the method's maxRoutines parameter.
		return ds.UpdateVulnerabilityHostCounts(ctx, 1)
	},
}
// PerformanceResult holds the results of a performance test
// All timing fields cover successful iterations only; failed iterations are
// counted in FailedIterations but contribute no duration.
type PerformanceResult struct {
	// TestFunction is the name the test was looked up under in testFunctions.
	TestFunction string
	// TotalTime is the sum of the durations of all successful iterations.
	TotalTime time.Duration
	// AverageTime is TotalTime divided by SuccessfulIterations (zero when none succeeded).
	AverageTime time.Duration
	// MinTime is the shortest successful iteration.
	MinTime time.Duration
	// MaxTime is the longest successful iteration.
	MaxTime time.Duration
	// SuccessfulIterations counts iterations that returned a nil error.
	SuccessfulIterations int
	// FailedIterations counts iterations that returned an error.
	FailedIterations int
	// Iterations lists the duration of each successful iteration in run order.
	Iterations []time.Duration
}
// runPerformanceTest executes the test function registered as testFuncName
// iterations times against ds and collects timing statistics.
//
// It returns nil when the name is not registered (after printing the available
// names). Failed iterations are logged and counted but not timed. When no
// iteration succeeds the returned result carries only the failure count;
// otherwise total, average, minimum and maximum are filled in from the
// successful runs. Progress is printed as one line per iteration in verbose
// mode, or as a row of dots otherwise.
func runPerformanceTest(ctx context.Context, ds *mysql.Datastore, testFuncName string, iterations int, verbose bool) *PerformanceResult {
	run, known := testFunctions[testFuncName]
	if !known {
		fmt.Printf("Unknown test function: %s\n", testFuncName)
		fmt.Printf("Available functions: %v\n", getTestFunctionNames())
		return nil
	}

	fmt.Printf("Running %d iterations of %s...\n", iterations, testFuncName)
	report := &PerformanceResult{
		TestFunction: testFuncName,
		Iterations:   make([]time.Duration, 0, iterations),
	}

	for round := 1; round <= iterations; round++ {
		began := time.Now()
		err := run(ctx, ds)
		if err != nil {
			log.Printf("Iteration %d failed: %v", round, err)
			report.FailedIterations++
			continue
		}
		elapsed := time.Since(began)

		report.SuccessfulIterations++
		report.TotalTime += elapsed
		report.Iterations = append(report.Iterations, elapsed)

		if verbose {
			fmt.Printf("Iteration %d: %v\n", round, elapsed)
		} else {
			fmt.Printf(".")
		}
	}
	if !verbose {
		// Terminate the row of progress dots.
		fmt.Printf("\n")
	}

	if report.SuccessfulIterations == 0 {
		fmt.Printf("All iterations failed!\n")
		return report
	}

	// Calculate statistics
	report.AverageTime = report.TotalTime / time.Duration(report.SuccessfulIterations)

	// Find min and max
	report.MinTime, report.MaxTime = report.Iterations[0], report.Iterations[0]
	for _, sample := range report.Iterations[1:] {
		report.MinTime = min(report.MinTime, sample)
		report.MaxTime = max(report.MaxTime, sample)
	}
	return report
}
// printResults writes a summary of every non-nil result to standard output:
// total, average, minimum and maximum time plus the success rate. With
// showDetails it also lists each successful iteration's duration and, when at
// least two samples exist, the P50, P90 and P99 durations.
//
// The success rate is reported as 0.0% when a result recorded no attempts at
// all, rather than dividing by zero. Percentiles use the nearest-rank
// definition (see percentileOf).
func printResults(results []*PerformanceResult, showDetails bool) {
	if len(results) == 0 {
		fmt.Printf("No results to display\n")
		return
	}

	fmt.Print("\n" + strings.Repeat("=", 80) + "\n")
	fmt.Printf("PERFORMANCE TEST RESULTS\n")
	fmt.Print(strings.Repeat("=", 80) + "\n\n")

	for _, result := range results {
		if result == nil {
			// Unknown test functions produce a nil result; nothing to report.
			continue
		}

		fmt.Printf("Function: %s\n", result.TestFunction)
		fmt.Printf(" Total time: %v\n", result.TotalTime)
		fmt.Printf(" Average time: %v\n", result.AverageTime)
		fmt.Printf(" Min time: %v\n", result.MinTime)
		fmt.Printf(" Max time: %v\n", result.MaxTime)

		attempts := result.SuccessfulIterations + result.FailedIterations
		successPct := 0.0
		if attempts > 0 {
			successPct = float64(result.SuccessfulIterations) / float64(attempts) * 100
		}
		fmt.Printf(" Success rate: %d/%d (%.1f%%)\n", result.SuccessfulIterations, attempts, successPct)

		if showDetails && len(result.Iterations) > 0 {
			fmt.Printf(" All times: ")
			for i, duration := range result.Iterations {
				if i > 0 {
					fmt.Printf(", ")
				}
				fmt.Printf("%v", duration)
			}
			fmt.Printf("\n")

			// Calculate percentiles on a sorted copy so the run order kept in
			// result.Iterations is not disturbed.
			sortedTimes := make([]time.Duration, len(result.Iterations))
			copy(sortedTimes, result.Iterations)
			sort.Slice(sortedTimes, func(i, j int) bool {
				return sortedTimes[i] < sortedTimes[j]
			})
			if len(sortedTimes) >= 2 {
				fmt.Printf(" P50: %v\n", percentileOf(sortedTimes, 50))
				fmt.Printf(" P90: %v\n", percentileOf(sortedTimes, 90))
				fmt.Printf(" P99: %v\n", percentileOf(sortedTimes, 99))
			}
		}
		fmt.Printf("\n")
	}
}

// percentileOf returns the nearest-rank pct-th percentile of sorted, which
// must be in ascending order: the smallest sample such that at least pct
// percent of all samples are less than or equal to it. It returns 0 for an
// empty slice.
func percentileOf(sorted []time.Duration, pct int) time.Duration {
	n := len(sorted)
	if n == 0 {
		return 0
	}
	// rank = ceil(n * pct / 100), computed in integers and clamped to [1, n].
	rank := (n*pct + 99) / 100
	if rank < 1 {
		rank = 1
	}
	if rank > n {
		rank = n
	}
	return sorted[rank-1]
}
// getTestFunctionNames returns the keys of testFunctions in ascending
// alphabetical order, so listings are stable despite random map iteration.
func getTestFunctionNames() []string {
	var sorted []string
	for key := range testFunctions {
		sorted = append(sorted, key)
	}
	sort.Sort(sort.StringSlice(sorted))
	return sorted
}
// main parses the tester's flags and then either lists the registered test
// functions (-list), prints usage help (-help), or connects to the local Fleet
// MySQL database, runs every function named in -funcs for the requested number
// of iterations and prints the collected results. -list takes precedence over
// -help when both are given.
func main() {
	funcsFlag := flag.String("funcs", "UpdateVulnerabilityHostCounts", "Comma-separated list of test functions to run")
	iterFlag := flag.Int("iterations", 5, "Number of iterations per test function")
	verboseFlag := flag.Bool("verbose", false, "Show timing for each iteration")
	detailsFlag := flag.Bool("details", false, "Show detailed statistics including percentiles")
	listFlag := flag.Bool("list", false, "List available test functions and exit")
	helpFlag := flag.Bool("help", false, "Show help information")
	flag.Parse()

	// Shared by -list and -help.
	listAvailable := func() {
		fmt.Printf("Available test functions:\n")
		for _, name := range getTestFunctionNames() {
			fmt.Printf(" %s\n", name)
		}
	}

	switch {
	case *listFlag:
		listAvailable()
		return
	case *helpFlag:
		fmt.Printf("Fleet Datastore Performance Tester\n\n")
		fmt.Printf("This tool measures the performance of Fleet datastore methods.\n")
		fmt.Printf("It assumes test data has already been seeded using the data seeding tool.\n\n")
		listAvailable()
		fmt.Printf("\nExamples:\n")
		fmt.Printf(" %s -funcs=UpdateVulnerabilityHostCounts -iterations=10\n", os.Args[0])
		fmt.Printf(" %s -funcs=UpdateVulnerabilityHostCounts -iterations=5 -details\n", os.Args[0])
		fmt.Printf(" %s -funcs=UpdateVulnerabilityHostCounts -verbose\n", os.Args[0])
		fmt.Printf("\n")
		flag.Usage()
		return
	}

	ctx := context.Background()

	// Connect to datastore
	dbConfig := config.MysqlConfig{
		Protocol: "tcp",
		Address:  mysqlAddr,
		Username: mysqlUser,
		Password: mysqlPass,
		Database: mysqlDB,
	}
	ds, err := mysql.New(dbConfig, clock.C)
	if err != nil {
		log.Fatal(err)
	}
	defer func() { _ = ds.Close() }()

	// Parse test functions: blank entries (e.g. from a trailing comma) are
	// ignored, unknown names yield a nil result that printResults skips.
	var results []*PerformanceResult
	for _, raw := range strings.Split(*funcsFlag, ",") {
		name := strings.TrimSpace(raw)
		if name == "" {
			continue
		}
		results = append(results, runPerformanceTest(ctx, ds, name, *iterFlag, *verboseFlag))
	}

	printResults(results, *detailsFlag)
}