fleet/server/datastore/mysql/mysql_test.go
Nico b40fa26e2e
Follow-up changes to observer live query bypass (#41146)
**Related issue:** Resolves #36093

This is a follow-up of https://github.com/fleetdm/fleet/pull/40717

# Checklist for submitter

- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files)
for more information.

## Testing

- [x] Added/updated automated tests

- [x] QA'd all new/changed functionality manually

Verified that the manual test cases I described in
https://github.com/fleetdm/fleet/pull/40717 still pass.

Used the following setup:
- 1 host on Servers.
- 1 host on Servers (canary).
- 9999 hosts on Unassigned.

<img width="1292" height="448" alt="Screenshot 2026-03-10 at 9 41 33 PM"
src="https://github.com/user-attachments/assets/37ba2ad9-aa7b-4d40-b134-56a943e2635c"
/>


Users:
- Team user with these assignments for test cases 1 and 2.

<img width="570" height="269" alt="Screenshot 2026-03-10 at 9 42 41 PM"
src="https://github.com/user-attachments/assets/f4bcf180-b7cc-4d80-a727-26ce887cbe84"
/>

- Global observer user for test cases 3 to 5.

### Test case 1

Report on Workstations (canary) with observers_can_run=true

<img width="470" height="538" alt="Screenshot 2026-03-10 at 9 42 30 PM"
src="https://github.com/user-attachments/assets/11c02ee9-c6eb-463a-9d4b-168a6155feed"
/>

Verified that I can only target that host, whether through "All hosts",
"macOS", or other labels. Searching under "Target specific hosts" also
retrieves only that host.



https://github.com/user-attachments/assets/150d986a-b4f2-49ab-86d9-0308685873eb

### Test case 2

Confirmed that I can't target `perf-host-1` from `Servers
(canary)` through a manual label when running the same report above.
For this, I created a manual label and assigned it only to `perf-host-1`:

<img width="603" height="349" alt="Screenshot 2026-03-10 at 9 50 52 PM"
src="https://github.com/user-attachments/assets/98b4a27a-4e46-466e-a377-622d36903feb"
/>

Note that 0 hosts are targeted and **Run** is disabled:
<img width="950" height="814" alt="Screenshot 2026-03-10 at 9 52 26 PM"
src="https://github.com/user-attachments/assets/3b42c0e9-3005-40cc-8733-85b9b729ce89"
/>
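Test cases 1 and 2 check the same rule: however hosts are selected ("All hosts", a built-in label such as "macOS", a manual label, or a direct search), an observer running a team report can only reach hosts on that report's team. A minimal Go sketch of that intersection follows; the `host` type, `resolveTargets` helper, and team IDs are hypothetical simplifications, not Fleet's actual code.

```go
package main

import "fmt"

// host is a simplified, hypothetical stand-in for a Fleet host record.
type host struct {
	name   string
	teamID *uint // nil models an Unassigned host (no team)
}

// resolveTargets is a hypothetical model of target resolution for an observer
// running a team report: a host is targetable only if it was selected AND it
// belongs to the report's team. The way hosts were selected never widens the
// result beyond that team.
func resolveTargets(selected []host, reportTeamID uint) []host {
	var targets []host
	for _, h := range selected {
		if h.teamID != nil && *h.teamID == reportTeamID {
			targets = append(targets, h)
		}
	}
	return targets
}

// uintPtr returns a pointer to a copy of v.
func uintPtr(v uint) *uint { return &v }

func main() {
	const reportTeam, otherTeam = 1, 2 // hypothetical team IDs
	teamHost := host{name: "team-host", teamID: uintPtr(reportTeam)}
	perfHost1 := host{name: "perf-host-1", teamID: uintPtr(otherTeam)}
	unassigned := host{name: "unassigned-host"}

	// "All hosts" selects everything, but only the report team's host remains.
	all := resolveTargets([]host{teamHost, perfHost1, unassigned}, reportTeam)
	fmt.Println(len(all), all[0].name) // 1 team-host

	// A manual label containing only perf-host-1 yields no targets.
	labeled := resolveTargets([]host{perfHost1}, reportTeam)
	fmt.Println(len(labeled)) // 0
}
```

Filtering after selection, instead of trusting whatever the selection returned, is what keeps a manual label from acting as a bypass.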

### Test case 3

Accessed the same `Workstations (canary)` report above as a Global
Observer user.
Confirmed that no hosts can be targeted in any way:

<img width="977" height="649" alt="Screenshot 2026-03-11 at 8 29 26 AM"
src="https://github.com/user-attachments/assets/ac87ac7e-3097-4228-a724-1d9324dec504"
/>
<img width="986" height="746" alt="Screenshot 2026-03-11 at 8 30 06 AM"
src="https://github.com/user-attachments/assets/5ca592d2-be8c-43c0-8a27-d18fdee35442"
/>
<img width="1017" height="812" alt="Screenshot 2026-03-11 at 8 30 12 AM"
src="https://github.com/user-attachments/assets/fb92940d-3ab2-4136-9e04-825f2c5eb3fe"
/>
<img width="998" height="809" alt="Screenshot 2026-03-11 at 8 30 17 AM"
src="https://github.com/user-attachments/assets/67cc9c0a-e1aa-49df-ad68-1988d6471d32"
/>
<img width="1444" height="311" alt="Screenshot 2026-03-11 at 8 30 35 AM"
src="https://github.com/user-attachments/assets/4b725bf1-0d6d-4458-840e-a96666a34903"
/>
<img width="1444" height="303" alt="Screenshot 2026-03-11 at 8 30 42 AM"
src="https://github.com/user-attachments/assets/54a9cd65-90f5-4454-a713-334e23118295"
/>

### Test case 4

As a global observer, accessing a global report with
observers_can_run=true, I can target all the hosts across all teams.

<img width="951" height="640" alt="Screenshot 2026-03-11 at 8 34 58 AM"
src="https://github.com/user-attachments/assets/3c235b3d-acd5-4801-834f-6fe6cd67d3dd"
/>
<img width="1448" height="527" alt="Screenshot 2026-03-11 at 8 35 06 AM"
src="https://github.com/user-attachments/assets/0f5f663d-8597-4320-aceb-ee6f168ec552"
/>
<img width="1474" height="179" alt="Screenshot 2026-03-11 at 8 35 14 AM"
src="https://github.com/user-attachments/assets/042eda04-e7f6-4c21-9503-878a23435fcd"
/>
 
### Test case 5

With the same report from test case 4, but observers_can_run=false, I
can't target any hosts.

<img width="971" height="804" alt="Screenshot 2026-03-11 at 8 36 49 AM"
src="https://github.com/user-attachments/assets/3a3a9fe3-a159-4ef9-8b08-4c987b9c0828"
/>
<img width="967" height="813" alt="Screenshot 2026-03-11 at 8 37 00 AM"
src="https://github.com/user-attachments/assets/aba5588d-dd96-4b88-9911-ebdd743bfa65"
/>
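At the datastore level, observer scoping is pinned down by the `ObserverTeamID` cases in `TestWhereFilterHostsByTeams` further down in this file. The following minimal Go sketch reproduces those expected WHERE fragments. `hostTeamClause` and the lowercase types are hypothetical simplifications of `whereFilterHostsByTeams` and `fleet.TeamFilter`; only the admin, maintainer, and observer roles are modeled, and the sketch does not cover whatever checks decide whether a report can be run at all. How a conflicting `TeamID` and `ObserverTeamID` combine is not covered by those tests, so the sketch assumes that combination grants nothing.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// userTeam, user and teamFilter are simplified, hypothetical stand-ins for
// fleet.UserTeam, fleet.User and fleet.TeamFilter.
type userTeam struct {
	teamID uint
	role   string // "admin", "maintainer" or "observer"; anything else is ignored
}

type user struct {
	globalRole *string
	teams      []userTeam
}

type teamFilter struct {
	user            *user
	includeObserver bool
	teamID          *uint // explicit team requested by the caller
	observerTeamID  *uint // limits observer access to one team (e.g. the report's team)
}

// hostTeamClause returns a WHERE fragment restricting table's hosts to the
// teams the filter allows.
func hostTeamClause(f teamFilter, table string) string {
	if f.user == nil {
		return "FALSE"
	}
	if r := f.user.globalRole; r != nil {
		switch *r {
		case "admin", "maintainer":
			if f.teamID != nil {
				return fmt.Sprintf("%s.team_id = %d", table, *f.teamID)
			}
			return "TRUE"
		case "observer":
			if !f.includeObserver {
				return "FALSE"
			}
			team := f.teamID
			if f.observerTeamID != nil {
				// Assumption: a teamID that conflicts with observerTeamID grants nothing.
				if team != nil && *team != *f.observerTeamID {
					return "FALSE"
				}
				team = f.observerTeamID
			}
			if team != nil {
				return fmt.Sprintf("%s.team_id = %d", table, *team)
			}
			return "TRUE"
		}
	}
	var allowed []uint
	for _, t := range f.user.teams {
		switch t.role {
		case "admin", "maintainer":
			allowed = append(allowed, t.teamID) // unaffected by observerTeamID
		case "observer":
			if f.includeObserver && (f.observerTeamID == nil || *f.observerTeamID == t.teamID) {
				allowed = append(allowed, t.teamID)
			}
		}
	}
	if len(allowed) == 0 {
		return "FALSE"
	}
	if f.teamID != nil {
		for _, id := range allowed {
			if id == *f.teamID {
				return fmt.Sprintf("%s.team_id = %d", table, id)
			}
		}
		return "FALSE"
	}
	ids := make([]string, len(allowed))
	for i, id := range allowed {
		ids[i] = strconv.FormatUint(uint64(id), 10)
	}
	return fmt.Sprintf("%s.team_id IN (%s)", table, strings.Join(ids, ","))
}

func main() {
	observer := "observer"
	one, two := uint(1), uint(2)
	// Global observer whose observer access is limited to team 1.
	fmt.Println(hostTeamClause(teamFilter{user: &user{globalRole: &observer}, includeObserver: true, observerTeamID: &one}, "hosts"))
	// Observer on team 1 and maintainer on team 2, limited to team 2.
	u := &user{teams: []userTeam{{teamID: 1, role: "observer"}, {teamID: 2, role: "maintainer"}}}
	fmt.Println(hostTeamClause(teamFilter{user: u, includeObserver: true, observerTeamID: &two}, "hosts"))
}
```

Admin and maintainer team roles are appended unconditionally, so `ObserverTeamID` narrows only the observer grants; this is what produces the expected `hosts.team_id IN (1,3)` and `hosts.team_id IN (2)` fragments in the unit tests.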
2026-03-11 13:42:33 -03:00


package mysql
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"database/sql"
"encoding/pem"
"errors"
"fmt"
"log/slog"
"math/big"
"net"
"os"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/VividCortex/mysqlerr"
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/ctxdb"
"github.com/fleetdm/fleet/v4/server/fleet"
common_mysql "github.com/fleetdm/fleet/v4/server/platform/mysql"
"github.com/fleetdm/fleet/v4/server/platform/mysql/testing_utils"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDatastoreReplica(t *testing.T) {
// a bit unfortunate to create temp databases just for this - could be mixed
// with other tests when/if we move to subtests to minimize the number of
// databases created for tests (see #1805).
ctx := context.Background()
t.Run("noreplica", func(t *testing.T) {
ds := CreateMySQLDSWithOptions(t, nil)
defer ds.Close()
require.Equal(t, ds.reader(ctx), ds.writer(ctx))
})
t.Run("replica", func(t *testing.T) {
opts := &testing_utils.DatastoreTestOptions{DummyReplica: true}
ds := CreateMySQLDSWithOptions(t, opts)
defer ds.Close()
require.NotEqual(t, ds.reader(ctx), ds.writer(ctx))
// create a new host
host, err := ds.NewHost(ctx, &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: ptr.String("1"),
UUID: "1",
Hostname: "foo.local",
PrimaryIP: "192.168.1.1",
PrimaryMac: "30-65-EC-6F-C4-58",
})
require.NoError(t, err)
require.NotNil(t, host)
// trying to read it fails, not replicated yet
_, err = ds.Host(ctx, host.ID)
require.Error(t, err)
require.True(t, errors.Is(err, sql.ErrNoRows), err)
// force read from primary works
ctx = ctxdb.RequirePrimary(ctx, true)
got, err := ds.Host(ctx, host.ID)
require.NoError(t, err)
require.Equal(t, host.ID, got.ID)
// but from replica still fails
ctx = ctxdb.RequirePrimary(ctx, false)
_, err = ds.Host(ctx, host.ID)
require.Error(t, err)
require.True(t, errors.Is(err, sql.ErrNoRows))
opts.RunReplication()
// now it can read it from replica
got, err = ds.Host(ctx, host.ID)
require.NoError(t, err)
require.Equal(t, host.ID, got.ID)
})
}
func TestSanitizeColumn(t *testing.T) {
t.Parallel()
testCases := []struct {
input string
output string
}{
{"", ""},
{"foobar-column", "`foobar-column`"},
{"foobar_column", "`foobar_column`"},
{"foobar;column", "`foobarcolumn`"},
{"foobar#", "`foobar`"},
{"foobar*baz", "`foobarbaz`"},
{"....", ""},
{"h.id", "`h`.`id`"},
{"id;delete from hosts", "`iddeletefromhosts`"},
{"select * from foo", "`selectfromfoo`"},
}
for _, tt := range testCases {
t.Run(tt.input, func(t *testing.T) {
require.Equal(t, tt.output, sanitizeColumn(tt.input))
})
}
}
func TestSearchLike(t *testing.T) {
t.Parallel()
testCases := []struct {
inSQL string
inParams []interface{}
match string
columns []string
outSQL string
outParams []interface{}
}{
{
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
inParams: []interface{}{},
match: "foobar",
columns: []string{"hostname"},
outSQL: "SELECT * FROM HOSTS WHERE TRUE AND (hostname LIKE ?)",
outParams: []interface{}{"%foobar%"},
},
{
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
inParams: []interface{}{3},
match: "foobar",
columns: []string{},
outSQL: "SELECT * FROM HOSTS WHERE TRUE",
outParams: []interface{}{3},
},
{
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
inParams: []interface{}{1},
match: "foobar",
columns: []string{"hostname"},
outSQL: "SELECT * FROM HOSTS WHERE TRUE AND (hostname LIKE ?)",
outParams: []interface{}{1, "%foobar%"},
},
{
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
inParams: []interface{}{1},
match: "foobar",
columns: []string{"hostname", "uuid"},
outSQL: "SELECT * FROM HOSTS WHERE TRUE AND (hostname LIKE ? OR uuid LIKE ?)",
outParams: []interface{}{1, "%foobar%", "%foobar%"},
},
{
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
inParams: []interface{}{1},
match: "foobar",
columns: []string{"hostname", "uuid"},
outSQL: "SELECT * FROM HOSTS WHERE TRUE AND (hostname LIKE ? OR uuid LIKE ?)",
outParams: []interface{}{1, "%foobar%", "%foobar%"},
},
{
inSQL: "SELECT * FROM HOSTS WHERE 1=1",
inParams: []interface{}{1},
match: "forty_%",
columns: []string{"ipv4", "uuid"},
outSQL: "SELECT * FROM HOSTS WHERE 1=1 AND (ipv4 LIKE ? OR uuid LIKE ?)",
outParams: []interface{}{1, "%forty\\_\\%%", "%forty\\_\\%%"},
},
{
inSQL: "SELECT * FROM HOSTS WHERE 1=1",
inParams: []interface{}{1},
match: "forty_%",
columns: []string{"ipv4", "uuid"},
outSQL: "SELECT * FROM HOSTS WHERE 1=1 AND (ipv4 LIKE ? OR uuid LIKE ?)",
outParams: []interface{}{1, "%forty\\_\\%%", "%forty\\_\\%%"},
},
{
inSQL: "SELECT * FROM HOSTS WHERE 1=1",
inParams: []interface{}{1},
match: "a@b.c",
columns: []string{"ipv4", "uuid"},
outSQL: "SELECT * FROM HOSTS WHERE 1=1 AND (ipv4 LIKE ? OR uuid LIKE ?)",
outParams: []interface{}{1, "%a@b.c%", "%a@b.c%"},
},
}
for _, tt := range testCases {
t.Run("", func(t *testing.T) {
sql, params := searchLike(tt.inSQL, tt.inParams, tt.match, tt.columns...)
assert.Equal(t, tt.outSQL, sql)
assert.Equal(t, tt.outParams, params)
})
}
}
func TestHostSearchLike(t *testing.T) {
t.Parallel()
testCases := []struct {
inSQL string
inParams []any
match string
columns []string
outSQL string
outParams []any
}{
{
inSQL: "SELECT * FROM HOSTS h WHERE TRUE",
inParams: []any{},
match: "foobar",
columns: []string{"hostname"},
outSQL: "SELECT * FROM HOSTS h WHERE TRUE AND (hostname LIKE ? OR ( EXISTS (SELECT 1 FROM host_emails he WHERE he.host_id = h.id AND he.email LIKE ?)))",
outParams: []any{"%foobar%", "%foobar%"},
},
{
inSQL: "SELECT * FROM HOSTS h WHERE 1=1",
inParams: []any{1},
match: "a@b.c",
columns: []string{"ipv4"},
outSQL: "SELECT * FROM HOSTS h WHERE 1=1 AND (ipv4 LIKE ? OR ( EXISTS (SELECT 1 FROM host_emails he WHERE he.host_id = h.id AND he.email LIKE ?)))",
outParams: []any{1, "%a@b.c%", "%a@b.c%"},
},
}
for _, tt := range testCases {
t.Run("", func(t *testing.T) {
sql, params := hostSearchLike(tt.inSQL, tt.inParams, tt.match, tt.columns...)
assert.Equal(t, tt.outSQL, sql)
assert.Equal(t, tt.outParams, params)
})
}
}
func mockDatastore(t *testing.T) (sqlmock.Sqlmock, *Datastore) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
dbmock := sqlx.NewDb(db, "sqlmock")
ds := &Datastore{
primary: dbmock,
replica: dbmock,
logger: slog.New(slog.DiscardHandler),
}
return mock, ds
}
func TestWithRetryTxxSuccess(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
require.NoError(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithRetryTxxRollbackSuccess(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("fail"))
mock.ExpectRollback()
require.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithRetryTxxRollbackError(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("fail"))
mock.ExpectRollback().WillReturnError(errors.New("rollback failed"))
require.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithRetryTxxRetrySuccess(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
// Return a retryable error
mock.ExpectExec("SELECT 1").WillReturnError(&mysql.MySQLError{Number: mysqlerr.ER_LOCK_DEADLOCK})
mock.ExpectRollback()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
assert.NoError(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithRetryTxxCommitRetrySuccess(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
// Return a retryable error
mock.ExpectCommit().WillReturnError(&mysql.MySQLError{Number: mysqlerr.ER_LOCK_DEADLOCK})
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
assert.NoError(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithRetryTxxCommitError(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
// Return a non-retryable error
mock.ExpectCommit().WillReturnError(errors.New("fail"))
assert.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestAppendListOptionsToSQLSecure(t *testing.T) {
// Test allowlist for mapping order keys to actual column names
testAllowlist := common_mysql.OrderKeyAllowlist{
"name": "name",
}
sql := "SELECT * FROM my_table"
opts := fleet.ListOptions{
OrderKey: "name",
}
actual, _, err := appendListOptionsToSQLSecure(sql, &opts, testAllowlist)
require.NoError(t, err)
expected := "SELECT * FROM my_table ORDER BY name ASC LIMIT 1000000"
assert.Equal(t, expected, actual)
sql = "SELECT * FROM my_table"
opts.OrderDirection = fleet.OrderDescending
actual, _, err = appendListOptionsToSQLSecure(sql, &opts, testAllowlist)
require.NoError(t, err)
expected = "SELECT * FROM my_table ORDER BY name DESC LIMIT 1000000"
assert.Equal(t, expected, actual)
opts = fleet.ListOptions{
PerPage: 10,
}
sql = "SELECT * FROM my_table"
actual, _, err = appendListOptionsToSQLSecure(sql, &opts, testAllowlist)
require.NoError(t, err)
expected = "SELECT * FROM my_table LIMIT 10"
assert.Equal(t, expected, actual)
sql = "SELECT * FROM my_table"
opts.Page = 2
actual, _, err = appendListOptionsToSQLSecure(sql, &opts, testAllowlist)
require.NoError(t, err)
expected = "SELECT * FROM my_table LIMIT 10 OFFSET 20"
assert.Equal(t, expected, actual)
opts = fleet.ListOptions{}
sql = "SELECT * FROM my_table"
actual, _, err = appendListOptionsToSQLSecure(sql, &opts, testAllowlist)
require.NoError(t, err)
expected = "SELECT * FROM my_table LIMIT 1000000"
assert.Equal(t, expected, actual)
// Test that invalid order key returns an error
opts = fleet.ListOptions{OrderKey: "invalid_column"}
sql = "SELECT * FROM my_table"
_, _, err = appendListOptionsToSQLSecure(sql, &opts, testAllowlist)
require.Error(t, err)
var invalidKeyErr common_mysql.InvalidOrderKeyError
require.ErrorAs(t, err, &invalidKeyErr)
require.Equal(t, "invalid_column", invalidKeyErr.Key)
}
func TestAppendListOptionsToSQL(t *testing.T) {
sql := "SELECT * FROM my_table"
opts := fleet.ListOptions{
OrderKey: "***name***",
}
actual, _ := appendListOptionsToSQL(sql, &opts)
expected := "SELECT * FROM my_table ORDER BY `name` ASC LIMIT 1000000"
if actual != expected {
t.Error("Expected", expected, "Actual", actual)
}
sql = "SELECT * FROM my_table"
opts.OrderDirection = fleet.OrderDescending
actual, _ = appendListOptionsToSQL(sql, &opts)
expected = "SELECT * FROM my_table ORDER BY `name` DESC LIMIT 1000000"
if actual != expected {
t.Error("Expected", expected, "Actual", actual)
}
opts = fleet.ListOptions{
PerPage: 10,
}
sql = "SELECT * FROM my_table"
actual, _ = appendListOptionsToSQL(sql, &opts)
expected = "SELECT * FROM my_table LIMIT 10"
if actual != expected {
t.Error("Expected", expected, "Actual", actual)
}
sql = "SELECT * FROM my_table"
opts.Page = 2
actual, _ = appendListOptionsToSQL(sql, &opts)
expected = "SELECT * FROM my_table LIMIT 10 OFFSET 20"
if actual != expected {
t.Error("Expected", expected, "Actual", actual)
}
opts = fleet.ListOptions{}
sql = "SELECT * FROM my_table"
actual, _ = appendListOptionsToSQL(sql, &opts)
expected = "SELECT * FROM my_table LIMIT 1000000"
if actual != expected {
t.Error("Expected", expected, "Actual", actual)
}
}
func TestWhereFilterHostsByTeams(t *testing.T) {
t.Parallel()
testCases := []struct {
filter fleet.TeamFilter
expected string
}{
// No teams or global role
{
filter: fleet.TeamFilter{
User: &fleet.User{},
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{Teams: []fleet.UserTeam{}},
},
expected: "FALSE",
},
// Global role
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
},
expected: "TRUE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
},
expected: "TRUE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: true,
},
expected: "TRUE",
},
// Team roles
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
},
},
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
},
},
IncludeObserver: true,
},
expected: "hosts.team_id IN (1)",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 2}},
},
},
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
},
},
},
expected: "hosts.team_id IN (2)",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
},
},
IncludeObserver: true,
},
expected: "hosts.team_id IN (1,2)",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
// Invalid role should be ignored
{Role: "bad", Team: fleet.Team{ID: 37}},
},
},
},
expected: "hosts.team_id IN (2)",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
{Role: fleet.RoleAdmin, Team: fleet.Team{ID: 3}},
},
},
},
expected: "hosts.team_id IN (2,3)",
},
{
filter: fleet.TeamFilter{
TeamID: ptr.Uint(1),
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: true,
TeamID: ptr.Uint(1),
},
expected: "hosts.team_id = 1",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: false,
TeamID: ptr.Uint(1),
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
IncludeObserver: false,
TeamID: ptr.Uint(1),
},
expected: "hosts.team_id = 1",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
},
},
TeamID: ptr.Uint(3),
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
},
},
TeamID: ptr.Uint(2),
},
expected: "hosts.team_id = 2",
},
// ObserverTeamID: restricts observer access to a specific team (e.g. the live query's own team)
{
// Global observer with ObserverTeamID set: only that team's hosts
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: true,
ObserverTeamID: ptr.Uint(1),
},
expected: "hosts.team_id = 1",
},
{
// Observer on two teams with ObserverTeamID set: only the specified team
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 2}},
},
},
IncludeObserver: true,
ObserverTeamID: ptr.Uint(1),
},
expected: "hosts.team_id IN (1)",
},
{
// Admin on team 3 + observer on teams 1 and 2 with ObserverTeamID=1:
// admin access to team 3 is unaffected; observer access limited to team 1 only
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 2}},
{Role: fleet.RoleAdmin, Team: fleet.Team{ID: 3}},
},
},
IncludeObserver: true,
ObserverTeamID: ptr.Uint(1),
},
expected: "hosts.team_id IN (1,3)",
},
{
// Observer on team 1, maintainer on team 2, running a team-2 query (ObserverTeamID=2):
// team-1 observer access is excluded because ObserverTeamID=2 != team 1;
// maintainer access on team 2 is unaffected — only team-2 hosts are returned.
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
},
},
IncludeObserver: true,
ObserverTeamID: ptr.Uint(2),
},
expected: "hosts.team_id IN (2)",
},
}
for _, tt := range testCases {
tt := tt
t.Run("", func(t *testing.T) {
ds := &Datastore{logger: slog.New(slog.DiscardHandler)}
sql := ds.whereFilterHostsByTeams(tt.filter, "hosts")
assert.Equal(t, tt.expected, sql)
})
}
}
func TestWhereOmitIDs(t *testing.T) {
t.Parallel()
testCases := []struct {
omits []uint
expected string
}{
{
omits: nil,
expected: "TRUE",
},
{
omits: []uint{},
expected: "TRUE",
},
{
omits: []uint{1, 3, 4},
expected: "id NOT IN (1,3,4)",
},
{
omits: []uint{42},
expected: "id NOT IN (42)",
},
}
for _, tt := range testCases {
tt := tt
t.Run("", func(t *testing.T) {
ds := &Datastore{logger: slog.New(slog.DiscardHandler)}
sql := ds.whereOmitIDs("id", tt.omits)
assert.Equal(t, tt.expected, sql)
})
}
}
func TestWithRetryTxWithRollback(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
mock.ExpectRollback()
assert.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithRetryTxWillRollbackWhenPanic(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
defer func() { recover() }() //nolint:errcheck
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
mock.ExpectRollback()
assert.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
panic("ROLLBACK")
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithTxWithRollback(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
mock.ExpectRollback()
assert.Error(t, ds.withTx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithTxWillRollbackWhenPanic(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
defer func() { recover() }() //nolint:errcheck
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
mock.ExpectRollback()
assert.Error(t, ds.withTx(context.Background(), func(tx sqlx.ExtContext) error {
panic("ROLLBACK")
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestHealthCheckDetectsReadOnly(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
// Healthy: primary is writable.
mock.ExpectQuery("SELECT @@read_only").
WillReturnRows(sqlmock.NewRows([]string{"@@read_only"}).AddRow(0))
require.NoError(t, ds.HealthCheck())
// Unhealthy: primary is read-only (failover scenario).
mock.ExpectQuery("SELECT @@read_only").
WillReturnRows(sqlmock.NewRows([]string{"@@read_only"}).AddRow(1))
err := ds.HealthCheck()
require.Error(t, err)
assert.Contains(t, err.Error(), "read-only")
require.NoError(t, mock.ExpectationsWereMet())
}
func TestNewReadsPasswordFromDisk(t *testing.T) {
passwordFile, err := os.CreateTemp(t.TempDir(), "*.passwordtest")
require.NoError(t, err)
_, err = passwordFile.WriteString(testing_utils.TestPassword)
require.NoError(t, err)
passwordPath := passwordFile.Name()
require.NoError(t, passwordFile.Close())
dbName := t.Name()
// Create a datastore client in order to run migrations as usual
mysqlConfig := config.MysqlConfig{
Username: testing_utils.TestUsername,
Password: "",
PasswordPath: passwordPath,
Address: testing_utils.TestAddress,
Database: dbName,
}
ds, err := newDSWithConfig(t, dbName, mysqlConfig)
require.NoError(t, err)
defer ds.Close()
require.NoError(t, ds.HealthCheck())
}
func newDSWithConfig(t *testing.T, dbName string, config config.MysqlConfig) (*Datastore, error) {
db, err := sql.Open(
"mysql",
fmt.Sprintf("%s:%s@tcp(%s)/?multiStatements=true", testing_utils.TestUsername, testing_utils.TestPassword,
testing_utils.TestAddress),
)
require.NoError(t, err)
defer db.Close()
_, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", dbName, dbName))
require.NoError(t, err)
ds, err := New(config, clock.NewMockClock(), Logger(slog.New(slog.DiscardHandler)), LimitAttempts(1))
return ds, err
}
func generateTestCert(t *testing.T) (string, string) {
privateKeyCA, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(t, err)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"aa"},
},
NotBefore: time.Now().Add(-1 * time.Duration(24) * time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
IsCA: true,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKeyCA.PublicKey, privateKeyCA)
require.NoError(t, err)
publicPem, err := os.CreateTemp(t.TempDir(), "*-ca.pem")
require.NoError(t, err)
require.NoError(t, pem.Encode(publicPem, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
require.NoError(t, publicPem.Close())
keyPem, err := os.CreateTemp(t.TempDir(), "*-key.pem")
require.NoError(t, err)
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKeyCA)
require.NoError(t, pem.Encode(keyPem, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privateKeyBytes}))
require.NoError(t, keyPem.Close())
return publicPem.Name(), keyPem.Name()
}
func TestNewUsesRegisterTLS(t *testing.T) {
dbName := t.Name()
ca, _ := generateTestCert(t)
cert, key := generateTestCert(t)
mysqlConfig := config.MysqlConfig{
Username: testing_utils.TestUsername,
Password: testing_utils.TestPassword,
Address: testing_utils.TestAddress,
Database: dbName,
TLSCA: ca,
TLSCert: cert,
TLSKey: key,
}
// This fails because the certificate MySQL is using is different from the one generated here
_, err := newDSWithConfig(t, dbName, mysqlConfig)
require.Error(t, err)
// TODO: we're using a Regexp because the message is different depending on the version of mysql,
// we should refactor and use different error types instead.
require.Regexp(t, "(x509|tls|EOF)", err.Error())
}
func TestWhereFilterTeams(t *testing.T) {
t.Parallel()
testCases := []struct {
filter fleet.TeamFilter
expected string
}{
// No teams or global role
{
filter: fleet.TeamFilter{User: nil},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
},
expected: "TRUE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: false,
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: true,
},
expected: "TRUE",
},
{
filter: fleet.TeamFilter{User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}}},
expected: "t.id IN (1)",
},
{
filter: fleet.TeamFilter{User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}}},
expected: "t.id IN (1)",
},
{
filter: fleet.TeamFilter{User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}}},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}},
IncludeObserver: true,
},
expected: "t.id IN (1)",
},
}
for _, tt := range testCases {
tt := tt
t.Run("", func(t *testing.T) {
ds := &Datastore{logger: slog.New(slog.DiscardHandler)}
sql := ds.whereFilterTeams(tt.filter, "t")
assert.Equal(t, tt.expected, sql)
})
}
}
func TestCompareVersions(t *testing.T) {
for _, tc := range []struct {
name string
v1 []int64
v2 []int64
knownUnknowns map[int64]struct{}
expMissing []int64
expUnknown []int64
expEqual bool
}{
{
name: "both-empty",
v1: nil,
v2: nil,
expEqual: true,
},
{
name: "equal",
v1: []int64{1, 2, 3},
v2: []int64{1, 2, 3},
expEqual: true,
},
{
name: "equal-out-of-order",
v1: []int64{1, 2, 3},
v2: []int64{1, 3, 2},
expEqual: true,
},
{
name: "empty-with-unknown",
v1: nil,
v2: []int64{1},
expEqual: false,
expUnknown: []int64{1},
},
{
name: "empty-with-missing",
v1: []int64{1},
v2: nil,
expEqual: false,
expMissing: []int64{1},
},
{
name: "missing",
v1: []int64{1, 2, 3},
v2: []int64{1, 3},
expMissing: []int64{2},
expEqual: false,
},
{
name: "unknown",
v1: []int64{1, 2, 3},
v2: []int64{1, 2, 3, 4},
expUnknown: []int64{4},
expEqual: false,
},
{
name: "known-unknown",
v1: []int64{1, 2, 3},
v2: []int64{1, 2, 3, 4},
knownUnknowns: map[int64]struct{}{
4: {},
},
expEqual: true,
},
{
name: "unknowns",
v1: []int64{1, 2, 3},
v2: []int64{1, 2, 3, 4, 5},
expUnknown: []int64{5},
knownUnknowns: map[int64]struct{}{
4: {},
},
expEqual: false,
},
{
name: "missing-and-unknown",
v1: []int64{1, 2, 3},
v2: []int64{1, 2, 4},
expMissing: []int64{3},
expUnknown: []int64{4},
expEqual: false,
},
} {
t.Run(tc.name, func(t *testing.T) {
missing, unknown, equal := compareVersions(tc.v1, tc.v2, tc.knownUnknowns)
require.Equal(t, tc.expMissing, missing)
require.Equal(t, tc.expUnknown, unknown)
require.Equal(t, tc.expEqual, equal)
})
}
}
func TestDebugs(t *testing.T) {
ds := CreateMySQLDS(t)
status, err := ds.InnoDBStatus(context.Background())
require.NoError(t, err)
assert.NotEmpty(t, status)
processList, err := ds.ProcessList(context.Background())
require.NoError(t, err)
require.Greater(t, len(processList), 0)
}
func TestWantedModesEnabled(t *testing.T) {
ds := CreateMySQLDS(t)
var sqlMode string
err := ds.writer(context.Background()).GetContext(context.Background(), &sqlMode, `SELECT @@SQL_MODE`)
require.NoError(t, err)
require.Contains(t, sqlMode, "ANSI_QUOTES")
require.Contains(t, sqlMode, "ONLY_FULL_GROUP_BY")
}
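/*
When require.Contains is given a string, it performs a substring check. A mode
whose name is embedded in another mode's name (for example ANSI inside
ANSI_QUOTES) would therefore also pass. A stricter check splits the
comma-separated @@SQL_MODE value into individual entries. The sketch below shows
that approach. hasSQLMode is a hypothetical helper and is not part of this package.

```go
package main

import (
	"fmt"
	"strings"
)

// hasSQLMode reports whether mode is one of the comma-separated entries in the
// value returned by SELECT @@SQL_MODE. Hypothetical helper for illustration.
func hasSQLMode(sqlMode, mode string) bool {
	for _, entry := range strings.Split(sqlMode, ",") {
		if strings.EqualFold(strings.TrimSpace(entry), mode) {
			return true
		}
	}
	return false
}

func main() {
	modes := "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,ANSI_QUOTES"
	fmt.Println(hasSQLMode(modes, "ANSI_QUOTES")) // true
	fmt.Println(hasSQLMode(modes, "ANSI"))        // false, although strings.Contains would say true
}
```
*/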
func Test_buildWildcardMatchPhrase(t *testing.T) {
type args struct {
matchQuery string
}
tests := []struct {
name string
args args
want string
}{
{
name: "plain text is wrapped in wildcards",
args: args{matchQuery: "test"},
want: "%test%",
},
{
name: "underscores are escaped",
args: args{matchQuery: "Host_1"},
want: "%Host\\_1%",
},
{
name: "percent signs are escaped",
args: args{matchQuery: "Host%1"},
want: "%Host\\%1%",
},
{
name: "percent & underscore are escaped",
args: args{matchQuery: "Host_%1"},
want: "%Host\\_\\%1%",
},
{
name: "underscores added for wildcard search are not escaped",
args: args{matchQuery: "Alice’s MacbookPro"},
want: "%Alice_s MacbookPro%",
},
{
name: "underscores added for wildcard search are not escaped, but underscores in matchQuery are",
args: args{matchQuery: "Alice’s Macbook_Pro"},
want: "%Alice_s Macbook\\_Pro%",
},
{
name: "multiple occurrences of wildcard are not escaped",
args: args{matchQuery: "Alice’’s Macbook_Pro"},
want: "%Alice__s Macbook\\_Pro%",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equalf(t, tt.want, buildWildcardMatchPhrase(tt.args.matchQuery), "buildWildcardMatchPhrase(%v)", tt.args.matchQuery)
})
}
}
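/*
This is a minimal sketch of the transformation the table above expects from
buildWildcardMatchPhrase, written as a standalone program. It is not Fleet's
implementation, and wildcardPhrase is a hypothetical name. The sketch escapes the
LIKE metacharacters % and _, and also escapes the backslash escape character
itself (an extra assumption). It then assumes that every non-ASCII character,
such as the curly apostrophe in "Alice’s MacBook Pro", becomes the
single-character wildcard _. With that rule, the pattern still matches when the
stored name has a different character at that position.

```go
package main

import (
	"fmt"
	"strings"
	"unicode/utf8"
)

// wildcardPhrase is a hypothetical sketch, not Fleet's buildWildcardMatchPhrase.
func wildcardPhrase(q string) string {
	var b strings.Builder
	b.WriteByte('%') // match anywhere in the column
	for _, r := range q {
		switch {
		case r == '\\' || r == '%' || r == '_':
			// Escape so MySQL LIKE treats the character literally
			// (backslash is MySQL's default LIKE escape character).
			b.WriteByte('\\')
			b.WriteRune(r)
		case r >= utf8.RuneSelf:
			// Assumption: any non-ASCII rune matches exactly one character.
			b.WriteByte('_')
		default:
			b.WriteRune(r)
		}
	}
	b.WriteByte('%')
	return b.String()
}

func main() {
	fmt.Println(wildcardPhrase("Host_%1"))             // %Host\_\%1%
	fmt.Println(wildcardPhrase("Alice’s Macbook_Pro")) // %Alice_s Macbook\_Pro%
}
```

Escaping must happen before the wildcard substitution. Otherwise the _ that was
deliberately inserted would be escaped as well, and it would stop acting as a
wildcard.
*/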
func TestWhereFilterTeamWithGlobalStats(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
filter fleet.TeamFilter
expected string
}{
// No teams or global role
{
name: "empty user",
filter: fleet.TeamFilter{
User: &fleet.User{},
},
expected: "FALSE",
},
{
name: "empty user teams",
filter: fleet.TeamFilter{
User: &fleet.User{Teams: []fleet.UserTeam{}},
},
expected: "FALSE",
},
// Global role
{
name: "global admin",
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
},
expected: "hosts.team_id = 0 AND hosts.global_stats = 1",
},
{
name: "global maintainer",
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
},
expected: "hosts.team_id = 0 AND hosts.global_stats = 1",
},
{
name: "global observer",
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
},
expected: "FALSE",
},
{
name: "global observer include",
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: true,
},
expected: "hosts.team_id = 0 AND hosts.global_stats = 1",
},
// Team roles
{
name: "team observer",
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
},
},
},
expected: "FALSE",
},
{
name: "team observer include",
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
},
},
IncludeObserver: true,
},
expected: "hosts.team_id IN (1)",
},
{
name: "multi team observer",
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 2}},
},
},
},
expected: "FALSE",
},
{
name: "multi team maintainer and observer",
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
},
},
},
expected: "hosts.team_id IN (2)",
},
{
name: "multi team maintainer and observer include",
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
},
},
IncludeObserver: true,
},
expected: "hosts.team_id IN (1,2)",
},
{
name: "multi team maintainer and observer with invalid role",
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
// Invalid role should be ignored
{Role: "bad", Team: fleet.Team{ID: 37}},
},
},
},
expected: "hosts.team_id IN (2)",
},
{
name: "multi team maintainer and observer and admin",
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
{Role: fleet.RoleAdmin, Team: fleet.Team{ID: 3}},
},
},
},
expected: "hosts.team_id IN (2,3)",
},
{
name: "team id only",
filter: fleet.TeamFilter{
TeamID: ptr.Uint(1),
},
expected: "FALSE",
},
{
name: "team id with observer include",
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: true,
TeamID: ptr.Uint(1),
},
expected: "hosts.team_id = 1",
},
{
name: "team id with observer exclude",
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: false,
TeamID: ptr.Uint(1),
},
expected: "FALSE",
},
{
name: "team id with admin exclude observer",
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
IncludeObserver: false,
TeamID: ptr.Uint(1),
},
expected: "hosts.team_id = 1",
},
{
name: "team id not in multiple team roles",
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
},
},
TeamID: ptr.Uint(3),
},
expected: "FALSE",
},
{
name: "team id in multiple team roles",
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
},
},
TeamID: ptr.Uint(2),
},
expected: "hosts.team_id = 2",
},
}
for _, tt := range testCases {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ds := &Datastore{logger: slog.New(slog.DiscardHandler)}
sql := ds.whereFilterTeamWithGlobalStats(tt.filter, "hosts")
assert.Equal(t, tt.expected, sql)
})
}
}
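/*
The expectations above can be reduced to a small decision procedure. The program
below is a simplified sketch derived only from this table. It is not Fleet's
whereFilterTeamWithGlobalStats, and the types filter and teamRole are
hypothetical stand-ins for fleet.TeamFilter and fleet.UserTeam. In the sketch,
admins and maintainers always count, observers count only when IncludeObserver
is set, and unrecognized roles are ignored. Other Fleet roles are not modeled.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// Hypothetical, simplified stand-ins for fleet.UserTeam and fleet.TeamFilter.
type teamRole struct {
	TeamID uint
	Role   string
}

type filter struct {
	GlobalRole      string // empty when the user has no global role
	Teams           []teamRole
	IncludeObserver bool
	TeamID          *uint // optional: restrict the result to one team
}

// allowed: admins and maintainers always; observers only when opted in.
func allowed(role string, includeObserver bool) bool {
	switch role {
	case "admin", "maintainer":
		return true
	case "observer":
		return includeObserver
	}
	return false // unrecognized roles such as "bad" are ignored
}

func sketchFilter(f filter, alias string) string {
	global := allowed(f.GlobalRole, f.IncludeObserver)
	var teamIDs []uint
	for _, t := range f.Teams {
		if allowed(t.Role, f.IncludeObserver) {
			teamIDs = append(teamIDs, t.TeamID)
		}
	}
	if f.TeamID != nil {
		// A specific team is allowed through a global role or a matching team role.
		ok := global
		for _, id := range teamIDs {
			if id == *f.TeamID {
				ok = true
			}
		}
		if !ok {
			return "FALSE"
		}
		return fmt.Sprintf("%s.team_id = %d", alias, *f.TeamID)
	}
	if global {
		// Per the table, an allowed global role maps to the global-stats condition.
		return fmt.Sprintf("%s.team_id = 0 AND %s.global_stats = 1", alias, alias)
	}
	if len(teamIDs) == 0 {
		return "FALSE" // no usable role: match nothing
	}
	parts := make([]string, len(teamIDs))
	for i, id := range teamIDs {
		parts[i] = strconv.FormatUint(uint64(id), 10)
	}
	return fmt.Sprintf("%s.team_id IN (%s)", alias, strings.Join(parts, ","))
}

func main() {
	f := filter{Teams: []teamRole{{TeamID: 1, Role: "observer"}, {TeamID: 2, Role: "maintainer"}}}
	fmt.Println(sketchFilter(f, "hosts")) // hosts.team_id IN (2)
	f.IncludeObserver = true
	fmt.Println(sketchFilter(f, "hosts")) // hosts.team_id IN (1,2)
	two := uint(2)
	f.TeamID = &two
	fmt.Println(sketchFilter(f, "hosts")) // hosts.team_id = 2
}
```
*/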
func TestBatchProcessDB(t *testing.T) {
type testData struct {
id int
value string
}
payload := []interface{}{
&testData{id: 1, value: "a"},
&testData{id: 2, value: "b"},
&testData{id: 3, value: "c"},
}
generateValueArgs := func(item interface{}) (string, []any) {
p := item.(*testData)
valuePart := "(?, ?),"
args := []any{p.id, p.value}
return valuePart, args
}
t.Run("TestEmptyPayload", func(t *testing.T) {
executeBatch := func(valuePart string, args []any) error {
return errors.New("execute shouldn't be called for an empty payload")
}
err := batchProcessDB([]interface{}{}, 1000, generateValueArgs, executeBatch)
require.NoError(t, err)
})
t.Run("TestSingleBatch", func(t *testing.T) {
callCount := 0
executeBatch := func(valuePart string, args []any) error {
callCount++
require.Equal(t, 2, len(args)/2) // each item adds 2 args
return nil
}
err := batchProcessDB(payload[:2], 2, generateValueArgs, executeBatch)
require.NoError(t, err)
require.Equal(t, 1, callCount)
})
t.Run("TestMultipleBatches", func(t *testing.T) {
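callCount := 0
// 3 items with a batch size of 2: one full batch, then the remainder.
expectedItems := []int{2, 1}
executeBatch := func(valuePart string, args []any) error {
callCount++
require.LessOrEqual(t, callCount, len(expectedItems))
require.Equal(t, expectedItems[callCount-1], len(args)/2) // each item adds 2 args
return nil
}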
err := batchProcessDB(payload, 2, generateValueArgs, executeBatch)
require.NoError(t, err)
require.Equal(t, 2, callCount)
})
}
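/*
TestBatchProcessDB checks three properties. An empty payload never calls the
executor. A payload that fits in one batch produces a single call. A larger
payload is split so that the last batch holds the remainder. The generic program
below is a hypothetical sketch of that contract, not Fleet's batchProcessDB.
Trimming the trailing comma from the concatenated "(?, ?)," fragments is an
assumption based on the fragment format the test uses.

```go
package main

import (
	"fmt"
	"strings"
)

// batchProcess is a hypothetical generic sketch, not Fleet's batchProcessDB.
func batchProcess[T any](items []T, batchSize int,
	gen func(T) (string, []any), execute func(string, []any) error) error {
	if batchSize <= 0 {
		return fmt.Errorf("batch size must be positive, got %d", batchSize)
	}
	for start := 0; start < len(items); start += batchSize {
		end := start + batchSize
		if end > len(items) {
			end = len(items) // the last batch holds the remainder
		}
		var values strings.Builder
		var args []any
		for _, item := range items[start:end] {
			part, itemArgs := gen(item)
			values.WriteString(part)
			args = append(args, itemArgs...)
		}
		// Assumption: fragments end in ",", so trim the final one before executing.
		if err := execute(strings.TrimSuffix(values.String(), ","), args); err != nil {
			return fmt.Errorf("batch starting at item %d: %w", start, err)
		}
	}
	return nil
}

func main() {
	gen := func(id int) (string, []any) { return "(?, ?),", []any{id, fmt.Sprint("v", id)} }
	exec := func(values string, args []any) error {
		fmt.Printf("VALUES %s with %d args\n", values, len(args))
		return nil
	}
	if err := batchProcess([]int{1, 2, 3}, 2, gen, exec); err != nil {
		fmt.Println(err)
	}
	// VALUES (?, ?),(?, ?) with 4 args
	// VALUES (?, ?) with 2 args
}
```

Because the loop body never runs for an empty slice, the "no call for an empty
payload" property needs no special case.
*/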
func TestGetContextTryStmt(t *testing.T) {
ctx := context.Background()
dbMock, ds := mockDatastore(t)
ds.stmtCache = map[string]*sqlx.Stmt{}
t.Run("get with unknown statement error", func(t *testing.T) {
count := 0
query := "SELECT 1"
// first call to cache the statement
dbMock.ExpectPrepare(query)
mockResult := sqlmock.NewRows([]string{query})
mockResult.AddRow("1")
dbMock.ExpectQuery(query).WillReturnRows(mockResult)
err := ds.getContextTryStmt(ctx, &count, query)
require.NoError(t, err)
require.NoError(t, dbMock.ExpectationsWereMet())
// verify that the statement was cached
stmt := ds.loadOrPrepareStmt(ctx, query)
require.NotNil(t, stmt)
// call again to trigger the unknown statement error and ensure it retries
// first query, make it fail
queryMock := dbMock.ExpectQuery(query)
mySQLErr := &mysql.MySQLError{
Number: mysqlerr.ER_UNKNOWN_STMT_HANDLER,
}
queryMock.WillReturnError(mySQLErr)
// after the failure, a second call is made, this time without
// the prepared statement
mockResult = sqlmock.NewRows([]string{query})
mockResult.AddRow("1")
dbMock.ExpectQuery(query).WillReturnRows(mockResult)
// make the call and verify we removed the prepared statement
err = ds.getContextTryStmt(ctx, &count, query)
require.NoError(t, err)
require.NoError(t, dbMock.ExpectationsWereMet())
stmt = ds.loadOrPrepareStmt(ctx, query)
require.Nil(t, stmt)
})
t.Run("get with other error", func(t *testing.T) {
dbMock, ds := mockDatastore(t)
ds.stmtCache = map[string]*sqlx.Stmt{}
count := 0
query := "SELECT 1"
// first call to cache the statement
dbMock.ExpectPrepare(query)
mockResult := sqlmock.NewRows([]string{query})
mockResult.AddRow("1")
dbMock.ExpectQuery(query).WillReturnRows(mockResult)
err := ds.getContextTryStmt(ctx, &count, query)
require.NoError(t, err)
require.Equal(t, 1, count)
require.NoError(t, dbMock.ExpectationsWereMet())
// verify that the statement was cached
stmt := ds.loadOrPrepareStmt(ctx, query)
require.NotNil(t, stmt)
// make the query fail with a duplicate-entry error, which must not trigger a retry
queryMock := dbMock.ExpectQuery(query)
mySQLErr := &mysql.MySQLError{
Number: mysqlerr.ER_DUP_ENTRY,
}
queryMock.WillReturnError(mySQLErr)
count = 0
err = ds.getContextTryStmt(ctx, &count, query)
require.ErrorIs(t, err, mySQLErr)
require.NoError(t, dbMock.ExpectationsWereMet())
stmt = ds.loadOrPrepareStmt(ctx, query)
require.NotNil(t, stmt)
})
}
// createTestDatabase creates a temporary test database with the given name and
// registers cleanup logic to drop it and close the connection when the test
// completes.
func createTestDatabase(t *testing.T, dbName string) {
t.Helper()
db, err := sql.Open(
"mysql",
fmt.Sprintf("%s:%s@tcp(%s)/?multiStatements=true", testing_utils.TestUsername, testing_utils.TestPassword,
testing_utils.TestAddress),
)
require.NoError(t, err)
// Register cleanup before creating the database so the connection is closed
// even if the CREATE statement fails.
t.Cleanup(func() {
_, _ = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName))
db.Close()
})
_, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", dbName, dbName))
require.NoError(t, err)
}
// TestReplicaPasswordReadFromDisk verifies that when a replica config uses PasswordPath,
// the password read from disk by checkAndModifyConfig is preserved for the actual DB connection.
//
// This is a regression test for https://github.com/fleetdm/fleet/pull/39689.
// Before the fix, NewDBConnections called fromCommonMysqlConfig twice for the replica config,
// and the second call created a fresh config where Password was empty (the mutation from
// checkAndModifyConfig reading PasswordPath was lost). This caused the replica to connect
// with an empty password instead of the one from disk.
func TestReplicaPasswordReadFromDisk(t *testing.T) {
// Write the correct password to a temp file.
passwordFile, err := os.CreateTemp(t.TempDir(), "*.passwordtest")
require.NoError(t, err)
_, err = passwordFile.WriteString(testing_utils.TestPassword)
require.NoError(t, err)
require.NoError(t, passwordFile.Close())
dbName := t.Name()
// Create the test database using a direct connection.
createTestDatabase(t, dbName)
// Primary config uses Password directly — this always works.
primaryConfig := config.MysqlConfig{
Username: testing_utils.TestUsername,
Password: testing_utils.TestPassword,
Address: testing_utils.TestAddress,
Database: dbName,
}
// Replica config uses PasswordPath instead of Password.
// checkAndModifyConfig must read the file and set Password on the config
// that is later passed to NewDB. Before the fix the mutation was discarded.
replicaConfig := config.MysqlConfig{
Username: testing_utils.TestUsername,
PasswordPath: passwordFile.Name(),
Address: testing_utils.TestAddress,
Database: dbName,
}
conns, err := NewDBConnections(primaryConfig, Replica(&replicaConfig), LimitAttempts(1), Logger(slog.New(slog.DiscardHandler)))
require.NoError(t, err, "replica connection should succeed when PasswordPath is used — "+
"if this fails with 'Access denied' the password read from disk was not preserved for the replica")
defer conns.Primary.Close()
defer conns.Replica.Close()
// Verify the replica connection actually works.
require.NoError(t, conns.Replica.Ping())
}
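/*
The regression described above follows a common pattern. A config is validated
and mutated in one place, but a freshly built copy is what reaches the
connection, so the mutation is lost. The program below reproduces that pattern
with hypothetical stand-ins: dbConfig for the replica's MySQL config, build for
fromCommonMysqlConfig, and resolvePassword for the part of checkAndModifyConfig
that reads PasswordPath. It is an illustration, not Fleet's code.

```go
package main

import "fmt"

// dbConfig is a hypothetical stand-in for the replica's MySQL config.
type dbConfig struct {
	Password     string
	PasswordPath string
}

// build stands in for fromCommonMysqlConfig: it returns a fresh copy on each call.
func build(src dbConfig) *dbConfig {
	c := src
	return &c
}

// resolvePassword stands in for checkAndModifyConfig reading PasswordPath.
func resolvePassword(c *dbConfig, readFile func(path string) string) {
	if c.Password == "" && c.PasswordPath != "" {
		c.Password = readFile(c.PasswordPath)
	}
}

// buggy mirrors the pre-fix shape: the validated config is block-scoped and
// discarded, and a second, unmutated copy is used for the connection.
func buggy(src dbConfig, readFile func(string) string) string {
	{
		validated := build(src)
		resolvePassword(validated, readFile)
	}
	return build(src).Password
}

// fixed builds the config once and passes the same, mutated value on.
func fixed(src dbConfig, readFile func(string) string) string {
	c := build(src)
	resolvePassword(c, readFile)
	return c.Password
}

func main() {
	read := func(string) string { return "s3cret" }
	src := dbConfig{PasswordPath: "/run/secrets/replica-password"}
	fmt.Printf("buggy=%q fixed=%q\n", buggy(src, read), fixed(src, read))
	// Output: buggy="" fixed="s3cret"
}
```

The same shape explains the TLS regression tested next. Any field set during
validation, such as TLSConfig, disappears when the config is rebuilt.
*/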
// TestReplicaTLSConfigPreserved verifies that when a replica config has TLSCA set,
// the TLSConfig="custom" mutation from checkAndModifyConfig is preserved so the
// replica actually connects with TLS.
//
// This is a regression test for https://github.com/fleetdm/fleet/pull/39689.
// Before the fix, NewDBConnections called fromCommonMysqlConfig twice for the replica config.
// The first call's config got TLSConfig="custom" via checkAndModifyConfig, but that config was
// block-scoped and discarded. The second call created a fresh config where TLSConfig was
// empty, so the replica silently connected without TLS.
func TestReplicaTLSConfigPreserved(t *testing.T) {
dbName := t.Name()
// Create the test database using a direct connection.
createTestDatabase(t, dbName)
// Generate a test CA cert that does NOT match the MySQL server's cert.
ca, _ := generateTestCert(t)
cert, key := generateTestCert(t)
// Primary config without TLS — connects normally.
primaryConfig := config.MysqlConfig{
Username: testing_utils.TestUsername,
Password: testing_utils.TestPassword,
Address: testing_utils.TestAddress,
Database: dbName,
}
// Replica config with TLS — checkAndModifyConfig should set TLSConfig="custom"
// and register the TLS profile. When NewDB builds the DSN it must include
// tls=custom so the driver actually uses TLS with our test CA.
replicaConfig := config.MysqlConfig{
Username: testing_utils.TestUsername,
Password: testing_utils.TestPassword,
Address: testing_utils.TestAddress,
Database: dbName,
TLSCA: ca,
TLSCert: cert,
TLSKey: key,
}
// After the fix the replica connection attempt uses TLS with our test CA,
// which doesn't match the MySQL server's certificate, so we expect a TLS error.
//
// Before the fix TLSConfig was empty for the replica's NewDB call, so the
// replica connected without TLS (no error) — meaning this assertion would fail.
_, err := NewDBConnections(primaryConfig, Replica(&replicaConfig), LimitAttempts(1), Logger(slog.New(slog.DiscardHandler)))
require.Error(t, err, "replica connection should fail with a TLS error when TLSCA is set — "+
"if this succeeds, TLS was silently not applied to the replica")
require.Regexp(t, "(x509|tls|EOF)", err.Error())
// Ensure the failure is due to the TLS handshake/connection and not TLS config registration.
assert.NotContains(t, err.Error(), "RegisterTLSConfig")
}