diff --git a/server/datastore/mysql/labels.go b/server/datastore/mysql/labels.go index 03a3fad47b..fcc6c99b87 100644 --- a/server/datastore/mysql/labels.go +++ b/server/datastore/mysql/labels.go @@ -276,13 +276,13 @@ DELETE FROM label_membership WHERE label_id = ? intRegex := regexp.MustCompile(`^[0-9]+$`) // Split hostnames into batches to avoid parameter limit in MySQL. - for _, hostIdentifiers := range batchHostnames(s.Hosts) { + for _, hostIdentifiersBatch := range batchHostnames(s.Hosts) { var stringIdents []string // Start with 0 so id IN (?) always has at least one element. // id = 0 never matches any real host. intIdents := []uint64{0} - for _, s := range hostIdentifiers { + for _, s := range hostIdentifiersBatch { stringIdents = append(stringIdents, s) // Use strconv to check if it's a valid integer if intRegex.MatchString(s) { @@ -291,10 +291,34 @@ DELETE FROM label_membership WHERE label_id = ? } } + hostsFilterClause := `(hostname IN (?) OR hardware_serial IN (?) OR uuid IN (?) OR id IN (?))` + + if s.TeamID != nil { + // Team labels can only be applied to hosts on that team. + hostnames := stringIdents + serialNumbers := stringIdents + uuids := stringIdents + hostIDs := intIdents + if err := checkHostIdentifiersInTeam(ctx, tx, + *s.TeamID, + hostsFilterClause, + []any{ + hostnames, + serialNumbers, + uuids, + hostIDs, + }, + ); err != nil { + return ctxerr.Wrap(ctx, err, "check host identifiers in team") + } + } + // Use ignore because duplicate hostnames could appear in // different batches and would result in duplicate key errors. - sql = ` -INSERT IGNORE INTO label_membership (label_id, host_id) (SELECT DISTINCT ?, id FROM hosts where hostname IN (?) OR hardware_serial IN (?) OR uuid IN (?) OR id IN (?))` + sql = fmt.Sprintf( + `INSERT IGNORE INTO label_membership (label_id, host_id) (SELECT DISTINCT ?, id FROM hosts WHERE %s)`, + hostsFilterClause, + ) sql, args, err := sqlx.In(sql, labelID, stringIdents, stringIdents, stringIdents, intIdents) if err != nil { return ctxerr.Wrap(ctx, err, "build membership IN statement") @@ -312,6 +336,32 @@ INSERT IGNORE INTO label_membership (label_id, host_id) (SELECT DISTINCT ?, id F return ctxerr.Wrap(ctx, err, "ApplyLabelSpecs transaction") } +var errLabelMismatchHostTeam = errors.New("supplied hosts are on a different team than the label") + +func checkHostIdentifiersInTeam( + ctx context.Context, + tx sqlx.QueryerContext, + teamID uint, + andFilter string, + args []any, +) error { + hostTeamCheckSql, args, err := sqlx.In( + `SELECT COUNT(id) FROM hosts WHERE (team_id != ? OR team_id IS NULL) AND `+andFilter, + append([]any{teamID}, args...)..., + ) + if err != nil { + return ctxerr.Wrap(ctx, err, "build host identifiers team membership check IN statement") + } + var hostCountOnWrongTeam int + if err := tx.QueryRowxContext(ctx, hostTeamCheckSql, args...).Scan(&hostCountOnWrongTeam); err != nil { + return ctxerr.Wrap(ctx, err, "execute host identifiers team membership check query") + } + if hostCountOnWrongTeam > 0 { + return ctxerr.Wrap(ctx, errLabelMismatchHostTeam) + } + return nil +} + func batchHostnames(hostnames []string) [][]string { // Split hostnames into batches so that they can all be inserted without // overflowing the MySQL max number of parameters (somewhere around 65,000 @@ -348,29 +398,24 @@ func (ds *Datastore) UpdateLabelMembershipByHostIDs(ctx context.Context, label f } // Split hostIds into batches to avoid parameter limit in MySQL. - for _, hostIds := range batchHostIds(hostIds) { - if label.TeamID != nil { // team labels can only be applied to hosts on that team - hostTeamCheckSql := `SELECT COUNT(id) FROM hosts WHERE (team_id != ? OR team_id IS NULL) AND id IN (?)` - hostTeamCheckSql, args, err := sqlx.In(hostTeamCheckSql, label.TeamID, hostIds) - if err != nil { - return ctxerr.Wrap(ctx, err, "build host team membership check IN statement") - } - - var hostCountOnWrongTeam int - if err := tx.QueryRowxContext(ctx, hostTeamCheckSql, args...).Scan(&hostCountOnWrongTeam); err != nil { - return ctxerr.Wrap(ctx, err, "execute host team membership check query") - } - if hostCountOnWrongTeam > 0 { - return ctxerr.Wrap(ctx, errors.New("supplied hosts are on a different team than the label")) + for _, hostIDsBatch := range batchHostIds(hostIds) { + if label.TeamID != nil { + // Team labels can only be applied to hosts on that team. + if err := checkHostIdentifiersInTeam(ctx, tx, + *label.TeamID, + `id IN (?)`, + []any{hostIDsBatch}, + ); err != nil { + return ctxerr.Wrap(ctx, err, "check host IDs in team") } } - // Use ignore because duplicate hostIds could appear in + // Use ignore because duplicate host IDs could appear in // different batches and would result in duplicate key errors. var values []any var placeholders []string - for _, hostID := range hostIds { + for _, hostID := range hostIDsBatch { values = append(values, label.ID, hostID) placeholders = append(placeholders, "(?, ?)") } diff --git a/server/datastore/mysql/labels_test.go b/server/datastore/mysql/labels_test.go index fe608134ac..a3766f318f 100644 --- a/server/datastore/mysql/labels_test.go +++ b/server/datastore/mysql/labels_test.go @@ -104,6 +104,7 @@ func TestLabels(t *testing.T) { {"TeamLabels", testTeamLabels}, {"UpdateLabelMembershipForTransferredHost", testUpdateLabelMembershipForTransferredHost}, {"SetAsideLabels", testSetAsideLabels}, + {"ApplyLabelSpecsWithManualTeamLabels", testApplyLabelSpecsWithManualTeamLabels}, } // call TruncateTables first to remove migration-created labels TruncateTables(t, ds) @@ -3305,3 +3306,139 @@ func testSetAsideLabels(t *testing.T, ds *Datastore) { }) } } + +func testApplyLabelSpecsWithManualTeamLabels(t *testing.T, ds *Datastore) { + ctx := t.Context() + teamFilter := fleet.TeamFilter{User: test.UserAdmin} + + // Create teams. + t1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"}) + require.NoError(t, err) + t2, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team2"}) + require.NoError(t, err) + t3, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team3"}) + require.NoError(t, err) + t4, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team4"}) + require.NoError(t, err) + + // Create hosts on the teams. + h1t1, err := ds.NewHost(ctx, &fleet.Host{ + OsqueryHostID: ptr.String("h1t1"), + NodeKey: ptr.String("h1t1"), + Hostname: "hostname-h1t1", + HardwareSerial: "serial-h1t1", + UUID: "uuid-h1t1", + Platform: "darwin", + TeamID: &t1.ID, + }) + require.NoError(t, err) + h2t2, err := ds.NewHost(ctx, &fleet.Host{ + OsqueryHostID: ptr.String("h2t2"), + NodeKey: ptr.String("h2t2"), + Hostname: "hostname-h2t2", + HardwareSerial: "serial-h2t2", + UUID: "uuid-h2t2", + Platform: "darwin", + TeamID: &t2.ID, + }) + require.NoError(t, err) + h3t3, err := ds.NewHost(ctx, &fleet.Host{ + OsqueryHostID: ptr.String("h3t3"), + NodeKey: ptr.String("h3t3"), + Hostname: "hostname-h3t3", + HardwareSerial: "serial-h3t3", + UUID: "uuid-h3t3", + Platform: "darwin", + TeamID: &t3.ID, + }) + require.NoError(t, err) + h4t4, err := ds.NewHost(ctx, &fleet.Host{ + OsqueryHostID: ptr.String("h4t4"), + NodeKey: ptr.String("h4t4"), + Hostname: "hostname-h4t4", + HardwareSerial: "serial-h4t4", + UUID: "uuid-h4t4", + Platform: "darwin", + TeamID: &t4.ID, + }) + require.NoError(t, err) + h5Global, err := ds.NewHost(ctx, &fleet.Host{ + OsqueryHostID: ptr.String("h5Global"), + NodeKey: ptr.String("h5Global"), + Hostname: "hostname-h5Global", + HardwareSerial: "serial-h5Global", + UUID: "uuid-h5Global", + Platform: "darwin", + TeamID: nil, + }) + require.NoError(t, err) + + // Create a global manual label, make sure you can add all. + err = ds.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{ + { + Name: "global1", + LabelMembershipType: fleet.LabelMembershipTypeManual, + Hosts: fleet.HostsSlice{ + h1t1.Hostname, + h2t2.HardwareSerial, + h3t3.UUID, + fmt.Sprint(h4t4.ID), + h5Global.Hostname, + }, + }, + }) + require.NoError(t, err) + + global1, err := ds.LabelByName(ctx, "global1", teamFilter) + require.NoError(t, err) + hosts, err := ds.ListHostsInLabel(ctx, teamFilter, global1.ID, fleet.HostListOptions{}) + require.NoError(t, err) + require.Len(t, hosts, 5) + + // Attempt to create team label, make sure we can only add hosts on that team. + for _, hostIdentifier := range []string{ + h2t2.Hostname, + h3t3.UUID, + h4t4.HardwareSerial, + fmt.Sprint(h5Global.ID), + } { + err := ds.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{ + { + Name: "l1t1", + LabelMembershipType: fleet.LabelMembershipTypeManual, + Hosts: fleet.HostsSlice{ + h1t1.Hostname, + hostIdentifier, // conflicting host identifier. + }, + TeamID: &t1.ID, + }, + }) + require.Error(t, err) + require.ErrorIs(t, err, errLabelMismatchHostTeam) + } + // Create team label with team host identifiers should work. + for _, hostIdentifier := range []string{ + h1t1.Hostname, + h1t1.UUID, + h1t1.HardwareSerial, + fmt.Sprint(h1t1.ID), + } { + err = ds.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{ + { + Name: "l1t1", + LabelMembershipType: fleet.LabelMembershipTypeManual, + Hosts: fleet.HostsSlice{ + hostIdentifier, + }, + TeamID: &t1.ID, + }, + }) + require.NoError(t, err) + l1t1, err := ds.LabelByName(ctx, "l1t1", teamFilter) + require.NoError(t, err) + hosts, err := ds.ListHostsInLabel(ctx, teamFilter, l1t1.ID, fleet.HostListOptions{}) + require.NoError(t, err) + require.Len(t, hosts, 1) + require.Equal(t, h1t1.ID, hosts[0].ID) + } +}