diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index 272abf391e..3178e5568f 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -309,6 +309,23 @@ func loadHostUsersDB(ctx context.Context, db sqlx.QueryerContext, hostID uint) ( return users, nil } +// hostRefs are the tables referenced by hosts. +// +// Defined here for testing purposes. +var hostRefs = []string{ + "host_seen_times", + "host_software", + "host_users", + "host_emails", + "host_additional", + "scheduled_query_stats", + "label_membership", + "policy_membership", + "host_mdm", + "host_munki_info", + "host_device_auth", +} + func (ds *Datastore) DeleteHost(ctx context.Context, hid uint) error { delHostRef := func(tx sqlx.ExtContext, table string) error { _, err := tx.ExecContext(ctx, fmt.Sprintf(`DELETE FROM %s WHERE host_id=?`, table), hid) @@ -324,20 +341,6 @@ func (ds *Datastore) DeleteHost(ctx context.Context, hid uint) error { return ctxerr.Wrapf(ctx, err, "delete host") } - hostRefs := []string{ - "host_seen_times", - "host_software", - "host_users", - "host_emails", - "host_additional", - "scheduled_query_stats", - "label_membership", - "policy_membership", - "host_mdm", - "host_munki_info", - "host_device_auth", - } - for _, table := range hostRefs { err := delHostRef(tx, table) if err != nil { @@ -1018,41 +1021,11 @@ func (ds *Datastore) TotalAndUnseenHostsSince(ctx context.Context, daysCount int } func (ds *Datastore) DeleteHosts(ctx context.Context, ids []uint) error { - _, err := ds.deleteEntities(ctx, hostsTable, ids) - if err != nil { - return ctxerr.Wrap(ctx, err, "deleting hosts") + for _, id := range ids { + if err := ds.DeleteHost(ctx, id); err != nil { + return ctxerr.Wrapf(ctx, err, "delete host %d", id) + } } - - query, args, err := sqlx.In(`DELETE FROM host_seen_times WHERE host_id in (?)`, ids) - if err != nil { - return ctxerr.Wrapf(ctx, err, "building delete host_seen_times query") - } - - _, err = ds.writer.ExecContext(ctx, query, args...) - if err != nil { - return ctxerr.Wrap(ctx, err, "deleting host seen times") - } - - query, args, err = sqlx.In(`DELETE FROM host_emails WHERE host_id in (?)`, ids) - if err != nil { - return ctxerr.Wrapf(ctx, err, "building delete host_emails query") - } - - _, err = ds.writer.ExecContext(ctx, query, args...) - if err != nil { - return ctxerr.Wrap(ctx, err, "deleting host emails") - } - - query, args, err = sqlx.In(`DELETE FROM pack_targets WHERE type=? AND target_id in (?)`, fleet.TargetHost, ids) - if err != nil { - return ctxerr.Wrapf(ctx, err, "building delete pack_targets query") - } - - _, err = ds.writer.ExecContext(ctx, query, args...) - if err != nil { - return ctxerr.Wrapf(ctx, err, "deleting pack_targets for hosts") - } - return nil } diff --git a/server/datastore/mysql/hosts_test.go b/server/datastore/mysql/hosts_test.go index af730d5e6c..0d07aea908 100644 --- a/server/datastore/mysql/hosts_test.go +++ b/server/datastore/mysql/hosts_test.go @@ -114,6 +114,7 @@ func TestHosts(t *testing.T) { {"UpdateRefetchRequested", testUpdateRefetchRequested}, {"LoadHostByDeviceAuthToken", testHostsLoadHostByDeviceAuthToken}, {"SetOrUpdateDeviceAuthToken", testHostsSetOrUpdateDeviceAuthToken}, + {"DeleteHosts", testHostsDeleteHosts}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -3779,3 +3780,135 @@ func testHostsSetOrUpdateDeviceAuthToken(t *testing.T, ds *Datastore) { require.Error(t, err) assert.ErrorIs(t, err, sql.ErrNoRows) } + +func testHostsDeleteHosts(t *testing.T, ds *Datastore) { + // Updates hosts and host_seen_times. + host, err := ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: "1", + UUID: "1", + Hostname: "foo.local", + }) + require.NoError(t, err) + require.NotNil(t, host) + // Updates host_software. + software := []fleet.Software{ + {Name: "foo", Version: "0.0.1", Source: "chrome_extensions"}, + {Name: "bar", Version: "1.0.0", Source: "deb_packages"}, + } + err = ds.UpdateHostSoftware(context.Background(), host.ID, software) + require.NoError(t, err) + // Updates host_users. + users := []fleet.HostUser{ + { + Uid: 42, + Username: "user1", + Type: "aaa", + GroupName: "group", + Shell: "shell", + }, + { + Uid: 43, + Username: "user2", + Type: "bbb", + GroupName: "group2", + Shell: "bash", + }, + } + err = ds.SaveHostUsers(context.Background(), host.ID, users) + require.NoError(t, err) + // Updates host_emails. + err = ds.ReplaceHostDeviceMapping(context.Background(), host.ID, []*fleet.HostDeviceMapping{ + {HostID: host.ID, Email: "a@b.c", Source: "src"}, + }) + require.NoError(t, err) + // Updates host_additional. + additional := json.RawMessage(`{"additional": "result"}`) + host.Additional = &additional + err = saveHostAdditionalDB(context.Background(), ds.writer, host.ID, host.Additional) + require.NoError(t, err) + // Updates scheduled_query_stats. + pack, err := ds.NewPack(context.Background(), &fleet.Pack{ + Name: "test1", + HostIDs: []uint{host.ID}, + }) + require.NoError(t, err) + query := test.NewQuery(t, ds, "time", "select * from time", 0, true) + squery := test.NewScheduledQuery(t, ds, pack.ID, query.ID, 30, true, true, "time-scheduled") + stats := []fleet.ScheduledQueryStats{ + { + ScheduledQueryName: squery.Name, + ScheduledQueryID: squery.ID, + QueryName: query.Name, + PackName: pack.Name, + PackID: pack.ID, + AverageMemory: 8000, + Denylisted: false, + Executions: 164, + Interval: 30, + LastExecuted: time.Unix(1620325191, 0).UTC(), + OutputSize: 1337, + SystemTime: 150, + UserTime: 180, + WallTime: 0, + }, + } + host.PackStats = []fleet.PackStats{ + { + PackName: "test1", + QueryStats: stats, + }, + } + err = ds.SaveHost(context.Background(), host) + require.NoError(t, err) + // Updates label_membership. + labelID := uint(1) + label := &fleet.LabelSpec{ + ID: labelID, + Name: "label foo", + Query: "select * from time;", + } + err = ds.ApplyLabelSpecs(context.Background(), []*fleet.LabelSpec{label}) + require.NoError(t, err) + err = ds.RecordLabelQueryExecutions(context.Background(), host, map[uint]*bool{label.ID: ptr.Bool(true)}, time.Now(), false) + require.NoError(t, err) + // Update policy_membership. + user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true) + policy, err := ds.NewGlobalPolicy(context.Background(), &user1.ID, fleet.PolicyPayload{ + Name: "policy foo", + Query: "select * from time", + }) + require.NoError(t, err) + require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false)) + // Update host_mdm. + err = ds.SetOrUpdateMDMData(context.Background(), host.ID, false, "", false) + require.NoError(t, err) + // Update host_munki_info. + err = ds.SetOrUpdateMunkiVersion(context.Background(), host.ID, "42") + require.NoError(t, err) + // Update device_auth_token. + err = ds.SetOrUpdateDeviceAuthToken(context.Background(), host.ID, "foo") + require.NoError(t, err) + + // Check there's an entry for the host in all the associated tables. + for _, hostRef := range hostRefs { + var ok bool + err = ds.writer.Get(&ok, fmt.Sprintf("SELECT 1 FROM %s WHERE host_id = ?", hostRef), host.ID) + require.NoError(t, err) + require.True(t, ok, "table: %s", hostRef) + } + + err = ds.DeleteHosts(context.Background(), []uint{host.ID}) + require.NoError(t, err) + + // Check that all the associated tables were cleaned up. + for _, hostRef := range hostRefs { + var ok bool + err = ds.writer.Get(&ok, fmt.Sprintf("SELECT 1 FROM %s WHERE host_id = ?", hostRef), host.ID) + require.True(t, err == nil || errors.Is(err, sql.ErrNoRows), "table: %s", hostRef) + require.False(t, ok, "table: %s", hostRef) + } +} diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 11b4a4d5e5..58cf63a371 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -196,6 +196,10 @@ type Datastore interface { TotalAndUnseenHostsSince(ctx context.Context, daysCount int) (total int, unseen int, err error) + // DeleteHosts deletes associated tables for multiple hosts. + // + // It atomically deletes each host but if it returns an error, some of the hosts may be + // deleted and others not. DeleteHosts(ctx context.Context, ids []uint) error CountHosts(ctx context.Context, filter TeamFilter, opt HostListOptions) (int, error)