mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 21:47:20 +00:00
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #36093 This is a follow-up of https://github.com/fleetdm/fleet/pull/40717 # Checklist for submitter - [x] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files) for more information. ## Testing - [x] Added/updated automated tests - [x] QA'd all new/changed functionality manually Verified that the manual test cases I described in https://github.com/fleetdm/fleet/pull/40717 still pass. Used the following setup: - 1 host on Servers. - 1 host on Servers (canary). - 9999 hosts on Unassigned. <img width="1292" height="448" alt="Screenshot 2026-03-10 at 9 41 33 PM" src="https://github.com/user-attachments/assets/37ba2ad9-aa7b-4d40-b134-56a943e2635c" /> Users: - Team user with these assignments for test cases 1 and 2. <img width="570" height="269" alt="Screenshot 2026-03-10 at 9 42 41 PM" src="https://github.com/user-attachments/assets/f4bcf180-b7cc-4d80-a727-26ce887cbe84" /> - Global observer user for test cases 3 to 5. ### Test case 1 Report on Workstations (canary) with observers_can_run=true <img width="470" height="538" alt="Screenshot 2026-03-10 at 9 42 30 PM" src="https://github.com/user-attachments/assets/11c02ee9-c6eb-463a-9d4b-168a6155feed" /> Tested that I'm only able to target that host using "All hosts", "macOS" and other labels. Also, searching for specific hosts under "Target specific hosts" only retrieves that host. https://github.com/user-attachments/assets/150d986a-b4f2-49ab-86d9-0308685873eb ### Test case 2 Confirmed that I'm not able to target `perf-host-1` from `Servers (canary)` using a manual label with the same report above. 
For this, I created a manual label and assigned only to `perf-host-1`: <img width="603" height="349" alt="Screenshot 2026-03-10 at 9 50 52 PM" src="https://github.com/user-attachments/assets/98b4a27a-4e46-466e-a377-622d36903feb" /> Note that 0 hosts are targeted and **Run** is disabled: <img width="950" height="814" alt="Screenshot 2026-03-10 at 9 52 26 PM" src="https://github.com/user-attachments/assets/3b42c0e9-3005-40cc-8733-85b9b729ce89" /> ### Test case 3 Accessed same report in `Workstations (canary)` above with a Global Observer user. Confirmed that no hosts can be targeted in any way: <img width="977" height="649" alt="Screenshot 2026-03-11 at 8 29 26 AM" src="https://github.com/user-attachments/assets/ac87ac7e-3097-4228-a724-1d9324dec504" /> <img width="986" height="746" alt="Screenshot 2026-03-11 at 8 30 06 AM" src="https://github.com/user-attachments/assets/5ca592d2-be8c-43c0-8a27-d18fdee35442" /> <img width="1017" height="812" alt="Screenshot 2026-03-11 at 8 30 12 AM" src="https://github.com/user-attachments/assets/fb92940d-3ab2-4136-9e04-825f2c5eb3fe" /> <img width="998" height="809" alt="Screenshot 2026-03-11 at 8 30 17 AM" src="https://github.com/user-attachments/assets/67cc9c0a-e1aa-49df-ad68-1988d6471d32" /> <img width="1444" height="311" alt="Screenshot 2026-03-11 at 8 30 35 AM" src="https://github.com/user-attachments/assets/4b725bf1-0d6d-4458-840e-a96666a34903" /> <img width="1444" height="303" alt="Screenshot 2026-03-11 at 8 30 42 AM" src="https://github.com/user-attachments/assets/54a9cd65-90f5-4454-a713-334e23118295" /> ### Test case 4 As a global observer, accessing a global report with observers_can_run=true, I can target all the hosts across all teams. 
<img width="951" height="640" alt="Screenshot 2026-03-11 at 8 34 58 AM" src="https://github.com/user-attachments/assets/3c235b3d-acd5-4801-834f-6fe6cd67d3dd" /> <img width="1448" height="527" alt="Screenshot 2026-03-11 at 8 35 06 AM" src="https://github.com/user-attachments/assets/0f5f663d-8597-4320-aceb-ee6f168ec552" /> <img width="1474" height="179" alt="Screenshot 2026-03-11 at 8 35 14 AM" src="https://github.com/user-attachments/assets/042eda04-e7f6-4c21-9503-878a23435fcd" /> ### Test case 5 With the same report from test case 4, but observers_can_run=false, I can't target any hosts. <img width="971" height="804" alt="Screenshot 2026-03-11 at 8 36 49 AM" src="https://github.com/user-attachments/assets/3a3a9fe3-a159-4ef9-8b08-4c987b9c0828" /> <img width="967" height="813" alt="Screenshot 2026-03-11 at 8 37 00 AM" src="https://github.com/user-attachments/assets/aba5588d-dd96-4b88-9911-ebdd743bfa65" />
1640 lines
45 KiB
Go
1640 lines
45 KiB
Go
package mysql
|
||
|
||
import (
|
||
"context"
|
||
"crypto/rand"
|
||
"crypto/rsa"
|
||
"crypto/x509"
|
||
"crypto/x509/pkix"
|
||
"database/sql"
|
||
"encoding/pem"
|
||
"errors"
|
||
"fmt"
|
||
"log/slog"
|
||
"math/big"
|
||
"net"
|
||
"os"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/DATA-DOG/go-sqlmock"
|
||
"github.com/VividCortex/mysqlerr"
|
||
"github.com/WatchBeam/clock"
|
||
"github.com/fleetdm/fleet/v4/server/config"
|
||
"github.com/fleetdm/fleet/v4/server/contexts/ctxdb"
|
||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||
common_mysql "github.com/fleetdm/fleet/v4/server/platform/mysql"
|
||
"github.com/fleetdm/fleet/v4/server/platform/mysql/testing_utils"
|
||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||
"github.com/go-sql-driver/mysql"
|
||
"github.com/jmoiron/sqlx"
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// TestDatastoreReplica verifies that reads are routed to the replica and
// writes to the primary, and that ctxdb.RequirePrimary can force reads to
// the primary before replication has run.
func TestDatastoreReplica(t *testing.T) {
	// a bit unfortunate to create temp databases just for this - could be mixed
	// with other tests when/if we move to subtests to minimize the number of
	// databases created for tests (see #1805).

	ctx := context.Background()
	t.Run("noreplica", func(t *testing.T) {
		// Without a replica configured, reader and writer are the same handle.
		ds := CreateMySQLDSWithOptions(t, nil)
		defer ds.Close()
		require.Equal(t, ds.reader(ctx), ds.writer(ctx))
	})

	t.Run("replica", func(t *testing.T) {
		// DummyReplica provides a distinct reader that only sees writes after
		// opts.RunReplication() is called, which lets us observe read routing.
		opts := &testing_utils.DatastoreTestOptions{DummyReplica: true}
		ds := CreateMySQLDSWithOptions(t, opts)
		defer ds.Close()
		require.NotEqual(t, ds.reader(ctx), ds.writer(ctx))

		// create a new host
		host, err := ds.NewHost(ctx, &fleet.Host{
			DetailUpdatedAt: time.Now(),
			LabelUpdatedAt:  time.Now(),
			PolicyUpdatedAt: time.Now(),
			SeenTime:        time.Now(),
			NodeKey:         ptr.String("1"),
			UUID:            "1",
			Hostname:        "foo.local",
			PrimaryIP:       "192.168.1.1",
			PrimaryMac:      "30-65-EC-6F-C4-58",
		})
		require.NoError(t, err)
		require.NotNil(t, host)

		// trying to read it fails, not replicated yet
		_, err = ds.Host(ctx, host.ID)
		require.Error(t, err)
		require.True(t, errors.Is(err, sql.ErrNoRows), err)

		// force read from primary works
		ctx = ctxdb.RequirePrimary(ctx, true)
		got, err := ds.Host(ctx, host.ID)
		require.NoError(t, err)
		require.Equal(t, host.ID, got.ID)

		// but from replica still fails
		ctx = ctxdb.RequirePrimary(ctx, false)
		_, err = ds.Host(ctx, host.ID)
		require.Error(t, err)
		require.True(t, errors.Is(err, sql.ErrNoRows))

		opts.RunReplication()

		// now it can read it from replica
		got, err = ds.Host(ctx, host.ID)
		require.NoError(t, err)
		require.Equal(t, host.ID, got.ID)
	})
}
|
||
|
||
func TestSanitizeColumn(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
testCases := []struct {
|
||
input string
|
||
output string
|
||
}{
|
||
{"", ""},
|
||
{"foobar-column", "`foobar-column`"},
|
||
{"foobar_column", "`foobar_column`"},
|
||
{"foobar;column", "`foobarcolumn`"},
|
||
{"foobar#", "`foobar`"},
|
||
{"foobar*baz", "`foobarbaz`"},
|
||
{"....", ""},
|
||
{"h.id", "`h`.`id`"},
|
||
{"id;delete from hosts", "`iddeletefromhosts`"},
|
||
{"select * from foo", "`selectfromfoo`"},
|
||
}
|
||
|
||
for _, tt := range testCases {
|
||
t.Run(tt.input, func(t *testing.T) {
|
||
require.Equal(t, tt.output, sanitizeColumn(tt.input))
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestSearchLike(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
testCases := []struct {
|
||
inSQL string
|
||
inParams []interface{}
|
||
match string
|
||
columns []string
|
||
outSQL string
|
||
outParams []interface{}
|
||
}{
|
||
{
|
||
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
|
||
inParams: []interface{}{},
|
||
match: "foobar",
|
||
columns: []string{"hostname"},
|
||
outSQL: "SELECT * FROM HOSTS WHERE TRUE AND (hostname LIKE ?)",
|
||
outParams: []interface{}{"%foobar%"},
|
||
},
|
||
{
|
||
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
|
||
inParams: []interface{}{3},
|
||
match: "foobar",
|
||
columns: []string{},
|
||
outSQL: "SELECT * FROM HOSTS WHERE TRUE",
|
||
outParams: []interface{}{3},
|
||
},
|
||
{
|
||
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
|
||
inParams: []interface{}{1},
|
||
match: "foobar",
|
||
columns: []string{"hostname"},
|
||
outSQL: "SELECT * FROM HOSTS WHERE TRUE AND (hostname LIKE ?)",
|
||
outParams: []interface{}{1, "%foobar%"},
|
||
},
|
||
{
|
||
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
|
||
inParams: []interface{}{1},
|
||
match: "foobar",
|
||
columns: []string{"hostname", "uuid"},
|
||
outSQL: "SELECT * FROM HOSTS WHERE TRUE AND (hostname LIKE ? OR uuid LIKE ?)",
|
||
outParams: []interface{}{1, "%foobar%", "%foobar%"},
|
||
},
|
||
{
|
||
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
|
||
inParams: []interface{}{1},
|
||
match: "foobar",
|
||
columns: []string{"hostname", "uuid"},
|
||
outSQL: "SELECT * FROM HOSTS WHERE TRUE AND (hostname LIKE ? OR uuid LIKE ?)",
|
||
outParams: []interface{}{1, "%foobar%", "%foobar%"},
|
||
},
|
||
{
|
||
inSQL: "SELECT * FROM HOSTS WHERE 1=1",
|
||
inParams: []interface{}{1},
|
||
match: "forty_%",
|
||
columns: []string{"ipv4", "uuid"},
|
||
outSQL: "SELECT * FROM HOSTS WHERE 1=1 AND (ipv4 LIKE ? OR uuid LIKE ?)",
|
||
outParams: []interface{}{1, "%forty\\_\\%%", "%forty\\_\\%%"},
|
||
},
|
||
{
|
||
inSQL: "SELECT * FROM HOSTS WHERE 1=1",
|
||
inParams: []interface{}{1},
|
||
match: "forty_%",
|
||
columns: []string{"ipv4", "uuid"},
|
||
outSQL: "SELECT * FROM HOSTS WHERE 1=1 AND (ipv4 LIKE ? OR uuid LIKE ?)",
|
||
outParams: []interface{}{1, "%forty\\_\\%%", "%forty\\_\\%%"},
|
||
},
|
||
{
|
||
inSQL: "SELECT * FROM HOSTS WHERE 1=1",
|
||
inParams: []interface{}{1},
|
||
match: "a@b.c",
|
||
columns: []string{"ipv4", "uuid"},
|
||
outSQL: "SELECT * FROM HOSTS WHERE 1=1 AND (ipv4 LIKE ? OR uuid LIKE ?)",
|
||
outParams: []interface{}{1, "%a@b.c%", "%a@b.c%"},
|
||
},
|
||
}
|
||
|
||
for _, tt := range testCases {
|
||
t.Run("", func(t *testing.T) {
|
||
sql, params := searchLike(tt.inSQL, tt.inParams, tt.match, tt.columns...)
|
||
assert.Equal(t, tt.outSQL, sql)
|
||
assert.Equal(t, tt.outParams, params)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestHostSearchLike(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
testCases := []struct {
|
||
inSQL string
|
||
inParams []any
|
||
match string
|
||
columns []string
|
||
outSQL string
|
||
outParams []any
|
||
}{
|
||
{
|
||
inSQL: "SELECT * FROM HOSTS h WHERE TRUE",
|
||
inParams: []any{},
|
||
match: "foobar",
|
||
columns: []string{"hostname"},
|
||
outSQL: "SELECT * FROM HOSTS h WHERE TRUE AND (hostname LIKE ? OR ( EXISTS (SELECT 1 FROM host_emails he WHERE he.host_id = h.id AND he.email LIKE ?)))",
|
||
outParams: []any{"%foobar%", "%foobar%"},
|
||
},
|
||
{
|
||
inSQL: "SELECT * FROM HOSTS h WHERE 1=1",
|
||
inParams: []any{1},
|
||
match: "a@b.c",
|
||
columns: []string{"ipv4"},
|
||
outSQL: "SELECT * FROM HOSTS h WHERE 1=1 AND (ipv4 LIKE ? OR ( EXISTS (SELECT 1 FROM host_emails he WHERE he.host_id = h.id AND he.email LIKE ?)))",
|
||
outParams: []any{1, "%a@b.c%", "%a@b.c%"},
|
||
},
|
||
}
|
||
|
||
for _, tt := range testCases {
|
||
t.Run("", func(t *testing.T) {
|
||
sql, params := hostSearchLike(tt.inSQL, tt.inParams, tt.match, tt.columns...)
|
||
assert.Equal(t, tt.outSQL, sql)
|
||
assert.Equal(t, tt.outParams, params)
|
||
})
|
||
}
|
||
}
|
||
|
||
// mockDatastore returns a Datastore backed by a sqlmock connection — the same
// handle is used for both primary and replica — together with the mock object
// used to set query expectations. Logging is discarded.
func mockDatastore(t *testing.T) (sqlmock.Sqlmock, *Datastore) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	dbmock := sqlx.NewDb(db, "sqlmock")
	ds := &Datastore{
		primary: dbmock,
		replica: dbmock,
		logger:  slog.New(slog.DiscardHandler),
	}

	return mock, ds
}
|
||
|
||
func TestWithRetryTxxSuccess(t *testing.T) {
|
||
mock, ds := mockDatastore(t)
|
||
defer ds.Close()
|
||
|
||
mock.ExpectBegin()
|
||
mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
|
||
mock.ExpectCommit()
|
||
|
||
require.NoError(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
|
||
_, err := tx.ExecContext(context.Background(), "SELECT 1")
|
||
return err
|
||
}))
|
||
|
||
require.NoError(t, mock.ExpectationsWereMet())
|
||
}
|
||
|
||
// TestWithRetryTxxRollbackSuccess verifies that when the callback fails with
// a non-retryable error, withRetryTxx rolls back and returns the error.
func TestWithRetryTxxRollbackSuccess(t *testing.T) {
	mock, ds := mockDatastore(t)
	defer ds.Close()

	mock.ExpectBegin()
	mock.ExpectExec("SELECT 1").WillReturnError(errors.New("fail"))
	mock.ExpectRollback()

	require.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(context.Background(), "SELECT 1")
		return err
	}))

	require.NoError(t, mock.ExpectationsWereMet())
}
|
||
|
||
// TestWithRetryTxxRollbackError verifies that an error is still surfaced when
// the rollback itself also fails after a failed statement.
func TestWithRetryTxxRollbackError(t *testing.T) {
	mock, ds := mockDatastore(t)
	defer ds.Close()

	mock.ExpectBegin()
	mock.ExpectExec("SELECT 1").WillReturnError(errors.New("fail"))
	// The rollback fails too; withRetryTxx must still return an error.
	mock.ExpectRollback().WillReturnError(errors.New("rollback failed"))

	require.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(context.Background(), "SELECT 1")
		return err
	}))

	require.NoError(t, mock.ExpectationsWereMet())
}
|
||
|
||
// TestWithRetryTxxRetrySuccess verifies that a deadlock error from the
// callback causes a rollback followed by a fresh transaction that succeeds.
func TestWithRetryTxxRetrySuccess(t *testing.T) {
	mock, ds := mockDatastore(t)
	defer ds.Close()

	mock.ExpectBegin()
	// Return a retryable error
	mock.ExpectExec("SELECT 1").WillReturnError(&mysql.MySQLError{Number: mysqlerr.ER_LOCK_DEADLOCK})
	mock.ExpectRollback()
	// Second attempt succeeds.
	mock.ExpectBegin()
	mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectCommit()

	assert.NoError(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(context.Background(), "SELECT 1")
		return err
	}))

	require.NoError(t, mock.ExpectationsWereMet())
}
|
||
|
||
// TestWithRetryTxxCommitRetrySuccess verifies that a deadlock error raised at
// COMMIT time (not from the callback) also triggers a retry that succeeds.
func TestWithRetryTxxCommitRetrySuccess(t *testing.T) {
	mock, ds := mockDatastore(t)
	defer ds.Close()

	mock.ExpectBegin()
	mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
	// Return a retryable error
	mock.ExpectCommit().WillReturnError(&mysql.MySQLError{Number: mysqlerr.ER_LOCK_DEADLOCK})
	// Second attempt succeeds.
	mock.ExpectBegin()
	mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectCommit()

	assert.NoError(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(context.Background(), "SELECT 1")
		return err
	}))

	require.NoError(t, mock.ExpectationsWereMet())
}
|
||
|
||
// TestWithRetryTxxCommitError verifies that a non-retryable commit failure is
// returned to the caller without a retry.
func TestWithRetryTxxCommitError(t *testing.T) {
	mock, ds := mockDatastore(t)
	defer ds.Close()

	mock.ExpectBegin()
	mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
	// Return a non-retryable error so the transaction is NOT retried.
	mock.ExpectCommit().WillReturnError(errors.New("fail"))

	assert.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(context.Background(), "SELECT 1")
		return err
	}))

	require.NoError(t, mock.ExpectationsWereMet())
}
|
||
|
||
// TestAppendListOptionsToSQLSecure covers the allowlist-based variant of
// appendListOptionsToSQL: ORDER BY keys must come from an explicit allowlist
// mapping, and unknown keys produce an InvalidOrderKeyError instead of being
// sanitized into the query.
func TestAppendListOptionsToSQLSecure(t *testing.T) {
	// Test allowlist for mapping order keys to actual column names
	testAllowlist := common_mysql.OrderKeyAllowlist{
		"name": "name",
	}

	// Allowed order key, default (ascending) direction.
	sql := "SELECT * FROM my_table"
	opts := fleet.ListOptions{
		OrderKey: "name",
	}

	actual, _, err := appendListOptionsToSQLSecure(sql, &opts, testAllowlist)
	require.NoError(t, err)
	expected := "SELECT * FROM my_table ORDER BY name ASC LIMIT 1000000"
	assert.Equal(t, expected, actual)

	// Descending direction.
	sql = "SELECT * FROM my_table"
	opts.OrderDirection = fleet.OrderDescending
	actual, _, err = appendListOptionsToSQLSecure(sql, &opts, testAllowlist)
	require.NoError(t, err)
	expected = "SELECT * FROM my_table ORDER BY name DESC LIMIT 1000000"
	assert.Equal(t, expected, actual)

	// PerPage controls LIMIT.
	opts = fleet.ListOptions{
		PerPage: 10,
	}

	sql = "SELECT * FROM my_table"
	actual, _, err = appendListOptionsToSQLSecure(sql, &opts, testAllowlist)
	require.NoError(t, err)
	expected = "SELECT * FROM my_table LIMIT 10"
	assert.Equal(t, expected, actual)

	// Page adds OFFSET = Page * PerPage.
	sql = "SELECT * FROM my_table"
	opts.Page = 2
	actual, _, err = appendListOptionsToSQLSecure(sql, &opts, testAllowlist)
	require.NoError(t, err)
	expected = "SELECT * FROM my_table LIMIT 10 OFFSET 20"
	assert.Equal(t, expected, actual)

	// Zero-value options fall back to the default (large) LIMIT.
	opts = fleet.ListOptions{}
	sql = "SELECT * FROM my_table"
	actual, _, err = appendListOptionsToSQLSecure(sql, &opts, testAllowlist)
	require.NoError(t, err)
	expected = "SELECT * FROM my_table LIMIT 1000000"
	assert.Equal(t, expected, actual)

	// Test that invalid order key returns an error
	opts = fleet.ListOptions{OrderKey: "invalid_column"}
	sql = "SELECT * FROM my_table"
	_, _, err = appendListOptionsToSQLSecure(sql, &opts, testAllowlist)
	require.Error(t, err)
	var invalidKeyErr common_mysql.InvalidOrderKeyError
	require.ErrorAs(t, err, &invalidKeyErr)
	require.Equal(t, "invalid_column", invalidKeyErr.Key)
}
|
||
|
||
func TestAppendListOptionsToSQL(t *testing.T) {
|
||
sql := "SELECT * FROM my_table"
|
||
opts := fleet.ListOptions{
|
||
OrderKey: "***name***",
|
||
}
|
||
|
||
actual, _ := appendListOptionsToSQL(sql, &opts)
|
||
expected := "SELECT * FROM my_table ORDER BY `name` ASC LIMIT 1000000"
|
||
if actual != expected {
|
||
t.Error("Expected", expected, "Actual", actual)
|
||
}
|
||
|
||
sql = "SELECT * FROM my_table"
|
||
opts.OrderDirection = fleet.OrderDescending
|
||
actual, _ = appendListOptionsToSQL(sql, &opts)
|
||
expected = "SELECT * FROM my_table ORDER BY `name` DESC LIMIT 1000000"
|
||
if actual != expected {
|
||
t.Error("Expected", expected, "Actual", actual)
|
||
}
|
||
|
||
opts = fleet.ListOptions{
|
||
PerPage: 10,
|
||
}
|
||
|
||
sql = "SELECT * FROM my_table"
|
||
actual, _ = appendListOptionsToSQL(sql, &opts)
|
||
expected = "SELECT * FROM my_table LIMIT 10"
|
||
if actual != expected {
|
||
t.Error("Expected", expected, "Actual", actual)
|
||
}
|
||
|
||
sql = "SELECT * FROM my_table"
|
||
opts.Page = 2
|
||
actual, _ = appendListOptionsToSQL(sql, &opts)
|
||
expected = "SELECT * FROM my_table LIMIT 10 OFFSET 20"
|
||
if actual != expected {
|
||
t.Error("Expected", expected, "Actual", actual)
|
||
}
|
||
|
||
opts = fleet.ListOptions{}
|
||
sql = "SELECT * FROM my_table"
|
||
actual, _ = appendListOptionsToSQL(sql, &opts)
|
||
expected = "SELECT * FROM my_table LIMIT 1000000"
|
||
|
||
if actual != expected {
|
||
t.Error("Expected", expected, "Actual", actual)
|
||
}
|
||
}
|
||
|
||
// TestWhereFilterHostsByTeams is a table test over whereFilterHostsByTeams,
// which builds the SQL predicate restricting hosts visible to a user given
// their global role, per-team roles, an optional explicit TeamID, and an
// optional ObserverTeamID restriction. "TRUE" means unrestricted access,
// "FALSE" means no access.
func TestWhereFilterHostsByTeams(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		filter   fleet.TeamFilter
		expected string
	}{
		// No teams or global role
		{
			filter: fleet.TeamFilter{
				User: &fleet.User{},
			},
			expected: "FALSE",
		},
		{
			filter: fleet.TeamFilter{
				User: &fleet.User{Teams: []fleet.UserTeam{}},
			},
			expected: "FALSE",
		},

		// Global role
		{
			filter: fleet.TeamFilter{
				User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
			},
			expected: "TRUE",
		},
		{
			filter: fleet.TeamFilter{
				User: &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
			},
			expected: "TRUE",
		},
		{
			// Global observer is excluded unless IncludeObserver is set.
			filter: fleet.TeamFilter{
				User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
			},
			expected: "FALSE",
		},
		{
			filter: fleet.TeamFilter{
				User:            &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
				IncludeObserver: true,
			},
			expected: "TRUE",
		},

		// Team roles
		{
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
					},
				},
			},
			expected: "FALSE",
		},
		{
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
					},
				},
				IncludeObserver: true,
			},
			expected: "hosts.team_id IN (1)",
		},
		{
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 2}},
					},
				},
			},
			expected: "FALSE",
		},
		{
			// Only the non-observer team is included when observers are excluded.
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
					},
				},
			},
			expected: "hosts.team_id IN (2)",
		},
		{
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
					},
				},
				IncludeObserver: true,
			},
			expected: "hosts.team_id IN (1,2)",
		},
		{
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
						// Invalid role should be ignored
						{Role: "bad", Team: fleet.Team{ID: 37}},
					},
				},
			},
			expected: "hosts.team_id IN (2)",
		},
		{
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
						{Role: fleet.RoleAdmin, Team: fleet.Team{ID: 3}},
					},
				},
			},
			expected: "hosts.team_id IN (2,3)",
		},
		{
			// TeamID without a user yields no access.
			filter: fleet.TeamFilter{
				TeamID: ptr.Uint(1),
			},
			expected: "FALSE",
		},
		{
			filter: fleet.TeamFilter{
				User:            &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
				IncludeObserver: true,
				TeamID:          ptr.Uint(1),
			},
			expected: "hosts.team_id = 1",
		},
		{
			filter: fleet.TeamFilter{
				User:            &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
				IncludeObserver: false,
				TeamID:          ptr.Uint(1),
			},
			expected: "FALSE",
		},
		{
			filter: fleet.TeamFilter{
				User:            &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
				IncludeObserver: false,
				TeamID:          ptr.Uint(1),
			},
			expected: "hosts.team_id = 1",
		},
		{
			// TeamID the user is not a member of yields no access.
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
					},
				},
				TeamID: ptr.Uint(3),
			},
			expected: "FALSE",
		},
		{
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
					},
				},
				TeamID: ptr.Uint(2),
			},
			expected: "hosts.team_id = 2",
		},

		// ObserverTeamID: restricts observer access to a specific team (e.g. the live query's own team)
		{
			// Global observer with ObserverTeamID set: only that team's hosts
			filter: fleet.TeamFilter{
				User:            &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
				IncludeObserver: true,
				ObserverTeamID:  ptr.Uint(1),
			},
			expected: "hosts.team_id = 1",
		},
		{
			// Observer on two teams with ObserverTeamID set: only the specified team
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 2}},
					},
				},
				IncludeObserver: true,
				ObserverTeamID:  ptr.Uint(1),
			},
			expected: "hosts.team_id IN (1)",
		},
		{
			// Admin on team 3 + observer on teams 1 and 2 with ObserverTeamID=1:
			// admin access to team 3 is unaffected; observer access limited to team 1 only
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 2}},
						{Role: fleet.RoleAdmin, Team: fleet.Team{ID: 3}},
					},
				},
				IncludeObserver: true,
				ObserverTeamID:  ptr.Uint(1),
			},
			expected: "hosts.team_id IN (1,3)",
		},
		{
			// Observer on team 1, maintainer on team 2, running a team-2 query (ObserverTeamID=2):
			// team-1 observer access is excluded because ObserverTeamID=2 != team 1;
			// maintainer access on team 2 is unaffected — only team-2 hosts are returned.
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
					},
				},
				IncludeObserver: true,
				ObserverTeamID:  ptr.Uint(2),
			},
			expected: "hosts.team_id IN (2)",
		},
	}

	for _, tt := range testCases {
		tt := tt
		t.Run("", func(t *testing.T) {
			ds := &Datastore{logger: slog.New(slog.DiscardHandler)}
			sql := ds.whereFilterHostsByTeams(tt.filter, "hosts")
			assert.Equal(t, tt.expected, sql)
		})
	}
}
|
||
|
||
func TestWhereOmitIDs(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
testCases := []struct {
|
||
omits []uint
|
||
expected string
|
||
}{
|
||
{
|
||
omits: nil,
|
||
expected: "TRUE",
|
||
},
|
||
{
|
||
omits: []uint{},
|
||
expected: "TRUE",
|
||
},
|
||
{
|
||
omits: []uint{1, 3, 4},
|
||
expected: "id NOT IN (1,3,4)",
|
||
},
|
||
{
|
||
omits: []uint{42},
|
||
expected: "id NOT IN (42)",
|
||
},
|
||
}
|
||
|
||
for _, tt := range testCases {
|
||
tt := tt
|
||
t.Run("", func(t *testing.T) {
|
||
ds := &Datastore{logger: slog.New(slog.DiscardHandler)}
|
||
sql := ds.whereOmitIDs("id", tt.omits)
|
||
assert.Equal(t, tt.expected, sql)
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestWithRetryTxWithRollback verifies that a callback error from
// withRetryTxx triggers a rollback and is returned to the caller.
func TestWithRetryTxWithRollback(t *testing.T) {
	mock, ds := mockDatastore(t)
	defer ds.Close()

	mock.ExpectBegin()
	mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
	mock.ExpectRollback()

	assert.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(context.Background(), "SELECT 1")
		return err
	}))

	require.NoError(t, mock.ExpectationsWereMet())
}
|
||
|
||
// TestWithRetryTxWillRollbackWhenPanic verifies that a panic inside the
// withRetryTxx callback still causes the transaction to roll back. The
// deferred recover keeps the re-raised panic from failing the test.
// NOTE(review): if the panic propagates out of withRetryTxx, the assertions
// after the call never execute — the recover() only keeps the test green;
// confirm this is the intended coverage.
func TestWithRetryTxWillRollbackWhenPanic(t *testing.T) {
	mock, ds := mockDatastore(t)
	defer ds.Close()
	defer func() { recover() }() //nolint:errcheck

	mock.ExpectBegin()
	mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
	mock.ExpectRollback()

	assert.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
		panic("ROLLBACK")
	}))

	require.NoError(t, mock.ExpectationsWereMet())
}
|
||
|
||
// TestWithTxWithRollback verifies that a callback error from withTx (the
// non-retrying variant) triggers a rollback and is returned to the caller.
func TestWithTxWithRollback(t *testing.T) {
	mock, ds := mockDatastore(t)
	defer ds.Close()

	mock.ExpectBegin()
	mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
	mock.ExpectRollback()

	assert.Error(t, ds.withTx(context.Background(), func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(context.Background(), "SELECT 1")
		return err
	}))

	require.NoError(t, mock.ExpectationsWereMet())
}
|
||
|
||
// TestWithTxWillRollbackWhenPanic verifies that a panic inside the withTx
// callback still causes the transaction to roll back. The deferred recover
// keeps the re-raised panic from failing the test.
// NOTE(review): as in the withRetryTxx variant, assertions after the call do
// not run if the panic propagates — confirm this is the intended coverage.
func TestWithTxWillRollbackWhenPanic(t *testing.T) {
	mock, ds := mockDatastore(t)
	defer ds.Close()
	defer func() { recover() }() //nolint:errcheck

	mock.ExpectBegin()
	mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
	mock.ExpectRollback()

	assert.Error(t, ds.withTx(context.Background(), func(tx sqlx.ExtContext) error {
		panic("ROLLBACK")
	}))

	require.NoError(t, mock.ExpectationsWereMet())
}
|
||
|
||
// TestHealthCheckDetectsReadOnly verifies that HealthCheck queries
// @@read_only on the primary and fails when the primary reports read-only
// mode (as happens mid-failover).
func TestHealthCheckDetectsReadOnly(t *testing.T) {
	mock, ds := mockDatastore(t)
	defer ds.Close()

	// Healthy: primary is writable.
	mock.ExpectQuery("SELECT @@read_only").
		WillReturnRows(sqlmock.NewRows([]string{"@@read_only"}).AddRow(0))
	require.NoError(t, ds.HealthCheck())

	// Unhealthy: primary is read-only (failover scenario).
	mock.ExpectQuery("SELECT @@read_only").
		WillReturnRows(sqlmock.NewRows([]string{"@@read_only"}).AddRow(1))
	err := ds.HealthCheck()
	require.Error(t, err)
	assert.Contains(t, err.Error(), "read-only")

	require.NoError(t, mock.ExpectationsWereMet())
}
|
||
|
||
// TestNewReadsPasswordFromDisk verifies that when MysqlConfig.Password is
// empty and PasswordPath is set, New reads the password from the file and
// can connect (proven by a successful HealthCheck).
func TestNewReadsPasswordFromDisk(t *testing.T) {
	// Write the known test password to a temp file.
	passwordFile, err := os.CreateTemp(t.TempDir(), "*.passwordtest")
	require.NoError(t, err)
	_, err = passwordFile.WriteString(testing_utils.TestPassword)
	require.NoError(t, err)
	passwordPath := passwordFile.Name()
	require.NoError(t, passwordFile.Close())

	dbName := t.Name()

	// Create a datastore client in order to run migrations as usual
	mysqlConfig := config.MysqlConfig{
		Username:     testing_utils.TestUsername,
		Password:     "", // intentionally empty: must be read from PasswordPath
		PasswordPath: passwordPath,
		Address:      testing_utils.TestAddress,
		Database:     dbName,
	}
	ds, err := newDSWithConfig(t, dbName, mysqlConfig)
	require.NoError(t, err)
	defer ds.Close()
	require.NoError(t, ds.HealthCheck())
}
|
||
|
||
// newDSWithConfig drops and recreates the named test database using the
// shared test credentials, then constructs a Datastore from the provided
// config (single connection attempt, discarded logging). The (ds, err)
// result from New is returned to the caller for inspection.
func newDSWithConfig(t *testing.T, dbName string, config config.MysqlConfig) (*Datastore, error) {
	db, err := sql.Open(
		"mysql",
		fmt.Sprintf("%s:%s@tcp(%s)/?multiStatements=true", testing_utils.TestUsername, testing_utils.TestPassword,
			testing_utils.TestAddress),
	)
	require.NoError(t, err)
	// Start from a clean database so migrations run from scratch.
	_, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", dbName, dbName))
	require.NoError(t, err)

	ds, err := New(config, clock.NewMockClock(), Logger(slog.New(slog.DiscardHandler)), LimitAttempts(1))
	return ds, err
}
|
||
|
||
// generateTestCert creates a self-signed CA certificate valid for 127.0.0.1
// (valid from 24h ago to 24h from now) and writes the certificate and its
// RSA private key to temp PEM files, returning their paths (cert, key).
// The 1024-bit key is deliberately small to keep test runs fast; it is not
// suitable for anything beyond these TLS-failure tests.
func generateTestCert(t *testing.T) (string, string) {
	privateKeyCA, err := rsa.GenerateKey(rand.Reader, 1024)
	require.NoError(t, err)

	// Random 128-bit serial number, as recommended for X.509 certs.
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	require.NoError(t, err)
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"aa"},
		},
		NotBefore:             time.Now().Add(-1 * time.Duration(24) * time.Hour),
		NotAfter:              time.Now().Add(24 * time.Hour),
		IsCA:                  true,
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
	}
	// Self-signed: the template is both subject and issuer.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKeyCA.PublicKey, privateKeyCA)
	require.NoError(t, err)

	publicPem, err := os.CreateTemp(t.TempDir(), "*-ca.pem")
	require.NoError(t, err)
	require.NoError(t, pem.Encode(publicPem, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
	require.NoError(t, publicPem.Close())

	keyPem, err := os.CreateTemp(t.TempDir(), "*-key.pem")
	require.NoError(t, err)
	privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKeyCA)
	require.NoError(t, pem.Encode(keyPem, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privateKeyBytes}))
	require.NoError(t, keyPem.Close())

	return publicPem.Name(), keyPem.Name()
}
|
||
|
||
func TestNewUsesRegisterTLS(t *testing.T) {
|
||
dbName := t.Name()
|
||
|
||
ca, _ := generateTestCert(t)
|
||
cert, key := generateTestCert(t)
|
||
|
||
mysqlConfig := config.MysqlConfig{
|
||
Username: testing_utils.TestUsername,
|
||
Password: testing_utils.TestPassword,
|
||
Address: testing_utils.TestAddress,
|
||
Database: dbName,
|
||
TLSCA: ca,
|
||
TLSCert: cert,
|
||
TLSKey: key,
|
||
}
|
||
// This fails because the certificate mysql is using is different than the one generated here
|
||
_, err := newDSWithConfig(t, dbName, mysqlConfig)
|
||
require.Error(t, err)
|
||
// TODO: we're using a Regexp because the message is different depending on the version of mysql,
|
||
// we should refactor and use different error types instead.
|
||
require.Regexp(t, "(x509|tls|EOF)", err.Error())
|
||
}
|
||
|
||
func TestWhereFilterTeams(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
testCases := []struct {
|
||
filter fleet.TeamFilter
|
||
expected string
|
||
}{
|
||
// No teams or global role
|
||
{
|
||
filter: fleet.TeamFilter{User: nil},
|
||
expected: "FALSE",
|
||
},
|
||
{
|
||
filter: fleet.TeamFilter{
|
||
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
|
||
},
|
||
expected: "TRUE",
|
||
},
|
||
{
|
||
filter: fleet.TeamFilter{
|
||
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
|
||
IncludeObserver: false,
|
||
},
|
||
expected: "FALSE",
|
||
},
|
||
{
|
||
filter: fleet.TeamFilter{
|
||
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
|
||
IncludeObserver: true,
|
||
},
|
||
expected: "TRUE",
|
||
},
|
||
{
|
||
filter: fleet.TeamFilter{User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}}},
|
||
expected: "t.id IN (1)",
|
||
},
|
||
{
|
||
filter: fleet.TeamFilter{User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}}},
|
||
expected: "t.id IN (1)",
|
||
},
|
||
{
|
||
filter: fleet.TeamFilter{User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}}},
|
||
expected: "FALSE",
|
||
},
|
||
{
|
||
filter: fleet.TeamFilter{
|
||
User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}},
|
||
IncludeObserver: true,
|
||
},
|
||
expected: "t.id IN (1)",
|
||
},
|
||
}
|
||
|
||
for _, tt := range testCases {
|
||
tt := tt
|
||
t.Run("", func(t *testing.T) {
|
||
ds := &Datastore{logger: slog.New(slog.DiscardHandler)}
|
||
sql := ds.whereFilterTeams(tt.filter, "t")
|
||
assert.Equal(t, tt.expected, sql)
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestCompareVersions exercises compareVersions, which diffs two sets of
// version numbers: "missing" are values present in v1 but absent from v2,
// "unknown" are values present in v2 but absent from v1 (excluding any listed
// in knownUnknowns), and "equal" reports whether the sets match after
// accounting for known unknowns.
func TestCompareVersions(t *testing.T) {
	for _, tc := range []struct {
		name string

		// Inputs: the two version sets and the unknowns to tolerate.
		v1            []int64
		v2            []int64
		knownUnknowns map[int64]struct{}

		// Expected outputs of compareVersions.
		expMissing []int64
		expUnknown []int64
		expEqual   bool
	}{
		{
			name:     "both-empty",
			v1:       nil,
			v2:       nil,
			expEqual: true,
		},
		{
			name:     "equal",
			v1:       []int64{1, 2, 3},
			v2:       []int64{1, 2, 3},
			expEqual: true,
		},
		// Order must not matter — only set membership.
		{
			name:     "equal-out-of-order",
			v1:       []int64{1, 2, 3},
			v2:       []int64{1, 3, 2},
			expEqual: true,
		},
		{
			name:       "empty-with-unknown",
			v1:         nil,
			v2:         []int64{1},
			expEqual:   false,
			expUnknown: []int64{1},
		},
		{
			name:       "empty-with-missing",
			v1:         []int64{1},
			v2:         nil,
			expEqual:   false,
			expMissing: []int64{1},
		},
		{
			name:       "missing",
			v1:         []int64{1, 2, 3},
			v2:         []int64{1, 3},
			expMissing: []int64{2},
			expEqual:   false,
		},
		{
			name:       "unknown",
			v1:         []int64{1, 2, 3},
			v2:         []int64{1, 2, 3, 4},
			expUnknown: []int64{4},
			expEqual:   false,
		},
		// An extra v2 entry listed in knownUnknowns is tolerated entirely.
		{
			name: "known-unknown",
			v1:   []int64{1, 2, 3},
			v2:   []int64{1, 2, 3, 4},
			knownUnknowns: map[int64]struct{}{
				4: {},
			},
			expEqual: true,
		},
		// Only the known unknown (4) is filtered; 5 is still reported.
		{
			name:       "unknowns",
			v1:         []int64{1, 2, 3},
			v2:         []int64{1, 2, 3, 4, 5},
			expUnknown: []int64{5},
			knownUnknowns: map[int64]struct{}{
				4: {},
			},
			expEqual: false,
		},
		{
			name:       "missing-and-unknown",
			v1:         []int64{1, 2, 3},
			v2:         []int64{1, 2, 4},
			expMissing: []int64{3},
			expUnknown: []int64{4},
			expEqual:   false,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			missing, unknown, equal := compareVersions(tc.v1, tc.v2, tc.knownUnknowns)
			require.Equal(t, tc.expMissing, missing)
			require.Equal(t, tc.expUnknown, unknown)
			require.Equal(t, tc.expEqual, equal)
		})
	}
}
|
||
|
||
func TestDebugs(t *testing.T) {
|
||
ds := CreateMySQLDS(t)
|
||
|
||
status, err := ds.InnoDBStatus(context.Background())
|
||
require.NoError(t, err)
|
||
assert.NotEmpty(t, status)
|
||
|
||
processList, err := ds.ProcessList(context.Background())
|
||
require.NoError(t, err)
|
||
require.Greater(t, len(processList), 0)
|
||
}
|
||
|
||
func TestWantedModesEnabled(t *testing.T) {
|
||
ds := CreateMySQLDS(t)
|
||
|
||
var sqlMode string
|
||
err := ds.writer(context.Background()).GetContext(context.Background(), &sqlMode, `SELECT @@SQL_MODE`)
|
||
require.NoError(t, err)
|
||
require.Contains(t, sqlMode, "ANSI_QUOTES")
|
||
require.Contains(t, sqlMode, "ONLY_FULL_GROUP_BY")
|
||
}
|
||
|
||
func Test_buildWildcardMatchPhrase(t *testing.T) {
|
||
type args struct {
|
||
matchQuery string
|
||
}
|
||
tests := []struct {
|
||
name string
|
||
args args
|
||
want string
|
||
}{
|
||
{
|
||
name: "",
|
||
args: args{matchQuery: "test"},
|
||
want: "%test%",
|
||
},
|
||
{
|
||
name: "underscores are escaped",
|
||
args: args{matchQuery: "Host_1"},
|
||
want: "%Host\\_1%",
|
||
},
|
||
{
|
||
name: "percent are escaped",
|
||
args: args{matchQuery: "Host%1"},
|
||
want: "%Host\\%1%",
|
||
},
|
||
{
|
||
name: "percent & underscore are escaped",
|
||
args: args{matchQuery: "Host_%1"},
|
||
want: "%Host\\_\\%1%",
|
||
},
|
||
{
|
||
name: "underscores added for wildcard search are not escaped",
|
||
args: args{matchQuery: "Alice‘s MacbookPro"},
|
||
want: "%Alice_s MacbookPro%",
|
||
},
|
||
{
|
||
name: "underscores added for wildcard search are not escaped, but underscores in matchQuery are",
|
||
args: args{matchQuery: "Alice‘s Macbook_Pro"},
|
||
want: "%Alice_s Macbook\\_Pro%",
|
||
},
|
||
{
|
||
name: "multiple occurances of wildcard are not escaped",
|
||
args: args{matchQuery: "Alice‘‘s Macbook_Pro"},
|
||
want: "%Alice__s Macbook\\_Pro%",
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
assert.Equalf(t, tt.want, buildWildcardMatchPhrase(tt.args.matchQuery), "buildWildcardMatchPhrase(%v)", tt.args.matchQuery)
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestWhereFilterTeamWithGlobalStats validates the SQL fragment produced by
// whereFilterTeamWithGlobalStats for combinations of global role, per-team
// role, the IncludeObserver flag, and an explicit TeamID. Unlike
// whereFilterTeams, a global role yields a filter on team_id = 0 rows flagged
// with global_stats = 1 rather than matching everything; "FALSE" means no
// rows match.
func TestWhereFilterTeamWithGlobalStats(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name     string
		filter   fleet.TeamFilter
		expected string
	}{
		// No teams or global role
		{
			name: "empty user",
			filter: fleet.TeamFilter{
				User: &fleet.User{},
			},
			expected: "FALSE",
		},
		{
			name: "empty user teams",
			filter: fleet.TeamFilter{
				User: &fleet.User{Teams: []fleet.UserTeam{}},
			},
			expected: "FALSE",
		},

		// Global role
		{
			name: "global admin",
			filter: fleet.TeamFilter{
				User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
			},
			expected: "hosts.team_id = 0 AND hosts.global_stats = 1",
		},
		{
			name: "global maintainer",
			filter: fleet.TeamFilter{
				User: &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
			},
			expected: "hosts.team_id = 0 AND hosts.global_stats = 1",
		},
		// Global observer is excluded unless IncludeObserver is set.
		{
			name: "global observer",
			filter: fleet.TeamFilter{
				User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
			},
			expected: "FALSE",
		},
		{
			name: "global observer include",
			filter: fleet.TeamFilter{
				User:            &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
				IncludeObserver: true,
			},
			expected: "hosts.team_id = 0 AND hosts.global_stats = 1",
		},

		// Team roles
		{
			name: "team observer",
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
					},
				},
			},
			expected: "FALSE",
		},
		{
			name: "team observer include",
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
					},
				},
				IncludeObserver: true,
			},
			expected: "hosts.team_id IN (1)",
		},
		{
			name: "multi team observer",
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 2}},
					},
				},
			},
			expected: "FALSE",
		},
		// Mixed roles: only the non-observer team is visible by default.
		{
			name: "multi team maintainer and observer",
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
					},
				},
			},
			expected: "hosts.team_id IN (2)",
		},
		{
			name: "multi team maintainer and observer include",
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
					},
				},
				IncludeObserver: true,
			},
			expected: "hosts.team_id IN (1,2)",
		},
		{
			name: "multi team maintainer and observer with invalid role",
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
						// Invalid role should be ignored
						{Role: "bad", Team: fleet.Team{ID: 37}},
					},
				},
			},
			expected: "hosts.team_id IN (2)",
		},
		{
			name: "multi team maintainer and observer and admin",
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
						{Role: fleet.RoleAdmin, Team: fleet.Team{ID: 3}},
					},
				},
			},
			expected: "hosts.team_id IN (2,3)",
		},

		// Explicit TeamID: narrows the filter to a single team, still subject
		// to the user's visibility rules.
		{
			name: "team id only",
			filter: fleet.TeamFilter{
				TeamID: ptr.Uint(1),
			},
			expected: "FALSE",
		},
		{
			name: "team id with observer include",
			filter: fleet.TeamFilter{
				User:            &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
				IncludeObserver: true,
				TeamID:          ptr.Uint(1),
			},
			expected: "hosts.team_id = 1",
		},
		{
			name: "team id with observer exclude",
			filter: fleet.TeamFilter{
				User:            &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
				IncludeObserver: false,
				TeamID:          ptr.Uint(1),
			},
			expected: "FALSE",
		},
		{
			name: "team id with admin exclude observer",
			filter: fleet.TeamFilter{
				User:            &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
				IncludeObserver: false,
				TeamID:          ptr.Uint(1),
			},
			expected: "hosts.team_id = 1",
		},
		// A TeamID the user has no role on matches nothing.
		{
			name: "team id not in multiple team roles",
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
					},
				},
				TeamID: ptr.Uint(3),
			},
			expected: "FALSE",
		},
		{
			name: "team id in multiple team roles",
			filter: fleet.TeamFilter{
				User: &fleet.User{
					Teams: []fleet.UserTeam{
						{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
						{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
					},
				},
				TeamID: ptr.Uint(2),
			},
			expected: "hosts.team_id = 2",
		},
	}

	for _, tt := range testCases {
		// Copy the loop variable: the subtests run in parallel, so the
		// closures must not share a single iteration variable (pre-Go 1.22).
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			ds := &Datastore{logger: slog.New(slog.DiscardHandler)}
			sql := ds.whereFilterTeamWithGlobalStats(tt.filter, "hosts")
			assert.Equal(t, tt.expected, sql)
		})
	}
}
|
||
|
||
func TestBatchProcessDB(t *testing.T) {
|
||
type testData struct {
|
||
id int
|
||
value string
|
||
}
|
||
|
||
payload := []interface{}{
|
||
&testData{id: 1, value: "a"},
|
||
&testData{id: 2, value: "b"},
|
||
&testData{id: 3, value: "c"},
|
||
}
|
||
|
||
generateValueArgs := func(item interface{}) (string, []any) {
|
||
p := item.(*testData)
|
||
valuePart := "(?, ?),"
|
||
args := []any{p.id, p.value}
|
||
return valuePart, args
|
||
}
|
||
|
||
t.Run("TestEmptyPayload", func(t *testing.T) {
|
||
executeBatch := func(valuePart string, args []any) error {
|
||
return errors.New("execute shouldn't be called for an empty payload")
|
||
}
|
||
err := batchProcessDB([]interface{}{}, 1000, generateValueArgs, executeBatch)
|
||
require.NoError(t, err)
|
||
})
|
||
|
||
t.Run("TestSingleBatch", func(t *testing.T) {
|
||
callCount := 0
|
||
executeBatch := func(valuePart string, args []any) error {
|
||
callCount++
|
||
require.Equal(t, 2, len(args)/2) // each item adds 2 args
|
||
return nil
|
||
}
|
||
err := batchProcessDB(payload[:2], 2, generateValueArgs, executeBatch)
|
||
require.NoError(t, err)
|
||
require.Equal(t, 1, callCount)
|
||
})
|
||
|
||
t.Run("TestMultipleBatches", func(t *testing.T) {
|
||
callCount := 0
|
||
executeBatch := func(valuePart string, args []any) error {
|
||
callCount++
|
||
require.Equal(t, 2/callCount, len(args)/2) // each item adds 2 args
|
||
return nil
|
||
}
|
||
err := batchProcessDB(payload, 2, generateValueArgs, executeBatch)
|
||
require.NoError(t, err)
|
||
require.Equal(t, 2, callCount)
|
||
})
|
||
}
|
||
|
||
func TestGetContextTryStmt(t *testing.T) {
|
||
ctx := context.Background()
|
||
|
||
dbMock, ds := mockDatastore(t)
|
||
ds.stmtCache = map[string]*sqlx.Stmt{}
|
||
|
||
t.Run("get with unknown statement error", func(t *testing.T) {
|
||
count := 0
|
||
query := "SELECT 1"
|
||
|
||
// first call to cache the statement
|
||
dbMock.ExpectPrepare(query)
|
||
mockResult := sqlmock.NewRows([]string{query})
|
||
mockResult.AddRow("1")
|
||
dbMock.ExpectQuery(query).WillReturnRows(mockResult)
|
||
err := ds.getContextTryStmt(ctx, &count, query)
|
||
require.NoError(t, err)
|
||
require.NoError(t, dbMock.ExpectationsWereMet())
|
||
|
||
// verify that the statement was cached
|
||
stmt := ds.loadOrPrepareStmt(ctx, query)
|
||
require.NotNil(t, stmt)
|
||
|
||
// call again to trigger the unknown statement error and ensure it retries
|
||
// first query, make it fail
|
||
queryMock := dbMock.ExpectQuery(query)
|
||
mySQLErr := &mysql.MySQLError{
|
||
Number: mysqlerr.ER_UNKNOWN_STMT_HANDLER,
|
||
}
|
||
queryMock.WillReturnError(mySQLErr)
|
||
|
||
// after the failure, a second call is made, this time without
|
||
// the prepared statement
|
||
mockResult = sqlmock.NewRows([]string{query})
|
||
mockResult.AddRow("1")
|
||
dbMock.ExpectQuery(query).WillReturnRows(mockResult)
|
||
|
||
// make the call and verify we removed the prepared statement
|
||
err = ds.getContextTryStmt(ctx, &count, query)
|
||
require.NoError(t, err)
|
||
require.NoError(t, dbMock.ExpectationsWereMet())
|
||
stmt = ds.loadOrPrepareStmt(ctx, query)
|
||
require.Nil(t, stmt)
|
||
})
|
||
|
||
t.Run("get with other error", func(t *testing.T) {
|
||
dbMock, ds := mockDatastore(t)
|
||
ds.stmtCache = map[string]*sqlx.Stmt{}
|
||
count := 0
|
||
query := "SELECT 1"
|
||
|
||
// first call to cache the statement
|
||
dbMock.ExpectPrepare(query)
|
||
mockResult := sqlmock.NewRows([]string{query})
|
||
mockResult.AddRow("1")
|
||
dbMock.ExpectQuery(query).WillReturnRows(mockResult)
|
||
err := ds.getContextTryStmt(ctx, &count, query)
|
||
require.NoError(t, err)
|
||
require.Equal(t, 1, count)
|
||
require.NoError(t, dbMock.ExpectationsWereMet())
|
||
|
||
// verify that the statement was cached
|
||
stmt := ds.loadOrPrepareStmt(ctx, query)
|
||
require.NotNil(t, stmt)
|
||
|
||
// return a duplicate error
|
||
queryMock := dbMock.ExpectQuery(query)
|
||
mySQLErr := &mysql.MySQLError{
|
||
Number: mysqlerr.ER_DUP_ENTRY,
|
||
}
|
||
queryMock.WillReturnError(mySQLErr)
|
||
|
||
count = 0
|
||
err = ds.getContextTryStmt(ctx, &count, query)
|
||
require.ErrorIs(t, mySQLErr, err)
|
||
require.NoError(t, dbMock.ExpectationsWereMet())
|
||
stmt = ds.loadOrPrepareStmt(ctx, query)
|
||
require.NotNil(t, stmt)
|
||
})
|
||
}
|
||
|
||
// createTestDatabase creates a temporary test database with the given name and
|
||
// registers cleanup logic to drop it and close the connection when the test
|
||
// completes.
|
||
func createTestDatabase(t *testing.T, dbName string) {
|
||
t.Helper()
|
||
db, err := sql.Open(
|
||
"mysql",
|
||
fmt.Sprintf("%s:%s@tcp(%s)/?multiStatements=true", testing_utils.TestUsername, testing_utils.TestPassword,
|
||
testing_utils.TestAddress),
|
||
)
|
||
require.NoError(t, err)
|
||
|
||
_, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", dbName, dbName))
|
||
require.NoError(t, err)
|
||
|
||
t.Cleanup(func() {
|
||
_, _ = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName))
|
||
db.Close()
|
||
})
|
||
}
|
||
|
||
// TestReplicaPasswordReadFromDisk verifies that when a replica config uses PasswordPath,
|
||
// the password read from disk by checkAndModifyConfig is preserved for the actual DB connection.
|
||
//
|
||
// This is a regression test for https://github.com/fleetdm/fleet/pull/39689.
|
||
// Before the fix, NewDBConnections called fromCommonMysqlConfig twice for the replica config,
|
||
// and the second call created a fresh config where Password was empty (the mutation from
|
||
// checkAndModifyConfig reading PasswordPath was lost). This caused the replica to connect
|
||
// with an empty password instead of the one from disk.
|
||
func TestReplicaPasswordReadFromDisk(t *testing.T) {
|
||
// Write the correct password to a temp file.
|
||
passwordFile, err := os.CreateTemp(t.TempDir(), "*.passwordtest")
|
||
require.NoError(t, err)
|
||
_, err = passwordFile.WriteString(testing_utils.TestPassword)
|
||
require.NoError(t, err)
|
||
require.NoError(t, passwordFile.Close())
|
||
|
||
dbName := t.Name()
|
||
|
||
// Create the test database using a direct connection.
|
||
createTestDatabase(t, dbName)
|
||
|
||
// Primary config uses Password directly — this always works.
|
||
primaryConfig := config.MysqlConfig{
|
||
Username: testing_utils.TestUsername,
|
||
Password: testing_utils.TestPassword,
|
||
Address: testing_utils.TestAddress,
|
||
Database: dbName,
|
||
}
|
||
|
||
// Replica config uses PasswordPath instead of Password.
|
||
// checkAndModifyConfig must read the file and set Password on the config
|
||
// that is later passed to NewDB. Before the fix the mutation was discarded.
|
||
replicaConfig := config.MysqlConfig{
|
||
Username: testing_utils.TestUsername,
|
||
PasswordPath: passwordFile.Name(),
|
||
Address: testing_utils.TestAddress,
|
||
Database: dbName,
|
||
}
|
||
|
||
conns, err := NewDBConnections(primaryConfig, Replica(&replicaConfig), LimitAttempts(1), Logger(slog.New(slog.DiscardHandler)))
|
||
require.NoError(t, err, "replica connection should succeed when PasswordPath is used — "+
|
||
"if this fails with 'Access denied' the password read from disk was not preserved for the replica")
|
||
defer conns.Primary.Close()
|
||
defer conns.Replica.Close()
|
||
|
||
// Verify the replica connection actually works.
|
||
require.NoError(t, conns.Replica.Ping())
|
||
}
|
||
|
||
// TestReplicaTLSConfigPreserved verifies that when a replica config has TLSCA set,
// the TLSConfig="custom" mutation from checkAndModifyConfig is preserved so the
// replica actually connects with TLS.
//
// This is a regression test for https://github.com/fleetdm/fleet/pull/39689.
// Before the fix, NewDBConnections called fromCommonMysqlConfig twice for the replica config.
// The first call's config got TLSConfig="custom" via checkAndModifyConfig, but that config was
// block-scoped and discarded. The second call created a fresh config where TLSConfig was
// empty, so the replica silently connected without TLS.
func TestReplicaTLSConfigPreserved(t *testing.T) {
	dbName := t.Name()

	// Create the test database using a direct connection.
	createTestDatabase(t, dbName)

	// Generate a test CA cert that does NOT match the MySQL server's cert.
	ca, _ := generateTestCert(t)
	cert, key := generateTestCert(t)

	// Primary config without TLS — connects normally.
	primaryConfig := config.MysqlConfig{
		Username: testing_utils.TestUsername,
		Password: testing_utils.TestPassword,
		Address:  testing_utils.TestAddress,
		Database: dbName,
	}

	// Replica config with TLS — checkAndModifyConfig should set TLSConfig="custom"
	// and register the TLS profile. When NewDB builds the DSN it must include
	// tls=custom so the driver actually uses TLS with our test CA.
	replicaConfig := config.MysqlConfig{
		Username: testing_utils.TestUsername,
		Password: testing_utils.TestPassword,
		Address:  testing_utils.TestAddress,
		Database: dbName,
		TLSCA:    ca,
		TLSCert:  cert,
		TLSKey:   key,
	}

	// After the fix the replica connection attempt uses TLS with our test CA,
	// which doesn't match the MySQL server's certificate, so we expect a TLS error.
	//
	// Before the fix TLSConfig was empty for the replica's NewDB call, so the
	// replica connected without TLS (no error) — meaning this assertion would fail.
	_, err := NewDBConnections(primaryConfig, Replica(&replicaConfig), LimitAttempts(1), Logger(slog.New(slog.DiscardHandler)))
	require.Error(t, err, "replica connection should fail with a TLS error when TLSCA is set — "+
		"if this succeeds, TLS was silently not applied to the replica")
	// NOTE: regexp because the exact message varies across MySQL versions.
	require.Regexp(t, "(x509|tls|EOF)", err.Error())
	// Ensure the failure is due to the TLS handshake/connection and not TLS config registration.
	assert.NotContains(t, err.Error(), "RegisterTLSConfig")
}
|