diff --git a/changes/27701-fix-manual-label-with-duplicate-serials b/changes/27701-fix-manual-label-with-duplicate-serials new file mode 100644 index 0000000000..9f8956d0ec --- /dev/null +++ b/changes/27701-fix-manual-label-with-duplicate-serials @@ -0,0 +1 @@ +- Fixed an issue where adding/updating a manual label had inconsistent results when multiple hosts shared a serial number diff --git a/frontend/services/entities/labels.ts b/frontend/services/entities/labels.ts index fcc246f8c3..f88d9f6fc2 100644 --- a/frontend/services/entities/labels.ts +++ b/frontend/services/entities/labels.ts @@ -31,10 +31,6 @@ const isManualLabelFormData = ( return "targetedHosts" in formData; }; -const getUniqueHostIdentifier = (host: IHost) => { - return host.hardware_serial || host.uuid || host.hostname; -}; - const generateCreateLabelBody = ( formData: IDynamicLabelFormData | IManualLabelFormData ) => { @@ -43,9 +39,7 @@ const generateCreateLabelBody = ( return { name: formData.name, description: formData.description, - hosts: formData.targetedHosts.map((host) => - getUniqueHostIdentifier(host) - ), + host_ids: formData.targetedHosts.map((host) => host.id), }; } return formData; diff --git a/server/fleet/labels.go b/server/fleet/labels.go index b8de018452..42838317d4 100644 --- a/server/fleet/labels.go +++ b/server/fleet/labels.go @@ -14,7 +14,8 @@ type ModifyLabelPayload struct { // valid for manual labels. If it is nil (not just len() == 0, but == nil), // then the list of hosts is not modified. If it is not nil and len == 0, // then all members are removed. - Hosts []string `json:"hosts"` + Hosts []string `json:"hosts"` + HostIDs []uint `json:"host_ids"` } type LabelPayload struct { @@ -30,7 +31,8 @@ type LabelPayload struct { // supported by HostByIdentifier) that are part of the label. This defines a // manual label. Can be empty for a manual label that doesn't target any // host. Must be empty for a dynamic label. - Hosts []string `json:"hosts"` + Hosts []string `json:"hosts"` + HostIDs []uint `json:"host_ids"` } // LabelType is used to catagorize the kind of label diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 8303231714..c37c6bbf4e 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -4218,7 +4218,7 @@ func (s *integrationTestSuite) TestLabels() { // create a label with both a query and hosts, error res := s.Do("POST", "/api/latest/fleet/labels", &fleet.LabelPayload{Name: t.Name(), Query: "select 1", Hosts: []string{manualHosts[0].UUID}}, http.StatusUnprocessableEntity) errMsg := extractServerErrorText(res.Body) - require.Contains(t, errMsg, `Only one of either "query" or "hosts" can be included in the request.`) + require.Contains(t, errMsg, `Only one of either "query" or "hosts/host_ids" can be included in the request.`) // create invalid label, conflicts with builtin name for n := range builtinsMap { @@ -4351,6 +4351,19 @@ func (s *integrationTestSuite) TestLabels() { assert.Equal(t, newName, modResp.Label.Name) manualLbl2.Name = newName + // modify manual label 2 adding some hosts by ID + modResp = modifyLabelResponse{} + newName = "modified_manual_label2" + s.DoJSON("PATCH", fmt.Sprintf("/api/latest/fleet/labels/%d", manualLbl2.ID), + &fleet.ModifyLabelPayload{Name: &newName, HostIDs: []uint{manualHosts[1].ID, manualHosts[2].ID}}, http.StatusOK, &modResp) + assert.Equal(t, manualLbl2.ID, modResp.Label.ID) + assert.Equal(t, fleet.LabelTypeRegular, modResp.Label.LabelType) + assert.Equal(t, fleet.LabelMembershipTypeManual, modResp.Label.LabelMembershipType) + assert.ElementsMatch(t, []uint{manualHosts[1].ID, manualHosts[2].ID}, modResp.Label.HostIDs) + assert.EqualValues(t, 2, modResp.Label.HostCount) + assert.Equal(t, newName, modResp.Label.Name) + manualLbl2.Name = newName + // modify manual label 2 clearing its hosts modResp = modifyLabelResponse{} s.DoJSON("PATCH", fmt.Sprintf("/api/latest/fleet/labels/%d", manualLbl2.ID), &fleet.ModifyLabelPayload{Hosts: []string{}, Description: ptr.String("desc")}, http.StatusOK, &modResp) diff --git a/server/service/labels.go b/server/service/labels.go index 4bb704abd8..81d3140676 100644 --- a/server/service/labels.go +++ b/server/service/labels.go @@ -53,6 +53,11 @@ func (svc *Service) NewLabel(ctx context.Context, p fleet.LabelPayload) (*fleet. if !ok { return nil, nil, fleet.ErrNoContext } + + if len(p.Hosts) > 0 && len(p.HostIDs) > 0 { + return nil, nil, fleet.NewInvalidArgumentError("hosts", `Only one of either "hosts" or "host_ids" can be included in the request.`) + } + filter := fleet.TeamFilter{User: vc.User, IncludeObserver: true} label := &fleet.Label{ @@ -66,8 +71,8 @@ func (svc *Service) NewLabel(ctx context.Context, p fleet.LabelPayload) (*fleet. } label.Name = p.Name - if p.Query != "" && len(p.Hosts) > 0 { - return nil, nil, fleet.NewInvalidArgumentError("query", `Only one of either "query" or "hosts" can be included in the request.`) + if p.Query != "" && (len(p.Hosts) > 0 || len(p.HostIDs) > 0) { + return nil, nil, fleet.NewInvalidArgumentError("query", `Only one of either "query" or "hosts/host_ids" can be included in the request.`) } label.Query = p.Query if p.Query == "" { @@ -90,11 +95,13 @@ func (svc *Service) NewLabel(ctx context.Context, p fleet.LabelPayload) (*fleet. return nil, nil, err } - var hostIDs []uint if label.LabelMembershipType == fleet.LabelMembershipTypeManual { - hostIDs, err = svc.ds.HostIDsByIdentifier(ctx, filter, p.Hosts) - if err != nil { - return nil, nil, err + hostIDs := p.HostIDs + if len(p.Hosts) > 0 { + hostIDs, err = svc.ds.HostIDsByIdentifier(ctx, filter, p.Hosts) + if err != nil { + return nil, nil, err + } } return svc.ds.UpdateLabelMembershipByHostIDs(ctx, label.ID, hostIDs, filter) } @@ -140,6 +147,11 @@ func (svc *Service) ModifyLabel(ctx context.Context, id uint, payload fleet.Modi if !ok { return nil, nil, fleet.ErrNoContext } + + if len(payload.Hosts) > 0 && len(payload.HostIDs) > 0 { + return nil, nil, fleet.NewInvalidArgumentError("hosts", `Only one of either "hosts" or "host_ids" can be included in the request.`) + } + filter := fleet.TeamFilter{User: vc.User, IncludeObserver: true} label, _, err := svc.ds.Label(ctx, id, filter) @@ -161,25 +173,30 @@ func (svc *Service) ModifyLabel(ctx context.Context, id uint, payload fleet.Modi if payload.Description != nil { label.Description = *payload.Description } - if len(payload.Hosts) > 0 && label.LabelMembershipType != fleet.LabelMembershipTypeManual { - return nil, nil, fleet.NewInvalidArgumentError("hosts", "cannot provide a list of hosts for a dynamic label") - } - // use SaveLabel to update label info, and UpdateLabelMembershipByHostIDs to update membership. Approach using label - // names and ApplyLabelSpecs doesn't work for multiple hosts with the same name. - - if payload.Hosts != nil { - // get host ids for valid hosts. since this endpoint will contain hosts identified by serial - // number, there should be no duplicates - - hostIds, err := svc.ds.HostIDsByIdentifier(ctx, filter, payload.Hosts) + hostIDs := payload.HostIDs + if len(payload.Hosts) > 0 { + // If hosts were provided, convert them to IDs. + hostIDs, err = svc.ds.HostIDsByIdentifier(ctx, filter, payload.Hosts) if err != nil { return nil, nil, err } - if _, _, err := svc.ds.UpdateLabelMembershipByHostIDs(ctx, label.ID, hostIds, filter); err != nil { + } else if payload.Hosts != nil { + // If an empry list was provided, create an empty list of IDs + // so that we can remove all hosts from the label. + hostIDs = make([]uint, 0) + } + + if len(hostIDs) > 0 && label.LabelMembershipType != fleet.LabelMembershipTypeManual { + return nil, nil, fleet.NewInvalidArgumentError("hosts", "cannot provide a list of hosts for a dynamic label") + } + + if hostIDs != nil { + if _, _, err := svc.ds.UpdateLabelMembershipByHostIDs(ctx, label.ID, hostIDs, filter); err != nil { return nil, nil, err } } + return svc.ds.SaveLabel(ctx, label, filter) } diff --git a/server/service/labels_test.go b/server/service/labels_test.go index 6803aae2af..0cea3448f2 100644 --- a/server/service/labels_test.go +++ b/server/service/labels_test.go @@ -431,3 +431,87 @@ func TestBatchValidateLabels(t *testing.T) { }) } } + +func TestNewManualLabel(t *testing.T) { + ds := new(mock.Store) + svc, ctx := newTestService(t, ds, nil, nil) + ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) + + ds.NewLabelFunc = func(ctx context.Context, lbl *fleet.Label, opts ...fleet.OptionalArg) (*fleet.Label, error) { + lbl.ID = 1 + lbl.LabelMembershipType = fleet.LabelMembershipTypeManual + return lbl, nil + } + ds.HostIDsByIdentifierFunc = func(ctx context.Context, filter fleet.TeamFilter, hostnames []string) ([]uint, error) { + return []uint{99, 100}, nil + } + + t.Run("using hostnames", func(t *testing.T) { + ds.UpdateLabelMembershipByHostIDsFunc = func(ctx context.Context, labelID uint, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { + require.Equal(t, uint(1), labelID) + require.Equal(t, []uint{99, 100}, hostIds) + return nil, nil, nil + } + _, _, err := svc.NewLabel(ctx, fleet.LabelPayload{ + Name: "foo", + Hosts: []string{"host1", "host2"}, + }) + require.NoError(t, err) + }) + + t.Run("using IDs", func(t *testing.T) { + ds.UpdateLabelMembershipByHostIDsFunc = func(ctx context.Context, labelID uint, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { + require.Equal(t, uint(1), labelID) + require.Equal(t, []uint{1, 2}, hostIds) + return nil, nil, nil + } + _, _, err := svc.NewLabel(ctx, fleet.LabelPayload{ + Name: "foo", + HostIDs: []uint{1, 2}, + }) + require.NoError(t, err) + }) +} + +func TestModifyManualLabel(t *testing.T) { + ds := new(mock.Store) + svc, ctx := newTestService(t, ds, nil, nil) + ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) + + ds.LabelFunc = func(ctx context.Context, lid uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { + return &fleet.Label{ + ID: lid, + LabelMembershipType: fleet.LabelMembershipTypeManual, + }, nil, nil + } + ds.HostIDsByIdentifierFunc = func(ctx context.Context, filter fleet.TeamFilter, hostnames []string) ([]uint, error) { + return []uint{99, 100}, nil + } + ds.SaveLabelFunc = func(ctx context.Context, lbl *fleet.Label, filter fleet.TeamFilter) (*fleet.Label, []uint, error) { + return nil, nil, nil + } + + t.Run("using hostnames", func(t *testing.T) { + ds.UpdateLabelMembershipByHostIDsFunc = func(ctx context.Context, labelID uint, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { + require.Equal(t, uint(1), labelID) + require.Equal(t, []uint{99, 100}, hostIds) + return nil, nil, nil + } + _, _, err := svc.ModifyLabel(ctx, 1, fleet.ModifyLabelPayload{ + Hosts: []string{"host1", "host2"}, + }) + require.NoError(t, err) + }) + + t.Run("using IDs", func(t *testing.T) { + ds.UpdateLabelMembershipByHostIDsFunc = func(ctx context.Context, labelID uint, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { + require.Equal(t, uint(1), labelID) + require.Equal(t, []uint{1, 2}, hostIds) + return nil, nil, nil + } + _, _, err := svc.ModifyLabel(ctx, 1, fleet.ModifyLabelPayload{ + HostIDs: []uint{1, 2}, + }) + require.NoError(t, err) + }) +}