mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 13:37:30 +00:00
Atomic vulnerability count calculations (#35317)
This commit is contained in:
parent
542e8ff259
commit
188a91cf4d
6 changed files with 1129 additions and 50 deletions
2
changes/35043-missing-vuln-counts
Normal file
2
changes/35043-missing-vuln-counts
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
- fixed issue where vulnerabilities would occasionally show as missing
|
||||
- added vulnerability seeding and performance testing tools
|
||||
|
|
@ -540,48 +540,28 @@ func getVulnHostCountQuery(scope CountScope) string {
|
|||
}
|
||||
|
||||
func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context, maxRoutines int) error {
|
||||
// set all counts to 0 to later identify rows to delete
|
||||
_, err := ds.writer(ctx).ExecContext(ctx, "UPDATE vulnerability_host_counts SET host_count = 0")
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "initializing vulnerability host counts")
|
||||
}
|
||||
|
||||
globalHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, GlobalCount, maxRoutines)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "fetching global vulnerability host counts")
|
||||
}
|
||||
|
||||
err = ds.batchInsertHostCounts(ctx, globalHostCounts)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "inserting global vulnerability host counts")
|
||||
}
|
||||
|
||||
teamHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, TeamCount, maxRoutines)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "fetching team vulnerability host counts")
|
||||
}
|
||||
|
||||
err = ds.batchInsertHostCounts(ctx, teamHostCounts)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "inserting team vulnerability host counts")
|
||||
}
|
||||
|
||||
noTeamHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, NoTeamCount, maxRoutines)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "fetching no team vulnerability host counts")
|
||||
}
|
||||
|
||||
err = ds.batchInsertHostCounts(ctx, noTeamHostCounts)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "inserting team vulnerability host counts")
|
||||
counts := vulnerabilityCounts{
|
||||
Global: globalHostCounts,
|
||||
Team: teamHostCounts,
|
||||
NoTeam: noTeamHostCounts,
|
||||
}
|
||||
|
||||
err = ds.cleanupVulnerabilityHostCounts(ctx)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "cleaning up vulnerability host counts")
|
||||
}
|
||||
|
||||
return nil
|
||||
return ds.atomicTableSwapVulnerabilityCounts(ctx, counts)
|
||||
}
|
||||
|
||||
type hostCount struct {
|
||||
|
|
@ -591,46 +571,109 @@ type hostCount struct {
|
|||
GlobalStats bool `db:"global_stats"`
|
||||
}
|
||||
|
||||
func (ds *Datastore) cleanupVulnerabilityHostCounts(ctx context.Context) error {
|
||||
_, err := ds.writer(ctx).ExecContext(ctx, "DELETE FROM vulnerability_host_counts WHERE host_count = 0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting zero host count entries: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
type vulnerabilityCounts struct {
|
||||
Global []hostCount
|
||||
Team []hostCount
|
||||
NoTeam []hostCount
|
||||
}
|
||||
|
||||
func (ds *Datastore) batchInsertHostCounts(ctx context.Context, counts []hostCount) error {
|
||||
const (
|
||||
vulnerabilityHostCountsSwapTable = "vulnerability_host_counts_swap"
|
||||
vulnerabilityHostCountsSwapTableSchema = `CREATE TABLE IF NOT EXISTS ` + vulnerabilityHostCountsSwapTable + ` LIKE vulnerability_host_counts`
|
||||
)
|
||||
|
||||
// atomicTableSwapVulnerabilityCounts implements atomic table swap pattern
|
||||
// 1. Populate swap table with new data
|
||||
// 2. Atomically rename tables to swap them
|
||||
// 3. Clean up old table
|
||||
func (ds *Datastore) atomicTableSwapVulnerabilityCounts(ctx context.Context, counts vulnerabilityCounts) error {
|
||||
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
||||
// Create/recreate the swap table fresh
|
||||
_, err := tx.ExecContext(ctx, "DROP TABLE IF EXISTS "+vulnerabilityHostCountsSwapTable)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "dropping existing swap table")
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, vulnerabilityHostCountsSwapTableSchema)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "creating swap table")
|
||||
}
|
||||
|
||||
// Insert each group of counts separately
|
||||
if len(counts.Global) > 0 {
|
||||
err = ds.insertHostCountsIntoTable(ctx, tx, counts.Global, vulnerabilityHostCountsSwapTable)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "populating swap table with global counts")
|
||||
}
|
||||
}
|
||||
|
||||
if len(counts.Team) > 0 {
|
||||
err = ds.insertHostCountsIntoTable(ctx, tx, counts.Team, vulnerabilityHostCountsSwapTable)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "populating swap table with team counts")
|
||||
}
|
||||
}
|
||||
|
||||
if len(counts.NoTeam) > 0 {
|
||||
err = ds.insertHostCountsIntoTable(ctx, tx, counts.NoTeam, vulnerabilityHostCountsSwapTable)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "populating swap table with no-team counts")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Atomic table swap using RENAME TABLE
|
||||
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
||||
_, err := tx.ExecContext(ctx, fmt.Sprintf(`
|
||||
RENAME TABLE
|
||||
vulnerability_host_counts TO vulnerability_host_counts_old,
|
||||
%s TO vulnerability_host_counts
|
||||
`, vulnerabilityHostCountsSwapTable))
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "atomic table swap")
|
||||
}
|
||||
|
||||
// Clean up old table (drop it)
|
||||
_, err = tx.ExecContext(ctx, "DROP TABLE vulnerability_host_counts_old")
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "dropping old table")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// insertHostCountsIntoTable inserts counts into specified table
|
||||
func (ds *Datastore) insertHostCountsIntoTable(ctx context.Context, tx sqlx.ExtContext, counts []hostCount, tableName string) error {
|
||||
if len(counts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
insertStmt := "INSERT INTO vulnerability_host_counts (team_id, cve, host_count, global_stats) VALUES "
|
||||
var insertArgs []interface{}
|
||||
insertStmt := fmt.Sprintf("INSERT INTO %s (team_id, cve, host_count, global_stats) VALUES ", tableName)
|
||||
|
||||
chunkSize := 100
|
||||
// Use smaller chunks to avoid parameter limits
|
||||
chunkSize := 500
|
||||
for i := 0; i < len(counts); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(counts) {
|
||||
end = len(counts)
|
||||
}
|
||||
end := min(i+chunkSize, len(counts))
|
||||
|
||||
valueStrings := make([]string, 0, end-i)
|
||||
chunkArgs := make([]interface{}, 0, (end-i)*4)
|
||||
|
||||
valueStrings := make([]string, 0, chunkSize)
|
||||
for _, count := range counts[i:end] {
|
||||
valueStrings = append(valueStrings, "(?, ?, ?, ?)")
|
||||
insertArgs = append(insertArgs, count.TeamID, count.CVE, count.HostCount, count.GlobalStats)
|
||||
chunkArgs = append(chunkArgs, count.TeamID, count.CVE, count.HostCount, count.GlobalStats)
|
||||
}
|
||||
|
||||
insertStmt += strings.Join(valueStrings, ", ")
|
||||
insertStmt += " ON DUPLICATE KEY UPDATE host_count = VALUES(host_count);"
|
||||
|
||||
_, err := ds.writer(ctx).ExecContext(ctx, insertStmt, insertArgs...)
|
||||
fullStmt := insertStmt + strings.Join(valueStrings, ", ")
|
||||
_, err := tx.ExecContext(ctx, fullStmt, chunkArgs...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("inserting host counts: %w", err)
|
||||
return fmt.Errorf("inserting host counts chunk %d-%d into %s: %w", i, end-1, tableName, err)
|
||||
}
|
||||
|
||||
insertStmt = "INSERT INTO vulnerability_host_counts (team_id, cve, host_count, global_stats) VALUES "
|
||||
insertArgs = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
|||
135
tools/software/vulnerabilities/performance_test/README.md
Normal file
135
tools/software/vulnerabilities/performance_test/README.md
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
# Vulnerability Performance Testing Tools
|
||||
|
||||
This directory contains tools for testing the performance of Fleet's vulnerability-related datastore methods.
|
||||
|
||||
## Tools
|
||||
|
||||
### Seeder (`seeder/volume_vuln_seeder.go`)
|
||||
|
||||
Seeds the database with test data for performance testing.
|
||||
|
||||
**Usage:**
|
||||
|
||||
```bash
|
||||
go run seeder/volume_vuln_seeder.go [options]
|
||||
```
|
||||
|
||||
**Options:**
|
||||
|
||||
- `-hosts=N` - Number of hosts to create (default: 100)
|
||||
- `-teams=N` - Number of teams to create (default: 5)
|
||||
- `-cves=N` - Total number of unique CVEs in the system (default: 500)
|
||||
- `-software=N` - Total number of unique software packages (default: 500)
|
||||
- `-help` - Show help information
|
||||
- `-verbose` - Enable verbose timing output for each step
|
||||
|
||||
**Example:**
|
||||
|
||||
```bash
|
||||
go run seeder/volume_vuln_seeder.go -hosts=1000 -teams=10 -cves=2000 -software=4000
|
||||
```
|
||||
|
||||
### Performance Tester (`tester/performance_tester.go`)
|
||||
|
||||
Benchmarks any Fleet datastore method with statistical analysis.
|
||||
|
||||
**Usage:**
|
||||
|
||||
```bash
|
||||
go run tester/performance_tester.go [options]
|
||||
```
|
||||
|
||||
**Options:**
|
||||
|
||||
- `-funcs=NAME[,NAME2,...]` - Comma-separated list of test functions (default: "UpdateVulnerabilityHostCounts")
|
||||
- `-iterations=N` - Number of iterations per test (default: 5)
|
||||
- `-verbose` - Show timing for each iteration
|
||||
- `-details` - Show detailed statistics including percentiles
|
||||
- `-list` - List available test functions
|
||||
- `-help` - Show help information
|
||||
|
||||
**Available Test Functions:**
|
||||
|
||||
- `UpdateVulnerabilityHostCounts` - Test vulnerability host count updates
|
||||
|
||||
### Adding New Test Functions
|
||||
|
||||
To add support for additional datastore methods, edit the `testFunctions` map in `tester/performance_tester.go`:
|
||||
|
||||
```go
|
||||
var testFunctions = map[string]TestFunction{
|
||||
// Existing functions...
|
||||
|
||||
// Add new function
|
||||
"CountHosts": func(ctx context.Context, ds *mysql.Datastore) error {
|
||||
_, err := ds.CountHosts(ctx, fleet.TeamFilter{User: &fleet.User{}}, fleet.HostListOptions{})
|
||||
return err
|
||||
},
|
||||
|
||||
// Add function with parameters
|
||||
"ListHosts:100": func(ctx context.Context, ds *mysql.Datastore) error {
|
||||
_, err := ds.ListHosts(ctx, fleet.TeamFilter{User: &fleet.User{}}, fleet.HostListOptions{
|
||||
ListOptions: fleet.ListOptions{Page: 0, PerPage: 100},
|
||||
})
|
||||
return err
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
Each function should:
|
||||
|
||||
1. Accept `context.Context` and `*mysql.Datastore` as parameters
|
||||
2. Return only an `error`
|
||||
3. Handle any return values from the datastore method (discard non-error returns)
|
||||
4. Use meaningful parameter values for realistic testing
|
||||
|
||||
**Examples:**
|
||||
|
||||
```bash
|
||||
# Test single function with details
|
||||
go run tester/performance_tester.go -funcs=UpdateVulnerabilityHostCounts -iterations=10 -details
|
||||
|
||||
# Test different batch sizes
|
||||
go run tester/performance_tester.go -funcs=UpdateVulnerabilityHostCounts:5,UpdateVulnerabilityHostCounts:20 -iterations=5
|
||||
|
||||
# Verbose output
|
||||
go run tester/performance_tester.go -funcs=UpdateVulnerabilityHostCounts -verbose
|
||||
```
|
||||
|
||||
## Performance Analysis
|
||||
|
||||
The tools provide comprehensive performance metrics:
|
||||
|
||||
- **Total time** - Sum of all successful iterations
|
||||
- **Average time** - Mean execution time
|
||||
- **Min/Max time** - Fastest and slowest iterations
|
||||
- **Success rate** - Percentage of successful vs failed iterations
|
||||
- **Percentiles** - P50, P90, P99 response times (with `-details`)
|
||||
|
||||
## Typical Workflow
|
||||
|
||||
1. **Seed test data:**
|
||||
|
||||
```bash
|
||||
go run seeder/volume_vuln_seeder.go -hosts=1000 -teams=10 -cves=2000 -software=4000
|
||||
```
|
||||
|
||||
2. **Test baseline performance:**
|
||||
|
||||
```bash
|
||||
go run tester/performance_tester.go -funcs=UpdateVulnerabilityHostCounts -iterations=10 -details
|
||||
```
|
||||
|
||||
3. **Make code changes to optimize**
|
||||
|
||||
4. **Test optimized performance:**
|
||||
|
||||
```bash
|
||||
go run tester/performance_tester.go -funcs=UpdateVulnerabilityHostCounts -iterations=10 -details
|
||||
```
|
||||
|
||||
5. **Compare results**
|
||||
|
||||
## Notes
|
||||
|
||||
- The seeder is not idempotent - run `make db-reset` to reset the database before reseeding.
|
||||
|
|
@ -0,0 +1,656 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/WatchBeam/clock"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
)
|
||||
|
||||
func retryOnDeadlock(operation func() error, maxRetries int) error {
|
||||
var err error
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
err = operation()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if it's a deadlock error
|
||||
if strings.Contains(err.Error(), "Deadlock found") || strings.Contains(err.Error(), "1213") {
|
||||
if attempt < maxRetries {
|
||||
// Exponential backoff with jitter
|
||||
// #nosec G404 - weak random is acceptable for retry backoff
|
||||
backoff := time.Duration(10+rand.Intn(50)) * time.Millisecond * time.Duration(1<<attempt)
|
||||
time.Sleep(backoff)
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func timeStep(name string, verbose bool, fn func() error) error {
|
||||
if verbose {
|
||||
fmt.Printf("Starting: %s...\n", name)
|
||||
}
|
||||
start := time.Now()
|
||||
err := fn()
|
||||
duration := time.Since(start)
|
||||
if verbose {
|
||||
fmt.Printf("Completed: %s in %v\n", name, duration)
|
||||
} else {
|
||||
fmt.Printf("%s: %v\n", name, duration)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
// MySQL config
|
||||
mysqlAddr = "localhost:3306"
|
||||
mysqlUser = "fleet"
|
||||
mysqlPass = "insecure"
|
||||
mysqlDB = "fleet"
|
||||
)
|
||||
|
||||
// Common CVE patterns for realistic data
|
||||
var cvePatterns = []string{
|
||||
"CVE-2024-%04d", "CVE-2023-%04d", "CVE-2022-%04d", "CVE-2021-%04d",
|
||||
}
|
||||
|
||||
// batchCreateHosts creates multiple hosts in batches for better performance
|
||||
func batchCreateHosts(ctx context.Context, ds *mysql.Datastore, hostCount int, teams []*fleet.Team, verbose bool) ([]*fleet.Host, error) {
|
||||
batchSize := 100 // Insert 100 hosts per transaction
|
||||
var allHosts []*fleet.Host
|
||||
now := time.Now()
|
||||
|
||||
db, err := getDB(ds)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get DB connection: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
for batchStart := 0; batchStart < hostCount; batchStart += batchSize {
|
||||
batchEnd := batchStart + batchSize
|
||||
if batchEnd > hostCount {
|
||||
batchEnd = hostCount
|
||||
}
|
||||
|
||||
// Prepare batch insert
|
||||
var args []interface{}
|
||||
var placeholders []string
|
||||
|
||||
for i := batchStart; i < batchEnd; i++ {
|
||||
var teamID *uint
|
||||
if len(teams) > 0 && i%3 != 0 { // 2/3 of hosts have teams
|
||||
teamID = &teams[i%len(teams)].ID
|
||||
}
|
||||
|
||||
identifier := fmt.Sprintf("test-host-%d", i)
|
||||
|
||||
osqueryHostID := fmt.Sprintf("osquery-host-%d", i)
|
||||
placeholders = append(placeholders, "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)")
|
||||
args = append(args,
|
||||
osqueryHostID, // osquery_host_id
|
||||
now, // detail_updated_at
|
||||
now, // label_updated_at
|
||||
now, // policy_updated_at
|
||||
identifier, // node_key
|
||||
identifier, // hostname
|
||||
identifier, // computer_name
|
||||
identifier, // uuid
|
||||
"ubuntu", // platform
|
||||
"", // platform_like
|
||||
"", // osquery_version
|
||||
"Ubuntu 20.04.6 LTS", // os_version
|
||||
0, // uptime
|
||||
0, // memory
|
||||
teamID, // team_id
|
||||
0, // distributed_interval
|
||||
0, // logger_tls_period
|
||||
0, // config_tls_refresh
|
||||
false, // refetch_requested
|
||||
"", // hardware_serial
|
||||
nil, // refetch_critical_queries_until (can be NULL)
|
||||
)
|
||||
}
|
||||
|
||||
// Execute batch insert (exactly matching the NewHost function)
|
||||
sql := `INSERT INTO hosts (
|
||||
osquery_host_id, detail_updated_at, label_updated_at, policy_updated_at,
|
||||
node_key, hostname, computer_name, uuid, platform, platform_like,
|
||||
osquery_version, os_version, uptime, memory, team_id,
|
||||
distributed_interval, logger_tls_period, config_tls_refresh,
|
||||
refetch_requested, hardware_serial, refetch_critical_queries_until
|
||||
) VALUES ` + strings.Join(placeholders, ", ")
|
||||
|
||||
_, err := db.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch insert hosts %d-%d: %w", batchStart, batchEnd-1, err)
|
||||
}
|
||||
|
||||
// Fetch the created hosts to get their IDs
|
||||
var batchHosts []fleet.Host
|
||||
var uuids []string
|
||||
for i := batchStart; i < batchEnd; i++ {
|
||||
uuids = append(uuids, fmt.Sprintf("'test-host-%d'", i))
|
||||
}
|
||||
err = sqlx.SelectContext(ctx, db, &batchHosts,
|
||||
"SELECT id, uuid, hostname, computer_name, node_key, detail_updated_at, label_updated_at, policy_updated_at, platform, os_version, team_id FROM hosts WHERE uuid IN ("+strings.Join(uuids, ",")+")")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch created hosts: %w", err)
|
||||
}
|
||||
|
||||
// Insert host_display_names for the created hosts
|
||||
if len(batchHosts) > 0 {
|
||||
var displayNamePlaceholders []string
|
||||
var displayNameArgs []interface{}
|
||||
|
||||
for _, host := range batchHosts {
|
||||
displayName := host.Hostname // Use hostname as display name (same logic as migration)
|
||||
if host.ComputerName != "" {
|
||||
displayName = host.ComputerName
|
||||
}
|
||||
displayNamePlaceholders = append(displayNamePlaceholders, "(?, ?)")
|
||||
displayNameArgs = append(displayNameArgs, host.ID, displayName)
|
||||
}
|
||||
|
||||
displayNameSQL := "INSERT INTO host_display_names (host_id, display_name) VALUES " + strings.Join(displayNamePlaceholders, ", ")
|
||||
_, err = db.ExecContext(ctx, displayNameSQL, displayNameArgs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch insert host display names %d-%d: %w", batchStart, batchEnd-1, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to pointers and add to result
|
||||
for i := range batchHosts {
|
||||
allHosts = append(allHosts, &batchHosts[i])
|
||||
}
|
||||
|
||||
if verbose && (batchEnd%500 == 0 || batchEnd == hostCount) {
|
||||
fmt.Printf(" Created %d/%d hosts\n", batchEnd, hostCount)
|
||||
}
|
||||
}
|
||||
|
||||
return allHosts, nil
|
||||
}
|
||||
|
||||
// batchUpdateHostOS updates operating system info for multiple hosts efficiently
|
||||
func batchUpdateHostOS(ctx context.Context, ds *mysql.Datastore, hosts []*fleet.Host, verbose bool) error {
|
||||
batchSize := 100
|
||||
|
||||
for batchStart := 0; batchStart < len(hosts); batchStart += batchSize {
|
||||
batchEnd := batchStart + batchSize
|
||||
if batchEnd > len(hosts) {
|
||||
batchEnd = len(hosts)
|
||||
}
|
||||
|
||||
for i := batchStart; i < batchEnd; i++ {
|
||||
host := hosts[i]
|
||||
err := retryOnDeadlock(func() error {
|
||||
return ds.UpdateHostOperatingSystem(ctx, host.ID, fleet.OperatingSystem{
|
||||
Name: "Ubuntu",
|
||||
Version: fmt.Sprintf("20.04.%d", i%10), // Vary the version slightly
|
||||
Platform: "ubuntu",
|
||||
Arch: "x86_64",
|
||||
KernelVersion: "5.4.0-148-generic",
|
||||
DisplayVersion: "20.04",
|
||||
})
|
||||
}, 3)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update host %d OS: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
if verbose && (batchEnd%500 == 0 || batchEnd == len(hosts)) {
|
||||
fmt.Printf(" Updated OS for %d/%d hosts\n", batchEnd, len(hosts))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// batchInstallSoftware installs software on hosts in batches for better performance
|
||||
func batchInstallSoftware(ctx context.Context, ds *mysql.Datastore, hosts []*fleet.Host, vulnerableSoftware []fleet.Software, verbose bool) error {
|
||||
db, err := getDB(ds)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get DB connection: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// First, create all software entries using INSERT ... ON DUPLICATE KEY UPDATE
|
||||
softwareMap := make(map[string]uint) // name+version+source -> software_id
|
||||
|
||||
if len(vulnerableSoftware) > 0 {
|
||||
// Insert software in smaller batches to avoid max_allowed_packet issues
|
||||
batchSize := 100
|
||||
for batchStart := 0; batchStart < len(vulnerableSoftware); batchStart += batchSize {
|
||||
batchEnd := batchStart + batchSize
|
||||
if batchEnd > len(vulnerableSoftware) {
|
||||
batchEnd = len(vulnerableSoftware)
|
||||
}
|
||||
|
||||
var placeholders []string
|
||||
var args []interface{}
|
||||
|
||||
for i := batchStart; i < batchEnd; i++ {
|
||||
software := vulnerableSoftware[i]
|
||||
placeholders = append(placeholders, "(?, ?, ?, ?, ?, ?, ?, UNHEX(MD5(CONCAT(COALESCE(?, ''), COALESCE(?, ''), ?))))")
|
||||
args = append(args,
|
||||
software.Name,
|
||||
software.Version,
|
||||
software.Source,
|
||||
software.BundleIdentifier,
|
||||
software.Release,
|
||||
software.Vendor,
|
||||
software.Arch,
|
||||
// Checksum calculation args (name, version, source)
|
||||
software.Name,
|
||||
software.Version,
|
||||
software.Source,
|
||||
)
|
||||
}
|
||||
|
||||
sql := `INSERT INTO software (name, version, source, bundle_identifier, ` + "`release`" + `, vendor, arch, checksum)
|
||||
VALUES ` + strings.Join(placeholders, ", ") + `
|
||||
ON DUPLICATE KEY UPDATE name = VALUES(name)`
|
||||
_, err := db.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("batch insert software batch %d-%d: %w", batchStart, batchEnd-1, err)
|
||||
}
|
||||
|
||||
if verbose && (batchEnd%200 == 0 || batchEnd == len(vulnerableSoftware)) {
|
||||
fmt.Printf(" Inserted software batch %d/%d\n", batchEnd, len(vulnerableSoftware))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now get all software IDs (this approach guarantees we find them)
|
||||
for _, software := range vulnerableSoftware {
|
||||
key := fmt.Sprintf("%s|%s|%s", software.Name, software.Version, software.Source)
|
||||
|
||||
var softwareID uint
|
||||
err := sqlx.GetContext(ctx, db, &softwareID,
|
||||
"SELECT id FROM software WHERE name = ? AND version = ? AND source = ? LIMIT 1",
|
||||
software.Name, software.Version, software.Source)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get software ID for %s: %w", software.Name, err)
|
||||
}
|
||||
|
||||
softwareMap[key] = softwareID
|
||||
}
|
||||
|
||||
// Now batch install software on hosts
|
||||
batchSize := 50 // Smaller batches to avoid deadlocks
|
||||
|
||||
for batchStart := 0; batchStart < len(hosts); batchStart += batchSize {
|
||||
batchEnd := batchStart + batchSize
|
||||
if batchEnd > len(hosts) {
|
||||
batchEnd = len(hosts)
|
||||
}
|
||||
|
||||
// Process this batch of hosts
|
||||
for i := batchStart; i < batchEnd; i++ {
|
||||
host := hosts[i]
|
||||
|
||||
// Each host gets 20-80% of the software
|
||||
// #nosec G404 - weak random is acceptable for test data generation
|
||||
pct := 0.2 + rand.Float64()*0.6
|
||||
hostSoftwareCount := int(float64(len(vulnerableSoftware)) * pct)
|
||||
|
||||
// Randomly select which software this host has
|
||||
hostVulnSoftware := make([]fleet.Software, len(vulnerableSoftware))
|
||||
copy(hostVulnSoftware, vulnerableSoftware)
|
||||
rand.Shuffle(len(hostVulnSoftware), func(i, j int) {
|
||||
hostVulnSoftware[i], hostVulnSoftware[j] = hostVulnSoftware[j], hostVulnSoftware[i]
|
||||
})
|
||||
|
||||
hostSoftware := hostVulnSoftware[:hostSoftwareCount]
|
||||
|
||||
// Clear existing software for this host
|
||||
_, err := db.ExecContext(ctx, "DELETE FROM host_software WHERE host_id = ?", host.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("clear host %d software: %w", host.ID, err)
|
||||
}
|
||||
|
||||
// Batch insert software for this host
|
||||
if len(hostSoftware) > 0 {
|
||||
var placeholders []string
|
||||
var args []interface{}
|
||||
|
||||
for _, software := range hostSoftware {
|
||||
key := fmt.Sprintf("%s|%s|%s", software.Name, software.Version, software.Source)
|
||||
softwareID, exists := softwareMap[key]
|
||||
if !exists {
|
||||
continue // Skip if software ID not found
|
||||
}
|
||||
|
||||
placeholders = append(placeholders, "(?, ?)")
|
||||
args = append(args, host.ID, softwareID)
|
||||
}
|
||||
|
||||
if len(placeholders) > 0 {
|
||||
sql := "INSERT IGNORE INTO host_software (host_id, software_id) VALUES " + strings.Join(placeholders, ", ")
|
||||
_, err := db.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("batch insert software for host %d: %w", host.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if verbose && (batchEnd%200 == 0 || batchEnd == len(hosts)) {
|
||||
fmt.Printf(" Installed software on %d/%d hosts\n", batchEnd, len(hosts))
|
||||
}
|
||||
|
||||
// Small delay between batches to reduce DB pressure
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createTestTeam(ctx context.Context, ds *mysql.Datastore, name string) (*fleet.Team, error) {
|
||||
team := &fleet.Team{
|
||||
Name: name,
|
||||
}
|
||||
return ds.NewTeam(ctx, team)
|
||||
}
|
||||
|
||||
func generateVulnerableSoftware(softwareCount int) []fleet.Software {
|
||||
var software []fleet.Software
|
||||
|
||||
for i := range softwareCount {
|
||||
// Create vulnerable software
|
||||
software = append(software, fleet.Software{
|
||||
Name: fmt.Sprintf("vulnerable-package-%d", i),
|
||||
// #nosec G404 - weak random is acceptable for test data generation
|
||||
Version: fmt.Sprintf("1.%d.0", rand.Intn(100)),
|
||||
Source: "Package (deb)",
|
||||
})
|
||||
}
|
||||
|
||||
return software
|
||||
}
|
||||
|
||||
func generateCVEs(cveCount int) []string {
|
||||
var cves []string
|
||||
for i := 0; i < cveCount; i++ {
|
||||
// #nosec G404 - weak random is acceptable for test data generation
|
||||
yearIdx := rand.Intn(len(cvePatterns))
|
||||
// #nosec G404 - weak random is acceptable for test data generation
|
||||
cveID := fmt.Sprintf(cvePatterns[yearIdx], rand.Intn(9999)+1)
|
||||
cves = append(cves, cveID)
|
||||
}
|
||||
return cves
|
||||
}
|
||||
|
||||
// getDB gets a new sqlx database connection for direct queries
|
||||
func getDB(ds *mysql.Datastore) (*sqlx.DB, error) {
|
||||
cfg := config.MysqlConfig{
|
||||
Protocol: "tcp",
|
||||
Address: mysqlAddr,
|
||||
Username: mysqlUser,
|
||||
Password: mysqlPass,
|
||||
Database: mysqlDB,
|
||||
}
|
||||
|
||||
dsn := cfg.Username + ":" + cfg.Password + "@" + cfg.Protocol + "(" + cfg.Address + ")/" + cfg.Database + "?charset=utf8mb4&parseTime=True&loc=Local"
|
||||
return sqlx.Open("mysql", dsn)
|
||||
}
|
||||
|
||||
func seedSoftwareCVEs(ctx context.Context, ds *mysql.Datastore, cves []string) error {
|
||||
db, err := getDB(ds)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get DB connection: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// First, get software IDs that exist
|
||||
var softwareIDs []uint
|
||||
err = sqlx.SelectContext(ctx, db, &softwareIDs, "SELECT id FROM software LIMIT 1000")
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch software IDs: %w", err)
|
||||
}
|
||||
|
||||
if len(softwareIDs) == 0 {
|
||||
return errors.New("no software found - run seedVulnerabilities first")
|
||||
}
|
||||
|
||||
// Insert software_cve mappings
|
||||
for i, cve := range cves {
|
||||
// Each CVE affects 1-3 software packages
|
||||
// #nosec G404 - weak random is acceptable for test data generation
|
||||
affectedCount := 1 + rand.Intn(3)
|
||||
for j := 0; j < affectedCount && j < len(softwareIDs); j++ {
|
||||
softwareID := softwareIDs[(i+j)%len(softwareIDs)]
|
||||
err := retryOnDeadlock(func() error {
|
||||
_, err := db.ExecContext(ctx,
|
||||
"INSERT IGNORE INTO software_cve (software_id, cve) VALUES (?, ?)",
|
||||
softwareID, cve)
|
||||
return err
|
||||
}, 3)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert software_cve for %s: %w", cve, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureOSRecordsExist(ctx context.Context, ds *mysql.Datastore) error {
|
||||
db, err := getDB(ds)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get DB connection: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Check if we have any OS records
|
||||
var count int
|
||||
err = sqlx.GetContext(ctx, db, &count, "SELECT COUNT(*) FROM operating_systems")
|
||||
if err != nil {
|
||||
return fmt.Errorf("count operating systems: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Found %d existing OS records\n", count)
|
||||
|
||||
if count == 0 {
|
||||
fmt.Printf("No OS records found. The hosts may not have had their OS info updated yet.\n")
|
||||
fmt.Printf("Try running with fewer hosts first, or check that UpdateHostOperatingSystem was called.\n")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func seedOSVulnerabilities(ctx context.Context, ds *mysql.Datastore, cves []string) error {
|
||||
db, err := getDB(ds)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get DB connection: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Get OS IDs
|
||||
var osIDs []uint
|
||||
err = sqlx.SelectContext(ctx, db, &osIDs, "SELECT id FROM operating_systems LIMIT 100")
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch OS IDs: %w", err)
|
||||
}
|
||||
|
||||
if len(osIDs) == 0 {
|
||||
fmt.Printf("Warning: No operating systems found in database. Skipping OS vulnerability seeding.\n")
|
||||
fmt.Printf("This might be normal if your test setup doesn't include OS vulnerability testing.\n")
|
||||
return nil // Don't fail, just skip OS vulnerabilities
|
||||
}
|
||||
|
||||
fmt.Printf("Found %d operating systems to map vulnerabilities to\n", len(osIDs))
|
||||
|
||||
// Insert OS vulnerabilities (about 30% of CVEs affect OS)
|
||||
for i, cve := range cves {
|
||||
// #nosec G404 - weak random is acceptable for test data generation
|
||||
if rand.Float64() < 0.3 { // 30% chance this CVE affects OS
|
||||
osID := osIDs[i%len(osIDs)]
|
||||
err := retryOnDeadlock(func() error {
|
||||
_, err := db.ExecContext(ctx,
|
||||
"INSERT IGNORE INTO operating_system_vulnerabilities (operating_system_id, cve) VALUES (?, ?)",
|
||||
osID, cve)
|
||||
return err
|
||||
}, 3)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert OS vulnerability for %s: %w", cve, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func seedVulnerabilities(ctx context.Context, ds *mysql.Datastore, hostCount, teamCount, cveCount, softwareCount int, verbose bool) error {
|
||||
var teams []*fleet.Team
|
||||
err := timeStep(fmt.Sprintf("Creating %d teams", teamCount), verbose, func() error {
|
||||
for i := 0; i < teamCount; i++ {
|
||||
team, err := createTestTeam(ctx, ds, fmt.Sprintf("test-team-%d", i))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create team %d: %w", i, err)
|
||||
}
|
||||
teams = append(teams, team)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var hosts []*fleet.Host
|
||||
err = timeStep(fmt.Sprintf("Creating %d hosts", hostCount), verbose, func() error {
|
||||
var err error
|
||||
hosts, err = batchCreateHosts(ctx, ds, hostCount, teams, verbose)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = timeStep("Updating host operating systems", verbose, func() error {
|
||||
return batchUpdateHostOS(ctx, ds, hosts, verbose)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var cves []string
|
||||
var vulnerableSoftware []fleet.Software
|
||||
err = timeStep(fmt.Sprintf("Generating %d CVEs and %d software packages", cveCount, softwareCount), verbose, func() error {
|
||||
cves = generateCVEs(cveCount)
|
||||
vulnerableSoftware = generateVulnerableSoftware(softwareCount)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Install software on each host using batch approach
|
||||
err = timeStep(fmt.Sprintf("Installing software on %d hosts", len(hosts)), verbose, func() error {
|
||||
return batchInstallSoftware(ctx, ds, hosts, vulnerableSoftware, verbose)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = timeStep("Seeding software-CVE mappings", verbose, func() error {
|
||||
return seedSoftwareCVEs(ctx, ds, cves)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("seed software CVEs: %w", err)
|
||||
}
|
||||
|
||||
err = timeStep("Ensuring OS records exist", verbose, func() error {
|
||||
return ensureOSRecordsExist(ctx, ds)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ensure OS records: %w", err)
|
||||
}
|
||||
|
||||
err = timeStep("Seeding OS vulnerabilities", verbose, func() error {
|
||||
return seedOSVulnerabilities(ctx, ds, cves)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("seed OS vulnerabilities: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
hostCount = flag.Int("hosts", 100, "Number of hosts to create")
|
||||
teamCount = flag.Int("teams", 5, "Number of teams to create")
|
||||
cveCount = flag.Int("cves", 500, "Total number of unique CVEs in the system")
|
||||
softwareCount = flag.Int("software", 500, "Total number of unique software packages (each host gets 20-80% randomly)")
|
||||
help = flag.Bool("help", false, "Show help information")
|
||||
verbose = flag.Bool("verbose", false, "Enable verbose timing output for each step")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
if *help {
|
||||
fmt.Printf("Fleet Test Data Seeder\n\n")
|
||||
fmt.Printf("This tool creates test data for Fleet performance testing.\n\n")
|
||||
fmt.Printf("Data model:\n")
|
||||
fmt.Printf("- Creates %d total unique software packages\n", *softwareCount)
|
||||
fmt.Printf("- Each host gets 20-80%% of software packages randomly assigned\n")
|
||||
fmt.Printf("- Creates %d total unique CVEs\n", *cveCount)
|
||||
fmt.Printf("- Each CVE affects 1-3 random software packages\n")
|
||||
fmt.Printf("- About 30%% of CVEs also affect the operating system\n")
|
||||
fmt.Printf("- Host vulnerability counts depend on which software they have installed\n\n")
|
||||
fmt.Printf("Examples:\n")
|
||||
fmt.Printf(" %s -hosts=1000 -teams=10 -cves=500 -software=1000\n", os.Args[0])
|
||||
fmt.Printf(" %s -hosts=5000 -teams=20 -cves=2000 -software=5000 -verbose\n", os.Args[0])
|
||||
fmt.Printf("\n")
|
||||
fmt.Printf("After seeding data, use performance_tester.go to test datastore methods.\n")
|
||||
fmt.Printf("\n")
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Connect to datastore
|
||||
ds, err := mysql.New(config.MysqlConfig{
|
||||
Protocol: "tcp",
|
||||
Address: mysqlAddr,
|
||||
Username: mysqlUser,
|
||||
Password: mysqlPass,
|
||||
Database: mysqlDB,
|
||||
}, clock.C)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer ds.Close()
|
||||
|
||||
fmt.Printf("Seeding test data...\n")
|
||||
fmt.Printf("Configuration: %d hosts, %d teams, %d CVEs, %d software packages\n", *hostCount, *teamCount, *cveCount, *softwareCount)
|
||||
|
||||
if err := seedVulnerabilities(ctx, ds, *hostCount, *teamCount, *cveCount, *softwareCount, *verbose); err != nil {
|
||||
fmt.Printf("Failed to seed vulnerabilities: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Data seeding complete!\n")
|
||||
fmt.Printf("\nUse performance_tester.go to test datastore methods with this data.\n")
|
||||
fmt.Printf("Example: go run performance_tester.go -funcs=UpdateVulnerabilityHostCounts -iterations=5\n")
|
||||
|
||||
fmt.Println("Done.")
|
||||
}
|
||||
|
|
@ -0,0 +1,243 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/WatchBeam/clock"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
)
|
||||
|
||||
var (
|
||||
// MySQL config
|
||||
mysqlAddr = "localhost:3306"
|
||||
mysqlUser = "fleet"
|
||||
mysqlPass = "insecure"
|
||||
mysqlDB = "fleet"
|
||||
)
|
||||
|
||||
// TestFunction represents a datastore method to test
|
||||
type TestFunction func(context.Context, *mysql.Datastore) error
|
||||
|
||||
// All available test functions
|
||||
var testFunctions = map[string]TestFunction{
|
||||
"UpdateVulnerabilityHostCounts": func(ctx context.Context, ds *mysql.Datastore) error {
|
||||
return ds.UpdateVulnerabilityHostCounts(ctx, 1)
|
||||
},
|
||||
}
|
||||
|
||||
// PerformanceResult holds the results of a performance test
|
||||
type PerformanceResult struct {
|
||||
TestFunction string
|
||||
TotalTime time.Duration
|
||||
AverageTime time.Duration
|
||||
MinTime time.Duration
|
||||
MaxTime time.Duration
|
||||
SuccessfulIterations int
|
||||
FailedIterations int
|
||||
Iterations []time.Duration
|
||||
}
|
||||
|
||||
func runPerformanceTest(ctx context.Context, ds *mysql.Datastore, testFuncName string, iterations int, verbose bool) *PerformanceResult {
|
||||
testFunc, exists := testFunctions[testFuncName]
|
||||
if !exists {
|
||||
fmt.Printf("Unknown test function: %s\n", testFuncName)
|
||||
fmt.Printf("Available functions: %v\n", getTestFunctionNames())
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("Running %d iterations of %s...\n", iterations, testFuncName)
|
||||
|
||||
result := &PerformanceResult{
|
||||
TestFunction: testFuncName,
|
||||
Iterations: make([]time.Duration, 0, iterations),
|
||||
}
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
start := time.Now()
|
||||
|
||||
if err := testFunc(ctx, ds); err != nil {
|
||||
log.Printf("Iteration %d failed: %v", i+1, err)
|
||||
result.FailedIterations++
|
||||
continue
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
result.Iterations = append(result.Iterations, duration)
|
||||
result.TotalTime += duration
|
||||
result.SuccessfulIterations++
|
||||
|
||||
if verbose {
|
||||
fmt.Printf("Iteration %d: %v\n", i+1, duration)
|
||||
} else {
|
||||
fmt.Printf(".")
|
||||
}
|
||||
}
|
||||
|
||||
if !verbose {
|
||||
fmt.Printf("\n")
|
||||
}
|
||||
|
||||
if result.SuccessfulIterations == 0 {
|
||||
fmt.Printf("All iterations failed!\n")
|
||||
return result
|
||||
}
|
||||
|
||||
// Calculate statistics
|
||||
result.AverageTime = result.TotalTime / time.Duration(result.SuccessfulIterations)
|
||||
|
||||
// Find min and max
|
||||
result.MinTime = result.Iterations[0]
|
||||
result.MaxTime = result.Iterations[0]
|
||||
for _, duration := range result.Iterations {
|
||||
if duration < result.MinTime {
|
||||
result.MinTime = duration
|
||||
}
|
||||
if duration > result.MaxTime {
|
||||
result.MaxTime = duration
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func printResults(results []*PerformanceResult, showDetails bool) {
|
||||
if len(results) == 0 {
|
||||
fmt.Printf("No results to display\n")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Print("\n" + strings.Repeat("=", 80) + "\n")
|
||||
fmt.Printf("PERFORMANCE TEST RESULTS\n")
|
||||
fmt.Print(strings.Repeat("=", 80) + "\n\n")
|
||||
|
||||
for _, result := range results {
|
||||
if result == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("Function: %s\n", result.TestFunction)
|
||||
fmt.Printf(" Total time: %v\n", result.TotalTime)
|
||||
fmt.Printf(" Average time: %v\n", result.AverageTime)
|
||||
fmt.Printf(" Min time: %v\n", result.MinTime)
|
||||
fmt.Printf(" Max time: %v\n", result.MaxTime)
|
||||
fmt.Printf(" Success rate: %d/%d (%.1f%%)\n",
|
||||
result.SuccessfulIterations,
|
||||
result.SuccessfulIterations+result.FailedIterations,
|
||||
float64(result.SuccessfulIterations)/float64(result.SuccessfulIterations+result.FailedIterations)*100)
|
||||
|
||||
if showDetails && len(result.Iterations) > 0 {
|
||||
fmt.Printf(" All times: ")
|
||||
for i, duration := range result.Iterations {
|
||||
if i > 0 {
|
||||
fmt.Printf(", ")
|
||||
}
|
||||
fmt.Printf("%v", duration)
|
||||
}
|
||||
fmt.Printf("\n")
|
||||
|
||||
// Calculate percentiles
|
||||
sortedTimes := make([]time.Duration, len(result.Iterations))
|
||||
copy(sortedTimes, result.Iterations)
|
||||
sort.Slice(sortedTimes, func(i, j int) bool {
|
||||
return sortedTimes[i] < sortedTimes[j]
|
||||
})
|
||||
|
||||
if len(sortedTimes) >= 2 {
|
||||
p50 := sortedTimes[len(sortedTimes)/2]
|
||||
p90 := sortedTimes[int(float64(len(sortedTimes))*0.9)]
|
||||
p99 := sortedTimes[int(float64(len(sortedTimes))*0.99)]
|
||||
fmt.Printf(" P50: %v\n", p50)
|
||||
fmt.Printf(" P90: %v\n", p90)
|
||||
fmt.Printf(" P99: %v\n", p99)
|
||||
}
|
||||
}
|
||||
fmt.Printf("\n")
|
||||
}
|
||||
}
|
||||
|
||||
func getTestFunctionNames() []string {
|
||||
var names []string
|
||||
for name := range testFunctions {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
testFuncs = flag.String("funcs", "UpdateVulnerabilityHostCounts", "Comma-separated list of test functions to run")
|
||||
iterations = flag.Int("iterations", 5, "Number of iterations per test function")
|
||||
verbose = flag.Bool("verbose", false, "Show timing for each iteration")
|
||||
details = flag.Bool("details", false, "Show detailed statistics including percentiles")
|
||||
listFuncs = flag.Bool("list", false, "List available test functions and exit")
|
||||
help = flag.Bool("help", false, "Show help information")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
if *listFuncs {
|
||||
fmt.Printf("Available test functions:\n")
|
||||
for _, name := range getTestFunctionNames() {
|
||||
fmt.Printf(" %s\n", name)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if *help {
|
||||
fmt.Printf("Fleet Datastore Performance Tester\n\n")
|
||||
fmt.Printf("This tool measures the performance of Fleet datastore methods.\n")
|
||||
fmt.Printf("It assumes test data has already been seeded using the data seeding tool.\n\n")
|
||||
fmt.Printf("Available test functions:\n")
|
||||
for _, name := range getTestFunctionNames() {
|
||||
fmt.Printf(" %s\n", name)
|
||||
}
|
||||
fmt.Printf("\nExamples:\n")
|
||||
fmt.Printf(" %s -funcs=UpdateVulnerabilityHostCounts -iterations=10\n", os.Args[0])
|
||||
fmt.Printf(" %s -funcs=UpdateVulnerabilityHostCounts -iterations=5 -details\n", os.Args[0])
|
||||
fmt.Printf(" %s -funcs=UpdateVulnerabilityHostCounts -verbose\n", os.Args[0])
|
||||
fmt.Printf("\n")
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Connect to datastore
|
||||
ds, err := mysql.New(config.MysqlConfig{
|
||||
Protocol: "tcp",
|
||||
Address: mysqlAddr,
|
||||
Username: mysqlUser,
|
||||
Password: mysqlPass,
|
||||
Database: mysqlDB,
|
||||
}, clock.C)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer func() { _ = ds.Close() }()
|
||||
|
||||
// Parse test functions
|
||||
funcNames := strings.Split(*testFuncs, ",")
|
||||
var results []*PerformanceResult
|
||||
|
||||
for _, funcName := range funcNames {
|
||||
funcName = strings.TrimSpace(funcName)
|
||||
if funcName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
result := runPerformanceTest(ctx, ds, funcName, *iterations, *verbose)
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
printResults(results, *details)
|
||||
}
|
||||
Loading…
Reference in a new issue