diff --git a/changes/14779-fix-host_batteries-deadlock b/changes/14779-fix-host_batteries-deadlock new file mode 100644 index 0000000000..d70a3eedbc --- /dev/null +++ b/changes/14779-fix-host_batteries-deadlock @@ -0,0 +1 @@ +* Fix possible deadlocks when upserting to `host_batteries` (found during load test). diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index 71a4ac1144..df82354692 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -2912,65 +2912,95 @@ func (ds *Datastore) ReplaceHostDeviceMapping(ctx context.Context, hid uint, map } func (ds *Datastore) ReplaceHostBatteries(ctx context.Context, hid uint, mappings []*fleet.HostBattery) error { + for _, m := range mappings { + if hid != m.HostID { + return ctxerr.Errorf(ctx, "host batteries mapping are not all for the provided host id %d, found %d", hid, m.HostID) + } + } + + // The following SQL statements assume a small number of batteries reported per host. + // This is using the same pattern as ReplaceHostDeviceMapping. const ( - replaceStmt = ` - INSERT INTO - host_batteries ( + selStmt = ` + SELECT + id, host_id, serial_number, - cycle_count, - health - ) - VALUES - %s - ON DUPLICATE KEY UPDATE - cycle_count = VALUES(cycle_count), - health = VALUES(health), - updated_at = CURRENT_TIMESTAMP -` - valuesPart = `(?, ?, ?, ?),` + cycle_count, + health + FROM + host_batteries + WHERE + host_id = ?` - deleteExceptStmt = ` - DELETE FROM - host_batteries - WHERE - host_id = ? AND - serial_number NOT IN (?) -` - deleteAllStmt = ` - DELETE FROM - host_batteries - WHERE - host_id = ? 
-` + delStmt = ` + DELETE FROM + host_batteries + WHERE + id IN (?)` + + insStmt = ` + INSERT INTO + host_batteries (host_id, serial_number, cycle_count, health) + VALUES` + insPart = ` (?, ?, ?, ?),` ) - replaceArgs := make([]interface{}, 0, len(mappings)*4) - deleteNotIn := make([]string, 0, len(mappings)) - for _, hb := range mappings { - deleteNotIn = append(deleteNotIn, hb.SerialNumber) - replaceArgs = append(replaceArgs, hid, hb.SerialNumber, hb.CycleCount, hb.Health) + keyFn := func(b *fleet.HostBattery) string { + return b.SerialNumber + ":" + fmt.Sprint(b.CycleCount) + ":" + b.Health } return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { - // first, insert the new batteries or update the existing ones - if len(replaceArgs) > 0 { - if _, err := tx.ExecContext(ctx, fmt.Sprintf(replaceStmt, strings.TrimSuffix(strings.Repeat(valuesPart, len(mappings)), ",")), replaceArgs...); err != nil { - return ctxerr.Wrap(ctx, err, "upsert host batteries") + // Index the mappings by serial_number, to quickly check which ones + // need to be deleted and inserted. + toIns := make(map[string]*fleet.HostBattery) + serials := make(map[string]struct{}) + for _, m := range mappings { + if _, ok := serials[m.SerialNumber]; ok { + // Ignore multiple rows with the same serial number + // (e.g. in case of bugs in results reported by osquery). 
+ continue + } + toIns[keyFn(m)] = m + serials[m.SerialNumber] = struct{}{} + } + + var prevMappings []*fleet.HostBattery + if err := sqlx.SelectContext(ctx, tx, &prevMappings, selStmt, hid); err != nil { + return ctxerr.Wrap(ctx, err, "select previous host batteries") + } + + var delIDs []uint + for _, pm := range prevMappings { + key := keyFn(pm) + if _, ok := toIns[key]; ok { + // already exists, no need to insert + delete(toIns, key) + } else { + // does not exist anymore, must be deleted + delIDs = append(delIDs, pm.ID) } } - // then, delete the old ones - if len(deleteNotIn) > 0 { - delStmt, args, err := sqlx.In(deleteExceptStmt, hid, deleteNotIn) + if len(delIDs) > 0 { + stmt, args, err := sqlx.In(delStmt, delIDs) if err != nil { - return ctxerr.Wrap(ctx, err, "generating host batteries delete NOT IN statement") + return ctxerr.Wrap(ctx, err, "prepare delete statement") } - if _, err := tx.ExecContext(ctx, delStmt, args...); err != nil { + if _, err := tx.ExecContext(ctx, stmt, args...); err != nil { return ctxerr.Wrap(ctx, err, "delete host batteries") } - } else if _, err := tx.ExecContext(ctx, deleteAllStmt, hid); err != nil { - return ctxerr.Wrap(ctx, err, "delete all host batteries") + } + + if len(toIns) > 0 { + var args []interface{} + for _, m := range toIns { + args = append(args, hid, m.SerialNumber, m.CycleCount, m.Health) + } + stmt := insStmt + strings.TrimSuffix(strings.Repeat(insPart, len(toIns)), ",") + if _, err := tx.ExecContext(ctx, stmt, args...); err != nil { + return ctxerr.Wrap(ctx, err, "insert host batteries") + } } return nil }) diff --git a/server/datastore/mysql/hosts_test.go b/server/datastore/mysql/hosts_test.go index 36ebae62b4..710ffbe802 100644 --- a/server/datastore/mysql/hosts_test.go +++ b/server/datastore/mysql/hosts_test.go @@ -29,6 +29,7 @@ import ( "github.com/micromdm/nanodep/godep" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) var expLastExec = func() 
time.Time { @@ -134,6 +135,7 @@ func TestHosts(t *testing.T) { {"DeleteHosts", testHostsDeleteHosts}, {"HostIDsByOSVersion", testHostIDsByOSVersion}, {"ReplaceHostBatteries", testHostsReplaceHostBatteries}, + {"ReplaceHostBatteriesDeadlock", testHostsReplaceHostBatteriesDeadlock}, {"CountHostsNotResponding", testCountHostsNotResponding}, {"FailingPoliciesCount", testFailingPoliciesCount}, {"HostRecordNoPolicies", testHostsRecordNoPolicies}, @@ -6000,6 +6002,31 @@ func testHostsReplaceHostBatteries(t *testing.T, ds *Datastore) { require.NoError(t, err) require.ElementsMatch(t, h1Bat, bat1) + type timestamp struct { + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` + } + var timestamps1 []timestamp + ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error { + return sqlx.SelectContext(ctx, q, &timestamps1, `SELECT created_at, updated_at FROM host_batteries WHERE host_id = ?`, h1.ID) + }) + + // Insert the same battery data again. + err = ds.ReplaceHostBatteries(ctx, h1.ID, h1Bat) + require.NoError(t, err) + + var timestamps2 []timestamp + ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error { + return sqlx.SelectContext(ctx, q, &timestamps2, `SELECT created_at, updated_at FROM host_batteries WHERE host_id = ?`, h1.ID) + }) + + // Verify that there were no inserts/updates (because reported data hasn't changed). 
+ require.ElementsMatch(t, timestamps1, timestamps2) + + bat1, err = ds.ListHostBatteries(ctx, h1.ID) + require.NoError(t, err) + require.ElementsMatch(t, h1Bat, bat1) + bat2, err := ds.ListHostBatteries(ctx, h2.ID) require.NoError(t, err) require.Len(t, bat2, 0) @@ -6045,6 +6072,46 @@ func testHostsReplaceHostBatteries(t *testing.T, ds *Datastore) { require.ElementsMatch(t, h2Bat, bat2) } +func testHostsReplaceHostBatteriesDeadlock(t *testing.T, ds *Datastore) { + ctx := context.Background() + var hosts []*fleet.Host + for i := 1; i <= 100; i++ { + h, err := ds.NewHost(ctx, &fleet.Host{ + ID: uint(i), + OsqueryHostID: ptr.String(fmt.Sprintf("id-%d", i)), + NodeKey: ptr.String(fmt.Sprintf("key-%d", i)), + Platform: "linux", + Hostname: fmt.Sprintf("host-%d", i), + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + }) + require.NoError(t, err) + hosts = append(hosts, h) + } + + var g errgroup.Group + for _, h := range hosts { + hostID := h.ID + g.Go(func() error { + for i := 0; i < 100; i++ { + if err := ds.ReplaceHostBatteries(ctx, hostID, []*fleet.HostBattery{ + {HostID: hostID, SerialNumber: fmt.Sprintf("%d-0000", hostID), CycleCount: 1, Health: "Good"}, + {HostID: hostID, SerialNumber: fmt.Sprintf("%d-0000", hostID), CycleCount: 2, Health: "Fair"}, + }); err != nil { + return err + } + time.Sleep(10 * time.Millisecond) + } + return nil + }) + } + + err := g.Wait() + require.NoError(t, err) +} + func testCountHostsNotResponding(t *testing.T, ds *Datastore) { ctx := context.Background() config := config.FleetConfig{Osquery: config.OsqueryConfig{DetailUpdateInterval: 1 * time.Hour}} diff --git a/server/fleet/hosts.go b/server/fleet/hosts.go index 6de340ebe8..9610f720cd 100644 --- a/server/fleet/hosts.go +++ b/server/fleet/hosts.go @@ -982,6 +982,7 @@ func (h *HostMDM) UnmarshalJSON(b []byte) error { // HostBattery represents a host's battery, as reported by the osquery battery // table. 
type HostBattery struct { + ID uint `json:"-" db:"id"` HostID uint `json:"-" db:"host_id"` SerialNumber string `json:"-" db:"serial_number"` CycleCount int `json:"cycle_count" db:"cycle_count"`