Check host teams before adding them to team manual labels (#37975)

**Related issue:** #33760

## Testing

- [X] Added/updated automated tests
- [X] QA'd all new/changed functionality manually

For unreleased bug fixes in a release candidate, one of:

- [X] Confirmed that the fix is not expected to adversely impact load
test results
This commit is contained in:
Lucas Manuel Rodriguez 2026-01-07 13:13:25 -03:00 committed by GitHub
parent 4a2462605a
commit fb785ce9ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 202 additions and 20 deletions

View file

@ -276,13 +276,13 @@ DELETE FROM label_membership WHERE label_id = ?
intRegex := regexp.MustCompile(`^[0-9]+$`)
// Split hostnames into batches to avoid parameter limit in MySQL.
for _, hostIdentifiers := range batchHostnames(s.Hosts) {
for _, hostIdentifiersBatch := range batchHostnames(s.Hosts) {
var stringIdents []string
// Start with 0 so id IN (?) always has at least one element.
// id = 0 never matches any real host.
intIdents := []uint64{0}
for _, s := range hostIdentifiers {
for _, s := range hostIdentifiersBatch {
stringIdents = append(stringIdents, s)
// Use strconv to check if it's a valid integer
if intRegex.MatchString(s) {
@ -291,10 +291,34 @@ DELETE FROM label_membership WHERE label_id = ?
}
}
hostsFilterClause := `(hostname IN (?) OR hardware_serial IN (?) OR uuid IN (?) OR id IN (?))`
if s.TeamID != nil {
// Team labels can only be applied to hosts on that team.
hostnames := stringIdents
serialNumbers := stringIdents
uuids := stringIdents
hostIDs := intIdents
if err := checkHostIdentifiersInTeam(ctx, tx,
*s.TeamID,
hostsFilterClause,
[]any{
hostnames,
serialNumbers,
uuids,
hostIDs,
},
); err != nil {
return ctxerr.Wrap(ctx, err, "check host identifiers in team")
}
}
// Use ignore because duplicate hostnames could appear in
// different batches and would result in duplicate key errors.
sql = `
INSERT IGNORE INTO label_membership (label_id, host_id) (SELECT DISTINCT ?, id FROM hosts where hostname IN (?) OR hardware_serial IN (?) OR uuid IN (?) OR id IN (?))`
sql = fmt.Sprintf(
`INSERT IGNORE INTO label_membership (label_id, host_id) (SELECT DISTINCT ?, id FROM hosts WHERE %s)`,
hostsFilterClause,
)
sql, args, err := sqlx.In(sql, labelID, stringIdents, stringIdents, stringIdents, intIdents)
if err != nil {
return ctxerr.Wrap(ctx, err, "build membership IN statement")
@ -312,6 +336,32 @@ INSERT IGNORE INTO label_membership (label_id, host_id) (SELECT DISTINCT ?, id F
return ctxerr.Wrap(ctx, err, "ApplyLabelSpecs transaction")
}
var errLabelMismatchHostTeam = errors.New("supplied hosts are on a different team than the label")
func checkHostIdentifiersInTeam(
ctx context.Context,
tx sqlx.QueryerContext,
teamID uint,
andFilter string,
args []any,
) error {
hostTeamCheckSql, args, err := sqlx.In(
`SELECT COUNT(id) FROM hosts WHERE (team_id != ? OR team_id IS NULL) AND `+andFilter,
append([]any{teamID}, args...)...,
)
if err != nil {
return ctxerr.Wrap(ctx, err, "build host identifiers team membership check IN statement")
}
var hostCountOnWrongTeam int
if err := tx.QueryRowxContext(ctx, hostTeamCheckSql, args...).Scan(&hostCountOnWrongTeam); err != nil {
return ctxerr.Wrap(ctx, err, "execute host identifiers team membership check query")
}
if hostCountOnWrongTeam > 0 {
return ctxerr.Wrap(ctx, errLabelMismatchHostTeam)
}
return nil
}
func batchHostnames(hostnames []string) [][]string {
// Split hostnames into batches so that they can all be inserted without
// overflowing the MySQL max number of parameters (somewhere around 65,000
@ -348,29 +398,24 @@ func (ds *Datastore) UpdateLabelMembershipByHostIDs(ctx context.Context, label f
}
// Split hostIds into batches to avoid parameter limit in MySQL.
for _, hostIds := range batchHostIds(hostIds) {
if label.TeamID != nil { // team labels can only be applied to hosts on that team
hostTeamCheckSql := `SELECT COUNT(id) FROM hosts WHERE (team_id != ? OR team_id IS NULL) AND id IN (?)`
hostTeamCheckSql, args, err := sqlx.In(hostTeamCheckSql, label.TeamID, hostIds)
if err != nil {
return ctxerr.Wrap(ctx, err, "build host team membership check IN statement")
}
var hostCountOnWrongTeam int
if err := tx.QueryRowxContext(ctx, hostTeamCheckSql, args...).Scan(&hostCountOnWrongTeam); err != nil {
return ctxerr.Wrap(ctx, err, "execute host team membership check query")
}
if hostCountOnWrongTeam > 0 {
return ctxerr.Wrap(ctx, errors.New("supplied hosts are on a different team than the label"))
for _, hostIDsBatch := range batchHostIds(hostIds) {
if label.TeamID != nil {
// Team labels can only be applied to hosts on that team.
if err := checkHostIdentifiersInTeam(ctx, tx,
*label.TeamID,
`id IN (?)`,
[]any{hostIDsBatch},
); err != nil {
return ctxerr.Wrap(ctx, err, "check host IDs in team")
}
}
// Use ignore because duplicate hostIds could appear in
// Use ignore because duplicate host IDs could appear in
// different batches and would result in duplicate key errors.
var values []any
var placeholders []string
for _, hostID := range hostIds {
for _, hostID := range hostIDsBatch {
values = append(values, label.ID, hostID)
placeholders = append(placeholders, "(?, ?)")
}

View file

@ -104,6 +104,7 @@ func TestLabels(t *testing.T) {
{"TeamLabels", testTeamLabels},
{"UpdateLabelMembershipForTransferredHost", testUpdateLabelMembershipForTransferredHost},
{"SetAsideLabels", testSetAsideLabels},
{"ApplyLabelSpecsWithManualTeamLabels", testApplyLabelSpecsWithManualTeamLabels},
}
// call TruncateTables first to remove migration-created labels
TruncateTables(t, ds)
@ -3305,3 +3306,139 @@ func testSetAsideLabels(t *testing.T, ds *Datastore) {
})
}
}
func testApplyLabelSpecsWithManualTeamLabels(t *testing.T, ds *Datastore) {
ctx := t.Context()
teamFilter := fleet.TeamFilter{User: test.UserAdmin}
// Create teams.
t1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"})
require.NoError(t, err)
t2, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team2"})
require.NoError(t, err)
t3, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team3"})
require.NoError(t, err)
t4, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team4"})
require.NoError(t, err)
// Create hosts on the teams.
h1t1, err := ds.NewHost(ctx, &fleet.Host{
OsqueryHostID: ptr.String("h1t1"),
NodeKey: ptr.String("h1t1"),
Hostname: "hostname-h1t1",
HardwareSerial: "serial-h1t1",
UUID: "uuid-h1t1",
Platform: "darwin",
TeamID: &t1.ID,
})
require.NoError(t, err)
h2t2, err := ds.NewHost(ctx, &fleet.Host{
OsqueryHostID: ptr.String("h2t2"),
NodeKey: ptr.String("h2t2"),
Hostname: "hostname-h2t2",
HardwareSerial: "serial-h2t2",
UUID: "uuid-h2t2",
Platform: "darwin",
TeamID: &t2.ID,
})
require.NoError(t, err)
h3t3, err := ds.NewHost(ctx, &fleet.Host{
OsqueryHostID: ptr.String("h3t3"),
NodeKey: ptr.String("h3t3"),
Hostname: "hostname-h3t3",
HardwareSerial: "serial-h3t3",
UUID: "uuid-h3t3",
Platform: "darwin",
TeamID: &t3.ID,
})
require.NoError(t, err)
h4t4, err := ds.NewHost(ctx, &fleet.Host{
OsqueryHostID: ptr.String("h4t4"),
NodeKey: ptr.String("h4t4"),
Hostname: "hostname-h4t4",
HardwareSerial: "serial-h4t4",
UUID: "uuid-h4t4",
Platform: "darwin",
TeamID: &t4.ID,
})
require.NoError(t, err)
h5Global, err := ds.NewHost(ctx, &fleet.Host{
OsqueryHostID: ptr.String("h5Global"),
NodeKey: ptr.String("h5Global"),
Hostname: "hostname-h5Global",
HardwareSerial: "serial-h5Global",
UUID: "uuid-h5Global",
Platform: "darwin",
TeamID: nil,
})
require.NoError(t, err)
// Create a global manual label, make sure you can add all.
err = ds.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{
{
Name: "global1",
LabelMembershipType: fleet.LabelMembershipTypeManual,
Hosts: fleet.HostsSlice{
h1t1.Hostname,
h2t2.HardwareSerial,
h3t3.UUID,
fmt.Sprint(h4t4.ID),
h5Global.Hostname,
},
},
})
require.NoError(t, err)
global1, err := ds.LabelByName(ctx, "global1", teamFilter)
require.NoError(t, err)
hosts, err := ds.ListHostsInLabel(ctx, teamFilter, global1.ID, fleet.HostListOptions{})
require.NoError(t, err)
require.Len(t, hosts, 5)
// Attempt to create team label, make sure we can only add hosts on that team.
for _, hostIdentifier := range []string{
h2t2.Hostname,
h3t3.UUID,
h4t4.HardwareSerial,
fmt.Sprint(h5Global.ID),
} {
err := ds.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{
{
Name: "l1t1",
LabelMembershipType: fleet.LabelMembershipTypeManual,
Hosts: fleet.HostsSlice{
h1t1.Hostname,
hostIdentifier, // conflicting host identifier.
},
TeamID: &t1.ID,
},
})
require.Error(t, err)
require.ErrorIs(t, err, errLabelMismatchHostTeam)
}
// Create team label with team host identifiers should work.
for _, hostIdentifier := range []string{
h1t1.Hostname,
h1t1.UUID,
h1t1.HardwareSerial,
fmt.Sprint(h1t1.ID),
} {
err = ds.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{
{
Name: "l1t1",
LabelMembershipType: fleet.LabelMembershipTypeManual,
Hosts: fleet.HostsSlice{
hostIdentifier,
},
TeamID: &t1.ID,
},
})
require.NoError(t, err)
l1t1, err := ds.LabelByName(ctx, "l1t1", teamFilter)
require.NoError(t, err)
hosts, err := ds.ListHostsInLabel(ctx, teamFilter, l1t1.ID, fleet.HostListOptions{})
require.NoError(t, err)
require.Len(t, hosts, 1)
require.Equal(t, h1t1.ID, hosts[0].ID)
}
}