diff --git a/changes/36781-team-labels-backend b/changes/36781-team-labels-backend index a539eee73b..ab5b739b42 100644 --- a/changes/36781-team-labels-backend +++ b/changes/36781-team-labels-backend @@ -1 +1 @@ -* Added backend support for team labels. +- Added support for team-specific labels. Currently team-specific labels must be created via spec endpoints, used by GitOps. diff --git a/cmd/fleet/cron.go b/cmd/fleet/cron.go index 5078779719..4441b66212 100644 --- a/cmd/fleet/cron.go +++ b/cmd/fleet/cron.go @@ -1572,7 +1572,7 @@ func cronHostVitalsLabelMembership( // so we'll filter them later. labels, err := ds.ListLabels(ctx, fleet.TeamFilter{}, fleet.ListOptions{ PerPage: 0, // No limit. - }) + }, false) if err != nil { return ctxerr.Wrap(ctx, err, "list labels") } diff --git a/cmd/fleet/serve_test.go b/cmd/fleet/serve_test.go index 064349d150..5782979e55 100644 --- a/cmd/fleet/serve_test.go +++ b/cmd/fleet/serve_test.go @@ -1262,7 +1262,7 @@ func TestHostVitalsLabelMembershipJob(t *testing.T) { {Name: "Vital Vince", ID: 3, HostVitalsCriteria: ptr.RawMessage(json.RawMessage(`{"vital":"owl", "value":"hoot"}`)), LabelType: fleet.LabelTypeRegular, LabelMembershipType: fleet.LabelMembershipTypeHostVitals}, } - ds.ListLabelsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.ListOptions) ([]*fleet.Label, error) { + ds.ListLabelsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.ListOptions, includeHostCounts bool) ([]*fleet.Label, error) { return labels, nil } diff --git a/cmd/fleetctl/fleetctl/apply_test.go b/cmd/fleetctl/fleetctl/apply_test.go index 8da701c0ff..309ebf4443 100644 --- a/cmd/fleetctl/fleetctl/apply_test.go +++ b/cmd/fleetctl/fleetctl/apply_test.go @@ -196,9 +196,9 @@ func TestApplyTeamSpecs(t *testing.T) { return nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - require.Len(t, labels, 1) - switch labels[0] { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + require.Len(t, names, 1) + switch names[0] { case fleet.BuiltinLabelMacOS14Plus: return map[string]uint{fleet.BuiltinLabelMacOS14Plus: 1}, nil case fleet.BuiltinLabelIOS: @@ -657,8 +657,8 @@ func TestApplyAppConfig(t *testing.T) { return fleet.MDMProfilesUpdates{}, nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - require.ElementsMatch(t, labels, []string{fleet.BuiltinLabelMacOS14Plus}) + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + require.ElementsMatch(t, names, []string{fleet.BuiltinLabelMacOS14Plus}) return map[string]uint{fleet.BuiltinLabelMacOS14Plus: 1}, nil } @@ -1349,10 +1349,13 @@ func TestApplyAsGitOps(t *testing.T) { ds.DeleteMDMWindowsConfigProfileByTeamAndNameFunc = func(ctx context.Context, teamID *uint, profileName string) error { return nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - require.ElementsMatch(t, labels, []string{fleet.BuiltinLabelMacOS14Plus}) + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + require.ElementsMatch(t, names, []string{fleet.BuiltinLabelMacOS14Plus}) return map[string]uint{fleet.BuiltinLabelMacOS14Plus: 1}, nil } + ds.SetAsideLabelsFunc = func(ctx context.Context, notOnTeamID *uint, names []string, user fleet.User) error { + return nil + } ds.SetOrUpdateMDMAppleDeclarationFunc = func(ctx context.Context, declaration *fleet.MDMAppleDeclaration) (*fleet.MDMAppleDeclaration, error) { declaration.DeclarationUUID = uuid.NewString() return declaration, nil @@ -1777,6 +1780,9 @@ func TestApplyLabels(t *testing.T) { _, ds := testing_utils.RunServerWithMockedDS(t) var appliedLabels []*fleet.LabelSpec + ds.SetAsideLabelsFunc = func(ctx context.Context, notOnTeamID *uint, names []string, user fleet.User) error { + return nil + } ds.ApplyLabelSpecsWithAuthorFunc = func(ctx context.Context, specs []*fleet.LabelSpec, authorId *uint) error { appliedLabels = specs return nil @@ -1831,7 +1837,7 @@ func TestApplyLabels(t *testing.T) { LabelType: fleet.LabelTypeBuiltIn, LabelMembershipType: fleet.LabelMembershipTypeDynamic, } - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { assert.ElementsMatch(t, []string{fleet.BuiltinLabelNameUbuntuLinux}, names) return map[string]*fleet.Label{ fleet.BuiltinLabelNameUbuntuLinux: ubuntuLabel, diff --git a/cmd/fleetctl/fleetctl/delete_test.go b/cmd/fleetctl/fleetctl/delete_test.go index dfdd584768..c95d39d6f1 100644 --- a/cmd/fleetctl/fleetctl/delete_test.go +++ b/cmd/fleetctl/fleetctl/delete_test.go @@ -14,10 +14,13 @@ func TestDeleteLabel(t *testing.T) { _, ds := testing_utils.RunServerWithMockedDS(t) var deletedLabel string - ds.DeleteLabelFunc = func(ctx context.Context, name string) error { + ds.DeleteLabelFunc = func(ctx context.Context, name string, filter fleet.TeamFilter) error { deletedLabel = name return nil } + ds.LabelByNameFunc = func(ctx context.Context, name string, filter fleet.TeamFilter) (*fleet.Label, error) { + return &fleet.Label{Name: name}, nil + } name := writeTmpYml(t, `--- apiVersion: v1 diff --git a/cmd/fleetctl/fleetctl/generate_gitops.go b/cmd/fleetctl/fleetctl/generate_gitops.go index 514e523e26..0e4be976de 100644 --- a/cmd/fleetctl/fleetctl/generate_gitops.go +++ b/cmd/fleetctl/fleetctl/generate_gitops.go @@ -361,6 +361,7 @@ func (cmd *GenerateGitopsCommand) Run() error { cmd.FilesToWrite[fileName].(map[string]interface{})["agent_options"] = cmd.AppConfig.AgentOptions + // TODO gitops do this for every team other than no team // Generate labels. labels, err := cmd.generateLabels() if err != nil { @@ -1684,6 +1685,7 @@ func (cmd *GenerateGitopsCommand) generateSoftware(filePath string, teamID uint, } func (cmd *GenerateGitopsCommand) generateLabels() ([]map[string]interface{}, error) { + // TODO gitops pass team ID labels, err := cmd.Client.GetLabels() if err != nil { fmt.Fprintf(cmd.CLI.App.ErrWriter, "Error getting labels: %s\n", err) diff --git a/cmd/fleetctl/fleetctl/get_test.go b/cmd/fleetctl/fleetctl/get_test.go index 935f115efe..d08d7dde71 100644 --- a/cmd/fleetctl/fleetctl/get_test.go +++ b/cmd/fleetctl/fleetctl/get_test.go @@ -1072,7 +1072,7 @@ spec: func TestGetLabels(t *testing.T) { _, ds := testing_utils.RunServerWithMockedDS(t) - ds.GetLabelSpecsFunc = func(ctx context.Context) ([]*fleet.LabelSpec, error) { + ds.GetLabelSpecsFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { return []*fleet.LabelSpec{ { ID: 32, @@ -1110,6 +1110,7 @@ spec: name: label1 platform: windows query: select 1; + team_id: null --- apiVersion: v1 kind: label @@ -1121,9 +1122,10 @@ spec: name: label2 platform: linux query: select 42; + team_id: null ` - expectedJson := `{"kind":"label","apiVersion":"v1","spec":{"id":32,"name":"label1","description":"some description","query":"select 1;","platform":"windows","label_membership_type":"dynamic","hosts":null}} -{"kind":"label","apiVersion":"v1","spec":{"id":33,"name":"label2","description":"some other description","query":"select 42;","platform":"linux","label_membership_type":"dynamic","hosts":null}} + expectedJson := `{"kind":"label","apiVersion":"v1","spec":{"id":32,"name":"label1","description":"some description","query":"select 1;","platform":"windows","label_membership_type":"dynamic","hosts":null,"team_id":null}} +{"kind":"label","apiVersion":"v1","spec":{"id":33,"name":"label2","description":"some other description","query":"select 42;","platform":"linux","label_membership_type":"dynamic","hosts":null,"team_id":null}} ` assert.Equal(t, expected, RunAppForTest(t, []string{"get", "labels"})) @@ -1134,7 +1136,7 @@ spec: func TestGetLabel(t *testing.T) { _, ds := testing_utils.RunServerWithMockedDS(t) - ds.GetLabelSpecFunc = func(ctx context.Context, name string) (*fleet.LabelSpec, error) { + ds.GetLabelSpecFunc = func(ctx context.Context, filter fleet.TeamFilter, name string) (*fleet.LabelSpec, error) { if name != "label1" { return nil, nil } @@ -1158,8 +1160,9 @@ spec: name: label1 platform: windows query: select 1; + team_id: null ` - expectedJson := `{"kind":"label","apiVersion":"v1","spec":{"id":32,"name":"label1","description":"some description","query":"select 1;","platform":"windows","label_membership_type":"dynamic","hosts":null}} + expectedJson := `{"kind":"label","apiVersion":"v1","spec":{"id":32,"name":"label1","description":"some description","query":"select 1;","platform":"windows","label_membership_type":"dynamic","hosts":null,"team_id":null}} ` assert.Equal(t, expectedYaml, RunAppForTest(t, []string{"get", "label", "label1"})) @@ -2476,8 +2479,8 @@ func TestGetTeamsYAMLAndApply(t *testing.T) { ds.DeleteMDMAppleDeclarationByNameFunc = func(ctx context.Context, teamID *uint, name string) error { return nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - require.ElementsMatch(t, labels, []string{fleet.BuiltinLabelMacOS14Plus}) + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + require.ElementsMatch(t, names, []string{fleet.BuiltinLabelMacOS14Plus}) return map[string]uint{fleet.BuiltinLabelMacOS14Plus: 1}, nil } ds.SetOrUpdateMDMAppleDeclarationFunc = func(ctx context.Context, declaration *fleet.MDMAppleDeclaration) (*fleet.MDMAppleDeclaration, error) { diff --git a/cmd/fleetctl/fleetctl/gitops.go b/cmd/fleetctl/fleetctl/gitops.go index fceb9bbd06..e40f220573 100644 --- a/cmd/fleetctl/fleetctl/gitops.go +++ b/cmd/fleetctl/fleetctl/gitops.go @@ -223,6 +223,8 @@ func gitopsCommand() *cli.Command { config.Controls = noTeamControls } + // TODO GitOps move this to have team-specific and global names + // If config.Labels is nil, it means we plan on deleting all existing labels. if config.Labels == nil { proposedLabelNames = make([]string, 0) diff --git a/cmd/fleetctl/fleetctl/gitops_test.go b/cmd/fleetctl/fleetctl/gitops_test.go index 35d6b8ce8d..6f49236d2b 100644 --- a/cmd/fleetctl/fleetctl/gitops_test.go +++ b/cmd/fleetctl/fleetctl/gitops_test.go @@ -127,7 +127,7 @@ func TestGitOpsBasicGlobalFree(t *testing.T) { savedAppConfig = config return nil } - ds.GetLabelSpecsFunc = func(ctx context.Context) ([]*fleet.LabelSpec, error) { + ds.GetLabelSpecsFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { return nil, nil } @@ -314,7 +314,7 @@ func TestGitOpsBasicGlobalPremium(t *testing.T) { savedAppConfig = config return nil } - ds.GetLabelSpecsFunc = func(ctx context.Context) ([]*fleet.LabelSpec, error) { + ds.GetLabelSpecsFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { return nil, nil } @@ -323,8 +323,8 @@ func TestGitOpsBasicGlobalPremium(t *testing.T) { enrolledSecrets = secrets return nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - return map[string]uint{labels[0]: 1}, nil + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + return map[string]uint{names[0]: 1}, nil } ds.SetOrUpdateMDMAppleDeclarationFunc = func(ctx context.Context, declaration *fleet.MDMAppleDeclaration) (*fleet.MDMAppleDeclaration, error) { return &fleet.MDMAppleDeclaration{}, nil @@ -640,7 +640,7 @@ func TestGitOpsBasicTeam(t *testing.T) { ds.ListQueriesFunc = func(ctx context.Context, opts fleet.ListQueryOptions) ([]*fleet.Query, int, *fleet.PaginationMetadata, error) { return nil, 0, nil, nil } - ds.GetLabelSpecsFunc = func(ctx context.Context) ([]*fleet.LabelSpec, error) { + ds.GetLabelSpecsFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { return nil, nil } ds.DeleteIconsAssociatedWithTitlesWithoutInstallersFunc = func(ctx context.Context, teamID uint) error { @@ -708,9 +708,9 @@ func TestGitOpsBasicTeam(t *testing.T) { savedTeam = team return team, nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - require.Len(t, labels, 1) - switch labels[0] { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + require.Len(t, names, 1) + switch names[0] { case fleet.BuiltinLabelMacOS14Plus: return map[string]uint{fleet.BuiltinLabelMacOS14Plus: 1}, nil case fleet.BuiltinLabelIOS: @@ -904,6 +904,9 @@ func TestGitOpsFullGlobal(t *testing.T) { ds.DeleteIconsAssociatedWithTitlesWithoutInstallersFunc = func(ctx context.Context, teamID uint) error { return nil } + ds.SetAsideLabelsFunc = func(ctx context.Context, notOnTeamID *uint, names []string, user fleet.User) error { + return nil + } // Policies policy := fleet.Policy{} @@ -967,7 +970,7 @@ func TestGitOpsFullGlobal(t *testing.T) { var appliedLabelSpecs []*fleet.LabelSpec var deletedLabels []string - ds.GetLabelSpecsFunc = func(ctx context.Context) ([]*fleet.LabelSpec, error) { + ds.GetLabelSpecsFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { return []*fleet.LabelSpec{ { Name: "a", @@ -988,12 +991,15 @@ func TestGitOpsFullGlobal(t *testing.T) { return nil } - ds.DeleteLabelFunc = func(ctx context.Context, name string) error { + ds.LabelByNameFunc = func(ctx context.Context, name string, filter fleet.TeamFilter) (*fleet.Label, error) { + return &fleet.Label{Name: name}, nil + } + ds.DeleteLabelFunc = func(ctx context.Context, name string, filter fleet.TeamFilter) error { deletedLabels = append(deletedLabels, name) return nil } - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { return map[string]*fleet.Label{ "a": { ID: 1, @@ -1203,8 +1209,8 @@ func TestGitOpsFullTeam(t *testing.T) { ds.NewMDMAppleDeclarationFunc = func(ctx context.Context, declaration *fleet.MDMAppleDeclaration) (*fleet.MDMAppleDeclaration, error) { return declaration, nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - require.ElementsMatch(t, labels, []string{fleet.BuiltinLabelMacOS14Plus}) + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + require.ElementsMatch(t, names, []string{fleet.BuiltinLabelMacOS14Plus}) return map[string]uint{fleet.BuiltinLabelMacOS14Plus: 1}, nil } ds.SetOrUpdateMDMAppleDeclarationFunc = func(ctx context.Context, declaration *fleet.MDMAppleDeclaration) (*fleet.MDMAppleDeclaration, error) { @@ -1647,8 +1653,8 @@ func TestGitOpsBasicGlobalAndTeam(t *testing.T) { ds.DeleteMDMAppleDeclarationByNameFunc = func(ctx context.Context, teamID *uint, name string) error { return nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - require.ElementsMatch(t, labels, []string{fleet.BuiltinLabelMacOS14Plus}) + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + require.ElementsMatch(t, names, []string{fleet.BuiltinLabelMacOS14Plus}) return map[string]uint{fleet.BuiltinLabelMacOS14Plus: 1}, nil } ds.ListGlobalPoliciesFunc = func(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error) { return nil, nil } @@ -2046,8 +2052,8 @@ func TestGitOpsBasicGlobalAndNoTeam(t *testing.T) { ds.DeleteMDMAppleDeclarationByNameFunc = func(ctx context.Context, teamID *uint, name string) error { return nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - require.ElementsMatch(t, labels, []string{fleet.BuiltinLabelMacOS14Plus}) + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + require.ElementsMatch(t, names, []string{fleet.BuiltinLabelMacOS14Plus}) return map[string]uint{fleet.BuiltinLabelMacOS14Plus: 1}, nil } ds.ListGlobalPoliciesFunc = func(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error) { return nil, nil } @@ -2513,6 +2519,9 @@ func TestGitOpsFullGlobalAndTeam(t *testing.T) { ds.DeleteIconsAssociatedWithTitlesWithoutInstallersFunc = func(ctx context.Context, teamID uint) error { return nil } + ds.SetAsideLabelsFunc = func(ctx context.Context, notOnTeamID *uint, names []string, user fleet.User) error { + return nil + } apnsCert, apnsKey, err := mysql.GenerateTestCertBytes(mdmtesting.NewTestMDMAppleCertTemplate()) require.NoError(t, err) @@ -2539,7 +2548,7 @@ func TestGitOpsFullGlobalAndTeam(t *testing.T) { return nil } - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { return map[string]*fleet.Label{ "a": { ID: 1, @@ -2702,7 +2711,7 @@ func TestGitOpsCustomSettings(t *testing.T) { ds, appCfgPtr, _ := testing_utils.SetupFullGitOpsPremiumServer(t) (*appCfgPtr).MDM.EnabledAndConfigured = true (*appCfgPtr).MDM.WindowsEnabledAndConfigured = true - ds.GetLabelSpecsFunc = func(ctx context.Context) ([]*fleet.LabelSpec, error) { + ds.GetLabelSpecsFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { return []*fleet.LabelSpec{ { Name: "A", @@ -2731,10 +2740,10 @@ func TestGitOpsCustomSettings(t *testing.T) { "C": 4, } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { // for this test, recognize labels A, B and C (as well as the built-in macos 14+ one) ret := make(map[string]uint) - for _, lbl := range labels { + for _, lbl := range names { id, ok := labelToIDs[lbl] if ok { ret[lbl] = id @@ -2742,7 +2751,7 @@ func TestGitOpsCustomSettings(t *testing.T) { } return ret, nil } - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { // for this test, recognize labels A, B and C (as well as the built-in macos 14+ one) ret := make(map[string]*fleet.Label) for _, lbl := range names { @@ -3873,8 +3882,8 @@ func setupAndroidCertificatesTestMocks(t *testing.T, ds *mock.Store) []*fleet.Ce } // Override LabelIDsByNameFunc to handle empty labels - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - if len(labels) == 0 { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + if len(names) == 0 { return map[string]uint{}, nil } return map[string]uint{fleet.BuiltinLabelMacOS14Plus: 1}, nil @@ -4493,7 +4502,7 @@ func TestGitOpsWindowsUpdates(t *testing.T) { ds.BatchSetScriptsFunc = func(ctx context.Context, tmID *uint, scripts []*fleet.Script) ([]fleet.ScriptResponse, error) { return []fleet.ScriptResponse{}, nil } - ds.GetLabelSpecsFunc = func(ctx context.Context) ([]*fleet.LabelSpec, error) { + ds.GetLabelSpecsFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { return nil, nil } ds.DeleteIconsAssociatedWithTitlesWithoutInstallersFunc = func(ctx context.Context, teamID uint) error { @@ -4561,7 +4570,7 @@ func TestGitOpsWindowsUpdates(t *testing.T) { ds.DeleteSetupExperienceScriptFunc = func(ctx context.Context, teamID *uint) error { return nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { return map[string]uint{}, nil } diff --git a/cmd/fleetctl/fleetctl/hosts_test.go b/cmd/fleetctl/fleetctl/hosts_test.go index 0b02d71085..97d82acb04 100644 --- a/cmd/fleetctl/fleetctl/hosts_test.go +++ b/cmd/fleetctl/fleetctl/hosts_test.go @@ -101,8 +101,8 @@ func TestHostsTransferByLabel(t *testing.T) { return &fleet.Team{ID: 99, Name: "team1"}, nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - require.Equal(t, []string{"label1"}, labels) + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + require.Equal(t, []string{"label1"}, names) return map[string]uint{"label1": uint(11)}, nil } @@ -173,8 +173,8 @@ func TestHostsTransferByStatus(t *testing.T) { return &fleet.Team{ID: 99, Name: "team1"}, nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - require.Equal(t, []string{"label1"}, labels) + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + require.Equal(t, []string{"label1"}, names) return map[string]uint{"label1": uint(11)}, nil } @@ -232,8 +232,8 @@ func TestHostsTransferByStatusAndSearchQuery(t *testing.T) { return &fleet.Team{ID: 99, Name: "team1"}, nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - require.Equal(t, []string{"label1"}, labels) + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + require.Equal(t, []string{"label1"}, names) return map[string]uint{"label1": uint(11)}, nil } diff --git a/cmd/fleetctl/fleetctl/query_test.go b/cmd/fleetctl/fleetctl/query_test.go index 4ea1ec1cc3..0edfb267f9 100644 --- a/cmd/fleetctl/fleetctl/query_test.go +++ b/cmd/fleetctl/fleetctl/query_test.go @@ -55,7 +55,7 @@ func TestSavedLiveQuery(t *testing.T) { } return nil, nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { return nil, nil } ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { @@ -219,7 +219,7 @@ func TestAdHocLiveQuery(t *testing.T) { ds.HostIDsByIdentifierFunc = func(ctx context.Context, filter fleet.TeamFilter, hostIdentifiers []string) ([]uint, error) { return []uint{1234}, nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { return map[string]uint{"label1": uint(1)}, nil } diff --git a/cmd/fleetctl/fleetctl/testing_utils/testing_utils.go b/cmd/fleetctl/fleetctl/testing_utils/testing_utils.go index d38474129d..20578d2d31 100644 --- a/cmd/fleetctl/fleetctl/testing_utils/testing_utils.go +++ b/cmd/fleetctl/fleetctl/testing_utils/testing_utils.go @@ -328,8 +328,8 @@ func SetupFullGitOpsPremiumServer(t *testing.T) (*mock.Store, **fleet.AppConfig, ds.IsEnrollSecretAvailableFunc = func(ctx context.Context, secret string, isNew bool, teamID *uint) (bool, error) { return true, nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - require.ElementsMatch(t, labels, []string{fleet.BuiltinLabelMacOS14Plus}) + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + require.ElementsMatch(t, names, []string{fleet.BuiltinLabelMacOS14Plus}) return map[string]uint{fleet.BuiltinLabelMacOS14Plus: 1}, nil } ds.ListGlobalPoliciesFunc = func(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error) { return nil, nil } @@ -722,7 +722,7 @@ func (m *MemKeyValueStore) Get(ctx context.Context, key string) (*string, error) func AddLabelMocks(ds *mock.Store) { var deletedLabels []string - ds.GetLabelSpecsFunc = func(ctx context.Context) ([]*fleet.LabelSpec, error) { + ds.GetLabelSpecsFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { return []*fleet.LabelSpec{ { Name: "a", @@ -742,12 +742,12 @@ func AddLabelMocks(ds *mock.Store) { return nil } - ds.DeleteLabelFunc = func(ctx context.Context, name string) error { + ds.DeleteLabelFunc = func(ctx context.Context, name string, filter fleet.TeamFilter) error { deletedLabels = append(deletedLabels, name) return nil } - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { - return map[string]*fleet.Label{ + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { + validLabels := map[string]*fleet.Label{ "a": { ID: 1, Name: "a", @@ -756,7 +756,15 @@ func AddLabelMocks(ds *mock.Store) { ID: 2, Name: "b", }, - }, nil + } + + found := make(map[string]*fleet.Label) + for _, l := range names { + if label, ok := validLabels[l]; ok { + found[l] = label + } + } + return found, nil } } diff --git a/cmd/fleetctl/integrationtest/gitops/gitops_enterprise_integration_test.go b/cmd/fleetctl/integrationtest/gitops/gitops_enterprise_integration_test.go index 1de563fe2c..ada5e254ca 100644 --- a/cmd/fleetctl/integrationtest/gitops/gitops_enterprise_integration_test.go +++ b/cmd/fleetctl/integrationtest/gitops/gitops_enterprise_integration_test.go @@ -149,11 +149,11 @@ func (s *enterpriseIntegrationGitopsTestSuite) TearDownTest() { return err }) - lbls, err := s.DS.ListLabels(ctx, fleet.TeamFilter{User: test.UserAdmin}, fleet.ListOptions{}) + lbls, err := s.DS.ListLabels(ctx, fleet.TeamFilter{User: test.UserAdmin}, fleet.ListOptions{}, false) require.NoError(t, err) for _, lbl := range lbls { if lbl.LabelType != fleet.LabelTypeBuiltIn { - err := s.DS.DeleteLabel(ctx, lbl.Name) + err := s.DS.DeleteLabel(ctx, lbl.Name, fleet.TeamFilter{User: test.UserAdmin}) require.NoError(t, err) } } @@ -1602,7 +1602,7 @@ func (s *enterpriseIntegrationGitopsTestSuite) TestFleetGitOpsDeletesNonManagedL _ = fleetctl.RunAppForTest(t, []string{"gitops", "--config", fleetctlConfig.Name(), "-f", opsFile}) // Check label was removed successfully - result, err := s.DS.LabelIDsByName(ctx, []string{nonManagedLabel.Name}) + result, err := s.DS.LabelIDsByName(ctx, []string{nonManagedLabel.Name}, fleet.TeamFilter{}) require.NoError(t, err) require.Empty(t, result) } @@ -1999,7 +1999,7 @@ labels: s.assertRealRunOutput(t, fleetctl.RunAppForTest(t, []string{"gitops", "--config", fleetctlConfig.Name(), "-f", globalFile.Name()})) // Verify the label was created and has the correct hosts - labels, err := s.DS.LabelsByName(ctx, []string{"my-label"}) + labels, err := s.DS.LabelsByName(ctx, []string{"my-label"}, fleet.TeamFilter{}) require.NoError(t, err) require.Len(t, labels, 1) label := labels["my-label"] diff --git a/cmd/fleetctl/integrationtest/gitops/software_test.go b/cmd/fleetctl/integrationtest/gitops/software_test.go index 861fd73cc0..189a237ac4 100644 --- a/cmd/fleetctl/integrationtest/gitops/software_test.go +++ b/cmd/fleetctl/integrationtest/gitops/software_test.go @@ -106,7 +106,7 @@ func TestGitOpsTeamSoftwareInstallers(t *testing.T) { }, nil } - ds.GetLabelSpecsFunc = func(ctx context.Context) ([]*fleet.LabelSpec, error) { + ds.GetLabelSpecsFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { return []*fleet.LabelSpec{ { Name: "a", @@ -128,10 +128,10 @@ func TestGitOpsTeamSoftwareInstallers(t *testing.T) { "a": 2, "b": 3, } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { // for this test, recognize labels a and b (as well as the built-in macos 14+ one) ret := make(map[string]uint) - for _, lbl := range labels { + for _, lbl := range names { id, ok := labelToIDs[lbl] if ok { ret[lbl] = id @@ -272,10 +272,10 @@ func TestGitOpsNoTeamVPPPolicies(t *testing.T) { "a": 2, "b": 3, } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { // for this test, recognize labels a and b (as well as the built-in macos 14+ one) ret := make(map[string]uint) - for _, lbl := range labels { + for _, lbl := range names { id, ok := labelToIDs[lbl] if ok { ret[lbl] = id @@ -283,7 +283,7 @@ func TestGitOpsNoTeamVPPPolicies(t *testing.T) { } return ret, nil } - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { return map[string]*fleet.Label{ "a": { ID: 1, @@ -295,6 +295,9 @@ func TestGitOpsNoTeamVPPPolicies(t *testing.T) { }, }, nil } + ds.SetAsideLabelsFunc = func(ctx context.Context, notOnTeamID *uint, names []string, user fleet.User) error { + return nil + } ds.GetSoftwareCategoryIDsFunc = func(ctx context.Context, names []string) ([]uint, error) { return []uint{}, nil } @@ -387,7 +390,7 @@ func TestGitOpsNoTeamSoftwareInstallers(t *testing.T) { Teams: nil, }, nil } - ds.GetLabelSpecsFunc = func(ctx context.Context) ([]*fleet.LabelSpec, error) { + ds.GetLabelSpecsFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { return []*fleet.LabelSpec{ { Name: "a", @@ -403,15 +406,18 @@ func TestGitOpsNoTeamSoftwareInstallers(t *testing.T) { }, }, nil } + ds.SetAsideLabelsFunc = func(ctx context.Context, notOnTeamID *uint, names []string, user fleet.User) error { + return nil + } labelToIDs := map[string]uint{ fleet.BuiltinLabelMacOS14Plus: 1, "a": 2, "b": 3, } - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { // for this test, recognize labels a and b (as well as the built-in macos 14+ one) ret := make(map[string]uint) - for _, lbl := range labels { + for _, lbl := range names { id, ok := labelToIDs[lbl] if ok { ret[lbl] = id @@ -522,7 +528,7 @@ func TestGitOpsTeamVPPApps(t *testing.T) { }, nil } - ds.GetLabelSpecsFunc = func(ctx context.Context) ([]*fleet.LabelSpec, error) { + ds.GetLabelSpecsFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { return []*fleet.LabelSpec{ { Name: "label 1", @@ -543,15 +549,15 @@ func TestGitOpsTeamVPPApps(t *testing.T) { } found := make(map[string]uint) - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { - for _, l := range labels { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { + for _, l := range names { if id, ok := c.expectedLabels[l]; ok { found[l] = id } } return found, nil } - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { found2 := make(map[string]*fleet.Label) for _, l := range names { if id, ok := c.expectedLabels[l]; ok { diff --git a/ee/server/service/mdm.go b/ee/server/service/mdm.go index 5cd3f81dd5..a0516e50a8 100644 --- a/ee/server/service/mdm.go +++ b/ee/server/service/mdm.go @@ -1314,7 +1314,7 @@ func (svc *Service) mdmAppleEditedAppleOSUpdates(ctx context.Context, teamID *ui d := fleet.NewMDMAppleDeclaration(rawDecl, teamID, osUpdatesProfileName, softwareUpdateType, softwareUpdateIdentifier) // Associate the profile with the built-in label to ensure that the profile is applied to the targeted devices. - lblIDs, err := svc.ds.LabelIDsByName(ctx, []string{labelName}) + lblIDs, err := svc.ds.LabelIDsByName(ctx, []string{labelName}, fleet.TeamFilter{}) // built-in labels are global if err != nil { return err } diff --git a/ee/server/service/mdm_external_test.go b/ee/server/service/mdm_external_test.go index 8e9a792dc6..152cae0d9f 100644 --- a/ee/server/service/mdm_external_test.go +++ b/ee/server/service/mdm_external_test.go @@ -230,7 +230,7 @@ func TestGetOrCreatePreassignTeam(t *testing.T) { ds.GetMDMAppleSetupAssistantFunc = func(ctx context.Context, teamID *uint) (*fleet.MDMAppleSetupAssistant, error) { return nil, errors.New("not implemented") } - ds.LabelIDsByNameFunc = func(ctx context.Context, names []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { require.Len(t, names, 1) require.ElementsMatch(t, names, []string{fleet.BuiltinLabelMacOS14Plus}) return map[string]uint{names[0]: 1}, nil diff --git a/ee/server/service/teams.go b/ee/server/service/teams.go index 00207042c7..78aff44424 100644 --- a/ee/server/service/teams.go +++ b/ee/server/service/teams.go @@ -493,7 +493,7 @@ func (svc *Service) ModifyTeamAgentOptions(ctx context.Context, teamID uint, tea } if teamOptions != nil { - if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, teamOptions, true); err != nil { + if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, teamOptions, true, teamID); err != nil { err = fleet.SuggestAgentOptionsCorrection(err) err = fleet.NewUserMessageError(err, http.StatusBadRequest) if applyOptions.Force && !applyOptions.DryRun { @@ -1040,8 +1040,13 @@ func (svc *Service) ApplyTeamSpecs(ctx context.Context, specs []*fleet.TeamSpec, } } + var tmID uint + if team != nil { + tmID = team.ID + } + if len(spec.AgentOptions) > 0 && !bytes.Equal(spec.AgentOptions, jsonNull) { - if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, spec.AgentOptions, true); err != nil { + if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, spec.AgentOptions, true, tmID); err != nil { err = fleet.SuggestAgentOptionsCorrection(err) err = fleet.NewUserMessageError(err, http.StatusBadRequest) if applyOpts.Force && !applyOpts.DryRun { diff --git a/frontend/__mocks__/labelsMock.ts b/frontend/__mocks__/labelsMock.ts index 145ce11740..996e51807a 100644 --- a/frontend/__mocks__/labelsMock.ts +++ b/frontend/__mocks__/labelsMock.ts @@ -14,6 +14,7 @@ const DEFAULT_LABEL_MOCK: ILabel = { display_text: "test macsss", count: 0, host_ids: null, + team_id: null, criteria: { vital: "end_user_idp_department", value: " IT admins", diff --git a/frontend/components/LiveQuery/SelectTargets.tsx b/frontend/components/LiveQuery/SelectTargets.tsx index 7634d8fa35..106491b6f6 100644 --- a/frontend/components/LiveQuery/SelectTargets.tsx +++ b/frontend/components/LiveQuery/SelectTargets.tsx @@ -166,7 +166,8 @@ const SelectTargets = ({ isLoading: isLoadingLabels, } = useQuery( ["labelsSummary"], - labelsAPI.summary, + // labels API automatically filters to global/team labels user has access to, so no need for additional params + () => labelsAPI.summary(), { select: (data) => data.labels, staleTime: STALE_TIME, // TODO: confirm diff --git a/frontend/interfaces/label.ts b/frontend/interfaces/label.ts index fbce8d852d..61789b3131 100644 --- a/frontend/interfaces/label.ts +++ b/frontend/interfaces/label.ts @@ -80,6 +80,8 @@ export interface ILabel extends ILabelSummary { slug?: string; // e.g., "labels/13" | "online" target_type?: string; // e.g., "labels" author_id?: number; + team_id: number | null; + team_name?: string | null; // returned on individual label endpoints but not list endpoints label_membership_type: LabelMembershipType; diff --git a/frontend/pages/ManageControlsPage/OSSettings/cards/CustomSettings/components/ProfileUploader/components/AddProfileModal/AddProfileModal.tsx b/frontend/pages/ManageControlsPage/OSSettings/cards/CustomSettings/components/ProfileUploader/components/AddProfileModal/AddProfileModal.tsx index ab8c356a7b..16dbc4348e 100644 --- a/frontend/pages/ManageControlsPage/OSSettings/cards/CustomSettings/components/ProfileUploader/components/AddProfileModal/AddProfileModal.tsx +++ b/frontend/pages/ManageControlsPage/OSSettings/cards/CustomSettings/components/ProfileUploader/components/AddProfileModal/AddProfileModal.tsx @@ -130,7 +130,10 @@ const AddProfileModal = ({ isError: isErrorLabels, } = useQuery( ["custom_labels"], - () => labelsAPI.summary().then((res) => getCustomLabels(res.labels)), + () => + labelsAPI + .summary(currentTeamId) + .then((res) => getCustomLabels(res.labels)), { enabled: isPremiumTier, refetchOnWindowFocus: false, diff --git a/frontend/pages/SoftwarePage/SoftwareAddPage/SoftwareAppStore/SoftwareAppStoreVpp/SoftwareAppStoreVpp.tsx b/frontend/pages/SoftwarePage/SoftwareAddPage/SoftwareAppStore/SoftwareAppStoreVpp/SoftwareAppStoreVpp.tsx index 2eed6f9a65..5c3a90af20 100644 --- a/frontend/pages/SoftwarePage/SoftwareAddPage/SoftwareAppStore/SoftwareAppStoreVpp/SoftwareAppStoreVpp.tsx +++ b/frontend/pages/SoftwarePage/SoftwareAddPage/SoftwareAppStore/SoftwareAppStoreVpp/SoftwareAppStoreVpp.tsx @@ -118,7 +118,10 @@ const SoftwareAppStoreVpp = ({ isError: isErrorLabels, } = useQuery( ["custom_labels"], - () => labelsAPI.summary().then((res) => getCustomLabels(res.labels)), + () => + labelsAPI + .summary(currentTeamId) + .then((res) => getCustomLabels(res.labels)), { ...DEFAULT_USE_QUERY_OPTIONS, diff --git a/frontend/pages/SoftwarePage/SoftwareAddPage/SoftwareCustomPackage/SoftwareCustomPackage.tsx b/frontend/pages/SoftwarePage/SoftwareAddPage/SoftwareCustomPackage/SoftwareCustomPackage.tsx index c3b21a41e6..03b03f2f9a 100644 --- a/frontend/pages/SoftwarePage/SoftwareAddPage/SoftwareCustomPackage/SoftwareCustomPackage.tsx +++ b/frontend/pages/SoftwarePage/SoftwareAddPage/SoftwareCustomPackage/SoftwareCustomPackage.tsx @@ -63,7 +63,10 @@ const SoftwareCustomPackage = ({ isError: isErrorLabels, } = useQuery( ["custom_labels"], - () => labelsAPI.summary().then((res) => getCustomLabels(res.labels)), + () => + labelsAPI + .summary(currentTeamId) + .then((res) => getCustomLabels(res.labels)), { ...DEFAULT_USE_QUERY_OPTIONS, enabled: isPremiumTier, diff --git a/frontend/pages/SoftwarePage/SoftwareAddPage/SoftwareFleetMaintained/FleetMaintainedAppDetailsPage/FleetMaintainedAppDetailsPage.tsx b/frontend/pages/SoftwarePage/SoftwareAddPage/SoftwareFleetMaintained/FleetMaintainedAppDetailsPage/FleetMaintainedAppDetailsPage.tsx index 777eeb4f55..2c018e75ce 100644 --- a/frontend/pages/SoftwarePage/SoftwareAddPage/SoftwareFleetMaintained/FleetMaintainedAppDetailsPage/FleetMaintainedAppDetailsPage.tsx +++ b/frontend/pages/SoftwarePage/SoftwareAddPage/SoftwareFleetMaintained/FleetMaintainedAppDetailsPage/FleetMaintainedAppDetailsPage.tsx @@ -178,7 +178,10 @@ const FleetMaintainedAppDetailsPage = ({ isError: isErrorLabels, } = useQuery( ["custom_labels"], - () => labelsAPI.summary().then((res) => getCustomLabels(res.labels)), + () => + labelsAPI + .summary(parseInt(teamId || "0", 10)) + .then((res) => getCustomLabels(res.labels)), { ...DEFAULT_USE_QUERY_OPTIONS, diff --git a/frontend/pages/SoftwarePage/SoftwareTitleDetailsPage/EditSoftwareModal/EditSoftwareModal.tsx b/frontend/pages/SoftwarePage/SoftwareTitleDetailsPage/EditSoftwareModal/EditSoftwareModal.tsx index f8f39d5da8..1f9333015c 100644 --- a/frontend/pages/SoftwarePage/SoftwareTitleDetailsPage/EditSoftwareModal/EditSoftwareModal.tsx +++ b/frontend/pages/SoftwarePage/SoftwareTitleDetailsPage/EditSoftwareModal/EditSoftwareModal.tsx @@ -121,7 +121,7 @@ const EditSoftwareModal = ({ const { data: labels } = useQuery( ["custom_labels"], - () => labelsAPI.summary().then((res) => getCustomLabels(res.labels)), + () => labelsAPI.summary(teamId).then((res) => getCustomLabels(res.labels)), { ...DEFAULT_USE_QUERY_OPTIONS, } diff --git a/frontend/pages/hosts/ManageHostsPage/ManageHostsPage.tsx b/frontend/pages/hosts/ManageHostsPage/ManageHostsPage.tsx index 4353e7ef5b..d4a9555ab8 100644 --- a/frontend/pages/hosts/ManageHostsPage/ManageHostsPage.tsx +++ b/frontend/pages/hosts/ManageHostsPage/ManageHostsPage.tsx @@ -407,7 +407,7 @@ const ManageHostsPage = ({ ILabelsResponse, Error, ILabel[] - >(["labels"], () => labelsAPI.loadAll(), { + >(["labels", currentTeamId], () => labelsAPI.loadAll(currentTeamId), { enabled: isRouteOk, select: (data: ILabelsResponse) => data.labels, }); diff --git a/frontend/pages/labels/EditLabelPage/EditLabelPage.tests.tsx b/frontend/pages/labels/EditLabelPage/EditLabelPage.tests.tsx index e20a403e57..62ca04ca1b 100644 --- a/frontend/pages/labels/EditLabelPage/EditLabelPage.tests.tsx +++ b/frontend/pages/labels/EditLabelPage/EditLabelPage.tests.tsx @@ -55,9 +55,8 @@ describe("EditLabelPage", () => { expect(queryLabel).toBeInTheDocument(); expect(platformLabel).toBeInTheDocument(); - expect(screen.getByText(/Label queries are immutable/)).toBeInTheDocument(); expect( - screen.getByText(/Label platforms are immutable/) + screen.getByText(/Label queries and platforms are immutable/) ).toBeInTheDocument(); }); diff --git a/frontend/pages/labels/EditLabelPage/EditLabelPage.tsx b/frontend/pages/labels/EditLabelPage/EditLabelPage.tsx index 5aae244f7d..8f0f8319aa 100644 --- a/frontend/pages/labels/EditLabelPage/EditLabelPage.tsx +++ b/frontend/pages/labels/EditLabelPage/EditLabelPage.tsx @@ -12,6 +12,7 @@ import { DEFAULT_USE_QUERY_OPTIONS } from "utilities/constants"; import { ILabel } from "interfaces/label"; import { IHost } from "interfaces/host"; import { NotificationContext } from "context/notification"; +import { AppContext } from "context/app"; import MainContent from "components/MainContent"; import Spinner from "components/Spinner"; @@ -21,6 +22,7 @@ import DynamicLabelForm from "../components/DynamicLabelForm"; import ManualLabelForm from "../components/ManualLabelForm"; import { IDynamicLabelFormData } from "../components/DynamicLabelForm/DynamicLabelForm"; import { IManualLabelFormData } from "../components/ManualLabelForm/ManualLabelForm"; +import { hasEditPermission } from "../ManageLabelsPage/LabelsTable/LabelsTableConfig"; const baseClass = "edit-label-page"; @@ -35,6 +37,7 @@ type IEditLabelPageProps = RouteComponentProps< const EditLabelPage = ({ routeParams, router }: IEditLabelPageProps) => { const { renderFlash } = useContext(NotificationContext); + const { currentUser } = useContext(AppContext); const labelId = parseInt(routeParams.label_id, 10); @@ -43,7 +46,7 @@ const EditLabelPage = ({ routeParams, router }: IEditLabelPageProps) => { isLoading: isLoadingLabel, isError: isErrorLabel, } = useQuery( - ["label", labelId], + ["label", labelId, currentUser], () => labelsAPI.getLabel(labelId), { ...DEFAULT_USE_QUERY_OPTIONS, @@ -57,6 +60,14 @@ const EditLabelPage = ({ routeParams, router }: IEditLabelPageProps) => { ); router.replace(PATHS.MANAGE_LABELS); } + + if (currentUser && !hasEditPermission(currentUser, data)) { + renderFlash( + "error", + "You do not have permission to edit this label." + ); + router.replace(PATHS.MANAGE_LABELS); + } }, } ); @@ -118,6 +129,7 @@ const EditLabelPage = ({ routeParams, router }: IEditLabelPageProps) => { defaultDescription={label.description} defaultQuery={label.query} defaultPlatform={label.platform} + teamName={label.team_name || null} isEditing onSave={onUpdateLabel} onCancel={onCancelEdit} @@ -128,6 +140,7 @@ const EditLabelPage = ({ routeParams, router }: IEditLabelPageProps) => { defaultName={label.name} defaultDescription={label.description} defaultTargetedHosts={targetedHosts} + teamName={label.team_name || null} onSave={onUpdateLabel} onCancel={onCancelEdit} /> diff --git a/frontend/pages/labels/ManageLabelsPage/LabelsTable/LabelsTableConfig.tsx b/frontend/pages/labels/ManageLabelsPage/LabelsTable/LabelsTableConfig.tsx index 3149dfae6c..4f97f0ead3 100644 --- a/frontend/pages/labels/ManageLabelsPage/LabelsTable/LabelsTableConfig.tsx +++ b/frontend/pages/labels/ManageLabelsPage/LabelsTable/LabelsTableConfig.tsx @@ -7,6 +7,8 @@ import { isGlobalAdmin, isGlobalMaintainer, isAnyTeamMaintainerOrTeamAdmin, + isTeamAdmin, + isTeamMaintainer, } from "utilities/permissions/permissions"; import { IUser } from "interfaces/user"; import HeaderCell from "components/TableContainer/DataTable/HeaderCell"; @@ -50,6 +52,21 @@ interface IDataColumn { sortType?: string; } +const hasEditPermission = (currentUser: IUser, label: ILabel): boolean => { + return ( + // global permissions + isGlobalAdmin(currentUser) || + isGlobalMaintainer(currentUser) || + // author permission + (label.author_id === currentUser.id && + isAnyTeamMaintainerOrTeamAdmin(currentUser)) || + // team permission + (label.team_id != null && + (isTeamAdmin(currentUser, label.team_id) || + isTeamMaintainer(currentUser, label.team_id))) + ); +}; + const generateActionDropdownOptions = ( currentUser: IUser, label: ILabel @@ -62,14 +79,7 @@ const generateActionDropdownOptions = ( }, ]; - const hasGlobalWritePermission = - isGlobalAdmin(currentUser) || isGlobalMaintainer(currentUser); - - const hasLabelAuthorWritePermission = - isAnyTeamMaintainerOrTeamAdmin(currentUser) && - label.author_id === currentUser.id; - - if (hasGlobalWritePermission || hasLabelAuthorWritePermission) { + if (hasEditPermission(currentUser, label)) { if (label.label_membership_type !== "host_vitals") { options.push({ label: "Edit", @@ -164,4 +174,4 @@ const generateTableHeaders = ( const generateDataSet = (labels: ILabel[]) => labels.filter((label) => label.label_type !== "builtin"); -export { generateTableHeaders, generateDataSet }; +export { generateTableHeaders, generateDataSet, hasEditPermission }; diff --git a/frontend/pages/labels/NewLabelPage/NewLabelPage.tsx b/frontend/pages/labels/NewLabelPage/NewLabelPage.tsx index a3ce5583b0..45c46ba6c0 100644 --- a/frontend/pages/labels/NewLabelPage/NewLabelPage.tsx +++ b/frontend/pages/labels/NewLabelPage/NewLabelPage.tsx @@ -321,8 +321,13 @@ const NewLabelPage = ({ await labelsAPI.create(formData); router.push(PATHS.MANAGE_LABELS); renderFlash("success", "Label added successfully."); - } catch { - renderFlash("error", "Couldn't add label. Please try again."); + } catch (error) { + renderFlash( + "error", + (error as { status: number }).status === 409 + ? "A label with this name already exists." + : "Couldn't add label. Please try again." + ); } setIsUpdating(false); }; diff --git a/frontend/pages/labels/components/DynamicLabelForm/DynamicLabelForm.tests.tsx b/frontend/pages/labels/components/DynamicLabelForm/DynamicLabelForm.tests.tsx index 13cee290fa..0d79fa854e 100644 --- a/frontend/pages/labels/components/DynamicLabelForm/DynamicLabelForm.tests.tsx +++ b/frontend/pages/labels/components/DynamicLabelForm/DynamicLabelForm.tests.tsx @@ -8,7 +8,7 @@ import DynamicLabelForm from "./DynamicLabelForm"; describe("DynamicLabelForm", () => { it("should render the Fleet Ace and Select Platform input", () => { - render(); + render(); expect(screen.getByText("Query")).toBeInTheDocument(); expect(screen.getByText("All platforms")).toBeInTheDocument(); @@ -28,6 +28,7 @@ describe("DynamicLabelForm", () => { onCancel={noop} defaultQuery={query} defaultPlatform={platform} + teamName={null} /> ); diff --git a/frontend/pages/labels/components/DynamicLabelForm/DynamicLabelForm.tsx b/frontend/pages/labels/components/DynamicLabelForm/DynamicLabelForm.tsx index e6cf13bbbe..be9b128dd6 100644 --- a/frontend/pages/labels/components/DynamicLabelForm/DynamicLabelForm.tsx +++ b/frontend/pages/labels/components/DynamicLabelForm/DynamicLabelForm.tsx @@ -13,9 +13,6 @@ import PlatformField from "../PlatformField"; const baseClass = "dynamic-label-form"; -const IMMUTABLE_QUERY_HELP_TEXT = - "Label queries are immutable. To change the query, delete this label and create a new one."; - export interface IDynamicLabelFormData { name: string; description: string; @@ -32,6 +29,7 @@ interface IDynamicLabelFormProps { isEditing?: boolean; onOpenSidebar?: () => void; onOsqueryTableSelect?: (tableName: string) => void; + teamName: string | null; onSave: (formData: IDynamicLabelFormData) => void; onCancel: () => void; } @@ -45,6 +43,7 @@ const DynamicLabelForm = ({ showOpenSidebarButton = false, onOpenSidebar, onOsqueryTableSelect, + teamName, onSave, onCancel, }: IDynamicLabelFormProps) => { @@ -120,8 +119,14 @@ const DynamicLabelForm = ({ { it("should validate the name to be required", async () => { const { user } = renderWithSetup( - + ); const nameInput = screen.getByLabelText("Name"); @@ -30,6 +35,8 @@ describe("LabelForm", () => { } /> ); @@ -40,7 +47,12 @@ describe("LabelForm", () => { it("should pass up the form data when the form is submitted and valid", async () => { const onSave = jest.fn(); const { user } = renderWithSetup( - + ); const nameValue = "Test Name"; diff --git a/frontend/pages/labels/components/LabelForm/LabelForm.tsx b/frontend/pages/labels/components/LabelForm/LabelForm.tsx index eb9f7f155e..044c8c74b7 100644 --- a/frontend/pages/labels/components/LabelForm/LabelForm.tsx +++ b/frontend/pages/labels/components/LabelForm/LabelForm.tsx @@ -5,6 +5,7 @@ import validate_presence from "components/forms/validators/validate_presence"; // @ts-ignore import InputField from "components/forms/fields/InputField"; import Button from "components/buttons/Button"; +import TeamNameField from "../TeamNameField/TeamNameField"; export interface ILabelFormData { name: string; @@ -16,19 +17,38 @@ interface ILabelFormProps { defaultDescription?: string; additionalFields?: ReactNode; isUpdatingLabel?: boolean; + teamName: string | null; onCancel: () => void; + immutableFields: string[]; onSave: (formData: ILabelFormData, isValid: boolean) => void; } const baseClass = "label-form"; +const generateDescriptionHelpText = (immutableFields: string[]) => { + if (immutableFields.length === 0) { + return ""; + } + + const SUFFIX = + "are immutable. To make changes, delete this label and create a new one."; + + return immutableFields.length === 1 + ? `Label ${immutableFields[0]} ${SUFFIX}` + : `Label ${immutableFields + .slice(0, -1) + .join(", ")} and ${immutableFields.pop()} ${SUFFIX}`; +}; + const LabelForm = ({ defaultName = "", defaultDescription = "", additionalFields, isUpdatingLabel, + teamName, onCancel, onSave, + immutableFields, }: ILabelFormProps) => { const [name, setName] = useState(defaultName); const [description, setDescription] = useState(defaultDescription); @@ -75,6 +95,12 @@ const LabelForm = ({ type="textarea" placeholder="Label description (optional)" /> + {immutableFields.length > 0 ? ( + + {generateDescriptionHelpText(immutableFields)} + + ) : null} + {teamName ? : null} {additionalFields}
) : ( - + <>

{platform ? PLATFORM_STRINGS[platform] : "All platforms"}

diff --git a/frontend/pages/labels/components/TeamNameField/TeamNameField.tsx b/frontend/pages/labels/components/TeamNameField/TeamNameField.tsx new file mode 100644 index 0000000000..a671be48e3 --- /dev/null +++ b/frontend/pages/labels/components/TeamNameField/TeamNameField.tsx @@ -0,0 +1,21 @@ +import React from "react"; + +import FormField from "components/forms/FormField"; + +const baseClass = "team-name-field"; + +interface ITeamNameFieldProps { + name: string; +} + +const TeamNameField = ({ name }: ITeamNameFieldProps) => { + return ( +
+ +

{name}

+
+
+ ); +}; + +export default TeamNameField; diff --git a/frontend/pages/labels/components/TeamNameField/index.ts b/frontend/pages/labels/components/TeamNameField/index.ts new file mode 100644 index 0000000000..ef92c621e9 --- /dev/null +++ b/frontend/pages/labels/components/TeamNameField/index.ts @@ -0,0 +1 @@ +export { default } from "./TeamNameField"; diff --git a/frontend/pages/policies/PolicyPage/components/PolicyForm/PolicyForm.tests.tsx b/frontend/pages/policies/PolicyPage/components/PolicyForm/PolicyForm.tests.tsx index 9f36323e60..de56d04c59 100644 --- a/frontend/pages/policies/PolicyPage/components/PolicyForm/PolicyForm.tests.tsx +++ b/frontend/pages/policies/PolicyPage/components/PolicyForm/PolicyForm.tests.tsx @@ -8,6 +8,7 @@ import userEvent from "@testing-library/user-event"; import createMockPolicy from "__mocks__/policyMock"; import createMockUser from "__mocks__/userMock"; import createMockConfig from "__mocks__/configMock"; +import { createMockTeamSummary } from "__mocks__/teamMock"; import { ILabelSummary } from "interfaces/label"; import PolicyForm from "./PolicyForm"; @@ -305,6 +306,7 @@ describe("PolicyForm - component", () => { context: { app: { currentUser: createMockUser(), + currentTeam: createMockTeamSummary(), isGlobalObserver: false, isGlobalAdmin: true, isGlobalMaintainer: false, diff --git a/frontend/pages/policies/PolicyPage/components/PolicyForm/PolicyForm.tsx b/frontend/pages/policies/PolicyPage/components/PolicyForm/PolicyForm.tsx index 977c8fcc80..6e4e2d3e23 100644 --- a/frontend/pages/policies/PolicyPage/components/PolicyForm/PolicyForm.tsx +++ b/frontend/pages/policies/PolicyPage/components/PolicyForm/PolicyForm.tsx @@ -167,12 +167,21 @@ const PolicyForm = ({ config, } = useContext(AppContext); - const { - data: { labels } = { labels: [] }, - isFetching: isFetchingLabels, - } = useQuery( - ["custom_labels"], - () => labelsAPI.summary(), + const { data: { labels } = { labels: [] } } = useQuery< + ILabelsSummaryResponse, + Error + >( + ["custom_labels", currentTeam], + () => { + // Wait for the current team to load from context before pulling labels, otherwise on a page load + // directly on the policies new/edit page this gets called with currentTeam not set, then again + // with the correct team value. If we don't trigger on currentTeam changes we'll just start with a + // null team ID here and never populate with the correct team unless we navigate from another page + // where team context is already set prior to navigation. + return !currentTeam + ? ({ labels: [] } as ILabelsSummaryResponse) + : labelsAPI.summary(currentTeam?.id, true); + }, { ...DEFAULT_USE_QUERY_OPTIONS, enabled: isPremiumTier, diff --git a/frontend/pages/queries/edit/components/EditQueryForm/EditQueryForm.tsx b/frontend/pages/queries/edit/components/EditQueryForm/EditQueryForm.tsx index 1ccd042e36..0079d219a9 100644 --- a/frontend/pages/queries/edit/components/EditQueryForm/EditQueryForm.tsx +++ b/frontend/pages/queries/edit/components/EditQueryForm/EditQueryForm.tsx @@ -244,7 +244,8 @@ const EditQueryForm = ({ isFetching: isFetchingLabels, } = useQuery( ["custom_labels"], - () => labelsAPI.summary(), + // All-teams queries can only be assigned global labels + () => labelsAPI.summary(currentTeamId, true), { ...DEFAULT_USE_QUERY_OPTIONS, enabled: isPremiumTier, diff --git a/frontend/pages/queries/edit/components/SaveNewQueryModal/SaveNewQueryModal.tsx b/frontend/pages/queries/edit/components/SaveNewQueryModal/SaveNewQueryModal.tsx index 0a6956111f..60b6a04fa0 100644 --- a/frontend/pages/queries/edit/components/SaveNewQueryModal/SaveNewQueryModal.tsx +++ b/frontend/pages/queries/edit/components/SaveNewQueryModal/SaveNewQueryModal.tsx @@ -107,7 +107,7 @@ const SaveNewQueryModal = ({ isFetching: isFetchingLabels, } = useQuery( ["custom_labels"], - () => labelsAPI.summary(), + () => labelsAPI.summary(apiTeamIdForQuery, true), { ...DEFAULT_USE_QUERY_OPTIONS, enabled: isPremiumTier, diff --git a/frontend/services/entities/labels.ts b/frontend/services/entities/labels.ts index 0d7ef536f4..1eb0880f56 100644 --- a/frontend/services/entities/labels.ts +++ b/frontend/services/entities/labels.ts @@ -110,12 +110,19 @@ export default { return sendRequest("DELETE", path); }, - loadAll: async (includeHostCounts = false): Promise => { + loadAll: async (teamID: number | null = null): Promise => { const { LABELS } = endpoints; const queryStringParams = { - include_host_counts: includeHostCounts, + include_host_counts: false, + team_id: null as null | number | string, }; + if (teamID === 0) { + queryStringParams.team_id = "global"; + } else if (teamID !== null && teamID > 0) { + // filter out "all teams" -1 + queryStringParams.team_id = teamID; + } const queryString = buildQueryStringFromParams(queryStringParams); const path = `${LABELS}?${queryString}`; @@ -128,10 +135,27 @@ export default { return Promise.reject(error); } }, - summary: (): Promise => { + summary: ( + teamID: number | null = null, + treatAllTeamsAsGlobalOnly = false + ): Promise => { const { LABELS_SUMMARY } = endpoints; - return sendRequest("GET", LABELS_SUMMARY); + const queryStringParams = { + team_id: null as null | number | string, + }; + if (teamID === 0 || (teamID === -1 && treatAllTeamsAsGlobalOnly)) { + queryStringParams.team_id = "global"; + } else if (teamID !== null && teamID > 0) { + queryStringParams.team_id = teamID; + } + + const queryString = buildQueryStringFromParams(queryStringParams); + + return sendRequest( + "GET", + queryString ? `${LABELS_SUMMARY}?${queryString}` : LABELS_SUMMARY + ); }, update: async ( diff --git a/server/authz/policy.rego b/server/authz/policy.rego index 1dd88afcf6..abadf65fe2 100644 --- a/server/authz/policy.rego +++ b/server/authz/policy.rego @@ -13,6 +13,7 @@ import input.subject read := "read" list := "list" write := "write" +create := "create" # only for labels right now write_host_label := "write_host_label" cancel_host_activity := "cancel_host_activity" @@ -358,12 +359,13 @@ allow { action == read } -# Team admins, maintainers, observer_plus, observers and gitops can read labels. +# Team admins, maintainers, observer_plus, observers and gitops can read global labels. allow { - object.type == "label" + object.type == "label" + is_null(object.team_id) # If role is admin, maintainer, observer_plus or observer on any team. team_role(subject, subject.teams[_].id) == [admin, maintainer, observer_plus, observer, gitops][_] - action == read + action == read } # Global admins, maintainers and gitops can write labels @@ -373,15 +375,54 @@ allow { action == write } - -# Team admins and maintainers can write labels +# Global admins, maintainers and gitops can write labels allow { object.type == "label" - # If role is admin, maintainer or gitops on any team. - team_role(subject, subject.teams[_].id) == [admin, maintainer][_] + subject.global_role == [admin, maintainer, gitops][_] + action == create +} + +# Team admins, maintainers, and gitops can create global labels +allow { + object.type == "label" + is_null(object.team_id) + team_role(subject, subject.teams[_].id) == [admin, maintainer, gitops][_] + action == create +} + +# Team admins, maintainers, and gitops can write global labels they created +allow { + object.type == "label" + is_null(object.team_id) + not is_null(object.author_id) + object.author_id = subject.id + team_role(subject, subject.teams[_].id) == [admin, maintainer, gitops][_] action == write } +# Team users can read labels on their team +allow { + object.type == "label" + not is_null(object.team_id) + team_role(subject, object.team_id) == [admin, maintainer, gitops, observer_plus, observer][_] + action == read +} + +# Team admins, maintainers, and gitops can write labels on their team +allow { + object.type == "label" + not is_null(object.team_id) + team_role(subject, object.team_id) == [admin, maintainer, gitops][_] + action == write +} + +# Team admins, maintainers, and gitops can create labels on their team +allow { + object.type == "label" + not is_null(object.team_id) + team_role(subject, object.team_id) == [admin, maintainer, gitops][_] + action == create +} ## # Queries diff --git a/server/authz/policy_test.go b/server/authz/policy_test.go index bd5494bcbe..51e979f113 100644 --- a/server/authz/policy_test.go +++ b/server/authz/policy_test.go @@ -24,6 +24,7 @@ const ( selectiveRead = fleet.ActionSelectiveRead selectiveList = fleet.ActionSelectiveList cancelHostActivity = fleet.ActionCancelHostActivity + create = fleet.ActionCreate ) var auth *Authorizer @@ -456,36 +457,88 @@ func TestAuthorizeLabel(t *testing.T) { t.Parallel() label := &fleet.Label{} + authoredLabel := func(user *fleet.User) fleet.Label { + return fleet.Label{AuthorID: &user.ID} + } + sameTeamLabel := func(user *fleet.User) fleet.Label { + return fleet.Label{TeamID: &user.Teams[0].ID} + } + differentTeamLabel := func(user *fleet.User) fleet.Label { + return fleet.Label{TeamID: ptr.Uint(999)} + } + runTestCases(t, []authTestCase{ {user: nil, object: label, action: read, allow: false}, {user: nil, object: label, action: write, allow: false}, + {user: nil, object: label, action: create, allow: false}, {user: test.UserNoRoles, object: label, action: read, allow: false}, {user: test.UserNoRoles, object: label, action: write, allow: false}, + {user: test.UserNoRoles, object: label, action: create, allow: false}, {user: test.UserAdmin, object: label, action: read, allow: true}, {user: test.UserAdmin, object: label, action: write, allow: true}, + {user: test.UserAdmin, object: label, action: create, allow: true}, {user: test.UserMaintainer, object: label, action: read, allow: true}, {user: test.UserMaintainer, object: label, action: write, allow: true}, + {user: test.UserMaintainer, object: label, action: create, allow: true}, {user: test.UserObserver, object: label, action: read, allow: true}, {user: test.UserObserver, object: label, action: write, allow: false}, + {user: test.UserObserver, object: label, action: create, allow: false}, {user: test.UserObserverPlus, object: label, action: read, allow: true}, {user: test.UserObserverPlus, object: label, action: write, allow: false}, + {user: test.UserObserverPlus, object: label, action: create, allow: false}, {user: test.UserGitOps, object: label, action: read, allow: true}, {user: test.UserGitOps, object: label, action: write, allow: true}, + {user: test.UserGitOps, object: label, action: create, allow: true}, + + {user: test.UserTeamObserverTeam1, object: label, action: read, allow: true}, + {user: test.UserTeamObserverTeam1, object: label, action: write, allow: false}, + {user: test.UserTeamObserverTeam1, object: label, action: create, allow: false}, + + {user: test.UserTeamObserverPlusTeam1, object: label, action: read, allow: true}, + {user: test.UserTeamObserverPlusTeam1, object: label, action: write, allow: false}, + {user: test.UserTeamObserverPlusTeam1, object: label, action: create, allow: false}, {user: test.UserTeamGitOpsTeam1, object: label, action: read, allow: true}, {user: test.UserTeamGitOpsTeam1, object: label, action: write, allow: false}, + {user: test.UserTeamGitOpsTeam1, object: label, action: create, allow: true}, {user: test.UserTeamAdminTeam1, object: label, action: read, allow: true}, - {user: test.UserTeamAdminTeam1, object: label, action: write, allow: true}, + {user: test.UserTeamAdminTeam1, object: label, action: write, allow: false}, + {user: test.UserTeamAdminTeam1, object: label, action: create, allow: true}, {user: test.UserTeamMaintainerTeam1, object: label, action: read, allow: true}, - {user: test.UserTeamMaintainerTeam1, object: label, action: write, allow: true}, + {user: test.UserTeamMaintainerTeam1, object: label, action: write, allow: false}, + {user: test.UserTeamMaintainerTeam1, object: label, action: create, allow: true}, + + {user: test.UserTeamObserverTeam1, object: authoredLabel(test.UserTeamObserverTeam1), action: read, allow: true}, + {user: test.UserTeamObserverTeam1, object: authoredLabel(test.UserTeamObserverTeam1), action: write, allow: false}, + {user: test.UserTeamObserverTeam1, object: authoredLabel(test.UserTeamObserverTeam1), action: create, allow: false}, + + {user: test.UserTeamGitOpsTeam1, object: authoredLabel(test.UserTeamGitOpsTeam1), action: read, allow: true}, + {user: test.UserTeamGitOpsTeam1, object: authoredLabel(test.UserTeamGitOpsTeam1), action: write, allow: true}, + {user: test.UserTeamGitOpsTeam1, object: authoredLabel(test.UserTeamGitOpsTeam1), action: create, allow: true}, + + {user: test.UserTeamObserverTeam1, object: sameTeamLabel(test.UserTeamObserverTeam1), action: read, allow: true}, + {user: test.UserTeamObserverTeam1, object: sameTeamLabel(test.UserTeamObserverTeam1), action: write, allow: false}, + {user: test.UserTeamObserverTeam1, object: sameTeamLabel(test.UserTeamObserverTeam1), action: create, allow: false}, + + {user: test.UserTeamGitOpsTeam1, object: sameTeamLabel(test.UserTeamGitOpsTeam1), action: read, allow: true}, + {user: test.UserTeamGitOpsTeam1, object: sameTeamLabel(test.UserTeamGitOpsTeam1), action: write, allow: true}, + {user: test.UserTeamGitOpsTeam1, object: sameTeamLabel(test.UserTeamGitOpsTeam1), action: create, allow: true}, + + {user: test.UserTeamObserverTeam1, object: differentTeamLabel(test.UserTeamObserverTeam1), action: read, allow: false}, + {user: test.UserTeamObserverTeam1, object: differentTeamLabel(test.UserTeamObserverTeam1), action: write, allow: false}, + {user: test.UserTeamObserverTeam1, object: differentTeamLabel(test.UserTeamObserverTeam1), action: create, allow: false}, + + {user: test.UserTeamGitOpsTeam1, object: differentTeamLabel(test.UserTeamGitOpsTeam1), action: read, allow: false}, + {user: test.UserTeamGitOpsTeam1, object: differentTeamLabel(test.UserTeamGitOpsTeam1), action: write, allow: false}, + {user: test.UserTeamGitOpsTeam1, object: differentTeamLabel(test.UserTeamGitOpsTeam1), action: create, allow: false}, }) } diff --git a/server/datastore/mysql/android_test.go b/server/datastore/mysql/android_test.go index 8237c405ef..e6b929d85e 100644 --- a/server/datastore/mysql/android_test.go +++ b/server/datastore/mysql/android_test.go @@ -1152,7 +1152,7 @@ func testListMDMAndroidProfilesToSend(t *testing.T, ds *Datastore) { }, profs) // make host[0] a member of only one of the labels - _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, lblIncAll1.ID, []uint{hosts[0].ID}, fleet.TeamFilter{}) + _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, *lblIncAll1, []uint{hosts[0].ID}, fleet.TeamFilter{}) require.NoError(t, err) // no change, host is not a member of both labels @@ -1167,7 +1167,7 @@ func testListMDMAndroidProfilesToSend(t *testing.T, ds *Datastore) { }, profs) // make host[0] a member of the other label - _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, lblIncAll2.ID, []uint{hosts[0].ID}, fleet.TeamFilter{}) + _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, *lblIncAll2, []uint{hosts[0].ID}, fleet.TeamFilter{}) require.NoError(t, err) // now p4 is applicable to host 0 @@ -1203,7 +1203,7 @@ func testListMDMAndroidProfilesToSend(t *testing.T, ds *Datastore) { }, profs) // make host[0] a member of one of the labels - _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, lblIncAny1.ID, []uint{hosts[0].ID}, fleet.TeamFilter{}) + _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, *lblIncAny1, []uint{hosts[0].ID}, fleet.TeamFilter{}) require.NoError(t, err) // now p5 is applicable to host 0 @@ -1261,7 +1261,7 @@ func testListMDMAndroidProfilesToSend(t *testing.T, ds *Datastore) { }, profs) // make host[0] a member of one of the exclude labels - _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, lblExclAny2.ID, []uint{hosts[0].ID}, fleet.TeamFilter{}) + _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, *lblExclAny2, []uint{hosts[0].ID}, fleet.TeamFilter{}) require.NoError(t, err) // p6 is not applicable anymore @@ -1493,7 +1493,7 @@ func testListMDMAndroidProfilesToSendWithExcludeAny(t *testing.T, ds *Datastore) }, profs) // Make host 0 a member of labelExclAny2 which excludes everything except p1 for it - _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, lblExclAny2.ID, []uint{hosts[0].ID}, fleet.TeamFilter{}) + _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, *lblExclAny2, []uint{hosts[0].ID}, fleet.TeamFilter{}) require.NoError(t, err) profs, toRemoveProfs, err = ds.ListMDMAndroidProfilesToSend(ctx) @@ -1510,7 +1510,7 @@ func testListMDMAndroidProfilesToSendWithExcludeAny(t *testing.T, ds *Datastore) // Make hosts 0 and 1 members of labelExclAny1 which excludes everything except p5 for host p1. Android doesn't // currently support dynamic labels but this ensures the datastore processes it right if somehow an Android host // becomes a member of one - _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, lblExclAny1.ID, []uint{hosts[0].ID, hosts[1].ID}, fleet.TeamFilter{}) + _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, *lblExclAny1, []uint{hosts[0].ID, hosts[1].ID}, fleet.TeamFilter{}) require.NoError(t, err) profs, toRemoveProfs, err = ds.ListMDMAndroidProfilesToSend(ctx) diff --git a/server/datastore/mysql/apple_mdm_test.go b/server/datastore/mysql/apple_mdm_test.go index 443da8a40f..8011d2c1b4 100644 --- a/server/datastore/mysql/apple_mdm_test.go +++ b/server/datastore/mysql/apple_mdm_test.go @@ -268,7 +268,7 @@ func testNewMDMAppleConfigProfileDuplicateIdentifier(t *testing.T, ds *Datastore require.False(t, prof.LabelsIncludeAll[0].Broken) // break the profile by deleting the label - require.NoError(t, ds.DeleteLabel(ctx, lbl.Name)) + require.NoError(t, ds.DeleteLabel(ctx, lbl.Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})) prof, err = ds.GetMDMAppleConfigProfile(ctx, labelProf.ProfileUUID) require.NoError(t, err) diff --git a/server/datastore/mysql/hosts_test.go b/server/datastore/mysql/hosts_test.go index 08da0312ba..28fca1b9ce 100644 --- a/server/datastore/mysql/hosts_test.go +++ b/server/datastore/mysql/hosts_test.go @@ -5908,7 +5908,7 @@ func testHostsPackStatsMultipleHosts(t *testing.T, ds *Datastore) { // Create global pack (and one scheduled query in it). test.AddAllHostsLabel(t, ds) // the global pack needs the "All Hosts" label. - labels, err := ds.ListLabels(context.Background(), fleet.TeamFilter{}, fleet.ListOptions{}) + labels, err := ds.ListLabels(context.Background(), fleet.TeamFilter{}, fleet.ListOptions{}, false) require.NoError(t, err) require.Len(t, labels, 1) @@ -6100,7 +6100,7 @@ func testHostsPackStatsForPlatform(t *testing.T, ds *Datastore) { require.NotNil(t, host2) test.AddAllHostsLabel(t, ds) - labels, err := ds.ListLabels(context.Background(), fleet.TeamFilter{}, fleet.ListOptions{}) + labels, err := ds.ListLabels(context.Background(), fleet.TeamFilter{}, fleet.ListOptions{}, false) require.NoError(t, err) require.Len(t, labels, 1) diff --git a/server/datastore/mysql/labels.go b/server/datastore/mysql/labels.go index b89169a11f..27f8bf82ae 100644 --- a/server/datastore/mysql/labels.go +++ b/server/datastore/mysql/labels.go @@ -3,6 +3,7 @@ package mysql import ( "context" "database/sql" + "errors" "fmt" "regexp" "sort" @@ -20,6 +21,123 @@ func (ds *Datastore) ApplyLabelSpecs(ctx context.Context, specs []*fleet.LabelSp return ds.ApplyLabelSpecsWithAuthor(ctx, specs, nil) } +func (ds *Datastore) SetAsideLabels(ctx context.Context, notOnTeamID *uint, names []string, user fleet.User) error { + if len(names) == 0 { + return nil + } + + type existingLabel struct { + ID uint `db:"id"` + AuthorID *uint `db:"author_id"` + TeamID *uint `db:"team_id"` + } + + stmt := `SELECT id, author_id, team_id FROM labels WHERE name IN (?) AND label_type != ?` + stmt, args, err := sqlx.In(stmt, names, uint(fleet.LabelTypeBuiltIn)) + if err != nil { + return ctxerr.Wrap(ctx, err, "build labels query") + } + + var labels []existingLabel + if err := sqlx.SelectContext(ctx, ds.writer(ctx), &labels, stmt, args...); err != nil { + return ctxerr.Wrap(ctx, err, "query existing labels") + } + + errCannotSetAside := ctxerr.New(ctx, "one or more specified labels to set aside do not exist or cannot be set aside") + errGlobal := ctxerr.New(ctx, "one or more specified labels to set aside is on the same team as you are trying to modify") + + if len(labels) != len(names) { + return errCannotSetAside + } + + // Helper function to check if user has a global write role (admin, maintainer, or gitops) + hasGlobalWriteRole := func() bool { + if user.GlobalRole == nil { + return false + } + return *user.GlobalRole == fleet.RoleAdmin || + *user.GlobalRole == fleet.RoleMaintainer || + *user.GlobalRole == fleet.RoleGitOps + } + + // Helper function to check if user has a write role on any team + hasWriteRoleAnywhere := func() bool { + for _, team := range user.Teams { + if team.Role == fleet.RoleAdmin || + team.Role == fleet.RoleMaintainer || + team.Role == fleet.RoleGitOps { + return true + } + } + return false + } + + // Helper function to check if user has a write role on a specific team + hasWriteRoleOnTeam := func(teamID uint) bool { + for _, team := range user.Teams { + if team.ID == teamID && + (team.Role == fleet.RoleAdmin || + team.Role == fleet.RoleMaintainer || + team.Role == fleet.RoleGitOps) { + return true + } + } + return false + } + + for _, label := range labels { + if label.TeamID == nil { // Global label + if notOnTeamID == nil { // Disallow moving aside since the label is on the same team + return errGlobal + } + + if hasGlobalWriteRole() { + continue + } + + if hasWriteRoleAnywhere() && label.AuthorID != nil && *label.AuthorID == user.ID { + continue + } + + // User doesn't have permission to set aside this global label + return errCannotSetAside + } + + // Team label + if notOnTeamID != nil && *notOnTeamID == *label.TeamID { // label is on the same team we're applying specs for + return errCannotSetAside // generic error here because label may not be visible to the user + } + + if hasGlobalWriteRole() { + continue + } + + if hasWriteRoleAnywhere() && label.AuthorID != nil && *label.AuthorID == user.ID { + continue + } + + if hasWriteRoleOnTeam(*label.TeamID) { + continue + } + + // User doesn't have permission to set aside this team label + return errCannotSetAside + } + + // Bulk update to rename labels by appending __team_{team_id} (or __team_0 for global labels) + updateStmt := `UPDATE labels SET name = CONCAT(name, '__team_', COALESCE(team_id, 0)) WHERE name IN (?)` + updateStmt, updateArgs, err := sqlx.In(updateStmt, names) + if err != nil { + return ctxerr.Wrap(ctx, err, "build update labels query") + } + + if _, err := ds.writer(ctx).ExecContext(ctx, updateStmt, updateArgs...); err != nil { + return ctxerr.Wrap(ctx, err, "rename labels to set aside") + } + + return nil +} + func (ds *Datastore) ApplyLabelSpecsWithAuthor(ctx context.Context, specs []*fleet.LabelSpec, authorID *uint) (err error) { // First, get existing labels to detect platform changes labelNames := make([]string, 0, len(specs)) @@ -33,9 +151,14 @@ func (ds *Datastore) ApplyLabelSpecsWithAuthor(ctx context.Context, specs []*fle ID uint `db:"id"` Name string `db:"name"` Platform string `db:"platform"` + TeamID *uint `db:"team_id"` } existingLabels := make(map[string]existingLabel, len(specs)) + // NOTE: Thie assumes the caller has verified that label specs are all writable by the user, either for authorship + // or team affiliation. We'll catch cases where a user is attempting to move the label between teams (which + // should've been cleaned up by SetAsideLabels). + err = ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { // TODO: do we want to allow on duplicate updating label_type or // label_membership_type or should those always be immutable? @@ -43,7 +166,7 @@ func (ds *Datastore) ApplyLabelSpecsWithAuthor(ctx context.Context, specs []*fle // are not changed? if len(labelNames) > 0 { - stmt := `SELECT id, name, platform FROM labels WHERE name IN (?)` + stmt := `SELECT id, name, platform, team_id FROM labels WHERE name IN (?)` stmt, args, err := sqlx.In(stmt, labelNames) if err != nil { return ctxerr.Wrap(ctx, err, "build existing labels query") @@ -55,7 +178,16 @@ func (ds *Datastore) ApplyLabelSpecsWithAuthor(ctx context.Context, specs []*fle } for _, label := range labels { - existingLabels[label.Name] = label + existingLabels[strings.ToLower(label.Name)] = label + } + + for _, spec := range specs { + if existingLabel, ok := existingLabels[strings.ToLower(spec.Name)]; ok && + (existingLabel.TeamID != nil && spec.TeamID == nil || + existingLabel.TeamID == nil && spec.TeamID != nil || + (existingLabel.TeamID != nil && spec.TeamID != nil && *existingLabel.TeamID != *spec.TeamID)) { + return ctxerr.Wrap(ctx, err, "one or more specified labels exists on another team") + } } } @@ -68,8 +200,9 @@ func (ds *Datastore) ApplyLabelSpecsWithAuthor(ctx context.Context, specs []*fle label_type, label_membership_type, criteria, - author_id - ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ? ) + author_id, + team_id + ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ? ) ON DUPLICATE KEY UPDATE name = VALUES(name), description = VALUES(description), @@ -94,13 +227,13 @@ func (ds *Datastore) ApplyLabelSpecsWithAuthor(ctx context.Context, specs []*fle if s.Name == "" { return ctxerr.New(ctx, "label name must not be empty") } - insertLabelResult, err := stmt.ExecContext(ctx, s.Name, s.Description, s.Query, s.Platform, s.LabelType, s.LabelMembershipType, s.HostVitalsCriteria, authorID) + insertLabelResult, err := stmt.ExecContext(ctx, s.Name, s.Description, s.Query, s.Platform, s.LabelType, s.LabelMembershipType, s.HostVitalsCriteria, authorID, s.TeamID) if err != nil { return ctxerr.Wrap(ctx, err, "exec ApplyLabelSpecs insert") } // Check if this is an existing label and platform changed -> clean up memberships if needed - if existing, ok := existingLabels[s.Name]; ok && existing.Platform != s.Platform { + if existing, ok := existingLabels[strings.ToLower(s.Name)]; ok && existing.Platform != s.Platform { // When a label's platform changes, we delete all existing memberships. // This ensures a clean slate - the label's query will be re-evaluated // by Fleet's label execution system, and only hosts matching the new @@ -121,7 +254,7 @@ func (ds *Datastore) ApplyLabelSpecsWithAuthor(ctx context.Context, specs []*fle // For manual labels, we need the label ID to update membership var labelID uint - if existing, ok := existingLabels[s.Name]; ok { + if existing, ok := existingLabels[strings.ToLower(s.Name)]; ok { // Use the existing label ID labelID = existing.ID } else { @@ -203,13 +336,13 @@ func batchHostnames(hostnames []string) [][]string { return batches } -func (ds *Datastore) UpdateLabelMembershipByHostIDs(ctx context.Context, labelID uint, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { +func (ds *Datastore) UpdateLabelMembershipByHostIDs(ctx context.Context, label fleet.Label, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { // delete all label membership sql := ` DELETE FROM label_membership WHERE label_id = ? ` - _, err := tx.ExecContext(ctx, sql, labelID) + _, err := tx.ExecContext(ctx, sql, label.ID) if err != nil { return ctxerr.Wrap(ctx, err, "clear membership for ID") } @@ -220,13 +353,42 @@ func (ds *Datastore) UpdateLabelMembershipByHostIDs(ctx context.Context, labelID // Split hostIds into batches to avoid parameter limit in MySQL. for _, hostIds := range batchHostIds(hostIds) { + if label.TeamID != nil { // team labels can only be applied to hosts on that team + hostTeamCheckSql := `SELECT COUNT(id) FROM hosts WHERE team_id != ? AND id IN (` + + strings.TrimRight(strings.Repeat("?,", len(hostIds)), ",") + ")" + hostTeamCheckSql, args, err := sqlx.In(hostTeamCheckSql, label.TeamID, hostIds) + if err != nil { + return ctxerr.Wrap(ctx, err, "build host team membership check IN statement") + } + + rows, err := tx.QueryContext(ctx, hostTeamCheckSql, args...) + if err != nil { + return ctxerr.Wrap(ctx, err, "execute host team membership check query") + } + + rows.Next() + var hostCountOnWrongTeam int + if err := rows.Scan(&hostCountOnWrongTeam); err != nil { + return ctxerr.Wrap(ctx, err, "check host team membership") + } + if err := rows.Err(); err != nil { + return ctxerr.Wrap(ctx, err, "check host team membership") + } + if err := rows.Close(); err != nil { //nolint:sqlclosecheck + return ctxerr.Wrap(ctx, err, "close result set for host team membership") + } + if hostCountOnWrongTeam > 0 { + return ctxerr.Wrap(ctx, errors.New("supplied hosts are on a different team than the label")) + } + } + // Use ignore because duplicate hostIds could appear in // different batches and would result in duplicate key errors. - values := []interface{}{} - placeholders := []string{} + var values []any + var placeholders []string for _, hostID := range hostIds { - values = append(values, labelID, hostID) + values = append(values, label.ID, hostID) placeholders = append(placeholders, "(?, ?)") } @@ -249,7 +411,12 @@ VALUES ` + strings.Join(placeholders, ", ") return nil, nil, ctxerr.Wrap(ctx, err, "UpdateLabelMembershipByHostIDs transaction") } - return ds.labelDB(ctx, labelID, teamFilter, ds.writer(ctx)) + updatedLabel, hostIDs, err := ds.labelDB(ctx, label.ID, teamFilter, ds.writer(ctx)) + if err != nil { + return nil, nil, ctxerr.Wrap(ctx, err, "UpdateLabelMembershipByHostIDs get label after update") + } + + return updatedLabel.GetLabel(), hostIDs, err } // Update label membership for a host vitals label. @@ -271,9 +438,13 @@ func (ds *Datastore) UpdateLabelMembershipByHostCriteria(ctx context.Context, hv return nil, ctxerr.New(ctx, "label query is empty after calculating host vitals query") } + labelSelect := fmt.Sprintf("%d as label_id, hosts.id as host_id", label.ID) + labelQuery := fmt.Sprintf(query, labelSelect, "hosts") + if label.TeamID != nil { + labelQuery = fmt.Sprintf(query, labelSelect, fmt.Sprintf("hosts JOIN (SELECT %d team_id) label_team ON label_team.team_id = hosts.team_id", *label.TeamID)) + } + err = ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { - labelSelect := fmt.Sprintf("%d as label_id, hosts.id as host_id", label.ID) - labelQuery := fmt.Sprintf(query, labelSelect, "hosts") // Insert new label membership based on the label query. sql := fmt.Sprintf(`INSERT INTO label_membership (label_id, host_id) SELECT candidate.label_id, candidate.host_id FROM (%s) as candidate ON DUPLICATE KEY UPDATE host_id = label_membership.host_id`, labelQuery) _, err := tx.ExecContext(ctx, sql, queryVals...) @@ -316,11 +487,17 @@ func batchHostIds(hostIds []uint) [][]uint { return batches } -func (ds *Datastore) GetLabelSpecs(ctx context.Context) ([]*fleet.LabelSpec, error) { +func (ds *Datastore) GetLabelSpecs(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { var specs []*fleet.LabelSpec // Get basic specs - query := "SELECT id, name, description, query, platform, label_type, label_membership_type, criteria FROM labels" - if err := sqlx.SelectContext(ctx, ds.reader(ctx), &specs, query); err != nil { + query, params, err := applyLabelTeamFilter(`SELECT id, name, description, query, platform, + label_type, label_membership_type, criteria, team_id + FROM labels l`, filter) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "building query for getting label specs") + } + + if err := sqlx.SelectContext(ctx, ds.reader(ctx), &specs, query, params...); err != nil { return nil, ctxerr.Wrap(ctx, err, "get labels") } @@ -336,14 +513,17 @@ func (ds *Datastore) GetLabelSpecs(ctx context.Context) ([]*fleet.LabelSpec, err return specs, nil } -func (ds *Datastore) GetLabelSpec(ctx context.Context, name string) (*fleet.LabelSpec, error) { +func (ds *Datastore) GetLabelSpec(ctx context.Context, filter fleet.TeamFilter, name string) (*fleet.LabelSpec, error) { var specs []*fleet.LabelSpec - query := ` -SELECT id, name, description, query, platform, label_type, label_membership_type -FROM labels -WHERE name = ? -` - if err := sqlx.SelectContext(ctx, ds.reader(ctx), &specs, query, name); err != nil { + query, params, err := applyLabelTeamFilter(` +SELECT l.id, l.name, l.description, l.query, l.platform, l.label_type, l.label_membership_type +FROM labels l +WHERE l.name = ?`, filter, name) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "building query for getting label spec") + } + + if err := sqlx.SelectContext(ctx, ds.reader(ctx), &specs, query, params...); err != nil { return nil, ctxerr.Wrap(ctx, err, "get label") } if len(specs) == 0 { @@ -420,7 +600,7 @@ func (ds *Datastore) NewLabel(ctx context.Context, label *fleet.Label, opts ...f return label, nil } -func (ds *Datastore) SaveLabel(ctx context.Context, label *fleet.Label, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { +func (ds *Datastore) SaveLabel(ctx context.Context, label *fleet.Label, teamFilter fleet.TeamFilter) (*fleet.LabelWithTeamName, []uint, error) { query := `UPDATE labels SET name = ?, description = ? WHERE id = ?` _, err := ds.writer(ctx).ExecContext(ctx, query, label.Name, label.Description, label.ID) if err != nil { @@ -438,10 +618,16 @@ func (ds *Datastore) SaveLabel(ctx context.Context, label *fleet.Label, teamFilt } // DeleteLabel deletes a fleet.Label -func (ds *Datastore) DeleteLabel(ctx context.Context, name string) error { +func (ds *Datastore) DeleteLabel(ctx context.Context, name string, filter fleet.TeamFilter) error { return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var labelID uint - err := sqlx.GetContext(ctx, tx, &labelID, `select id FROM labels WHERE name = ?`, name) + + query, params, err := applyLabelTeamFilter(`select id FROM labels WHERE name = ?`, filter, name) + if err != nil { + return ctxerr.Wrap(ctx, err, "getting label id to delete") + } + + err = sqlx.GetContext(ctx, tx, &labelID, query, params...) if err != nil { if err == sql.ErrNoRows { return ctxerr.Wrap(ctx, notFound("Label").WithName(name)) @@ -486,22 +672,45 @@ func deleteLabelsInTx(ctx context.Context, tx sqlx.ExtContext, labelIDs []uint) return nil } -// Label returns a fleet.Label identified by lid if one exists. -func (ds *Datastore) Label(ctx context.Context, lid uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { +// LabelByName returns a fleet.Label identified by name if one exists and is accessible to the specified user. +func (ds *Datastore) LabelByName(ctx context.Context, name string, teamFilter fleet.TeamFilter) (*fleet.Label, error) { + stmt, params, err := applyLabelTeamFilter("SELECT l.* FROM labels l WHERE l.name = ?", teamFilter, name) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "building label select query") + } + + var label fleet.Label + if err := sqlx.GetContext(ctx, ds.reader(ctx), &label, stmt, params...); err != nil { + if err == sql.ErrNoRows { + return nil, ctxerr.Wrap(ctx, notFound("Label").WithName(name)) + } + return nil, ctxerr.Wrap(ctx, err, "selecting label") + } + + return &label, nil +} + +// Label returns a fleet.LabelWithTeamName identified by lid if one exists and is accessible to the specified user. +func (ds *Datastore) Label(ctx context.Context, lid uint, teamFilter fleet.TeamFilter) (*fleet.LabelWithTeamName, []uint, error) { return ds.labelDB(ctx, lid, teamFilter, ds.reader(ctx)) } -func (ds *Datastore) labelDB(ctx context.Context, lid uint, teamFilter fleet.TeamFilter, q sqlx.QueryerContext) (*fleet.Label, []uint, error) { +func (ds *Datastore) labelDB(ctx context.Context, lid uint, teamFilter fleet.TeamFilter, q sqlx.QueryerContext) (*fleet.LabelWithTeamName, []uint, error) { stmt := fmt.Sprintf(` SELECT - l.*, + l.*, teams.name team_name, (SELECT COUNT(1) FROM label_membership lm JOIN hosts h ON (lm.host_id = h.id) WHERE label_id = l.id AND %s) AS host_count - FROM labels l - WHERE id = ? + FROM labels l LEFT JOIN teams ON teams.id = l.team_id + WHERE l.id = ? `, ds.whereFilterHostsByTeams(teamFilter, "h")) - var label fleet.Label - if err := sqlx.GetContext(ctx, q, &label, stmt, lid); err != nil { + stmt, params, err := applyLabelTeamFilter(stmt, teamFilter, lid) + if err != nil { + return nil, nil, ctxerr.Wrap(ctx, err, "building label select query") + } + + var label fleet.LabelWithTeamName + if err := sqlx.GetContext(ctx, q, &label, stmt, params...); err != nil { if err == sql.ErrNoRows { return nil, nil, ctxerr.Wrap(ctx, notFound("Label").WithID(lid)) } @@ -520,7 +729,7 @@ func (ds *Datastore) labelDB(ctx context.Context, lid uint, teamFilter fleet.Tea // ListLabels returns all labels limited or sorted by fleet.ListOptions. // MatchQuery not supported -func (ds *Datastore) ListLabels(ctx context.Context, filter fleet.TeamFilter, opt fleet.ListOptions) ([]*fleet.Label, error) { +func (ds *Datastore) ListLabels(ctx context.Context, filter fleet.TeamFilter, opt fleet.ListOptions, includeHostCounts bool) ([]*fleet.Label, error) { if opt.After != "" { return nil, &fleet.BadRequestError{Message: "parameter 'after' is not supported"} } @@ -528,31 +737,80 @@ func (ds *Datastore) ListLabels(ctx context.Context, filter fleet.TeamFilter, op return nil, &fleet.BadRequestError{Message: "parameter 'query' is not supported"} } - query := "SELECT * FROM labels l " - // If a team filter is provided, filter host membership by team and return counts with the labels. - if filter.User != nil { + query := "SELECT l.* FROM labels l " + // When applicable, filter host membership by team and return counts with the labels. + if filter.User != nil && includeHostCounts { query = fmt.Sprintf(` - SELECT *, - (SELECT COUNT(1) FROM label_membership lm JOIN hosts h ON (lm.host_id = h.id) WHERE label_id = l.id AND %s) AS host_count + SELECT l.*, + (SELECT COUNT(1) + FROM label_membership lm + JOIN hosts h ON (lm.host_id = h.id) WHERE label_id = l.id AND %s + ) AS host_count FROM labels l `, ds.whereFilterHostsByTeams(filter, "h"), ) } - query, params := appendListOptionsToSQL(query, &opt) - labels := []*fleet.Label{} + query, params, err := applyLabelTeamFilter(query, filter) + if err != nil { + return nil, err + } + + query, params = appendListOptionsWithCursorToSQL(query, params, &opt) + var labels []*fleet.Label if err := sqlx.SelectContext(ctx, ds.reader(ctx), &labels, query, params...); err != nil { // it's ok if no labels exist if err == sql.ErrNoRows { return labels, nil } + return nil, ctxerr.Wrap(ctx, err, "selecting labels") } return labels, nil } +var errInaccessibleTeam = errors.New("The team ID you provided refers to a team that either does not exist or you do not have permission to access.") + +// applyLabelTeamFilter requires the labels table to be aliased as "l" to work +func applyLabelTeamFilter(query string, filter fleet.TeamFilter, initialParams ...any) (string, []any, error) { + // using this rather than a "contains a WHERE" check because some queries have subqueries + // but don't have any parameters for those subqueries + whereOrAnd := " WHERE " + if len(initialParams) > 0 { + whereOrAnd = " AND " + } + + // apply sqlx.In if we had initial params, as they may include slices for where-ins other than the team one + maybeIn := func(query string) (string, []any, error) { + if len(initialParams) > 0 { + return sqlx.In(query, initialParams...) + } + return query, nil, nil + } + + if filter.User == nil { // fall back to safe (global-only) filter if this happens (it shouldn't) + return maybeIn(query + whereOrAnd + " l.team_id IS NULL") + } + + if filter.TeamID != nil { + if *filter.TeamID == 0 { // global labels only; any user can see them + return maybeIn(query + whereOrAnd + "l.team_id IS NULL") + } else if !filter.UserCanAccessSelectedTeam() { + return "", nil, fleet.NewUserMessageError(errInaccessibleTeam, 403) + } // else user can see the team labels they're asking for; return global labels plus that team's labels + + return sqlx.In(query+whereOrAnd+"(l.team_id IS NULL OR l.team_id = ?)", append(initialParams, *filter.TeamID)...) + } + + if !filter.User.HasAnyGlobalRole() && filter.User.HasAnyTeamRole() { // filter to teams user can see + return sqlx.In(query+whereOrAnd+"(l.team_id IS NULL OR l.team_id IN (?))", append(initialParams, filter.User.TeamIDsWithAnyRole())...) + } // else user exists and has a global role, so we don't need to filter out any team labels + + return maybeIn(query) +} + func platformForHost(host *fleet.Host) string { if host.Platform != "rhel" { return host.Platform @@ -709,6 +967,19 @@ func (ds *Datastore) ListLabelsForHost(ctx context.Context, hid uint) ([]*fleet. // ListHostsInLabel returns a list of fleet.Host that are associated // with fleet.Label referenced by Label ID func (ds *Datastore) ListHostsInLabel(ctx context.Context, filter fleet.TeamFilter, lid uint, opt fleet.HostListOptions) ([]*fleet.Host, error) { + labelCheckSql, labelCheckParams, err := applyLabelTeamFilter(`SELECT l.id FROM labels l WHERE id = ?`, filter, lid) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "building query to confirm label existence") + } + + var foundID uint + if err := sqlx.GetContext(ctx, ds.reader(ctx), &foundID, labelCheckSql, labelCheckParams...); err != nil { + if err == sql.ErrNoRows { + return nil, nil // matches previous behavior (invalid labels return no hosts) + } + return nil, ctxerr.Wrap(ctx, err, "confirming label existence") + } + queryFmt := ` SELECT h.id, @@ -955,101 +1226,28 @@ func (ds *Datastore) CountHostsInLabel(ctx context.Context, filter fleet.TeamFil return count, nil } -func (ds *Datastore) ListUniqueHostsInLabels(ctx context.Context, filter fleet.TeamFilter, labels []uint) ([]*fleet.Host, error) { - if len(labels) == 0 { - return []*fleet.Host{}, nil - } - - sqlStatement := fmt.Sprintf(` - SELECT DISTINCT - h.id, - h.osquery_host_id, - h.created_at, - h.updated_at, - h.detail_updated_at, - h.node_key, - h.hostname, - h.uuid, - h.platform, - h.osquery_version, - h.os_version, - h.build, - h.platform_like, - h.code_name, - h.uptime, - h.memory, - h.cpu_type, - h.cpu_subtype, - h.cpu_brand, - h.cpu_physical_cores, - h.cpu_logical_cores, - h.hardware_vendor, - h.hardware_model, - h.hardware_version, - h.hardware_serial, - h.computer_name, - h.primary_ip_id, - h.distributed_interval, - h.logger_tls_period, - h.config_tls_refresh, - h.primary_ip, - h.primary_mac, - h.label_updated_at, - h.last_enrolled_at, - h.refetch_requested, - h.refetch_critical_queries_until, - h.team_id, - h.policy_updated_at, - h.public_ip, - COALESCE(hd.gigs_disk_space_available, 0) as gigs_disk_space_available, - COALESCE(hd.percent_disk_space_available, 0) as percent_disk_space_available, - COALESCE(hd.gigs_total_disk_space, 0) as gigs_total_disk_space, - (SELECT name FROM teams t WHERE t.id = h.team_id) AS team_name - FROM label_membership lm - JOIN hosts h ON lm.host_id = h.id - LEFT JOIN host_disks hd ON hd.host_id = h.id - WHERE lm.label_id IN (?) AND %s -`, ds.whereFilterHostsByTeams(filter, "h"), - ) - - query, args, err := sqlx.In(sqlStatement, labels) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "building query listing unique hosts in labels") - } - - query = ds.reader(ctx).Rebind(query) - hosts := []*fleet.Host{} - err = sqlx.SelectContext(ctx, ds.reader(ctx), &hosts, query, args...) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "listing unique hosts in labels") - } - - return hosts, nil -} - func (ds *Datastore) searchLabelsWithOmits(ctx context.Context, filter fleet.TeamFilter, query string, omit ...uint) ([]*fleet.Label, error) { - transformedQuery := transformQuery(query) - sqlStatement := fmt.Sprintf(` - SELECT *, + SELECT l.*, (SELECT COUNT(1) FROM label_membership lm JOIN hosts h ON (lm.host_id = h.id) WHERE label_id = l.id AND %s ) AS host_count FROM labels l WHERE ( - MATCH(name) AGAINST(? IN BOOLEAN MODE) + MATCH(l.name) AGAINST(? IN BOOLEAN MODE) ) - AND id NOT IN (?) - ORDER BY label_type DESC, id ASC + AND l.id NOT IN (?) `, ds.whereFilterHostsByTeams(filter, "h"), ) - sql, args, err := sqlx.In(sqlStatement, transformedQuery, omit) + sql, args, err := applyLabelTeamFilter(sqlStatement, filter, transformQuery(query), omit) if err != nil { return nil, ctxerr.Wrap(ctx, err, "building query for labels with omits") } + sql += ` ORDER BY label_type DESC, id ASC` + sql = ds.reader(ctx).Rebind(sql) matches := []*fleet.Label{} @@ -1106,15 +1304,13 @@ func (ds *Datastore) addAllHostsLabelToList(ctx context.Context, filter fleet.Te func (ds *Datastore) searchLabelsDefault(ctx context.Context, filter fleet.TeamFilter, omit ...uint) ([]*fleet.Label, error) { sql := fmt.Sprintf(` - SELECT *, + SELECT l.*, (SELECT COUNT(1) FROM label_membership lm JOIN hosts h ON (lm.host_id = h.id) WHERE label_id = l.id AND %s ) AS host_count FROM labels l - WHERE id NOT IN (?) - GROUP BY id - ORDER BY label_type DESC, id ASC + WHERE l.id NOT IN (?) `, ds.whereFilterHostsByTeams(filter, "h"), ) @@ -1129,10 +1325,12 @@ func (ds *Datastore) searchLabelsDefault(ctx context.Context, filter fleet.TeamF } var labels []*fleet.Label - sql, args, err := sqlx.In(sql, in) + sql, args, err := applyLabelTeamFilter(sql, filter, in) if err != nil { return nil, ctxerr.Wrap(ctx, err, "searching default labels") } + sql += ` GROUP BY id ORDER BY label_type DESC, id ASC` + sql = ds.reader(ctx).Rebind(sql) if err := sqlx.SelectContext(ctx, ds.reader(ctx), &labels, sql, args...); err != nil { return nil, ctxerr.Wrap(ctx, err, "searching default labels rebound") @@ -1161,7 +1359,7 @@ func (ds *Datastore) SearchLabels(ctx context.Context, filter fleet.TeamFilter, // if additional label types are added. Ordering next by ID ensures // that the order is always consistent. sql := fmt.Sprintf(` - SELECT *, + SELECT l.*, (SELECT COUNT(1) FROM label_membership lm JOIN hosts h ON (lm.host_id = h.id) WHERE label_id = l.id AND %s @@ -1170,16 +1368,22 @@ func (ds *Datastore) SearchLabels(ctx context.Context, filter fleet.TeamFilter, WHERE ( MATCH(name) AGAINST(? IN BOOLEAN MODE) ) - ORDER BY label_type DESC, id ASC `, ds.whereFilterHostsByTeams(filter, "h"), ) + sql, args, err := applyLabelTeamFilter(sql, filter, transformQuery(query)) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "building query for searching labels") + } + + sql += ` ORDER BY label_type DESC, id ASC` + matches := []*fleet.Label{} - if err := sqlx.SelectContext(ctx, ds.reader(ctx), &matches, sql, transformedQuery); err != nil { + if err := sqlx.SelectContext(ctx, ds.reader(ctx), &matches, sql, args...); err != nil { return nil, ctxerr.Wrap(ctx, err, "selecting labels for search") } - matches, err := ds.addAllHostsLabelToList(ctx, filter, matches, omit...) + matches, err = ds.addAllHostsLabelToList(ctx, filter, matches, omit...) if err != nil { return nil, ctxerr.Wrap(ctx, err, "adding all hosts label to matches") } @@ -1187,17 +1391,12 @@ func (ds *Datastore) SearchLabels(ctx context.Context, filter fleet.TeamFilter, return matches, nil } -func (ds *Datastore) LabelIDsByName(ctx context.Context, names []string) (map[string]uint, error) { +func (ds *Datastore) LabelIDsByName(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { if len(names) == 0 { return map[string]uint{}, nil } - sqlStatement := ` - SELECT id, name FROM labels - WHERE name IN (?) - ` - - sql, args, err := sqlx.In(sqlStatement, names) + sql, args, err := applyLabelTeamFilter(`SELECT l.id, l.name FROM labels l WHERE l.name IN (?)`, filter, names) if err != nil { return nil, ctxerr.Wrap(ctx, err, "building query to get label ids by name") } @@ -1215,24 +1414,19 @@ func (ds *Datastore) LabelIDsByName(ctx context.Context, names []string) (map[st return result, nil } -func (ds *Datastore) LabelsByName(ctx context.Context, names []string) (map[string]*fleet.Label, error) { +func (ds *Datastore) LabelsByName(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { if len(names) == 0 { return map[string]*fleet.Label{}, nil } - sqlStatement := ` - SELECT * FROM labels - WHERE name IN (?) - ` - - sqlStatement, args, err := sqlx.In(sqlStatement, names) + sqlStatement, args, err := applyLabelTeamFilter(`SELECT l.* FROM labels l WHERE l.name IN (?)`, filter, names) if err != nil { return nil, ctxerr.Wrap(ctx, err, "building query to get label ids by name") } var labels []*fleet.Label if err := sqlx.SelectContext(ctx, ds.reader(ctx), &labels, sqlStatement, args...); err != nil { - return nil, ctxerr.Wrap(ctx, err, "get label ids by name") + return nil, ctxerr.Wrap(ctx, err, "get labels by name") } result := make(map[string]*fleet.Label, len(labels)) @@ -1320,9 +1514,16 @@ func amountLabelsDB(ctx context.Context, db sqlx.QueryerContext) (int, error) { return amount, nil } -func (ds *Datastore) LabelsSummary(ctx context.Context) ([]*fleet.LabelSummary, error) { - labelsSummary := []*fleet.LabelSummary{} - if err := sqlx.SelectContext(ctx, ds.reader(ctx), &labelsSummary, "SELECT id, name, description, label_type, team_id FROM labels"); err != nil { +func (ds *Datastore) LabelsSummary(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSummary, error) { + var labelsSummary []*fleet.LabelSummary + + query := "SELECT id, name, description, label_type, team_id FROM labels l" + query, params, err := applyLabelTeamFilter(query, filter) + if err != nil { + return nil, err + } + + if err := sqlx.SelectContext(ctx, ds.reader(ctx), &labelsSummary, query, params...); err != nil { return nil, ctxerr.Wrap(ctx, err, "labels summary") } return labelsSummary, nil @@ -1356,6 +1557,7 @@ func (ds *Datastore) HostMemberOfAllLabels(ctx context.Context, hostID uint, lab return ok, nil } +// AddLabelsToHost skips auth as it's only used in tests, and where label teams have already been validated. func (ds *Datastore) AddLabelsToHost(ctx context.Context, hostID uint, labelIDs []uint) error { if len(labelIDs) == 0 { return nil @@ -1375,6 +1577,7 @@ func (ds *Datastore) AddLabelsToHost(ctx context.Context, hostID uint, labelIDs } func (ds *Datastore) RemoveLabelsFromHost(ctx context.Context, hostID uint, labelIDs []uint) error { + // We *don't* check label team here because a wrong-team label won't be on the host in the first place if len(labelIDs) == 0 { return nil } diff --git a/server/datastore/mysql/labels_test.go b/server/datastore/mysql/labels_test.go index 15f8b0fa42..ba8d3c5e14 100644 --- a/server/datastore/mysql/labels_test.go +++ b/server/datastore/mysql/labels_test.go @@ -81,18 +81,18 @@ func TestLabels(t *testing.T) { {"ListHostsInLabelAndTeamFilterDeferred", func(t *testing.T, ds *Datastore) { testLabelsListHostsInLabelAndTeamFilter(true, t, ds) }}, {"ListHostsInLabelAndTeamFilterNotDeferred", func(t *testing.T, ds *Datastore) { testLabelsListHostsInLabelAndTeamFilter(false, t, ds) }}, {"BuiltIn", testLabelsBuiltIn}, - {"ListUniqueHostsInLabels", testLabelsListUniqueHostsInLabels}, {"ChangeDetails", testLabelsChangeDetails}, {"GetSpec", testLabelsGetSpec}, {"ApplySpecsRoundtrip", testLabelsApplySpecsRoundtrip}, {"UpdateLabelMembershipByHostIDs", testUpdateLabelMembershipByHostIDs}, {"IDsByName", testLabelsIDsByName}, {"ByName", testLabelsByName}, + {"SingleByName", testLabelByName}, {"Save", testLabelsSave}, {"QueriesForCentOSHost", testLabelsQueriesForCentOSHost}, {"RecordNonExistentQueryLabelExecution", testLabelsRecordNonexistentQueryLabelExecution}, {"DeleteLabel", testDeleteLabel}, - {"LabelsSummary", testLabelsSummary}, + {"LabelsSummaryAndListTeamFiltering", testLabelsSummaryAndListTeamFiltering}, {"ListHostsInLabelIssues", testListHostsInLabelIssues}, {"ListHostsInLabelDiskEncryptionStatus", testListHostsInLabelDiskEncryptionStatus}, {"HostMemberOfAllLabels", testHostMemberOfAllLabels}, @@ -103,6 +103,7 @@ func TestLabels(t *testing.T) { {"UpdateLabelMembershipByHostCriteria", testUpdateLabelMembershipByHostCriteria}, {"TeamLabels", testTeamLabels}, {"UpdateLabelMembershipForTransferredHost", testUpdateLabelMembershipForTransferredHost}, + {"SetAsideLabels", testSetAsideLabels}, } // call TruncateTables first to remove migration-created labels TruncateTables(t, ds) @@ -232,6 +233,8 @@ func testLabelsAddAllHosts(deferred bool, t *testing.T, db *Datastore) { } func testLabelsSearch(t *testing.T, db *Datastore) { + // TODO test team filtering + specs := []*fleet.LabelSpec{ {ID: 1, Name: "foo"}, {ID: 2, Name: "bar"}, @@ -253,6 +256,8 @@ func testLabelsSearch(t *testing.T, db *Datastore) { err := db.ApplyLabelSpecs(context.Background(), specs) require.Nil(t, err) + // TODO add team checking + user := &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)} filter := fleet.TeamFilter{User: user} @@ -266,12 +271,12 @@ func testLabelsSearch(t *testing.T, db *Datastore) { labels, err := db.SearchLabels(context.Background(), filter, "") require.Nil(t, err) assert.Len(t, labels, 12) - assert.Contains(t, labels, all) + assert.Contains(t, labels, &all.Label) labels, err = db.SearchLabels(context.Background(), filter, "foo") require.Nil(t, err) assert.Len(t, labels, 3) - assert.Contains(t, labels, all) + assert.Contains(t, labels, &all.Label) labels, err = db.SearchLabels(context.Background(), filter, "foo", all.ID, l3.ID) require.Nil(t, err) @@ -281,10 +286,12 @@ func testLabelsSearch(t *testing.T, db *Datastore) { labels, err = db.SearchLabels(context.Background(), filter, "xxx") require.Nil(t, err) assert.Len(t, labels, 1) - assert.Contains(t, labels, all) + assert.Contains(t, labels, &all.Label) } func testLabelsListHostsInLabel(t *testing.T, db *Datastore) { + // TODO test label filtering + h1, err := db.NewHost(context.Background(), &fleet.Host{ DetailUpdatedAt: time.Now(), LabelUpdatedAt: time.Now(), @@ -587,118 +594,6 @@ func testLabelsBuiltIn(t *testing.T, db *Datastore) { assert.Equal(t, fleet.LabelTypeBuiltIn, hits[1].LabelType) } -func testLabelsListUniqueHostsInLabels(t *testing.T, db *Datastore) { - hosts := make([]*fleet.Host, 4) - for i := range hosts { - h, err := db.NewHost(context.Background(), &fleet.Host{ - DetailUpdatedAt: time.Now(), - LabelUpdatedAt: time.Now(), - PolicyUpdatedAt: time.Now(), - SeenTime: time.Now(), - OsqueryHostID: ptr.String(strconv.Itoa(i)), - NodeKey: ptr.String(strconv.Itoa(i)), - UUID: strconv.Itoa(i), - Hostname: fmt.Sprintf("host_%d", i), - }) - require.Nil(t, err) - hosts[i] = h - } - - team1, err := db.NewTeam(context.Background(), &fleet.Team{Name: "team1"}) - require.NoError(t, err) - require.NoError(t, db.AddHostsToTeam(context.Background(), fleet.NewAddHostsToTeamParams(&team1.ID, []uint{hosts[0].ID}))) - - l1 := fleet.LabelSpec{ - ID: 1, - Name: "label foo", - Query: "query1", - } - l2 := fleet.LabelSpec{ - ID: 2, - Name: "label bar", - Query: "query2", - } - require.NoError(t, db.ApplyLabelSpecs(context.Background(), []*fleet.LabelSpec{&l1, &l2})) - - for i := 0; i < 3; i++ { - err = db.RecordLabelQueryExecutions(context.Background(), hosts[i], map[uint]*bool{l1.ID: ptr.Bool(true)}, time.Now(), false) - assert.Nil(t, err) - } - // host 2 executes twice - for i := 2; i < len(hosts); i++ { - err = db.RecordLabelQueryExecutions(context.Background(), hosts[i], map[uint]*bool{l2.ID: ptr.Bool(true)}, time.Now(), false) - assert.Nil(t, err) - } - - filter := fleet.TeamFilter{User: test.UserAdmin} - - uniqueHosts, err := db.ListUniqueHostsInLabels(context.Background(), filter, []uint{l1.ID, l2.ID}) - assert.Nil(t, err) - assert.Equal(t, len(hosts), len(uniqueHosts)) - - labels, err := db.ListLabels(context.Background(), filter, fleet.ListOptions{}) - require.Nil(t, err) - require.Len(t, labels, 2) - for _, l := range labels { - assert.True(t, l.HostCount > 0) - } - - // If an empty team filter is used, all hosts should be returned. - labelsNoTeamFilter, err := db.ListLabels(context.Background(), fleet.TeamFilter{}, fleet.ListOptions{}) - require.Nil(t, err) - require.Len(t, labelsNoTeamFilter, 2) - for _, l := range labelsNoTeamFilter { - assert.True(t, l.HostCount == 0) - } - - userObs := &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)} - filter = fleet.TeamFilter{User: userObs} - - // observer not included - uniqueHosts, err = db.ListUniqueHostsInLabels(context.Background(), filter, []uint{l1.ID, l2.ID}) - require.Nil(t, err) - assert.Len(t, uniqueHosts, 0) - - labels, err = db.ListLabels(context.Background(), filter, fleet.ListOptions{}) - require.Nil(t, err) - require.Len(t, labels, 2) - for _, l := range labels { - assert.Equal(t, 0, l.HostCount) - } - - // observer included - filter.IncludeObserver = true - uniqueHosts, err = db.ListUniqueHostsInLabels(context.Background(), filter, []uint{l1.ID, l2.ID}) - require.Nil(t, err) - assert.Len(t, uniqueHosts, len(hosts)) - - labels, err = db.ListLabels(context.Background(), filter, fleet.ListOptions{}) - require.Nil(t, err) - require.Len(t, labels, 2) - for _, l := range labels { - assert.True(t, l.HostCount > 0) - } - - userTeam1 := &fleet.User{Teams: []fleet.UserTeam{{Team: *team1, Role: fleet.RoleAdmin}}} - filter = fleet.TeamFilter{User: userTeam1} - - uniqueHosts, err = db.ListUniqueHostsInLabels(context.Background(), filter, []uint{l1.ID, l2.ID}) - require.Nil(t, err) - require.Len(t, uniqueHosts, 1) // only host 0 associated with this team - assert.Equal(t, hosts[0].ID, uniqueHosts[0].ID) - - labels, err = db.ListLabels(context.Background(), filter, fleet.ListOptions{}) - require.Nil(t, err) - require.Len(t, labels, 2) - for _, l := range labels { - if l.ID == l1.ID { - assert.Equal(t, 1, l.HostCount) - } else { - assert.Equal(t, 0, l.HostCount) - } - } -} - func testLabelsChangeDetails(t *testing.T, db *Datastore) { label := fleet.LabelSpec{ ID: 1, @@ -732,7 +627,7 @@ func testLabelsChangeDetails(t *testing.T, db *Datastore) { label.Name = "changed name" // ApplyLabelSpecs can't update the name -- it simply creates a new label, so we need to call SaveLabel. saved.Name = label.Name - saved2, _, err := db.SaveLabel(context.Background(), saved, filter) + saved2, _, err := db.SaveLabel(context.Background(), &saved.Label, filter) require.NoError(t, err) assert.Equal(t, label.Name, saved2.Name) assert.Equal(t, label.Description, saved2.Description) @@ -800,7 +695,7 @@ func testLabelsGetSpec(t *testing.T, ds *Datastore) { expectedSpecs := setupLabelSpecsTest(t, ds) for _, s := range expectedSpecs { - spec, err := ds.GetLabelSpec(context.Background(), s.Name) + spec, err := ds.GetLabelSpec(context.Background(), fleet.TeamFilter{}, s.Name) require.Nil(t, err) require.True(t, cmp.Equal(s, spec, cmp.FilterPath(func(p cmp.Path) bool { @@ -810,16 +705,19 @@ func testLabelsGetSpec(t *testing.T, ds *Datastore) { } func testLabelsApplySpecsRoundtrip(t *testing.T, ds *Datastore) { - expectedSpecs := setupLabelSpecsTest(t, ds) + // TODO test team labels - specs, err := ds.GetLabelSpecs(context.Background()) + expectedSpecs := setupLabelSpecsTest(t, ds) + globalOnlyFilter := fleet.TeamFilter{} + + specs, err := ds.GetLabelSpecs(context.Background(), globalOnlyFilter) require.Nil(t, err) test.ElementsMatchSkipTimestampsID(t, expectedSpecs, specs) // Should be idempotent err = ds.ApplyLabelSpecs(context.Background(), expectedSpecs) require.Nil(t, err) - specs, err = ds.GetLabelSpecs(context.Background()) + specs, err = ds.GetLabelSpecs(context.Background(), globalOnlyFilter) require.Nil(t, err) test.ElementsMatchSkipTimestampsID(t, expectedSpecs, specs) } @@ -827,7 +725,9 @@ func testLabelsApplySpecsRoundtrip(t *testing.T, ds *Datastore) { func testLabelsIDsByName(t *testing.T, ds *Datastore) { setupLabelSpecsTest(t, ds) - labels, err := ds.LabelIDsByName(context.Background(), []string{"foo", "bar", "bing"}) + // TODO test team labels + + labels, err := ds.LabelIDsByName(context.Background(), []string{"foo", "bar", "bing"}, fleet.TeamFilter{}) require.Nil(t, err) assert.Equal(t, map[string]uint{"foo": 1, "bar": 2, "bing": 3}, labels) } @@ -835,8 +735,10 @@ func testLabelsIDsByName(t *testing.T, ds *Datastore) { func testLabelsByName(t *testing.T, ds *Datastore) { setupLabelSpecsTest(t, ds) + // TODO test team labels + names := []string{"foo", "bar", "bing"} - labels, err := ds.LabelsByName(context.Background(), names) + labels, err := ds.LabelsByName(context.Background(), names, fleet.TeamFilter{}) require.NoError(t, err) require.Len(t, labels, 3) for _, name := range names { @@ -856,6 +758,10 @@ func testLabelsByName(t *testing.T, ds *Datastore) { } } +func testLabelByName(t *testing.T, ds *Datastore) { + // TODO implement, including team filtering +} + func testLabelsSave(t *testing.T, db *Datastore) { h1, err := db.NewHost(context.Background(), &fleet.Host{ DetailUpdatedAt: time.Now(), @@ -987,6 +893,8 @@ func testLabelsRecordNonexistentQueryLabelExecution(t *testing.T, db *Datastore) } func testDeleteLabel(t *testing.T, db *Datastore) { + // TODO test team label filtering + ctx := context.Background() l, err := db.NewLabel(ctx, &fleet.Label{ Name: t.Name(), @@ -1000,7 +908,7 @@ func testDeleteLabel(t *testing.T, db *Datastore) { }) require.NoError(t, err) - require.NoError(t, db.DeleteLabel(ctx, l.Name)) + require.NoError(t, db.DeleteLabel(ctx, l.Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})) newP, err := db.Pack(ctx, p.ID) require.NoError(t, err) @@ -1009,7 +917,7 @@ func testDeleteLabel(t *testing.T, db *Datastore) { require.NoError(t, db.DeletePack(ctx, newP.Name)) // delete a non-existing label - err = db.DeleteLabel(ctx, "no-such-label") + err = db.DeleteLabel(ctx, "no-such-label", fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) require.Error(t, err) var nfe fleet.NotFoundError require.ErrorAs(t, err, &nfe) @@ -1043,16 +951,16 @@ func testDeleteLabel(t *testing.T, db *Datastore) { }) // try to delete that label referenced by software installer - err = db.DeleteLabel(ctx, l2.Name) + err = db.DeleteLabel(ctx, l2.Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) require.Error(t, err) require.True(t, fleet.IsForeignKey(err)) } -func testLabelsSummary(t *testing.T, db *Datastore) { +func testLabelsSummaryAndListTeamFiltering(t *testing.T, db *Datastore) { test.AddAllHostsLabel(t, db) // Only 'All Hosts' label should be returned - labels, err := db.ListLabels(context.Background(), fleet.TeamFilter{}, fleet.ListOptions{}) + labels, err := db.ListLabels(context.Background(), fleet.TeamFilter{}, fleet.ListOptions{}, false) require.NoError(t, err) require.Len(t, labels, 1) @@ -1077,7 +985,28 @@ func testLabelsSummary(t *testing.T, db *Datastore) { err = db.ApplyLabelSpecs(context.Background(), newLabels) require.Nil(t, err) - labels, err = db.ListLabels(context.Background(), fleet.TeamFilter{}, fleet.ListOptions{}) + team1, err := db.NewTeam(context.Background(), &fleet.Team{Name: "team1"}) + require.NoError(t, err) + team2, err := db.NewTeam(context.Background(), &fleet.Team{Name: "team2"}) + require.NoError(t, err) + team3, err := db.NewTeam(context.Background(), &fleet.Team{Name: "team3"}) + require.NoError(t, err) + + team1Label, err := db.NewLabel(context.Background(), &fleet.Label{ + Name: "t1 label", + LabelMembershipType: fleet.LabelMembershipTypeManual, + TeamID: &team1.ID, + }) + require.NoError(t, err) + team2Label, err := db.NewLabel(context.Background(), &fleet.Label{ + Name: "t2 label", + LabelMembershipType: fleet.LabelMembershipTypeManual, + TeamID: &team2.ID, + }) + require.NoError(t, err) + + // should only show global labels + labels, err = db.ListLabels(context.Background(), fleet.TeamFilter{}, fleet.ListOptions{}, false) require.NoError(t, err) require.Len(t, labels, 4) labelsByID := make(map[uint]*fleet.Label) @@ -1085,7 +1014,8 @@ func testLabelsSummary(t *testing.T, db *Datastore) { labelsByID[l.ID] = l } - ls, err := db.LabelsSummary(context.Background()) + // should show only global labels + ls, err := db.LabelsSummary(context.Background(), fleet.TeamFilter{}) require.NoError(t, err) require.Len(t, ls, 4) for _, l := range ls { @@ -1101,9 +1031,125 @@ func testLabelsSummary(t *testing.T, db *Datastore) { }) require.NoError(t, err) - ls, err = db.LabelsSummary(context.Background()) + ls, err = db.LabelsSummary(context.Background(), fleet.TeamFilter{}) require.NoError(t, err) require.Len(t, ls, 5) + + for _, tc := range []struct { + name string + filter fleet.TeamFilter + expectedErr error + expectedTeamLabels map[*fleet.Team]*fleet.Label + }{ + { + name: "explicit global filter", + filter: fleet.TeamFilter{ + User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: team1.ID}, Role: fleet.RoleObserver}}}, + TeamID: ptr.Uint(0), + }, + }, + { + name: "global role filtered to team", + filter: fleet.TeamFilter{ + User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)}, + TeamID: &team1.ID, + }, + expectedTeamLabels: map[*fleet.Team]*fleet.Label{team1: team1Label}, + }, + { + name: "team role filtered to user-accessible team", + filter: fleet.TeamFilter{ + User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: team1.ID}, Role: fleet.RoleObserverPlus}}}, + TeamID: &team1.ID, + }, + expectedTeamLabels: map[*fleet.Team]*fleet.Label{team1: team1Label}, + }, + { + name: "team role filtered to inaccessible team", + filter: fleet.TeamFilter{ + User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: team1.ID}, Role: fleet.RoleObserverPlus}}}, + TeamID: &team2.ID, + }, + expectedErr: errInaccessibleTeam, + }, + { + name: "global role with no team filter", + filter: fleet.TeamFilter{ + User: &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)}, + }, + expectedTeamLabels: map[*fleet.Team]*fleet.Label{team1: team1Label, team2: team2Label}, + }, + { + name: "single-team user with no team filter", + filter: fleet.TeamFilter{ + User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: team1.ID}, Role: fleet.RoleObserverPlus}}}, + }, + expectedTeamLabels: map[*fleet.Team]*fleet.Label{team1: team1Label}, + }, + { + name: "multi-team user with no team filter, partial overlap with labels", + filter: fleet.TeamFilter{ + User: &fleet.User{Teams: []fleet.UserTeam{ + {Team: fleet.Team{ID: team1.ID}, Role: fleet.RoleObserverPlus}, + {Team: fleet.Team{ID: team3.ID}, Role: fleet.RoleMaintainer}, + }}, + }, + expectedTeamLabels: map[*fleet.Team]*fleet.Label{team1: team1Label}, + }, + { + name: "multi-team user with no team filter, full overlap with labels", + filter: fleet.TeamFilter{ + User: &fleet.User{Teams: []fleet.UserTeam{ + {Team: fleet.Team{ID: team1.ID}, Role: fleet.RoleObserverPlus}, + {Team: fleet.Team{ID: team2.ID}, Role: fleet.RoleMaintainer}, + }}, + }, + expectedTeamLabels: map[*fleet.Team]*fleet.Label{team1: team1Label, team2: team2Label}, + }, + } { + t.Run(tc.name+" summary", func(t *testing.T) { + ls, err := db.LabelsSummary(context.Background(), tc.filter) + if tc.expectedErr != nil { + require.ErrorContains(t, err, tc.expectedErr.Error()) + return + } + require.NoError(t, err) + require.Len(t, ls, 5+len(tc.expectedTeamLabels)) + + foundTeamLabels := make(map[uint]fleet.LabelSummary) + for _, l := range ls { + if l.TeamID != nil { + foundTeamLabels[*l.TeamID] = *l + } + } + for team, label := range tc.expectedTeamLabels { + foundLabel, labelInMap := foundTeamLabels[team.ID] + require.Truef(t, labelInMap, "%s label should have been found", team.Name) + require.Equalf(t, label.ID, foundLabel.ID, "Found team label %s label did not match expected (%s)", foundLabel.Name, label.Name) + } + }) + t.Run(tc.name+" list", func(t *testing.T) { + ls, err := db.ListLabels(context.Background(), tc.filter, fleet.ListOptions{}, false) + if tc.expectedErr != nil { + require.ErrorContains(t, err, tc.expectedErr.Error()) + return + } + require.NoError(t, err) + require.Len(t, ls, 5+len(tc.expectedTeamLabels)) + + foundTeamLabels := make(map[uint]fleet.Label) + for _, l := range ls { + if l.TeamID != nil { + foundTeamLabels[*l.TeamID] = *l + } + } + for team, label := range tc.expectedTeamLabels { + foundLabel, labelInMap := foundTeamLabels[team.ID] + require.Truef(t, labelInMap, "%s label should have been found", team.Name) + require.Equalf(t, label.ID, foundLabel.ID, "Found team label %s label did not match expected (%s)", foundLabel.Name, label.Name) + } + }) + } } func testListHostsInLabelIssues(t *testing.T, ds *Datastore) { @@ -1804,7 +1850,7 @@ func testAddDeleteLabelsToFromHost(t *testing.T, ds *Datastore) { } func labelIDFromName(t *testing.T, ds fleet.Datastore, name string) uint { - allLbls, err := ds.ListLabels(context.Background(), fleet.TeamFilter{User: test.UserAdmin}, fleet.ListOptions{}) + allLbls, err := ds.ListLabels(context.Background(), fleet.TeamFilter{User: test.UserAdmin}, fleet.ListOptions{}, false) require.Nil(t, err) for _, lbl := range allLbls { if lbl.Name == name { @@ -1815,6 +1861,8 @@ func labelIDFromName(t *testing.T, ds fleet.Datastore, name string) uint { } func testUpdateLabelMembershipByHostIDs(t *testing.T, ds *Datastore) { + // TODO validate team label host validation behavior + ctx := context.Background() filter := fleet.TeamFilter{User: test.UserAdmin} @@ -1853,7 +1901,7 @@ func testUpdateLabelMembershipByHostIDs(t *testing.T, ds *Datastore) { require.NoError(t, err) // add hosts 1 and 2 to the label - label, hostIDs, err := ds.UpdateLabelMembershipByHostIDs(ctx, label1.ID, []uint{host1.ID, host2.ID}, filter) + label, hostIDs, err := ds.UpdateLabelMembershipByHostIDs(ctx, *label1, []uint{host1.ID, host2.ID}, filter) require.NoError(t, err) require.Equal(t, label.HostCount, 2) @@ -1865,7 +1913,7 @@ func testUpdateLabelMembershipByHostIDs(t *testing.T, ds *Datastore) { require.Equal(t, host1.ID, hostIDs[0]) require.Equal(t, host2.ID, hostIDs[1]) - labelSpec, err := ds.GetLabelSpec(ctx, label1.Name) + labelSpec, err := ds.GetLabelSpec(ctx, fleet.TeamFilter{}, label1.Name) // only need global labels, so this works require.NoError(t, err) // label.Hosts contains hostnames require.Len(t, labelSpec.Hosts, 2) @@ -1887,7 +1935,7 @@ func testUpdateLabelMembershipByHostIDs(t *testing.T, ds *Datastore) { require.Len(t, labels, 0) // modify the label to contain hosts 1 and 3, confirm - label, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, label1.ID, []uint{host1.ID, host3.ID}, filter) + label, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, *label1, []uint{host1.ID, host3.ID}, filter) require.NoError(t, err) require.Equal(t, label.HostCount, 2) @@ -1907,7 +1955,7 @@ func testUpdateLabelMembershipByHostIDs(t *testing.T, ds *Datastore) { require.Equal(t, "label1", labels[0].Name) // modify the label to contain hosts 2 and 3, confirm - label, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, label1.ID, []uint{host2.ID, host3.ID}, filter) + label, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, *label1, []uint{host2.ID, host3.ID}, filter) require.NoError(t, err) require.Equal(t, label.HostCount, 2) @@ -1927,7 +1975,7 @@ func testUpdateLabelMembershipByHostIDs(t *testing.T, ds *Datastore) { require.Equal(t, "label1", labels[0].Name) // modify the label to contain no hosts, confirm - label, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, label1.ID, []uint{}, filter) + label, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, *label1, []uint{}, filter) require.NoError(t, err) require.Equal(t, label.HostCount, 0) @@ -1944,7 +1992,7 @@ func testUpdateLabelMembershipByHostIDs(t *testing.T, ds *Datastore) { require.Len(t, labels, 0) // modify the label to contain all 3 hosts, confirm - label, hostIDs, err = ds.UpdateLabelMembershipByHostIDs(ctx, label1.ID, []uint{host1.ID, host2.ID, host3.ID}, filter) + label, hostIDs, err = ds.UpdateLabelMembershipByHostIDs(ctx, *label1, []uint{host1.ID, host2.ID, host3.ID}, filter) require.NoError(t, err) require.Equal(t, label.HostCount, 3) @@ -1971,7 +2019,7 @@ func testUpdateLabelMembershipByHostIDs(t *testing.T, ds *Datastore) { require.Equal(t, host2.ID, hostIDs[1]) require.Equal(t, host3.ID, hostIDs[2]) - labelSpec, err = ds.GetLabelSpec(ctx, label1.Name) + labelSpec, err = ds.GetLabelSpec(ctx, fleet.TeamFilter{}, label1.Name) // only need global labels, so this works require.NoError(t, err) // label.Hosts contains hostnames @@ -2101,7 +2149,7 @@ func testApplyLabelSpecsWithPlatformChange(t *testing.T, ds *Datastore) { require.NoError(t, err) // Get the label ID - labels, err := ds.LabelsByName(ctx, []string{"platform_test_label"}) + labels, err := ds.LabelsByName(ctx, []string{"platform_test_label"}, fleet.TeamFilter{}) require.NoError(t, err) label := labels["platform_test_label"] require.NotNil(t, label) @@ -2192,8 +2240,20 @@ func (t *TestHostVitalsLabel) GetLabel() *fleet.Label { func testUpdateLabelMembershipByHostCriteria(t *testing.T, ds *Datastore) { ctx := context.Background() + team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"}) + require.NoError(t, err) + team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"}) + require.NoError(t, err) + hosts := make([]*fleet.Host, 4) for i := 1; i <= 4; i++ { + var teamID *uint + if i == 1 || i == 2 { + teamID = &team1.ID + } else if i == 3 { + teamID = &team2.ID + } + host, err := ds.NewHost(ctx, &fleet.Host{ OsqueryHostID: ptr.String(fmt.Sprintf("%d", i)), NodeKey: ptr.String(fmt.Sprintf("%d", i)), @@ -2201,6 +2261,7 @@ func testUpdateLabelMembershipByHostCriteria(t *testing.T, ds *Datastore) { Hostname: fmt.Sprintf("host%d.local", i), HardwareSerial: fmt.Sprintf("hwd%d", i), Platform: "darwin", + TeamID: teamID, }) require.NoError(t, err) hosts[i-1] = host @@ -2208,10 +2269,10 @@ func testUpdateLabelMembershipByHostCriteria(t *testing.T, ds *Datastore) { // Add users to the hosts ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error { _, err := q.ExecContext(ctx, ` - INSERT INTO host_users (host_id, uid, username) VALUES - (?, ?, ?), - (?, ?, ?), - (?, ?, ?), + INSERT INTO host_users (host_id, uid, username) VALUES + (?, ?, ?), + (?, ?, ?), + (?, ?, ?), (?, ?, ?), (?, ?, ?)`, hosts[0].ID, 1, "user1", @@ -2228,49 +2289,80 @@ func testUpdateLabelMembershipByHostCriteria(t *testing.T, ds *Datastore) { }) require.NoError(t, err) - var id uint + var ids []uint ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error { - result, err := q.ExecContext(context.Background(), - "INSERT INTO labels (name, description, platform, label_type, label_membership_type, query) VALUES (?, ?, ?, ?, ?, ?)", - "test host vitals label", "test", "", fleet.LabelTypeRegular, fleet.LabelMembershipTypeHostVitals, "") - if err != nil { - return err + for _, teamID := range []*uint{nil, &team1.ID, &team2.ID} { + result, err := q.ExecContext(context.Background(), + "INSERT INTO labels (name, description, platform, label_type, label_membership_type, query, team_id) VALUES (?, ?, ?, ?, ?, ?, ?)", + fmt.Sprintf("test host vitals label %d", teamID), "test", "", fleet.LabelTypeRegular, fleet.LabelMembershipTypeHostVitals, "", teamID) + if err != nil { + return err + } + id64, err := result.LastInsertId() + if err != nil { + return err + } + ids = append(ids, uint(id64)) // nolint:gosec } - id64, err := result.LastInsertId() - if err != nil { - return err - } - id = uint(id64) // nolint:gosec return nil }) - label := &TestHostVitalsLabel{ - Label: fleet.Label{ - ID: id, - Name: "Test Host Vitals Label", - LabelType: fleet.LabelTypeRegular, - LabelMembershipType: fleet.LabelMembershipTypeHostVitals, - HostVitalsCriteria: ptr.RawMessage(criteria), + testCases := []struct { + LabelID uint + TeamID *uint + BeforeHostIDs []uint + AfterHostIDs []uint + }{ + { + ids[0], + nil, + []uint{hosts[0].ID, hosts[2].ID}, // Only hosts 1 and 3 should match the criteria (user1) + []uint{hosts[1].ID, hosts[2].ID, hosts[3].ID}, // Only hosts 2, 3 and 4 should match the criteria (user1) }, + { + ids[1], + &team1.ID, + []uint{hosts[0].ID}, // Only host 1 is on the team affected by the label + []uint{hosts[1].ID}, // Only host 2 is on the team affected by the label after vitals changes + }, + } + + makeLabel := func(id uint, teamID *uint) *TestHostVitalsLabel { + return &TestHostVitalsLabel{ + Label: fleet.Label{ + ID: id, + TeamID: teamID, + Name: fmt.Sprintf("Test Host Vitals Label %d", teamID), + LabelType: fleet.LabelTypeRegular, + LabelMembershipType: fleet.LabelMembershipTypeHostVitals, + HostVitalsCriteria: ptr.RawMessage(criteria), + }, + } } filter := fleet.TeamFilter{User: test.UserAdmin} - updatedLabel, err := ds.UpdateLabelMembershipByHostCriteria(ctx, label) - require.NoError(t, err) - require.Equal(t, 2, updatedLabel.HostCount) + for _, tt := range testCases { + updatedLabel, err := ds.UpdateLabelMembershipByHostCriteria(ctx, makeLabel(tt.LabelID, tt.TeamID)) + require.NoError(t, err) + require.Equal(t, len(tt.BeforeHostIDs), updatedLabel.HostCount) - // Check that the label has the correct hosts - hostsInLabel, err := ds.ListHostsInLabel(ctx, filter, label.ID, fleet.HostListOptions{}) - require.NoError(t, err) - require.Len(t, hostsInLabel, 2) // Only hosts 1 and 3 should match the criteria (user1) - require.ElementsMatch(t, []uint{hosts[0].ID, hosts[2].ID}, []uint{hostsInLabel[0].ID, hostsInLabel[1].ID}) + // Check that the label has the correct hosts + hostsInLabel, err := ds.ListHostsInLabel(ctx, filter, tt.LabelID, fleet.HostListOptions{}) + require.NoError(t, err) + require.Len(t, hostsInLabel, len(tt.BeforeHostIDs)) + labelHostIDs := make([]uint, 0, len(hostsInLabel)) + for _, host := range hostsInLabel { + labelHostIDs = append(labelHostIDs, host.ID) + } + require.ElementsMatch(t, tt.BeforeHostIDs, labelHostIDs) + } // Update host users. ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error { _, err := q.ExecContext(ctx, ` - INSERT INTO host_users (host_id, uid, username) VALUES - (?, ?, ?), + INSERT INTO host_users (host_id, uid, username) VALUES + (?, ?, ?), (?, ?, ?), (?, ?, ?) ON DUPLICATE KEY UPDATE username = VALUES(username), uid = VALUES(uid)`, hosts[0].ID, 2, "user2", @@ -2284,15 +2376,22 @@ func testUpdateLabelMembershipByHostCriteria(t *testing.T, ds *Datastore) { hosts[0].ID, 1) // Remove user1 from host 1 return err }) - updatedLabel, err = ds.UpdateLabelMembershipByHostCriteria(ctx, label) - require.NoError(t, err) - require.Equal(t, 3, updatedLabel.HostCount) - // Check that the label has the correct hosts - hostsInLabel, err = ds.ListHostsInLabel(ctx, filter, label.ID, fleet.HostListOptions{}) - require.NoError(t, err) - require.Len(t, hostsInLabel, 3) // Only hosts 2, 3 and 4 should match the criteria (user1) - require.ElementsMatch(t, []uint{hosts[1].ID, hosts[2].ID, hosts[3].ID}, []uint{hostsInLabel[0].ID, hostsInLabel[1].ID, hostsInLabel[2].ID}) + for _, tt := range testCases { + updatedLabel, err := ds.UpdateLabelMembershipByHostCriteria(ctx, makeLabel(tt.LabelID, tt.TeamID)) + require.NoError(t, err) + require.Equal(t, len(tt.AfterHostIDs), updatedLabel.HostCount) + + // Check that the label has the correct hosts + hostsInLabel, err := ds.ListHostsInLabel(ctx, filter, tt.LabelID, fleet.HostListOptions{}) + require.NoError(t, err) + require.Len(t, hostsInLabel, len(tt.AfterHostIDs)) + labelHostIDs := make([]uint, 0, len(hostsInLabel)) + for _, host := range hostsInLabel { + labelHostIDs = append(labelHostIDs, host.ID) + } + require.ElementsMatch(t, tt.AfterHostIDs, labelHostIDs) + } } func testTeamLabels(t *testing.T, ds *Datastore) { @@ -2484,3 +2583,7 @@ func testUpdateLabelMembershipForTransferredHost(t *testing.T, ds *Datastore) { require.Len(t, labels, 1) require.Equal(t, "global", labels[0].Name) } + +func testSetAsideLabels(t *testing.T, ds *Datastore) { + // TODO +} diff --git a/server/datastore/mysql/mdm_test.go b/server/datastore/mysql/mdm_test.go index c280f98c12..3d287d0940 100644 --- a/server/datastore/mysql/mdm_test.go +++ b/server/datastore/mysql/mdm_test.go @@ -1052,9 +1052,9 @@ func testListMDMConfigProfiles(t *testing.T, ds *Datastore) { require.NoError(t, err) } // delete label 3, 4 and 8 so that profiles D, E and G are broken - require.NoError(t, ds.DeleteLabel(ctx, labels[3].Name)) - require.NoError(t, ds.DeleteLabel(ctx, labels[4].Name)) - require.NoError(t, ds.DeleteLabel(ctx, labels[8].Name)) + require.NoError(t, ds.DeleteLabel(ctx, labels[3].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})) + require.NoError(t, ds.DeleteLabel(ctx, labels[4].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})) + require.NoError(t, ds.DeleteLabel(ctx, labels[8].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})) profLabels := map[string][]fleet.ConfigurationProfileLabel{ "C": { {LabelName: labels[0].Name, LabelID: labels[0].ID, RequireAll: true}, @@ -3809,8 +3809,8 @@ func testBulkSetPendingMDMHostProfiles(t *testing.T, ds *Datastore) { }) // "break" the two G6 label-based profile by deleting labels[0] and [3] - require.NoError(t, ds.DeleteLabel(ctx, labels[0].Name)) - require.NoError(t, ds.DeleteLabel(ctx, labels[3].Name)) + require.NoError(t, ds.DeleteLabel(ctx, labels[0].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})) + require.NoError(t, ds.DeleteLabel(ctx, labels[3].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})) // sync the affected profiles updates, err = ds.BulkSetPendingMDMHostProfiles( @@ -4868,8 +4868,8 @@ func testBulkSetPendingMDMHostProfiles(t *testing.T, ds *Datastore) { }) // "break" the team 2 label-based profile by deleting a label - require.NoError(t, ds.DeleteLabel(ctx, labels[1].Name)) - require.NoError(t, ds.DeleteLabel(ctx, labels[4].Name)) + require.NoError(t, ds.DeleteLabel(ctx, labels[1].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})) + require.NoError(t, ds.DeleteLabel(ctx, labels[4].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})) // sync team 2, the label-based profile of team2 is left untouched (broken // profiles are ignored) @@ -5982,7 +5982,7 @@ func testGetHostMDMProfilesExpectedForVerification(t *testing.T, ds *Datastore) require.Len(t, profs, 3) // Now delete label, we shouldn't see the related profile - err = ds.DeleteLabel(ctx, testLabel4.Name) + err = ds.DeleteLabel(ctx, testLabel4.Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) require.NoError(t, err) return team.ID, host @@ -6483,7 +6483,7 @@ func testGetHostMDMProfilesExpectedForVerification(t *testing.T, ds *Datastore) require.Len(t, profs, 3) // Now delete label, we shouldn't see the related profile - err = ds.DeleteLabel(ctx, label.Name) + err = ds.DeleteLabel(ctx, label.Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) require.NoError(t, err) return team.ID, host @@ -8275,13 +8275,13 @@ func testBulkSetPendingMDMHostProfilesExcludeAny(t *testing.T, ds *Datastore) { }) // delete labels 0, 2, 3, and 6, breaking all profiles - err = ds.DeleteLabel(ctx, labels[0].Name) + err = ds.DeleteLabel(ctx, labels[0].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) require.NoError(t, err) - err = ds.DeleteLabel(ctx, labels[2].Name) + err = ds.DeleteLabel(ctx, labels[2].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) require.NoError(t, err) - err = ds.DeleteLabel(ctx, labels[3].Name) + err = ds.DeleteLabel(ctx, labels[3].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) require.NoError(t, err) - err = ds.DeleteLabel(ctx, labels[6].Name) + err = ds.DeleteLabel(ctx, labels[6].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) require.NoError(t, err) updates, err = ds.BulkSetPendingMDMHostProfiles(ctx, []uint{winHost.ID, appleHost.ID, androidHost.ID}, nil, nil, nil) diff --git a/server/datastore/mysql/microsoft_mdm_test.go b/server/datastore/mysql/microsoft_mdm_test.go index 700b0d9998..148f7296be 100644 --- a/server/datastore/mysql/microsoft_mdm_test.go +++ b/server/datastore/mysql/microsoft_mdm_test.go @@ -2144,7 +2144,7 @@ func testMDMWindowsConfigProfiles(t *testing.T, ds *Datastore) { require.False(t, prof.LabelsIncludeAll[0].Broken) // break that profile by deleting the label - require.NoError(t, ds.DeleteLabel(ctx, label.Name)) + require.NoError(t, ds.DeleteLabel(ctx, label.Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})) prof, err = ds.GetMDMWindowsConfigProfile(ctx, profWithLabel.ProfileUUID) require.NoError(t, err) diff --git a/server/datastore/mysql/mysql.go b/server/datastore/mysql/mysql.go index 97e85c3cb3..90b4169719 100644 --- a/server/datastore/mysql/mysql.go +++ b/server/datastore/mysql/mysql.go @@ -910,7 +910,7 @@ func (ds *Datastore) whereFilterHostsByTeams(filter fleet.TeamFilter, hostKey st return fmt.Sprintf("%s.team_id IN (%s)", hostKey, strings.Join(idStrs, ",")) } -// whereFilterGlobalOrTeamIDByTeams is the same as whereFilterHostsByTeams, it +// whereFilterTeamWithGlobalStats is the same as whereFilterHostsByTeams, it // returns the appropriate condition to use in the WHERE clause to render only // the appropriate teams, but is to be used when the team_id column uses "0" to // mean "all teams including no team". This is the case e.g. for @@ -919,7 +919,7 @@ func (ds *Datastore) whereFilterHostsByTeams(filter fleet.TeamFilter, hostKey st // filter provides the filtering parameters that should be used. // filterTableAlias is the name/alias of the table to use in generating the // SQL. -func (ds *Datastore) whereFilterGlobalOrTeamIDByTeams(filter fleet.TeamFilter, filterTableAlias string) string { +func (ds *Datastore) whereFilterTeamWithGlobalStats(filter fleet.TeamFilter, filterTableAlias string) string { globalFilter := fmt.Sprintf("%s.team_id = 0 AND %[1]s.global_stats = 1", filterTableAlias) teamIDFilter := fmt.Sprintf("%s.team_id", filterTableAlias) return ds.whereFilterGlobalOrTeamIDByTeamsWithSqlFilter(filter, globalFilter, teamIDFilter) diff --git a/server/datastore/mysql/mysql_test.go b/server/datastore/mysql/mysql_test.go index 48a777b75e..ba7484b69b 100644 --- a/server/datastore/mysql/mysql_test.go +++ b/server/datastore/mysql/mysql_test.go @@ -1032,7 +1032,7 @@ func Test_buildWildcardMatchPhrase(t *testing.T) { } } -func TestWhereFilterGlobalOrTeamIDByTeams(t *testing.T) { +func TestWhereFilterTeamWithGlobalStats(t *testing.T) { t.Parallel() testCases := []struct { @@ -1243,7 +1243,7 @@ func TestWhereFilterGlobalOrTeamIDByTeams(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() ds := &Datastore{logger: log.NewNopLogger()} - sql := ds.whereFilterGlobalOrTeamIDByTeams(tt.filter, "hosts") + sql := ds.whereFilterTeamWithGlobalStats(tt.filter, "hosts") assert.Equal(t, tt.expected, sql) }) } diff --git a/server/datastore/mysql/software.go b/server/datastore/mysql/software.go index 037ea76e89..d87082d056 100644 --- a/server/datastore/mysql/software.go +++ b/server/datastore/mysql/software.go @@ -2398,7 +2398,7 @@ func (ds *Datastore) SoftwareByID(ctx context.Context, id uint, teamID *uint, in // filter by teams if tmFilter != nil { - q = q.Where(goqu.L(ds.whereFilterGlobalOrTeamIDByTeams(*tmFilter, "shc"))) + q = q.Where(goqu.L(ds.whereFilterTeamWithGlobalStats(*tmFilter, "shc"))) } sql, args, err := q.ToSQL() diff --git a/server/datastore/mysql/software_installers_test.go b/server/datastore/mysql/software_installers_test.go index 9f1db8ad8f..9836204fc9 100644 --- a/server/datastore/mysql/software_installers_test.go +++ b/server/datastore/mysql/software_installers_test.go @@ -302,7 +302,7 @@ func testSoftwareInstallRequests(t *testing.T, ds *Datastore) { user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true) createBuiltinLabels(t, ds) - labelsByName, err := ds.LabelIDsByName(ctx, []string{fleet.BuiltinLabelNameAllHosts}) + labelsByName, err := ds.LabelIDsByName(ctx, []string{fleet.BuiltinLabelNameAllHosts}, fleet.TeamFilter{}) require.NoError(t, err) require.Len(t, labelsByName, 1) diff --git a/server/datastore/mysql/software_test.go b/server/datastore/mysql/software_test.go index a80cef9d94..9c9b862765 100644 --- a/server/datastore/mysql/software_test.go +++ b/server/datastore/mysql/software_test.go @@ -8905,25 +8905,25 @@ func testLabelScopingTimestampLogic(t *testing.T, ds *Datastore) { }) // Dynamic label - label1, err := ds.NewLabel(ctx, &fleet.Label{Name: "label1" + t.Name(), LabelMembershipType: fleet.LabelMembershipTypeDynamic}) + label1Orig, err := ds.NewLabel(ctx, &fleet.Label{Name: "label1" + t.Name(), LabelMembershipType: fleet.LabelMembershipTypeDynamic}) require.NoError(t, err) // Manual label - label2, err := ds.NewLabel(ctx, &fleet.Label{Name: "label2" + t.Name(), LabelMembershipType: fleet.LabelMembershipTypeManual}) + label2Orig, err := ds.NewLabel(ctx, &fleet.Label{Name: "label2" + t.Name(), LabelMembershipType: fleet.LabelMembershipTypeManual}) require.NoError(t, err) // make sure the label is created after the host's labels_updated_at timestamp ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error { - _, err = q.ExecContext(ctx, `UPDATE labels SET created_at = ? WHERE id in (?, ?)`, host.LabelUpdatedAt.Add(time.Hour), label1.ID, label2.ID) + _, err = q.ExecContext(ctx, `UPDATE labels SET created_at = ? WHERE id in (?, ?)`, host.LabelUpdatedAt.Add(time.Hour), label1Orig.ID, label2Orig.ID) if err != nil { return err } return nil }) // refetch labels to ensure their state is correct - label1, _, err = ds.Label(ctx, label1.ID, fleet.TeamFilter{}) + label1, _, err := ds.Label(ctx, label1Orig.ID, fleet.TeamFilter{}) require.NoError(t, err) - label2, _, err = ds.Label(ctx, label2.ID, fleet.TeamFilter{}) + label2, _, err := ds.Label(ctx, label2Orig.ID, fleet.TeamFilter{}) require.NoError(t, err) require.Greater(t, label1.CreatedAt, host.LabelUpdatedAt) diff --git a/server/datastore/mysql/software_titles.go b/server/datastore/mysql/software_titles.go index 322579bbad..c3492c5813 100644 --- a/server/datastore/mysql/software_titles.go +++ b/server/datastore/mysql/software_titles.go @@ -38,7 +38,7 @@ func (ds *Datastore) SoftwareTitleByID(ctx context.Context, id uint, teamID *uin vppAppsTeamsGlobalOrTeamIDFilter = fmt.Sprintf("vat.global_or_team_id = %d", *teamID) inHouseAppsTeamsGlobalOrTeamIDFilter = fmt.Sprintf("iha.global_or_team_id = %d", *teamID) } else { - teamFilter = ds.whereFilterGlobalOrTeamIDByTeams(tmFilter, "sthc") + teamFilter = ds.whereFilterTeamWithGlobalStats(tmFilter, "sthc") softwareInstallerGlobalOrTeamIDFilter = "TRUE" vppAppsTeamsGlobalOrTeamIDFilter = "TRUE" inHouseAppsTeamsGlobalOrTeamIDFilter = "TRUE" @@ -621,7 +621,7 @@ func (ds *Datastore) selectSoftwareVersionsSQL(titleIDs []uint, teamID *uint, tm if teamID != nil { teamFilter = fmt.Sprintf("shc.team_id = %d", *teamID) } else { - teamFilter = ds.whereFilterGlobalOrTeamIDByTeams(tmFilter, "shc") + teamFilter = ds.whereFilterTeamWithGlobalStats(tmFilter, "shc") } selectVersionsStmt := ` diff --git a/server/datastore/mysql/targets_test.go b/server/datastore/mysql/targets_test.go index 8d402fee12..e042c11065 100644 --- a/server/datastore/mysql/targets_test.go +++ b/server/datastore/mysql/targets_test.go @@ -399,23 +399,23 @@ func testTargetsHostIDsInTargets(t *testing.T, ds *Datastore) { allLinux, _, err := ds.Label(context.Background(), 12, filter) require.NoError(t, err) - allBuiltIn := []*fleet.Label{ + allBuiltIn := []*fleet.LabelWithTeamName{ allHosts, macOS, ubuntuLinux, centOSLinux, msWindows, redHatLinux, allLinux, } for _, item := range []struct { host *fleet.Host - labels map[*fleet.Label]struct{} + labels map[*fleet.LabelWithTeamName]struct{} }{ { host: h1, - labels: map[*fleet.Label]struct{}{ + labels: map[*fleet.LabelWithTeamName]struct{}{ allHosts: {}, macOS: {}, }, }, { host: h2, - labels: map[*fleet.Label]struct{}{ + labels: map[*fleet.LabelWithTeamName]struct{}{ allHosts: {}, centOSLinux: {}, allLinux: {}, @@ -423,7 +423,7 @@ func testTargetsHostIDsInTargets(t *testing.T, ds *Datastore) { }, { host: h3, - labels: map[*fleet.Label]struct{}{ + labels: map[*fleet.LabelWithTeamName]struct{}{ allHosts: {}, ubuntuLinux: {}, allLinux: {}, @@ -431,21 +431,21 @@ func testTargetsHostIDsInTargets(t *testing.T, ds *Datastore) { }, { host: h4, - labels: map[*fleet.Label]struct{}{ + labels: map[*fleet.LabelWithTeamName]struct{}{ allHosts: {}, msWindows: {}, }, }, { host: h5, - labels: map[*fleet.Label]struct{}{ + labels: map[*fleet.LabelWithTeamName]struct{}{ allHosts: {}, msWindows: {}, }, }, { host: h6, - labels: map[*fleet.Label]struct{}{ + labels: map[*fleet.LabelWithTeamName]struct{}{ allHosts: {}, macOS: {}, }, diff --git a/server/datastore/mysql/teams_test.go b/server/datastore/mysql/teams_test.go index 874122dd6e..758cceed86 100644 --- a/server/datastore/mysql/teams_test.go +++ b/server/datastore/mysql/teams_test.go @@ -328,7 +328,7 @@ func testTeamsGetSetDelete(t *testing.T, ds *Datastore) { require.NoError(t, ds.DeletePack(context.Background(), newP.Name)) // Check team label is gone. - labels, err := ds.LabelsByName(context.Background(), []string{teamLabel.Name}) + labels, err := ds.LabelsByName(context.Background(), []string{teamLabel.Name}, fleet.TeamFilter{}) require.NoError(t, err) require.Empty(t, labels) diff --git a/server/fleet/agent_options.go b/server/fleet/agent_options.go index 5058e22f6a..0c4fb391f9 100644 --- a/server/fleet/agent_options.go +++ b/server/fleet/agent_options.go @@ -7,6 +7,8 @@ import ( "errors" "fmt" "strings" + + "github.com/fleetdm/fleet/v4/server/ptr" ) //go:generate go run ../../tools/osquery-agent-options agent_options_generated.go @@ -67,7 +69,7 @@ func SuggestAgentOptionsCorrection(err error) error { // Options payload. It ensures that all fields are known and have valid values. // The validation always uses the most recent Osquery version that is available // at the time of the Fleet release. -func ValidateJSONAgentOptions(ctx context.Context, ds Datastore, rawJSON json.RawMessage, isPremium bool) error { +func ValidateJSONAgentOptions(ctx context.Context, ds Datastore, rawJSON json.RawMessage, isPremium bool, teamID uint) error { var opts AgentOptions if err := JSONStrictDecode(bytes.NewReader(rawJSON), &opts); err != nil { return err @@ -132,7 +134,7 @@ func ValidateJSONAgentOptions(ctx context.Context, ds Datastore, rawJSON json.Ra } if len(opts.Extensions) > 0 { - if err := validateJSONAgentOptionsExtensions(ctx, ds, opts.Extensions, isPremium); err != nil { + if err := validateJSONAgentOptionsExtensions(ctx, ds, opts.Extensions, isPremium, teamID); err != nil { return err } } @@ -156,23 +158,28 @@ func checkEmptyFields(prefix string, data json.RawMessage) error { return nil } -func validateJSONAgentOptionsExtensions(ctx context.Context, ds Datastore, optsExtensions json.RawMessage, isPremium bool) error { +func validateJSONAgentOptionsExtensions(ctx context.Context, ds Datastore, optsExtensions json.RawMessage, isPremium bool, teamID uint) error { var extensions map[string]ExtensionInfo if err := json.Unmarshal(optsExtensions, &extensions); err != nil { return fmt.Errorf("unmarshal extensions: %w", err) } + + // any user able to make it past auth checks elsewhere to modify agent options can see labels for the associated + // team; this filter is strictly to filter out mismatched team labels + teamFilter := TeamFilter{TeamID: &teamID, User: &User{GlobalRole: ptr.String(RoleAdmin)}} + for _, extensionInfo := range extensions { if !isPremium && len(extensionInfo.Labels) != 0 { // Setting labels settings in the extensions config is premium only. return ErrMissingLicense } for _, labelName := range extensionInfo.Labels { - switch _, err := ds.GetLabelSpec(ctx, labelName); { + switch _, err := ds.GetLabelSpec(ctx, teamFilter, labelName); { case err == nil: // OK case IsNotFound(err): // Label does not exist, fail the request. - return fmt.Errorf("Label %q does not exist", labelName) + return fmt.Errorf("Label %q does not exist, or cannot be used on this team", labelName) default: return fmt.Errorf("get label by name: %w", err) } diff --git a/server/fleet/agent_options_test.go b/server/fleet/agent_options_test.go index c44897b13a..6c663c29d2 100644 --- a/server/fleet/agent_options_test.go +++ b/server/fleet/agent_options_test.go @@ -202,7 +202,7 @@ func TestValidateAgentOptions(t *testing.T) { for _, c := range cases { t.Run(c.desc, func(t *testing.T) { - err := ValidateJSONAgentOptions(context.Background(), nil, []byte(c.in), c.isPremium) + err := ValidateJSONAgentOptions(context.Background(), nil, []byte(c.in), c.isPremium, 0) t.Logf("%T", errors.Unwrap(err)) if c.wantErr != "" { require.ErrorContains(t, err, c.wantErr) diff --git a/server/fleet/authz.go b/server/fleet/authz.go index 92ab377753..82ba3b92f6 100644 --- a/server/fleet/authz.go +++ b/server/fleet/authz.go @@ -9,6 +9,8 @@ const ( ActionWrite = "write" // ActionWriteHostLabel refers to writing labels on hosts. ActionWriteHostLabel = "write_host_label" + // ActionCreate refers to creating an entity when permissions differ from standard writes (e.g. global labels) + ActionCreate = "create" // ActionCancelHostActivity refers to canceling an upcoming activity on a host. ActionCancelHostActivity = "cancel_host_activity" diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index db6a151e42..5c53219232 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -190,21 +190,25 @@ type Datastore interface { ApplyLabelSpecs(ctx context.Context, specs []*LabelSpec) error // ApplyLabelSpecs does the same as ApplyLabelSpecs, additionally allowing an author ID to be set for the labels. ApplyLabelSpecsWithAuthor(ctx context.Context, specs []*LabelSpec, authorId *uint) error - // GetLabelSpecs returns all of the stored LabelSpecs. - GetLabelSpecs(ctx context.Context) ([]*LabelSpec, error) - // GetLabelSpec returns the spec for the named label. - GetLabelSpec(ctx context.Context, name string) (*LabelSpec, error) + // SetAsideLabels moves a set of labels out of the way if those labels *aren't* on the specified team and *are* + // writable by the specified user + SetAsideLabels(ctx context.Context, notOnTeamID *uint, names []string, user User) error + // GetLabelSpecs returns all of the stored LabelSpecs that the user can see. + GetLabelSpecs(ctx context.Context, filter TeamFilter) ([]*LabelSpec, error) + // GetLabelSpec returns the spec for the named label, filtered by the provided team filter. + GetLabelSpec(ctx context.Context, filter TeamFilter, name string) (*LabelSpec, error) - // AddLabelsToHost adds the given label IDs membership to the host. + // AddLabelsToHost adds the given label IDs membership to the host, with the assumption that the label + // is available for the host (visibility checks are assumed to have been done prior to this call). // If a host is already a member of the label then this will update the row's updated_at. AddLabelsToHost(ctx context.Context, hostID uint, labelIDs []uint) error // RemoveLabelsFromHost removes the given label IDs membership from the host. // If a host is already not a member of a label then such label will be ignored. RemoveLabelsFromHost(ctx context.Context, hostID uint, labelIDs []uint) error - // UpdateLabelMembershipByHostIDs updates the label membership for the given label ID with host - // IDs, applied in batches - UpdateLabelMembershipByHostIDs(ctx context.Context, labelID uint, hostIds []uint, teamFilter TeamFilter) (*Label, []uint, error) + // UpdateLabelMembershipByHostIDs updates the label membership for the given label with host + // IDs, applied in batches, then returns the updated label + UpdateLabelMembershipByHostIDs(ctx context.Context, label Label, hostIds []uint, teamFilter TeamFilter) (*Label, []uint, error) // UpdateLabelMembershipByHostCriteria updates the label membership for the given label // based on its host vitals criteria. UpdateLabelMembershipByHostCriteria(ctx context.Context, hvl HostVitalsLabel) (*Label, error) @@ -212,12 +216,13 @@ type Datastore interface { NewLabel(ctx context.Context, label *Label, opts ...OptionalArg) (*Label, error) // SaveLabel updates the label and returns the label and an array of host IDs // members of this label, or an error. - SaveLabel(ctx context.Context, label *Label, teamFilter TeamFilter) (*Label, []uint, error) - DeleteLabel(ctx context.Context, name string) error + SaveLabel(ctx context.Context, label *Label, teamFilter TeamFilter) (*LabelWithTeamName, []uint, error) + DeleteLabel(ctx context.Context, name string, filter TeamFilter) error + LabelByName(ctx context.Context, name string, filter TeamFilter) (*Label, error) // Label returns the label and an array of host IDs members of this label, or an error. - Label(ctx context.Context, lid uint, teamFilter TeamFilter) (*Label, []uint, error) - ListLabels(ctx context.Context, filter TeamFilter, opt ListOptions) ([]*Label, error) - LabelsSummary(ctx context.Context) ([]*LabelSummary, error) + Label(ctx context.Context, lid uint, teamFilter TeamFilter) (*LabelWithTeamName, []uint, error) + ListLabels(ctx context.Context, filter TeamFilter, opt ListOptions, includeHostCounts bool) ([]*Label, error) + LabelsSummary(ctx context.Context, filter TeamFilter) ([]*LabelSummary, error) GetEnrollmentIDsWithPendingMDMAppleCommands(ctx context.Context) ([]string, error) @@ -231,16 +236,12 @@ type Datastore interface { // ListHostsInLabel returns a slice of hosts in the label with the given ID. ListHostsInLabel(ctx context.Context, filter TeamFilter, lid uint, opt HostListOptions) ([]*Host, error) - // ListUniqueHostsInLabels returns a slice of all of the hosts in the given label IDs. A host will only appear once - // in the results even if it is in multiple of the provided labels. - ListUniqueHostsInLabels(ctx context.Context, filter TeamFilter, labels []uint) ([]*Host, error) - SearchLabels(ctx context.Context, filter TeamFilter, query string, omit ...uint) ([]*Label, error) // LabelIDsByName retrieves the IDs associated with the given label names - LabelIDsByName(ctx context.Context, labels []string) (map[string]uint, error) + LabelIDsByName(ctx context.Context, labels []string, filter TeamFilter) (map[string]uint, error) // LabelsByName retrieves the labels associated with the given label names - LabelsByName(ctx context.Context, names []string) (map[string]*Label, error) + LabelsByName(ctx context.Context, names []string, filter TeamFilter) (map[string]*Label, error) // Methods used for async processing of host label query results. AsyncBatchInsertLabelMembership(ctx context.Context, batch [][2]uint) error diff --git a/server/fleet/labels.go b/server/fleet/labels.go index bd7e0e5299..17d10d94e7 100644 --- a/server/fleet/labels.go +++ b/server/fleet/labels.go @@ -154,6 +154,11 @@ type Label struct { TeamID *uint `json:"team_id" db:"team_id"` } +type LabelWithTeamName struct { + Label + TeamName *string `json:"team_name" db:"team_name"` +} + // Implement the HostVitalsLabel interface. func (l *Label) GetLabel() *Label { return l @@ -225,6 +230,7 @@ type LabelSpec struct { LabelMembershipType LabelMembershipType `json:"label_membership_type" db:"label_membership_type"` Hosts HostsSlice `json:"hosts"` HostVitalsCriteria *json.RawMessage `json:"criteria,omitempty" db:"criteria"` + TeamID *uint `json:"team_id" db:"team_id"` } const ( @@ -351,7 +357,7 @@ func (l *Label) CalculateHostVitalsQuery() (query string, values []any, err erro // We'll use a set to gather the foreign vitals groups we need to join on, // so that we can avoid duplicates. foreignVitalsGroups := make(map[*HostForeignVitalGroup]struct{}) - // Hold values to be substituted in the paramerized query. + // Hold values to be substituted in the parameterized query. values = make([]any, 0) // Recursively parse the criteria to build the WHERE clause. whereClause, err := parseHostVitalCriteria(criteria, foreignVitalsGroups, &values) diff --git a/server/fleet/service.go b/server/fleet/service.go index 945a8639e8..2db2681fd1 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -253,18 +253,19 @@ type Service interface { // ///////////////////////////////////////////////////////////////////////////// // LabelService - // ApplyLabelSpecs applies a list of LabelSpecs to the datastore, creating and updating labels as necessary. - ApplyLabelSpecs(ctx context.Context, specs []*LabelSpec) error - // GetLabelSpecs returns all of the stored LabelSpecs. - GetLabelSpecs(ctx context.Context) ([]*LabelSpec, error) + // ApplyLabelSpecs applies a list of LabelSpecs to the datastore, creating and updating labels as necessary, + // plus rename existing labels *on other teams* to avoid name conflicts + ApplyLabelSpecs(ctx context.Context, specs []*LabelSpec, teamID *uint, namesToMove []string) error + // GetLabelSpecs returns global labels, plus either all team labels a user can see or just ones in the specified team ID. + GetLabelSpecs(ctx context.Context, teamID *uint) ([]*LabelSpec, error) // GetLabelSpec gets the spec for the label with the given name. GetLabelSpec(ctx context.Context, name string) (*LabelSpec, error) NewLabel(ctx context.Context, p LabelPayload) (label *Label, hostIDs []uint, err error) - ModifyLabel(ctx context.Context, id uint, payload ModifyLabelPayload) (*Label, []uint, error) - ListLabels(ctx context.Context, opt ListOptions, includeHostCounts bool) (labels []*Label, err error) - LabelsSummary(ctx context.Context) (labels []*LabelSummary, err error) - GetLabel(ctx context.Context, id uint) (label *Label, hostIDs []uint, err error) + ModifyLabel(ctx context.Context, id uint, payload ModifyLabelPayload) (*LabelWithTeamName, []uint, error) + ListLabels(ctx context.Context, opt ListOptions, teamID *uint, includeHostCounts bool) (labels []*Label, err error) + LabelsSummary(ctx context.Context, teamID *uint) (labels []*LabelSummary, err error) + GetLabel(ctx context.Context, id uint) (label *LabelWithTeamName, hostIDs []uint, err error) DeleteLabel(ctx context.Context, name string) (err error) // DeleteLabelByID is for backwards compatibility with the UI diff --git a/server/fleet/teams.go b/server/fleet/teams.go index c54f75e274..2b7c8f84a1 100644 --- a/server/fleet/teams.go +++ b/server/fleet/teams.go @@ -587,6 +587,14 @@ type TeamFilter struct { TeamID *uint } +func (f TeamFilter) UserCanAccessSelectedTeam() bool { + if f.TeamID == nil { // this method doesn't make sense if there's no team ID specified + return false + } + + return f.User.HasAnyGlobalRole() || f.User.HasAnyRoleInTeam(*f.TeamID) +} + const ( TeamKind = "team" ) diff --git a/server/fleet/users.go b/server/fleet/users.go index c98ee581a5..7a25186551 100644 --- a/server/fleet/users.go +++ b/server/fleet/users.go @@ -437,6 +437,32 @@ func (u *User) SetFakePassword(keySize, cost int) error { return nil } +func (u *User) TeamIDsWithAnyRole() (teamIDs []uint) { + for _, team := range u.Teams { + teamIDs = append(teamIDs, team.ID) + } + + return teamIDs +} + +func (u *User) HasAnyGlobalRole() bool { + return u.GlobalRole != nil +} + +func (u *User) HasAnyTeamRole() bool { + return len(u.Teams) > 0 +} + +func (u *User) HasAnyRoleInTeam(id uint) bool { + for _, team := range u.Teams { + if team.ID == id { + return true + } + } + + return false +} + func saltAndHashPassword(keySize int, plaintext string, cost int) (hashed []byte, salt string, err error) { salt, err = server.GenerateRandomText(keySize) if err != nil { diff --git a/server/mdm/android/service/profiles_test.go b/server/mdm/android/service/profiles_test.go index 63c56a3963..68ddd4ce9d 100644 --- a/server/mdm/android/service/profiles_test.go +++ b/server/mdm/android/service/profiles_test.go @@ -637,9 +637,9 @@ func testHostsWithLabelProfiles(t *testing.T, ds fleet.Datastore, client *mock.C }) // make h1 member of inclany and h2 of inclall - _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, linclAny.ID, []uint{h1.Host.ID}, fleet.TeamFilter{}) + _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, *linclAny, []uint{h1.Host.ID}, fleet.TeamFilter{}) require.NoError(t, err) - _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, linclAll.ID, []uint{h2.Host.ID}, fleet.TeamFilter{}) + _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, *linclAll, []uint{h2.Host.ID}, fleet.TeamFilter{}) require.NoError(t, err) // no-label, exclude any and the respective include profiles are applied @@ -661,7 +661,7 @@ func testHostsWithLabelProfiles(t *testing.T, ds fleet.Datastore, client *mock.C }) // make h1 member of exclAny so it stops receiving this profile - _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, lexclAny.ID, []uint{h1.Host.ID}, fleet.TeamFilter{}) + _, _, err = ds.UpdateLabelMembershipByHostIDs(ctx, *lexclAny, []uint{h1.Host.ID}, fleet.TeamFilter{}) require.NoError(t, err) // this only affects h1, h2 version is unchanged diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 9afe276b23..11586d539c 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -139,29 +139,33 @@ type ApplyLabelSpecsFunc func(ctx context.Context, specs []*fleet.LabelSpec) err type ApplyLabelSpecsWithAuthorFunc func(ctx context.Context, specs []*fleet.LabelSpec, authorId *uint) error -type GetLabelSpecsFunc func(ctx context.Context) ([]*fleet.LabelSpec, error) +type SetAsideLabelsFunc func(ctx context.Context, notOnTeamID *uint, names []string, user fleet.User) error -type GetLabelSpecFunc func(ctx context.Context, name string) (*fleet.LabelSpec, error) +type GetLabelSpecsFunc func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) + +type GetLabelSpecFunc func(ctx context.Context, filter fleet.TeamFilter, name string) (*fleet.LabelSpec, error) type AddLabelsToHostFunc func(ctx context.Context, hostID uint, labelIDs []uint) error type RemoveLabelsFromHostFunc func(ctx context.Context, hostID uint, labelIDs []uint) error -type UpdateLabelMembershipByHostIDsFunc func(ctx context.Context, labelID uint, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) +type UpdateLabelMembershipByHostIDsFunc func(ctx context.Context, label fleet.Label, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) type UpdateLabelMembershipByHostCriteriaFunc func(ctx context.Context, hvl fleet.HostVitalsLabel) (*fleet.Label, error) type NewLabelFunc func(ctx context.Context, label *fleet.Label, opts ...fleet.OptionalArg) (*fleet.Label, error) -type SaveLabelFunc func(ctx context.Context, label *fleet.Label, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) +type SaveLabelFunc func(ctx context.Context, label *fleet.Label, teamFilter fleet.TeamFilter) (*fleet.LabelWithTeamName, []uint, error) -type DeleteLabelFunc func(ctx context.Context, name string) error +type DeleteLabelFunc func(ctx context.Context, name string, filter fleet.TeamFilter) error -type LabelFunc func(ctx context.Context, lid uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) +type LabelByNameFunc func(ctx context.Context, name string, filter fleet.TeamFilter) (*fleet.Label, error) -type ListLabelsFunc func(ctx context.Context, filter fleet.TeamFilter, opt fleet.ListOptions) ([]*fleet.Label, error) +type LabelFunc func(ctx context.Context, lid uint, teamFilter fleet.TeamFilter) (*fleet.LabelWithTeamName, []uint, error) -type LabelsSummaryFunc func(ctx context.Context) ([]*fleet.LabelSummary, error) +type ListLabelsFunc func(ctx context.Context, filter fleet.TeamFilter, opt fleet.ListOptions, includeHostCounts bool) ([]*fleet.Label, error) + +type LabelsSummaryFunc func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSummary, error) type GetEnrollmentIDsWithPendingMDMAppleCommandsFunc func(ctx context.Context) ([]string, error) @@ -171,13 +175,11 @@ type ListLabelsForHostFunc func(ctx context.Context, hid uint) ([]*fleet.Label, type ListHostsInLabelFunc func(ctx context.Context, filter fleet.TeamFilter, lid uint, opt fleet.HostListOptions) ([]*fleet.Host, error) -type ListUniqueHostsInLabelsFunc func(ctx context.Context, filter fleet.TeamFilter, labels []uint) ([]*fleet.Host, error) - type SearchLabelsFunc func(ctx context.Context, filter fleet.TeamFilter, query string, omit ...uint) ([]*fleet.Label, error) -type LabelIDsByNameFunc func(ctx context.Context, labels []string) (map[string]uint, error) +type LabelIDsByNameFunc func(ctx context.Context, labels []string, filter fleet.TeamFilter) (map[string]uint, error) -type LabelsByNameFunc func(ctx context.Context, names []string) (map[string]*fleet.Label, error) +type LabelsByNameFunc func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) type AsyncBatchInsertLabelMembershipFunc func(ctx context.Context, batch [][2]uint) error @@ -1888,6 +1890,9 @@ type DataStore struct { ApplyLabelSpecsWithAuthorFunc ApplyLabelSpecsWithAuthorFunc ApplyLabelSpecsWithAuthorFuncInvoked bool + SetAsideLabelsFunc SetAsideLabelsFunc + SetAsideLabelsFuncInvoked bool + GetLabelSpecsFunc GetLabelSpecsFunc GetLabelSpecsFuncInvoked bool @@ -1915,6 +1920,9 @@ type DataStore struct { DeleteLabelFunc DeleteLabelFunc DeleteLabelFuncInvoked bool + LabelByNameFunc LabelByNameFunc + LabelByNameFuncInvoked bool + LabelFunc LabelFunc LabelFuncInvoked bool @@ -1936,9 +1944,6 @@ type DataStore struct { ListHostsInLabelFunc ListHostsInLabelFunc ListHostsInLabelFuncInvoked bool - ListUniqueHostsInLabelsFunc ListUniqueHostsInLabelsFunc - ListUniqueHostsInLabelsFuncInvoked bool - SearchLabelsFunc SearchLabelsFunc SearchLabelsFuncInvoked bool @@ -4658,18 +4663,25 @@ func (s *DataStore) ApplyLabelSpecsWithAuthor(ctx context.Context, specs []*flee return s.ApplyLabelSpecsWithAuthorFunc(ctx, specs, authorId) } -func (s *DataStore) GetLabelSpecs(ctx context.Context) ([]*fleet.LabelSpec, error) { +func (s *DataStore) SetAsideLabels(ctx context.Context, notOnTeamID *uint, names []string, user fleet.User) error { + s.mu.Lock() + s.SetAsideLabelsFuncInvoked = true + s.mu.Unlock() + return s.SetAsideLabelsFunc(ctx, notOnTeamID, names, user) +} + +func (s *DataStore) GetLabelSpecs(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { s.mu.Lock() s.GetLabelSpecsFuncInvoked = true s.mu.Unlock() - return s.GetLabelSpecsFunc(ctx) + return s.GetLabelSpecsFunc(ctx, filter) } -func (s *DataStore) GetLabelSpec(ctx context.Context, name string) (*fleet.LabelSpec, error) { +func (s *DataStore) GetLabelSpec(ctx context.Context, filter fleet.TeamFilter, name string) (*fleet.LabelSpec, error) { s.mu.Lock() s.GetLabelSpecFuncInvoked = true s.mu.Unlock() - return s.GetLabelSpecFunc(ctx, name) + return s.GetLabelSpecFunc(ctx, filter, name) } func (s *DataStore) AddLabelsToHost(ctx context.Context, hostID uint, labelIDs []uint) error { @@ -4686,11 +4698,11 @@ func (s *DataStore) RemoveLabelsFromHost(ctx context.Context, hostID uint, label return s.RemoveLabelsFromHostFunc(ctx, hostID, labelIDs) } -func (s *DataStore) UpdateLabelMembershipByHostIDs(ctx context.Context, labelID uint, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { +func (s *DataStore) UpdateLabelMembershipByHostIDs(ctx context.Context, label fleet.Label, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { s.mu.Lock() s.UpdateLabelMembershipByHostIDsFuncInvoked = true s.mu.Unlock() - return s.UpdateLabelMembershipByHostIDsFunc(ctx, labelID, hostIds, teamFilter) + return s.UpdateLabelMembershipByHostIDsFunc(ctx, label, hostIds, teamFilter) } func (s *DataStore) UpdateLabelMembershipByHostCriteria(ctx context.Context, hvl fleet.HostVitalsLabel) (*fleet.Label, error) { @@ -4707,39 +4719,46 @@ func (s *DataStore) NewLabel(ctx context.Context, label *fleet.Label, opts ...fl return s.NewLabelFunc(ctx, label, opts...) } -func (s *DataStore) SaveLabel(ctx context.Context, label *fleet.Label, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { +func (s *DataStore) SaveLabel(ctx context.Context, label *fleet.Label, teamFilter fleet.TeamFilter) (*fleet.LabelWithTeamName, []uint, error) { s.mu.Lock() s.SaveLabelFuncInvoked = true s.mu.Unlock() return s.SaveLabelFunc(ctx, label, teamFilter) } -func (s *DataStore) DeleteLabel(ctx context.Context, name string) error { +func (s *DataStore) DeleteLabel(ctx context.Context, name string, filter fleet.TeamFilter) error { s.mu.Lock() s.DeleteLabelFuncInvoked = true s.mu.Unlock() - return s.DeleteLabelFunc(ctx, name) + return s.DeleteLabelFunc(ctx, name, filter) } -func (s *DataStore) Label(ctx context.Context, lid uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { +func (s *DataStore) LabelByName(ctx context.Context, name string, filter fleet.TeamFilter) (*fleet.Label, error) { + s.mu.Lock() + s.LabelByNameFuncInvoked = true + s.mu.Unlock() + return s.LabelByNameFunc(ctx, name, filter) +} + +func (s *DataStore) Label(ctx context.Context, lid uint, teamFilter fleet.TeamFilter) (*fleet.LabelWithTeamName, []uint, error) { s.mu.Lock() s.LabelFuncInvoked = true s.mu.Unlock() return s.LabelFunc(ctx, lid, teamFilter) } -func (s *DataStore) ListLabels(ctx context.Context, filter fleet.TeamFilter, opt fleet.ListOptions) ([]*fleet.Label, error) { +func (s *DataStore) ListLabels(ctx context.Context, filter fleet.TeamFilter, opt fleet.ListOptions, includeHostCounts bool) ([]*fleet.Label, error) { s.mu.Lock() s.ListLabelsFuncInvoked = true s.mu.Unlock() - return s.ListLabelsFunc(ctx, filter, opt) + return s.ListLabelsFunc(ctx, filter, opt, includeHostCounts) } -func (s *DataStore) LabelsSummary(ctx context.Context) ([]*fleet.LabelSummary, error) { +func (s *DataStore) LabelsSummary(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSummary, error) { s.mu.Lock() s.LabelsSummaryFuncInvoked = true s.mu.Unlock() - return s.LabelsSummaryFunc(ctx) + return s.LabelsSummaryFunc(ctx, filter) } func (s *DataStore) GetEnrollmentIDsWithPendingMDMAppleCommands(ctx context.Context) ([]string, error) { @@ -4770,13 +4789,6 @@ func (s *DataStore) ListHostsInLabel(ctx context.Context, filter fleet.TeamFilte return s.ListHostsInLabelFunc(ctx, filter, lid, opt) } -func (s *DataStore) ListUniqueHostsInLabels(ctx context.Context, filter fleet.TeamFilter, labels []uint) ([]*fleet.Host, error) { - s.mu.Lock() - s.ListUniqueHostsInLabelsFuncInvoked = true - s.mu.Unlock() - return s.ListUniqueHostsInLabelsFunc(ctx, filter, labels) -} - func (s *DataStore) SearchLabels(ctx context.Context, filter fleet.TeamFilter, query string, omit ...uint) ([]*fleet.Label, error) { s.mu.Lock() s.SearchLabelsFuncInvoked = true @@ -4784,18 +4796,18 @@ func (s *DataStore) SearchLabels(ctx context.Context, filter fleet.TeamFilter, q return s.SearchLabelsFunc(ctx, filter, query, omit...) } -func (s *DataStore) LabelIDsByName(ctx context.Context, labels []string) (map[string]uint, error) { +func (s *DataStore) LabelIDsByName(ctx context.Context, labels []string, filter fleet.TeamFilter) (map[string]uint, error) { s.mu.Lock() s.LabelIDsByNameFuncInvoked = true s.mu.Unlock() - return s.LabelIDsByNameFunc(ctx, labels) + return s.LabelIDsByNameFunc(ctx, labels, filter) } -func (s *DataStore) LabelsByName(ctx context.Context, names []string) (map[string]*fleet.Label, error) { +func (s *DataStore) LabelsByName(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { s.mu.Lock() s.LabelsByNameFuncInvoked = true s.mu.Unlock() - return s.LabelsByNameFunc(ctx, names) + return s.LabelsByNameFunc(ctx, names, filter) } func (s *DataStore) AsyncBatchInsertLabelMembership(ctx context.Context, batch [][2]uint) error { diff --git a/server/mock/service/service_mock.go b/server/mock/service/service_mock.go index 312d8f6f30..a8afc07caf 100644 --- a/server/mock/service/service_mock.go +++ b/server/mock/service/service_mock.go @@ -139,21 +139,21 @@ type DeletePackByIDFunc func(ctx context.Context, id uint) (err error) type ListPacksForHostFunc func(ctx context.Context, hid uint) (packs []*fleet.Pack, err error) -type ApplyLabelSpecsFunc func(ctx context.Context, specs []*fleet.LabelSpec) error +type ApplyLabelSpecsFunc func(ctx context.Context, specs []*fleet.LabelSpec, teamID *uint, namesToMove []string) error -type GetLabelSpecsFunc func(ctx context.Context) ([]*fleet.LabelSpec, error) +type GetLabelSpecsFunc func(ctx context.Context, teamID *uint) ([]*fleet.LabelSpec, error) type GetLabelSpecFunc func(ctx context.Context, name string) (*fleet.LabelSpec, error) type NewLabelFunc func(ctx context.Context, p fleet.LabelPayload) (label *fleet.Label, hostIDs []uint, err error) -type ModifyLabelFunc func(ctx context.Context, id uint, payload fleet.ModifyLabelPayload) (*fleet.Label, []uint, error) +type ModifyLabelFunc func(ctx context.Context, id uint, payload fleet.ModifyLabelPayload) (*fleet.LabelWithTeamName, []uint, error) -type ListLabelsFunc func(ctx context.Context, opt fleet.ListOptions, includeHostCounts bool) (labels []*fleet.Label, err error) +type ListLabelsFunc func(ctx context.Context, opt fleet.ListOptions, teamID *uint, includeHostCounts bool) (labels []*fleet.Label, err error) -type LabelsSummaryFunc func(ctx context.Context) (labels []*fleet.LabelSummary, err error) +type LabelsSummaryFunc func(ctx context.Context, teamID *uint) (labels []*fleet.LabelSummary, err error) -type GetLabelFunc func(ctx context.Context, id uint) (label *fleet.Label, hostIDs []uint, err error) +type GetLabelFunc func(ctx context.Context, id uint) (label *fleet.LabelWithTeamName, hostIDs []uint, err error) type DeleteLabelFunc func(ctx context.Context, name string) (err error) @@ -2578,18 +2578,18 @@ func (s *Service) ListPacksForHost(ctx context.Context, hid uint) (packs []*flee return s.ListPacksForHostFunc(ctx, hid) } -func (s *Service) ApplyLabelSpecs(ctx context.Context, specs []*fleet.LabelSpec) error { +func (s *Service) ApplyLabelSpecs(ctx context.Context, specs []*fleet.LabelSpec, teamID *uint, namesToMove []string) error { s.mu.Lock() s.ApplyLabelSpecsFuncInvoked = true s.mu.Unlock() - return s.ApplyLabelSpecsFunc(ctx, specs) + return s.ApplyLabelSpecsFunc(ctx, specs, teamID, namesToMove) } -func (s *Service) GetLabelSpecs(ctx context.Context) ([]*fleet.LabelSpec, error) { +func (s *Service) GetLabelSpecs(ctx context.Context, teamID *uint) ([]*fleet.LabelSpec, error) { s.mu.Lock() s.GetLabelSpecsFuncInvoked = true s.mu.Unlock() - return s.GetLabelSpecsFunc(ctx) + return s.GetLabelSpecsFunc(ctx, teamID) } func (s *Service) GetLabelSpec(ctx context.Context, name string) (*fleet.LabelSpec, error) { @@ -2606,28 +2606,28 @@ func (s *Service) NewLabel(ctx context.Context, p fleet.LabelPayload) (label *fl return s.NewLabelFunc(ctx, p) } -func (s *Service) ModifyLabel(ctx context.Context, id uint, payload fleet.ModifyLabelPayload) (*fleet.Label, []uint, error) { +func (s *Service) ModifyLabel(ctx context.Context, id uint, payload fleet.ModifyLabelPayload) (*fleet.LabelWithTeamName, []uint, error) { s.mu.Lock() s.ModifyLabelFuncInvoked = true s.mu.Unlock() return s.ModifyLabelFunc(ctx, id, payload) } -func (s *Service) ListLabels(ctx context.Context, opt fleet.ListOptions, includeHostCounts bool) (labels []*fleet.Label, err error) { +func (s *Service) ListLabels(ctx context.Context, opt fleet.ListOptions, teamID *uint, includeHostCounts bool) (labels []*fleet.Label, err error) { s.mu.Lock() s.ListLabelsFuncInvoked = true s.mu.Unlock() - return s.ListLabelsFunc(ctx, opt, includeHostCounts) + return s.ListLabelsFunc(ctx, opt, teamID, includeHostCounts) } -func (s *Service) LabelsSummary(ctx context.Context) (labels []*fleet.LabelSummary, err error) { +func (s *Service) LabelsSummary(ctx context.Context, teamID *uint) (labels []*fleet.LabelSummary, err error) { s.mu.Lock() s.LabelsSummaryFuncInvoked = true s.mu.Unlock() - return s.LabelsSummaryFunc(ctx) + return s.LabelsSummaryFunc(ctx, teamID) } -func (s *Service) GetLabel(ctx context.Context, id uint) (label *fleet.Label, hostIDs []uint, err error) { +func (s *Service) GetLabel(ctx context.Context, id uint) (label *fleet.LabelWithTeamName, hostIDs []uint, err error) { s.mu.Lock() s.GetLabelFuncInvoked = true s.mu.Unlock() diff --git a/server/service/appconfig.go b/server/service/appconfig.go index f45f579bb4..7c73f153d7 100644 --- a/server/service/appconfig.go +++ b/server/service/appconfig.go @@ -477,7 +477,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle if newAppConfig.AgentOptions != nil { // if there were Agent Options in the new app config, then it replaced the // agent options in the resulting app config, so validate those. - if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, *appConfig.AgentOptions, license.IsPremium()); err != nil { + if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, *appConfig.AgentOptions, license.IsPremium(), 0); err != nil { err = fleet.SuggestAgentOptionsCorrection(err) err = fleet.NewUserMessageError(err, http.StatusBadRequest) if applyOpts.Force && !applyOpts.DryRun { diff --git a/server/service/apple_mdm.go b/server/service/apple_mdm.go index 54c74982c2..a4752919f0 100644 --- a/server/service/apple_mdm.go +++ b/server/service/apple_mdm.go @@ -888,7 +888,7 @@ func (svc *Service) NewMDMAppleDeclaration(ctx context.Context, teamID uint, dat tmID = &teamID } - validatedLabels, err := svc.validateDeclarationLabels(ctx, labels) + validatedLabels, err := svc.validateDeclarationLabels(ctx, labels, teamID) if err != nil { return nil, err } @@ -963,12 +963,12 @@ func validateDeclarationFleetVariables(contents string) error { return nil } -func (svc *Service) batchValidateDeclarationLabels(ctx context.Context, labelNames []string) (map[string]fleet.ConfigurationProfileLabel, error) { +func (svc *Service) batchValidateDeclarationLabels(ctx context.Context, labelNames []string, teamID uint) (map[string]fleet.ConfigurationProfileLabel, error) { if len(labelNames) == 0 { return nil, nil } - labels, err := svc.ds.LabelIDsByName(ctx, labelNames) + labels, err := svc.ds.LabelIDsByName(ctx, labelNames, fleet.TeamFilter{User: authz.UserFromContext(ctx), TeamID: &teamID}) if err != nil { return nil, ctxerr.Wrap(ctx, err, "getting label IDs by name") } @@ -997,8 +997,8 @@ func (svc *Service) batchValidateDeclarationLabels(ctx context.Context, labelNam return profLabels, nil } -func (svc *Service) validateDeclarationLabels(ctx context.Context, labelNames []string) ([]fleet.ConfigurationProfileLabel, error) { - labelMap, err := svc.batchValidateDeclarationLabels(ctx, labelNames) +func (svc *Service) validateDeclarationLabels(ctx context.Context, labelNames []string, teamID uint) ([]fleet.ConfigurationProfileLabel, error) { + labelMap, err := svc.batchValidateDeclarationLabels(ctx, labelNames, teamID) if err != nil { return nil, ctxerr.Wrap(ctx, err, "validating declaration labels") } diff --git a/server/service/campaigns.go b/server/service/campaigns.go index 7418fdcea3..85a0522eb1 100644 --- a/server/service/campaigns.go +++ b/server/service/campaigns.go @@ -210,7 +210,7 @@ func (svc *Service) NewDistributedQueryCampaignByIdentifiers(ctx context.Context if err := svc.authz.Authorize(ctx, &fleet.Label{}, fleet.ActionRead); err != nil { return nil, err } - labelMap, err := svc.ds.LabelIDsByName(ctx, labels) + labelMap, err := svc.ds.LabelIDsByName(ctx, labels, fleet.TeamFilter{User: vc.User}) if err != nil { return nil, ctxerr.Wrap(ctx, err, "finding label IDs") } diff --git a/server/service/campaigns_test.go b/server/service/campaigns_test.go index 91206ce990..2bc8afea7f 100644 --- a/server/service/campaigns_test.go +++ b/server/service/campaigns_test.go @@ -86,7 +86,7 @@ func TestLiveQueryAuth(t *testing.T) { ds.HostIDsByIdentifierFunc = func(ctx context.Context, filter fleet.TeamFilter, identifiers []string) ([]uint, error) { return nil, nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, names []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { return nil, nil } ds.CountHostsInTargetsFunc = func(ctx context.Context, filters fleet.TeamFilter, targets fleet.HostTargets, now time.Time) (fleet.TargetMetrics, error) { @@ -277,7 +277,7 @@ func TestLiveQueryLabelValidation(t *testing.T) { return query, nil } - ds.LabelIDsByNameFunc = func(ctx context.Context, names []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { return map[string]uint{"label1": uint(1)}, nil } diff --git a/server/service/client.go b/server/service/client.go index cc8404ede6..e80fa3011d 100644 --- a/server/service/client.go +++ b/server/service/client.go @@ -1882,6 +1882,7 @@ func (c *Client) DoGitOps( delete(incoming.OrgSettings, "certificate_authorities") // Labels + // TODO GitOps if incoming.Labels == nil || len(incoming.Labels) > 0 { labelsToDelete, err := c.doGitOpsLabels(incoming, logFn, dryRun) if err != nil { @@ -2665,6 +2666,7 @@ func (c *Client) doGitOpsNoTeamWebhookSettings( return nil } +// TODO allow spec'ing labels by either "everything" or team-specific (team ID or name?) func (c *Client) doGitOpsLabels(config *spec.GitOps, logFn func(format string, args ...interface{}), dryRun bool) ([]string, error) { persistedLabels, err := c.GetLabels() if err != nil { diff --git a/server/service/client_labels.go b/server/service/client_labels.go index e9dc69f8a1..804ab6096c 100644 --- a/server/service/client_labels.go +++ b/server/service/client_labels.go @@ -8,6 +8,7 @@ import ( // ApplyLabels sends the list of Labels to be applied (upserted) to the // Fleet instance. +// TODO gitops allow specifying by team func (c *Client) ApplyLabels(specs []*fleet.LabelSpec) error { req := applyLabelSpecsRequest{Specs: specs} verb, path := "POST", "/api/latest/fleet/spec/labels" @@ -24,6 +25,7 @@ func (c *Client) GetLabel(name string) (*fleet.LabelSpec, error) { } // GetLabels retrieves the list of all LabelSpecs. +// TODO gitops allow passing team ID or name func (c *Client) GetLabels() ([]*fleet.LabelSpec, error) { verb, path := "GET", "/api/latest/fleet/spec/labels" var responseBody getLabelSpecsResponse diff --git a/server/service/global_policies.go b/server/service/global_policies.go index 8301903d6c..0a12f976be 100644 --- a/server/service/global_policies.go +++ b/server/service/global_policies.go @@ -76,7 +76,7 @@ func (svc Service) NewGlobalPolicy(ctx context.Context, p fleet.PolicyPayload) ( }) } - if err := verifyLabelsToAssociate(ctx, svc.ds, nil, append(p.LabelsIncludeAny, p.LabelsExcludeAny...)); err != nil { + if err := verifyLabelsToAssociate(ctx, svc.ds, nil, append(p.LabelsIncludeAny, p.LabelsExcludeAny...), vc.User); err != nil { return nil, ctxerr.Wrap(ctx, err, "verify labels to associate") } @@ -516,42 +516,51 @@ func applyPolicySpecsEndpoint(ctx context.Context, request interface{}, svc flee } // checkPolicySpecAuthorization verifies that the user is authorized to modify the -// policies defined in the spec. -func (svc *Service) checkPolicySpecAuthorization(ctx context.Context, policies []*fleet.PolicySpec) error { +// policies defined in the spec, and returns a map from team names to team IDs if successful +func (svc *Service) checkPolicySpecAuthorization(ctx context.Context, policies []*fleet.PolicySpec) (map[string]uint, error) { checkGlobalPolicyAuth := false + var teamIDsByName = make(map[string]uint) for _, policy := range policies { if policy.Team != "" && policy.Team != "No team" { team, err := svc.ds.TeamByName(ctx, policy.Team) if err != nil { // This is so that the proper HTTP status code is returned svc.authz.SkipAuthorization(ctx) - return ctxerr.Wrap(ctx, err, "getting team by name") + return nil, ctxerr.Wrap(ctx, err, "getting team by name") } if err := svc.authz.Authorize(ctx, &fleet.Policy{ PolicyData: fleet.PolicyData{ TeamID: &team.ID, }, }, fleet.ActionWrite); err != nil { - return err + return nil, err } + + teamIDsByName[policy.Team] = team.ID } else { checkGlobalPolicyAuth = true } } if checkGlobalPolicyAuth { if err := svc.authz.Authorize(ctx, &fleet.Policy{}, fleet.ActionWrite); err != nil { - return err + return nil, err } } - return nil + return teamIDsByName, nil } func (svc *Service) ApplyPolicySpecs(ctx context.Context, policies []*fleet.PolicySpec) error { // Check authorization first. - if err := svc.checkPolicySpecAuthorization(ctx, policies); err != nil { + teamIDsByName, err := svc.checkPolicySpecAuthorization(ctx, policies) + if err != nil { return err } + vc, ok := viewer.FromContext(ctx) + if !ok { + return errors.New("user must be authenticated to apply policies") + } + // After the authorization check, check the policy fields. for _, policy := range policies { if err := policy.Verify(); err != nil { @@ -564,14 +573,19 @@ func (svc *Service) ApplyPolicySpecs(ctx context.Context, policies []*fleet.Poli labels := policy.LabelsIncludeAny labels = append(labels, policy.LabelsExcludeAny...) if len(labels) > 0 { - labelsMap, err := svc.ds.LabelsByName(ctx, labels) + var teamID *uint // ensure labels specified exist and are global or on the same team as the policy + if policy.Team != "" { // if we get 0 as team ID, we'll pull only global labels, which is fine + teamID = ptr.Uint(teamIDsByName[policy.Team]) + } + + labelsMap, err := svc.ds.LabelsByName(ctx, labels, fleet.TeamFilter{User: vc.User, TeamID: teamID}) if err != nil { return ctxerr.Wrap(ctx, err, "getting labels by name") } for _, label := range labels { if _, ok := labelsMap[label]; !ok { return ctxerr.Wrap(ctx, &fleet.BadRequestError{ - Message: fmt.Sprintf("label %q does not exist", label), + Message: fmt.Sprintf("label %q does not exist, or cannot be applied to this policy", label), }) } } @@ -586,10 +600,6 @@ func (svc *Service) ApplyPolicySpecs(ctx context.Context, policies []*fleet.Poli }) } - vc, ok := viewer.FromContext(ctx) - if !ok { - return errors.New("user must be authenticated to apply policies") - } if !license.IsPremium(ctx) { for i := range policies { policies[i].Critical = false diff --git a/server/service/global_policies_test.go b/server/service/global_policies_test.go index a0153dfd24..856dc4a83b 100644 --- a/server/service/global_policies_test.go +++ b/server/service/global_policies_test.go @@ -274,7 +274,7 @@ func TestApplyPolicySpecsLabelsValidation(t *testing.T) { ds.ApplyPolicySpecsFunc = func(ctx context.Context, authorID uint, specs []*fleet.PolicySpec) error { return nil } - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { labels := make(map[string]*fleet.Label, len(names)) for _, name := range names { if name == "foo" { diff --git a/server/service/handler.go b/server/service/handler.go index b45fb6d98a..7b264f36b5 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -463,12 +463,12 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC ue.PATCH("/api/_version_/fleet/labels/{id:[0-9]+}", modifyLabelEndpoint, modifyLabelRequest{}) ue.GET("/api/_version_/fleet/labels/{id:[0-9]+}", getLabelEndpoint, getLabelRequest{}) ue.GET("/api/_version_/fleet/labels", listLabelsEndpoint, listLabelsRequest{}) - ue.GET("/api/_version_/fleet/labels/summary", getLabelsSummaryEndpoint, nil) + ue.GET("/api/_version_/fleet/labels/summary", getLabelsSummaryEndpoint, getLabelsSummaryRequest{}) ue.GET("/api/_version_/fleet/labels/{id:[0-9]+}/hosts", listHostsInLabelEndpoint, listHostsInLabelRequest{}) ue.DELETE("/api/_version_/fleet/labels/{name}", deleteLabelEndpoint, deleteLabelRequest{}) ue.DELETE("/api/_version_/fleet/labels/id/{id:[0-9]+}", deleteLabelByIDEndpoint, deleteLabelByIDRequest{}) ue.POST("/api/_version_/fleet/spec/labels", applyLabelSpecsEndpoint, applyLabelSpecsRequest{}) - ue.GET("/api/_version_/fleet/spec/labels", getLabelSpecsEndpoint, nil) + ue.GET("/api/_version_/fleet/spec/labels", getLabelSpecsEndpoint, getLabelSpecsRequest{}) ue.GET("/api/_version_/fleet/spec/labels/{name}", getLabelSpecEndpoint, getGenericSpecRequest{}) // This endpoint runs live queries synchronously (with a configured timeout). diff --git a/server/service/hosts.go b/server/service/hosts.go index 8e34420fc0..8a8ea0b220 100644 --- a/server/service/hosts.go +++ b/server/service/hosts.go @@ -881,7 +881,7 @@ func (svc *Service) GetHostSummary(ctx context.Context, teamID *uint, platform * } hostSummary.AllLinuxCount = linuxCount - labelsSummary, err := svc.ds.LabelsSummary(ctx) + labelsSummary, err := svc.ds.LabelsSummary(ctx, fleet.TeamFilter{}) if err != nil { return nil, err } @@ -3096,7 +3096,12 @@ func (svc *Service) AddLabelsToHost(ctx context.Context, id uint, labelNames []s return ctxerr.Wrap(ctx, err) } - labelIDs, err := svc.validateLabelNames(ctx, "add", labelNames) + var tmID uint + if host.TeamID != nil { + tmID = *host.TeamID + } + + labelIDs, err := svc.validateLabelNames(ctx, "add", labelNames, tmID) if err != nil { return err } @@ -3141,7 +3146,12 @@ func (svc *Service) RemoveLabelsFromHost(ctx context.Context, id uint, labelName return ctxerr.Wrap(ctx, err) } - labelIDs, err := svc.validateLabelNames(ctx, "remove", labelNames) + var tmID uint + if host.TeamID != nil { + tmID = *host.TeamID + } + + labelIDs, err := svc.validateLabelNames(ctx, "remove", labelNames, tmID) if err != nil { return err } @@ -3156,7 +3166,7 @@ func (svc *Service) RemoveLabelsFromHost(ctx context.Context, id uint, labelName return nil } -func (svc *Service) validateLabelNames(ctx context.Context, action string, labelNames []string) ([]uint, error) { +func (svc *Service) validateLabelNames(ctx context.Context, action string, labelNames []string, teamID uint) ([]uint, error) { if len(labelNames) == 0 { return nil, nil } @@ -3174,7 +3184,8 @@ func (svc *Service) validateLabelNames(ctx context.Context, action string, label return nil, nil } - labels, err := svc.ds.LabelIDsByName(ctx, labelNames) + // team ID is always set because we are assigning labels to an entity; no-team entities can only use global labels + labels, err := svc.ds.LabelIDsByName(ctx, labelNames, fleet.TeamFilter{TeamID: &teamID, User: authz.UserFromContext(ctx)}) if err != nil { return nil, ctxerr.Wrap(ctx, err, "getting label IDs by name") } diff --git a/server/service/hosts_test.go b/server/service/hosts_test.go index a9026cd65d..53ab0d932c 100644 --- a/server/service/hosts_test.go +++ b/server/service/hosts_test.go @@ -1063,7 +1063,7 @@ func TestStreamHosts(t *testing.T) { hostIterator := func() iter.Seq2[*fleet.HostResponse, error] { return func(yield func(*fleet.HostResponse, error) bool) { for i := 1; i <= 3; i++ { - host := &fleet.HostResponse{Host: &fleet.Host{ID: uint(i)}} + host := &fleet.HostResponse{Host: &fleet.Host{ID: uint(i)}} // nolint:gosec if !yield(host, nil) { return } @@ -1287,7 +1287,7 @@ func TestGetHostSummary(t *testing.T) { Platforms: []*fleet.HostSummaryPlatform{{Platform: "darwin", HostsCount: 1}, {Platform: "debian", HostsCount: 2}, {Platform: "centos", HostsCount: 3}, {Platform: "ubuntu", HostsCount: 4}}, }, nil } - ds.LabelsSummaryFunc = func(ctx context.Context) ([]*fleet.LabelSummary, error) { + ds.LabelsSummaryFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSummary, error) { return []*fleet.LabelSummary{{ID: 1, Name: "All hosts", Description: "All hosts enrolled in Fleet", LabelType: fleet.LabelTypeBuiltIn}, {ID: 10, Name: "Other label", Description: "Not a builtin label", LabelType: fleet.LabelTypeRegular}}, nil } diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index eac6214b08..bc18c80960 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -3496,7 +3496,7 @@ func (s *integrationTestSuite) TestHostsAddToTeam() { require.Equal(t, tm2.ID, *getResp.Host.TeamID) // get all hosts label - lblIDs, err := s.ds.LabelIDsByName(context.Background(), []string{"All Hosts"}) + lblIDs, err := s.ds.LabelIDsByName(context.Background(), []string{"All Hosts"}, fleet.TeamFilter{}) require.NoError(t, err) labelID := lblIDs["All Hosts"] @@ -4563,6 +4563,8 @@ func (s *integrationTestSuite) TestGetMacadminsData() { } func (s *integrationTestSuite) TestLabels() { + // TODO team labels + t := s.T() // create some hosts to use for manual labels @@ -5162,7 +5164,7 @@ func (s *integrationTestSuite) TestLabels() { func (s *integrationTestSuite) TestListHostsByLabel() { t := s.T() - lblIDs, err := s.ds.LabelIDsByName(context.Background(), []string{"All Hosts"}) + lblIDs, err := s.ds.LabelIDsByName(context.Background(), []string{"All Hosts"}, fleet.TeamFilter{}) require.NoError(t, err) require.Len(t, lblIDs, 1) labelID := lblIDs["All Hosts"] @@ -5380,6 +5382,8 @@ func (s *integrationTestSuite) TestLabelSpecs() { // get a non-existing label spec s.DoJSON("GET", "/api/latest/fleet/spec/labels/zzz", nil, http.StatusNotFound, &getResp) + + // TODO team labels } func (s *integrationTestSuite) TestUsers() { @@ -8773,7 +8777,7 @@ func (s *integrationTestSuite) TestSearchTargets() { for name := range fleet.ReservedLabelNames() { builtinNames = append(builtinNames, name) } - lblMap, err := s.ds.LabelIDsByName(context.Background(), builtinNames) + lblMap, err := s.ds.LabelIDsByName(context.Background(), builtinNames, fleet.TeamFilter{}) require.NoError(t, err) require.Len(t, lblMap, len(builtinNames)) @@ -8894,7 +8898,7 @@ func (s *integrationTestSuite) TestCountTargets() { hosts := s.createHosts(t) - lblMap, err := s.ds.LabelIDsByName(context.Background(), []string{"All Hosts"}) + lblMap, err := s.ds.LabelIDsByName(context.Background(), []string{"All Hosts"}, fleet.TeamFilter{}) require.NoError(t, err) require.Len(t, lblMap, 1) @@ -9697,7 +9701,7 @@ func (s *integrationTestSuite) TestHostsReportDownload() { {Name: t.Name(), LabelMembershipType: fleet.LabelMembershipTypeManual, Query: "select 1", Hosts: []string{hosts[2].Hostname}}, }) require.NoError(t, err) - lids, err := s.ds.LabelIDsByName(context.Background(), []string{t.Name()}) + lids, err := s.ds.LabelIDsByName(context.Background(), []string{t.Name()}, fleet.TeamFilter{}) require.NoError(t, err) require.Len(t, lids, 1) customLabelID := lids[t.Name()] @@ -13475,6 +13479,8 @@ func (s *integrationTestSuite) TestListHostUpcomingActivities() { } func (s *integrationTestSuite) TestAddingRemovingManualLabels() { + // TODO team labels + t := s.T() ctx := context.Background() @@ -13531,7 +13537,7 @@ func (s *integrationTestSuite) TestAddingRemovingManualLabels() { host2 := newHostFunc("host2", nil) teamHost2 := newHostFunc("teamHost2", &team1.ID) - ls, err := s.ds.LabelIDsByName(ctx, []string{"All Hosts"}) + ls, err := s.ds.LabelIDsByName(ctx, []string{"All Hosts"}, fleet.TeamFilter{}) require.NoError(t, err) require.Len(t, ls, 1) allHostsLabelID, ok := ls["All Hosts"] diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index 2bf524b75c..b7361336a5 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -4507,7 +4507,7 @@ func (s *integrationEnterpriseTestSuite) TestListHosts() { s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp) require.Len(t, resp.Hosts, 3) - allHostsLabel, err := s.ds.GetLabelSpec(context.Background(), "All hosts") + allHostsLabel, err := s.ds.GetLabelSpec(context.Background(), fleet.TeamFilter{}, "All hosts") require.NoError(t, err) for _, h := range resp.Hosts { err = s.ds.RecordLabelQueryExecutions( @@ -7611,7 +7611,7 @@ func (s *integrationEnterpriseTestSuite) TestOrbitConfigExtensions() { Query: "SELECT 1;", }) require.NoError(t, err) - allHostsLabel, err := s.ds.GetLabelSpec(ctx, "All hosts") + allHostsLabel, err := s.ds.GetLabelSpec(ctx, fleet.TeamFilter{}, "All hosts") require.NoError(t, err) orbitDarwinClient := createOrbitEnrolledHost(t, "darwin", "foobar1", s.ds) @@ -22855,7 +22855,7 @@ func (s *integrationEnterpriseTestSuite) TestTeamLabelsDistributedReadWrite() { filterLabelQueries := func(queries map[string]string) map[string]string { allLabels, err := s.ds.ListLabels(t.Context(), fleet.TeamFilter{ User: user, - }, fleet.ListOptions{}) + }, fleet.ListOptions{}, false) require.NoError(t, err) builtinLabels := make(map[string]struct{}) for _, label := range allLabels { diff --git a/server/service/integration_mdm_profiles_test.go b/server/service/integration_mdm_profiles_test.go index b43fd61104..a132ea8483 100644 --- a/server/service/integration_mdm_profiles_test.go +++ b/server/service/integration_mdm_profiles_test.go @@ -3784,7 +3784,7 @@ func (s *integrationMDMTestSuite) TestListMDMConfigProfiles() { require.NoError(t, err) // break lblFoo by deleting it - require.NoError(t, s.ds.DeleteLabel(ctx, lblFoo.Name)) + require.NoError(t, s.ds.DeleteLabel(ctx, lblFoo.Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})) // test that all fields are correctly returned with team 2 var listResp listMDMConfigProfilesResponse @@ -5925,7 +5925,7 @@ func (s *integrationMDMTestSuite) TestHostMDMProfilesExcludeLabels() { }) // break the A1 profile by deleting labels [1] - err = s.ds.DeleteLabel(ctx, labels[1].Name) + err = s.ds.DeleteLabel(ctx, labels[1].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) require.NoError(t, err) // it doesn't get installed to the Apple host, as it is broken @@ -5966,9 +5966,9 @@ func (s *integrationMDMTestSuite) TestHostMDMProfilesExcludeLabels() { // delete labels [2] and [4], breaking D3 and W2, they don't get removed // since they are broken - err = s.ds.DeleteLabel(ctx, labels[2].Name) + err = s.ds.DeleteLabel(ctx, labels[2].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) require.NoError(t, err) - err = s.ds.DeleteLabel(ctx, labels[4].Name) + err = s.ds.DeleteLabel(ctx, labels[4].Name, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) require.NoError(t, err) triggerReconcileProfiles() diff --git a/server/service/integration_mdm_test.go b/server/service/integration_mdm_test.go index 166b4c1669..6038669f93 100644 --- a/server/service/integration_mdm_test.go +++ b/server/service/integration_mdm_test.go @@ -17859,7 +17859,7 @@ func (s *integrationMDMTestSuite) TestNonMDWindowsHostsIgnoredInDiskEncryptionSt s.setSkipWorkerJobs(t) // get the All hosts label ID - ls, err := s.ds.LabelIDsByName(ctx, []string{"All Hosts"}) + ls, err := s.ds.LabelIDsByName(ctx, []string{"All Hosts"}, fleet.TeamFilter{}) require.NoError(t, err) require.Len(t, ls, 1) allHostsLblID := ls["All Hosts"] @@ -17947,7 +17947,7 @@ func (s *integrationMDMTestSuite) TestLinuxHostsIgnoredInOSSettingsStats() { s.setSkipWorkerJobs(t) // get the All hosts label ID - ls, err := s.ds.LabelIDsByName(ctx, []string{"All Hosts"}) + ls, err := s.ds.LabelIDsByName(ctx, []string{"All Hosts"}, fleet.TeamFilter{}) require.NoError(t, err) require.Len(t, ls, 1) allHostsLblID := ls["All Hosts"] @@ -19440,7 +19440,7 @@ func (s *integrationMDMTestSuite) TestTeamLabelsTeamDeletion() { require.True(t, fleet.IsNotFound(err)) // Make sure l2t2 in t2 is unaffected. - _, _, err = s.ds.Label(t.Context(), l2t2.ID, fleet.TeamFilter{}) + _, _, err = s.ds.Label(t.Context(), l2t2.ID, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) require.NoError(t, err) // Make sure label membership for l1t1 is gone. diff --git a/server/service/labels.go b/server/service/labels.go index d554753147..a687820dea 100644 --- a/server/service/labels.go +++ b/server/service/labels.go @@ -5,8 +5,11 @@ import ( "encoding/json" "fmt" "net/http" + "slices" + "strconv" "github.com/fleetdm/fleet/v4/server" + "github.com/fleetdm/fleet/v4/server/authz" authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/contexts/license" @@ -47,7 +50,7 @@ func createLabelEndpoint(ctx context.Context, request interface{}, svc fleet.Ser } func (svc *Service) NewLabel(ctx context.Context, p fleet.LabelPayload) (*fleet.Label, []uint, error) { - if err := svc.authz.Authorize(ctx, &fleet.Label{}, fleet.ActionWrite); err != nil { + if err := svc.authz.Authorize(ctx, &fleet.Label{}, fleet.ActionCreate); err != nil { return nil, nil, err } vc, ok := viewer.FromContext(ctx) @@ -121,7 +124,7 @@ func (svc *Service) NewLabel(ctx context.Context, p fleet.LabelPayload) (*fleet. return nil, nil, err } } - return svc.ds.UpdateLabelMembershipByHostIDs(ctx, label.ID, hostIDs, filter) + return svc.ds.UpdateLabelMembershipByHostIDs(ctx, *label, hostIDs, filter) } return label, nil, nil } @@ -136,8 +139,8 @@ type modifyLabelRequest struct { } type modifyLabelResponse struct { - Label labelResponse `json:"label"` - Err error `json:"error,omitempty"` + Label labelWithTeamNameResponse `json:"label"` + Err error `json:"error,omitempty"` } func (r modifyLabelResponse) Error() error { return r.Err } @@ -149,7 +152,7 @@ func modifyLabelEndpoint(ctx context.Context, request interface{}, svc fleet.Ser return modifyLabelResponse{Err: err}, nil } - labelResp, err := labelResponseForLabel(label, hostIDs) + labelResp, err := labelResponseForLabelWithTeamName(label, hostIDs) if err != nil { return modifyLabelResponse{Err: err}, nil } @@ -157,25 +160,33 @@ func modifyLabelEndpoint(ctx context.Context, request interface{}, svc fleet.Ser return modifyLabelResponse{Label: *labelResp}, err } -func (svc *Service) ModifyLabel(ctx context.Context, id uint, payload fleet.ModifyLabelPayload) (*fleet.Label, []uint, error) { - if err := svc.authz.Authorize(ctx, &fleet.Label{}, fleet.ActionWrite); err != nil { - return nil, nil, err - } +func (svc *Service) ModifyLabel(ctx context.Context, id uint, payload fleet.ModifyLabelPayload) (*fleet.LabelWithTeamName, []uint, error) { vc, ok := viewer.FromContext(ctx) if !ok { + svc.SkipAuth(ctx) return nil, nil, fleet.ErrNoContext } if len(payload.Hosts) > 0 && len(payload.HostIDs) > 0 { + svc.SkipAuth(ctx) return nil, nil, fleet.NewInvalidArgumentError("hosts", `Only one of either "hosts" or "host_ids" can be included in the request.`) } filter := fleet.TeamFilter{User: vc.User, IncludeObserver: true} + // DB query will filter labels the user can't see; auth check filters labels the user can't write label, _, err := svc.ds.Label(ctx, id, filter) if err != nil { + // If we get a retrieval error, 403-wrap it if a user can't write global labels so we don't leak info + if authErr := svc.authz.Authorize(ctx, fleet.Label{}, fleet.ActionWrite); authErr != nil { + return nil, nil, authErr + } return nil, nil, err } + if err := svc.authz.Authorize(ctx, label, fleet.ActionWrite); err != nil { + return nil, nil, err + } + if label.LabelType == fleet.LabelTypeBuiltIn { return nil, nil, fleet.NewInvalidArgumentError("label_type", fmt.Sprintf("cannot modify built-in label '%s'", label.Name)) } @@ -210,12 +221,12 @@ func (svc *Service) ModifyLabel(ctx context.Context, id uint, payload fleet.Modi } if hostIDs != nil { - if _, _, err := svc.ds.UpdateLabelMembershipByHostIDs(ctx, label.ID, hostIDs, filter); err != nil { + if _, _, err := svc.ds.UpdateLabelMembershipByHostIDs(ctx, label.Label, hostIDs, filter); err != nil { return nil, nil, err } } - return svc.ds.SaveLabel(ctx, label, filter) + return svc.ds.SaveLabel(ctx, &label.Label, filter) } //////////////////////////////////////////////////////////////////////////////// @@ -226,6 +237,13 @@ type getLabelRequest struct { ID uint `url:"id"` } +type labelWithTeamNameResponse struct { + fleet.LabelWithTeamName + DisplayText string `json:"display_text"` + Count int `json:"count"` + HostIDs []uint `json:"host_ids,omitempty"` +} + type labelResponse struct { fleet.Label DisplayText string `json:"display_text"` @@ -234,8 +252,8 @@ type labelResponse struct { } type getLabelResponse struct { - Label labelResponse `json:"label"` - Err error `json:"error,omitempty"` + Label labelWithTeamNameResponse `json:"label"` + Err error `json:"error,omitempty"` } func (r getLabelResponse) Error() error { return r.Err } @@ -246,14 +264,15 @@ func getLabelEndpoint(ctx context.Context, request interface{}, svc fleet.Servic if err != nil { return getLabelResponse{Err: err}, nil } - resp, err := labelResponseForLabel(label, hostIDs) + resp, err := labelResponseForLabelWithTeamName(label, hostIDs) if err != nil { return getLabelResponse{Err: err}, nil } return getLabelResponse{Label: *resp}, nil } -func (svc *Service) GetLabel(ctx context.Context, id uint) (*fleet.Label, []uint, error) { +func (svc *Service) GetLabel(ctx context.Context, id uint) (*fleet.LabelWithTeamName, []uint, error) { + // authz intentionally casts a wide net here; we filter unauthorized labels out at the data store level if err := svc.authz.Authorize(ctx, &fleet.Label{}, fleet.ActionRead); err != nil { return nil, nil, err } @@ -272,6 +291,7 @@ func (svc *Service) GetLabel(ctx context.Context, id uint) (*fleet.Label, []uint type listLabelsRequest struct { ListOptions fleet.ListOptions `url:"list_options"` + TeamID *string `query:"team_id,optional"` // string because it's an int or "global" IncludeHostCounts *bool `query:"include_host_counts,optional"` } @@ -290,7 +310,7 @@ func listLabelsEndpoint(ctx context.Context, request interface{}, svc fleet.Serv includeHostCounts = *req.IncludeHostCounts } - labels, err := svc.ListLabels(ctx, req.ListOptions, includeHostCounts) + labels, err := svc.ListLabels(ctx, req.ListOptions, getTeamIDOrZeroForGlobal(req.TeamID), includeHostCounts) if err != nil { return listLabelsResponse{Err: err}, nil } @@ -306,29 +326,39 @@ func listLabelsEndpoint(ctx context.Context, request interface{}, svc fleet.Serv return resp, nil } -func (svc *Service) ListLabels(ctx context.Context, opt fleet.ListOptions, includeHostCounts bool) ([]*fleet.Label, error) { +func getTeamIDOrZeroForGlobal(stringID *string) *uint { + if stringID == nil || *stringID == "" { + return nil + } + + if *stringID == "global" { + return ptr.Uint(0) + } + + if parsedTeamID, err := strconv.ParseUint(*stringID, 10, 32); err == nil { + return ptr.Uint(uint(parsedTeamID)) + } + + return nil +} + +func (svc *Service) ListLabels(ctx context.Context, opt fleet.ListOptions, teamID *uint, includeHostCounts bool) ([]*fleet.Label, error) { if err := svc.authz.Authorize(ctx, &fleet.Label{}, fleet.ActionRead); err != nil { return nil, err } - filter := fleet.TeamFilter{} vc, ok := viewer.FromContext(ctx) if !ok { return nil, fleet.ErrNoContext } - // Default to including host counts. - if includeHostCounts { - filter = fleet.TeamFilter{User: vc.User, IncludeObserver: true} - } - // TODO(mna): ListLabels doesn't currently return the hostIDs members of the // label, the quick approach would be an N+1 queries endpoint. Leaving like // that for now because we're in a hurry before merge freeze but the solution // would probably be to do it in 2 queries : grab all label IDs from the // list, then select hostID+labelID tuples in one query (where labelID IN // )and fill the hostIDs per label. - return svc.ds.ListLabels(ctx, filter, opt) + return svc.ds.ListLabels(ctx, fleet.TeamFilter{User: vc.User, IncludeObserver: true, TeamID: teamID}, opt, includeHostCounts) } func labelResponseForLabel(label *fleet.Label, hostIDs []uint) (*labelResponse, error) { @@ -340,10 +370,23 @@ func labelResponseForLabel(label *fleet.Label, hostIDs []uint) (*labelResponse, }, nil } +func labelResponseForLabelWithTeamName(label *fleet.LabelWithTeamName, hostIDs []uint) (*labelWithTeamNameResponse, error) { + return &labelWithTeamNameResponse{ + LabelWithTeamName: *label, + DisplayText: label.Name, + Count: label.HostCount, + HostIDs: hostIDs, + }, nil +} + //////////////////////////////////////////////////////////////////////////////// // Labels Summary //////////////////////////////////////////////////////////////////////////////// +type getLabelsSummaryRequest struct { + TeamID *string `query:"team_id,optional"` // string because it's an int or "global" +} + type getLabelsSummaryResponse struct { Labels []*fleet.LabelSummary `json:"labels"` Err error `json:"error,omitempty"` @@ -352,19 +395,26 @@ type getLabelsSummaryResponse struct { func (r getLabelsSummaryResponse) Error() error { return r.Err } func getLabelsSummaryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (fleet.Errorer, error) { - labels, err := svc.LabelsSummary(ctx) + req := request.(*getLabelsSummaryRequest) + + labels, err := svc.LabelsSummary(ctx, getTeamIDOrZeroForGlobal(req.TeamID)) if err != nil { return getLabelsSummaryResponse{Err: err}, nil } return getLabelsSummaryResponse{Labels: labels}, nil } -func (svc *Service) LabelsSummary(ctx context.Context) ([]*fleet.LabelSummary, error) { +func (svc *Service) LabelsSummary(ctx context.Context, teamID *uint) ([]*fleet.LabelSummary, error) { if err := svc.authz.Authorize(ctx, &fleet.Label{}, fleet.ActionRead); err != nil { return nil, err } - return svc.ds.LabelsSummary(ctx) + vc, ok := viewer.FromContext(ctx) + if !ok { + return nil, fleet.ErrNoContext + } + + return svc.ds.LabelsSummary(ctx, fleet.TeamFilter{User: vc.User, IncludeObserver: true, TeamID: teamID}) } //////////////////////////////////////////////////////////////////////////////// @@ -456,18 +506,36 @@ func deleteLabelEndpoint(ctx context.Context, request interface{}, svc fleet.Ser } func (svc *Service) DeleteLabel(ctx context.Context, name string) error { - if err := svc.authz.Authorize(ctx, &fleet.Label{}, fleet.ActionWrite); err != nil { - return err + vc, ok := viewer.FromContext(ctx) + if !ok { + svc.SkipAuth(ctx) + return fleet.ErrNoContext } // check if the label is a built-in label for n := range fleet.ReservedLabelNames() { if n == name { + svc.SkipAuth(ctx) return fleet.NewInvalidArgumentError("name", fmt.Sprintf("cannot delete built-in label '%s'", name)) } } - return svc.ds.DeleteLabel(ctx, name) + filter := fleet.TeamFilter{User: vc.User} + + // need to grab the label first to see if we have permission to delete it; + // if the label doesn't exist global users will see the true 404, other users will get a 403 + label, err := svc.ds.LabelByName(ctx, name, filter) + if err != nil { + if authError := svc.authz.Authorize(ctx, fleet.Label{}, fleet.ActionWrite); authError != nil { + return authError + } + return err + } + if err := svc.authz.Authorize(ctx, label, fleet.ActionWrite); err != nil { + return err + } + + return svc.ds.DeleteLabel(ctx, name, filter) } //////////////////////////////////////////////////////////////////////////////// @@ -494,19 +562,27 @@ func deleteLabelByIDEndpoint(ctx context.Context, request interface{}, svc fleet } func (svc *Service) DeleteLabelByID(ctx context.Context, id uint) error { - if err := svc.authz.Authorize(ctx, &fleet.Label{}, fleet.ActionWrite); err != nil { - return err - } vc, ok := viewer.FromContext(ctx) if !ok { + svc.SkipAuth(ctx) return fleet.ErrNoContext } filter := fleet.TeamFilter{User: vc.User, IncludeObserver: true} + // need to grab the label first to see if we have permission to delete it; + // if the label doesn't exist global users will see the true 404, other users will get a 403 label, _, err := svc.ds.Label(ctx, id, filter) if err != nil { + // If we get a retrieval error, 403-wrap it if a user can't write global labels so we don't leak info + if authErr := svc.authz.Authorize(ctx, fleet.Label{}, fleet.ActionWrite); authErr != nil { + return authErr + } return err } + if err := svc.authz.Authorize(ctx, label, fleet.ActionWrite); err != nil { + return err + } + if label.LabelType == fleet.LabelTypeBuiltIn { return fleet.NewInvalidArgumentError("label_type", fmt.Sprintf("cannot delete built-in label '%s'", label.Name)) } @@ -516,7 +592,7 @@ func (svc *Service) DeleteLabelByID(ctx context.Context, id uint) error { } } - return svc.ds.DeleteLabel(ctx, label.Name) + return svc.ds.DeleteLabel(ctx, label.Name, filter) } //////////////////////////////////////////////////////////////////////////////// @@ -524,7 +600,9 @@ func (svc *Service) DeleteLabelByID(ctx context.Context, id uint) error { //////////////////////////////////////////////////////////////////////////////// type applyLabelSpecsRequest struct { - Specs []*fleet.LabelSpec `json:"specs"` + Specs []*fleet.LabelSpec `json:"specs"` + TeamID *uint `json:"-" query:"team_id,optional"` + NamesToMove []string `json:"names_to_move,omitempty"` } type applyLabelSpecsResponse struct { @@ -535,21 +613,27 @@ func (r applyLabelSpecsResponse) Error() error { return r.Err } func applyLabelSpecsEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (fleet.Errorer, error) { req := request.(*applyLabelSpecsRequest) - err := svc.ApplyLabelSpecs(ctx, req.Specs) + err := svc.ApplyLabelSpecs(ctx, req.Specs, req.TeamID, req.NamesToMove) if err != nil { return applyLabelSpecsResponse{Err: err}, nil } return applyLabelSpecsResponse{}, nil } -func (svc *Service) ApplyLabelSpecs(ctx context.Context, specs []*fleet.LabelSpec) error { - if err := svc.authz.Authorize(ctx, &fleet.Label{}, fleet.ActionWrite); err != nil { +func (svc *Service) ApplyLabelSpecs(ctx context.Context, specs []*fleet.LabelSpec, teamID *uint, namesToMove []string) error { + if err := svc.authz.Authorize(ctx, &fleet.Label{TeamID: teamID}, fleet.ActionWrite); err != nil { return err } + if !license.IsPremium(ctx) && teamID != nil && *teamID > 0 { + return fleet.ErrMissingLicense + } regularSpecs := make([]*fleet.LabelSpec, 0, len(specs)) var builtInSpecs []*fleet.LabelSpec var builtInSpecNames []string + + var specLabelNamesNeedingMoving []string // should match namesToMove once specs have been checked + for _, spec := range specs { if spec.LabelMembershipType == fleet.LabelMembershipTypeDynamic && len(spec.Hosts) > 0 { return fleet.NewUserMessageError( @@ -589,12 +673,25 @@ func (svc *Service) ApplyLabelSpecs(ctx context.Context, specs []*fleet.LabelSpe } } + if slices.Contains(namesToMove, spec.Name) { + specLabelNamesNeedingMoving = append(specLabelNamesNeedingMoving, spec.Name) + } + + // make sure we're only upserting labels on the team we specified; individual spec teams aren't used on writes + spec.TeamID = teamID regularSpecs = append(regularSpecs, spec) } + if len(specLabelNamesNeedingMoving) != len(namesToMove) { + return fleet.NewUserMessageError( + ctxerr.New(ctx, "label names to move list was not a subset of specified labels"), + http.StatusConflict, + ) + } + // If built-in labels have been provided, ensure that they are not attempted to be modified if len(builtInSpecs) > 0 { - labelMap, err := svc.ds.LabelsByName(ctx, builtInSpecNames) + labelMap, err := svc.ds.LabelsByName(ctx, builtInSpecNames, fleet.TeamFilter{}) // built-in labels are all global if err != nil { return err } @@ -616,8 +713,17 @@ func (svc *Service) ApplyLabelSpecs(ctx context.Context, specs []*fleet.LabelSpe return nil } - // Get the user from the context. user, ok := viewer.FromContext(ctx) + if ok && user.User != nil { + if err := svc.ds.SetAsideLabels(ctx, teamID, namesToMove, *user.User); err != nil { + return ctxerr.Wrap(ctx, err, "cleaning up conflicting other team labels") + } + } else if len(namesToMove) > 0 { + return fleet.NewUserMessageError( + ctxerr.New(ctx, "cannot move labels out of the way without user authentication"), http.StatusForbidden, + ) + } + // If we have a user, mark them as the label's author. if ok { return svc.ds.ApplyLabelSpecsWithAuthor(ctx, regularSpecs, ptr.Uint(user.UserID())) @@ -636,20 +742,30 @@ type getLabelSpecsResponse struct { func (r getLabelSpecsResponse) Error() error { return r.Err } +type getLabelSpecsRequest struct { + TeamID *uint `query:"team_id,optional"` +} + func getLabelSpecsEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (fleet.Errorer, error) { - specs, err := svc.GetLabelSpecs(ctx) + req := request.(*getLabelSpecsRequest) + specs, err := svc.GetLabelSpecs(ctx, req.TeamID) if err != nil { return getLabelSpecsResponse{Err: err}, nil } return getLabelSpecsResponse{Specs: specs}, nil } -func (svc *Service) GetLabelSpecs(ctx context.Context) ([]*fleet.LabelSpec, error) { +func (svc *Service) GetLabelSpecs(ctx context.Context, teamID *uint) ([]*fleet.LabelSpec, error) { if err := svc.authz.Authorize(ctx, &fleet.Label{}, fleet.ActionRead); err != nil { return nil, err } - return svc.ds.GetLabelSpecs(ctx) + vc, ok := viewer.FromContext(ctx) + if !ok { + return nil, fleet.ErrNoContext + } + + return svc.ds.GetLabelSpecs(ctx, fleet.TeamFilter{User: vc.User, IncludeObserver: true, TeamID: teamID}) } //////////////////////////////////////////////////////////////////////////////// @@ -677,7 +793,12 @@ func (svc *Service) GetLabelSpec(ctx context.Context, name string) (*fleet.Label return nil, err } - return svc.ds.GetLabelSpec(ctx, name) + vc, ok := viewer.FromContext(ctx) + if !ok { + return nil, fleet.ErrNoContext + } + + return svc.ds.GetLabelSpec(ctx, fleet.TeamFilter{User: vc.User, IncludeObserver: true}, name) } func (svc *Service) BatchValidateLabels(ctx context.Context, teamID *uint, labelNames []string) (map[string]fleet.LabelIdent, error) { @@ -693,7 +814,7 @@ func (svc *Service) BatchValidateLabels(ctx context.Context, teamID *uint, label uniqueNames := server.RemoveDuplicatesFromSlice(labelNames) - labels, err := svc.ds.LabelIDsByName(ctx, uniqueNames) + labels, err := svc.ds.LabelIDsByName(ctx, uniqueNames, fleet.TeamFilter{User: authz.UserFromContext(ctx)}) if err != nil { return nil, ctxerr.Wrap(ctx, err, "getting label IDs by name") } @@ -705,7 +826,7 @@ func (svc *Service) BatchValidateLabels(ctx context.Context, teamID *uint, label } } - if err := verifyLabelsToAssociate(ctx, svc.ds, teamID, labelNames); err != nil { + if err := verifyLabelsToAssociate(ctx, svc.ds, teamID, labelNames, authz.UserFromContext(ctx)); err != nil { return nil, ctxerr.Wrap(ctx, err, "verify labels to associate") } diff --git a/server/service/labels_test.go b/server/service/labels_test.go index 24eee8931f..309431ebbe 100644 --- a/server/service/labels_test.go +++ b/server/service/labels_test.go @@ -8,6 +8,7 @@ import ( "time" authz_ctx "github.com/fleetdm/fleet/v4/server/contexts/authz" + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/contexts/viewer" "github.com/fleetdm/fleet/v4/server/datastore/mysql" "github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql/testing_utils" @@ -25,63 +26,82 @@ func TestLabelsAuth(t *testing.T) { svc, ctx := newTestService(t, ds, nil, nil) ds.NewLabelFunc = func(ctx context.Context, lbl *fleet.Label, opts ...fleet.OptionalArg) (*fleet.Label, error) { + lbl.ID = 1 + if lbl.Name == "Other label" { + lbl.ID = 2 + } return lbl, nil } - ds.SaveLabelFunc = func(ctx context.Context, lbl *fleet.Label, filter fleet.TeamFilter) (*fleet.Label, []uint, error) { - return lbl, nil, nil + ds.SaveLabelFunc = func(ctx context.Context, lbl *fleet.Label, filter fleet.TeamFilter) (*fleet.LabelWithTeamName, []uint, error) { + return &fleet.LabelWithTeamName{Label: *lbl}, nil, nil } - ds.DeleteLabelFunc = func(ctx context.Context, nm string) error { + ds.DeleteLabelFunc = func(ctx context.Context, nm string, filter fleet.TeamFilter) error { return nil } ds.ApplyLabelSpecsFunc = func(ctx context.Context, specs []*fleet.LabelSpec) error { return nil } - ds.LabelFunc = func(ctx context.Context, id uint, filter fleet.TeamFilter) (*fleet.Label, []uint, error) { - return &fleet.Label{}, nil, nil + ds.LabelFunc = func(ctx context.Context, id uint, filter fleet.TeamFilter) (*fleet.LabelWithTeamName, []uint, error) { + switch id { + case uint(1): + return &fleet.LabelWithTeamName{Label: fleet.Label{ID: id, AuthorID: &filter.User.ID}}, nil, nil + case uint(2): + return &fleet.LabelWithTeamName{Label: fleet.Label{ID: id}}, nil, nil + } + + return nil, nil, ctxerr.Wrap(ctx, notFoundErr{"label", fleet.ErrorWithUUID{}}) } - ds.ListLabelsFunc = func(ctx context.Context, filter fleet.TeamFilter, opts fleet.ListOptions) ([]*fleet.Label, error) { + ds.LabelByNameFunc = func(ctx context.Context, name string, filter fleet.TeamFilter) (*fleet.Label, error) { + return &fleet.Label{ID: 2, Name: name}, nil // for deletes, TODO add cases for authorship/team differences + } + ds.ListLabelsFunc = func(ctx context.Context, filter fleet.TeamFilter, opts fleet.ListOptions, includeHostCounts bool) ([]*fleet.Label, error) { return nil, nil } - ds.LabelsSummaryFunc = func(ctx context.Context) ([]*fleet.LabelSummary, error) { + ds.LabelsSummaryFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSummary, error) { return nil, nil } ds.ListHostsInLabelFunc = func(ctx context.Context, filter fleet.TeamFilter, lid uint, opts fleet.HostListOptions) ([]*fleet.Host, error) { return nil, nil } - ds.GetLabelSpecsFunc = func(ctx context.Context) ([]*fleet.LabelSpec, error) { + ds.GetLabelSpecsFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSpec, error) { return nil, nil } - ds.GetLabelSpecFunc = func(ctx context.Context, name string) (*fleet.LabelSpec, error) { + ds.GetLabelSpecFunc = func(ctx context.Context, filter fleet.TeamFilter, name string) (*fleet.LabelSpec, error) { return &fleet.LabelSpec{}, nil } testCases := []struct { - name string - user *fleet.User - shouldFailWrite bool - shouldFailRead bool + name string + user *fleet.User + shouldFailGlobalWrite bool + shouldFailGlobalRead bool + shouldFailGlobalWriteIfAuthor bool }{ { "global admin", &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}, false, false, + false, }, { "global maintainer", &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)}, false, false, + false, }, { "global observer", &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)}, true, false, + true, }, { "team maintainer", &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}}, + true, false, false, }, @@ -90,44 +110,64 @@ func TestLabelsAuth(t *testing.T) { &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}}, true, false, + true, }, } + + // add a new label authored by no one so we can check writes for labels that aren't authored by the user + otherLabel, _, err := svc.NewLabel(viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{ID: 1, GlobalRole: ptr.String(fleet.RoleMaintainer)}}), fleet.LabelPayload{Name: "Other label", Query: "SELECT 0"}) + require.NoError(t, err) + + // TODO create other-team label + for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { ctx := viewer.NewContext(ctx, viewer.Viewer{User: tt.user}) - _, _, err := svc.NewLabel(ctx, fleet.LabelPayload{Name: t.Name(), Query: `SELECT 1`}) - checkAuthErr(t, tt.shouldFailWrite, err) + myLabel, _, err := svc.NewLabel(ctx, fleet.LabelPayload{Name: t.Name(), Query: `SELECT 1`}) + checkAuthErr(t, tt.shouldFailGlobalWriteIfAuthor, err) // team write users can still create global labels - _, _, err = svc.ModifyLabel(ctx, 1, fleet.ModifyLabelPayload{}) - checkAuthErr(t, tt.shouldFailWrite, err) + _, _, err = svc.ModifyLabel(ctx, otherLabel.ID, fleet.ModifyLabelPayload{}) + checkAuthErr(t, tt.shouldFailGlobalWrite, err) - err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{}) - checkAuthErr(t, tt.shouldFailWrite, err) + if myLabel != nil { + _, _, err = svc.ModifyLabel(ctx, myLabel.ID, fleet.ModifyLabelPayload{}) + checkAuthErr(t, tt.shouldFailGlobalWriteIfAuthor, err) + } - _, _, err = svc.GetLabel(ctx, 1) - checkAuthErr(t, tt.shouldFailRead, err) + err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{}, nil, nil) + checkAuthErr(t, tt.shouldFailGlobalWrite, err) - _, err = svc.GetLabelSpecs(ctx) - checkAuthErr(t, tt.shouldFailRead, err) + _, _, err = svc.GetLabel(ctx, otherLabel.ID) + checkAuthErr(t, tt.shouldFailGlobalRead, err) + + _, err = svc.GetLabelSpecs(ctx, nil) + checkAuthErr(t, tt.shouldFailGlobalRead, err) _, err = svc.GetLabelSpec(ctx, "abc") - checkAuthErr(t, tt.shouldFailRead, err) + checkAuthErr(t, tt.shouldFailGlobalRead, err) - _, err = svc.ListLabels(ctx, fleet.ListOptions{}, true) - checkAuthErr(t, tt.shouldFailRead, err) + _, err = svc.ListLabels(ctx, fleet.ListOptions{}, nil, true) + checkAuthErr(t, tt.shouldFailGlobalRead, err) - _, err = svc.LabelsSummary((ctx)) - checkAuthErr(t, tt.shouldFailRead, err) + _, err = svc.LabelsSummary(ctx, nil) + checkAuthErr(t, tt.shouldFailGlobalRead, err) _, err = svc.ListHostsInLabel(ctx, 1, fleet.HostListOptions{}) - checkAuthErr(t, tt.shouldFailRead, err) + checkAuthErr(t, tt.shouldFailGlobalRead, err) err = svc.DeleteLabel(ctx, "abc") - checkAuthErr(t, tt.shouldFailWrite, err) + checkAuthErr(t, tt.shouldFailGlobalWrite, err) - err = svc.DeleteLabelByID(ctx, 1) - checkAuthErr(t, tt.shouldFailWrite, err) + err = svc.DeleteLabelByID(ctx, otherLabel.ID) + checkAuthErr(t, tt.shouldFailGlobalWrite, err) + + if myLabel != nil { + err = svc.DeleteLabelByID(ctx, myLabel.ID) + checkAuthErr(t, tt.shouldFailGlobalWriteIfAuthor, err) + } + + // TODO add team label permissions }) } } @@ -138,24 +178,26 @@ func TestListLabelsHostCountOptions(t *testing.T) { user := &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)} ctx = viewer.NewContext(ctx, viewer.Viewer{User: user}) - ds.ListLabelsFunc = func(ctx context.Context, filter fleet.TeamFilter, opts fleet.ListOptions) ([]*fleet.Label, error) { - // Expect the team filter to be empty, meaning no host counts requested - require.Nil(t, filter.User) + ds.ListLabelsFunc = func(ctx context.Context, filter fleet.TeamFilter, opts fleet.ListOptions, includeHostCounts bool) ([]*fleet.Label, error) { + // Expect host counts not to be requested + require.False(t, includeHostCounts) return nil, nil } // Test explicitly setting include_host_counts to false - _, err := svc.ListLabels(ctx, fleet.ListOptions{}, false) + _, err := svc.ListLabels(ctx, fleet.ListOptions{}, nil, false) require.NoError(t, err) - ds.ListLabelsFunc = func(ctx context.Context, filter fleet.TeamFilter, opts fleet.ListOptions) ([]*fleet.Label, error) { - // Expect the team filter to be empty, meaning no host counts requested + ds.ListLabelsFunc = func(ctx context.Context, filter fleet.TeamFilter, opts fleet.ListOptions, includeHostCounts bool) ([]*fleet.Label, error) { + // Expect host counts to be requested + require.True(t, includeHostCounts) + // Expect the team filter to be set require.Equal(t, filter.User, user) return nil, nil } // Test explicitly setting include_host_counts to true - _, err = svc.ListLabels(ctx, fleet.ListOptions{}, true) + _, err = svc.ListLabels(ctx, fleet.ListOptions{}, nil, true) require.NoError(t, err) } @@ -198,11 +240,11 @@ func testLabelsListLabels(t *testing.T, ds *mysql.Datastore) { svc, ctx := newTestService(t, ds, nil, nil) require.NoError(t, ds.MigrateData(context.Background())) - labels, err := svc.ListLabels(test.UserContext(ctx, test.UserAdmin), fleet.ListOptions{Page: 0, PerPage: 1000}, true) + labels, err := svc.ListLabels(test.UserContext(ctx, test.UserAdmin), fleet.ListOptions{Page: 0, PerPage: 1000}, nil, true) require.NoError(t, err) require.Len(t, labels, 8) - labelsSummary, err := svc.LabelsSummary(test.UserContext(ctx, test.UserAdmin)) + labelsSummary, err := svc.LabelsSummary(test.UserContext(ctx, test.UserAdmin), nil) require.NoError(t, err) require.Len(t, labelsSummary, 8) } @@ -231,7 +273,7 @@ func TestApplyLabelSpecsWithBuiltInLabels(t *testing.T) { LabelMembershipType: labelMembershipType, } - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { return map[string]*fleet.Label{ name: { Name: name, @@ -245,7 +287,7 @@ func TestApplyLabelSpecsWithBuiltInLabels(t *testing.T) { } // all good - err := svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}) + err := svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}, nil, nil) require.NoError(t, err) // trying to add a regular label with the same name as a built-in label should fail @@ -257,7 +299,7 @@ func TestApplyLabelSpecsWithBuiltInLabels(t *testing.T) { Query: query, LabelType: fleet.LabelTypeRegular, }, - }) + }, nil, nil) assert.ErrorContains(t, err, fmt.Sprintf("cannot add label '%s' because it conflicts with the name of a built-in label", name)) } @@ -265,39 +307,39 @@ func TestApplyLabelSpecsWithBuiltInLabels(t *testing.T) { const errorMessage = "cannot modify or add built-in label" // not ok -- built-in label name doesn't exist name = "not-foo" - err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}) + err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}, nil, nil) assert.ErrorContains(t, err, errorMessage) name = "foo" // not ok -- description does not match description = "not-bar" - err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}) + err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}, nil, nil) assert.ErrorContains(t, err, errorMessage) description = "bar" // not ok -- query does not match query = "select * from not-foo;" - err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}) + err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}, nil, nil) assert.ErrorContains(t, err, errorMessage) query = "select * from foo;" // not ok -- label type does not match labelType = fleet.LabelTypeRegular - err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}) + err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}, nil, nil) assert.ErrorContains(t, err, errorMessage) labelType = fleet.LabelTypeBuiltIn // not ok -- label membership type does not match labelMembershipType = fleet.LabelMembershipTypeManual - err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}) + err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}, nil, nil) assert.ErrorContains(t, err, errorMessage) labelMembershipType = fleet.LabelMembershipTypeDynamic // not ok -- DB error - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { return nil, assert.AnError } - err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}) + err = svc.ApplyLabelSpecs(ctx, []*fleet.LabelSpec{spec}, nil, nil) assert.ErrorIs(t, err, assert.AnError) } @@ -347,27 +389,27 @@ func TestLabelsWithReplica(t *testing.T) { // make the newly-created label available to the reader opts.RunReplication("labels", "label_membership") - lbl, hostIDs, err = svc.ModifyLabel(ctx, lbl.ID, fleet.ModifyLabelPayload{Hosts: []string{"host1"}}) + lblWithName, hostIDs, err := svc.ModifyLabel(ctx, lbl.ID, fleet.ModifyLabelPayload{Hosts: []string{"host1"}}) require.NoError(t, err) require.ElementsMatch(t, []uint{h1.ID}, hostIDs) - require.Equal(t, 1, lbl.HostCount) - require.Equal(t, user.ID, *lbl.AuthorID) + require.Equal(t, 1, lblWithName.HostCount) + require.Equal(t, user.ID, *lblWithName.AuthorID) // reading this label without replication returns the old data as it only uses the reader - lbl, hostIDs, err = svc.GetLabel(ctx, lbl.ID) + lblWithName, hostIDs, err = svc.GetLabel(ctx, lblWithName.ID) require.NoError(t, err) require.ElementsMatch(t, []uint{h1.ID, h2.ID}, hostIDs) - require.Equal(t, 2, lbl.HostCount) - require.Equal(t, user.ID, *lbl.AuthorID) + require.Equal(t, 2, lblWithName.HostCount) + require.Equal(t, user.ID, *lblWithName.AuthorID) // running the replication makes the updated data available opts.RunReplication("labels", "label_membership") - lbl, hostIDs, err = svc.GetLabel(ctx, lbl.ID) + lblWithName, hostIDs, err = svc.GetLabel(ctx, lblWithName.ID) require.NoError(t, err) require.ElementsMatch(t, []uint{h1.ID}, hostIDs) - require.Equal(t, 1, lbl.HostCount) - require.Equal(t, user.ID, *lbl.AuthorID) + require.Equal(t, 1, lblWithName.HostCount) + require.Equal(t, user.ID, *lblWithName.AuthorID) } func TestBatchValidateLabels(t *testing.T) { @@ -401,7 +443,7 @@ func TestBatchValidateLabels(t *testing.T) { return fleet.LabelIdent{LabelID: id, LabelName: name} } - ds.LabelIDsByNameFunc = func(ctx context.Context, names []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { res := make(map[string]uint) if names == nil { return res, nil @@ -413,7 +455,7 @@ func TestBatchValidateLabels(t *testing.T) { } return res, nil } - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { res := make(map[string]*fleet.Label) if names == nil { return res, nil @@ -505,8 +547,8 @@ func TestNewManualLabel(t *testing.T) { } t.Run("using hostnames", func(t *testing.T) { - ds.UpdateLabelMembershipByHostIDsFunc = func(ctx context.Context, labelID uint, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { - require.Equal(t, uint(1), labelID) + ds.UpdateLabelMembershipByHostIDsFunc = func(ctx context.Context, label fleet.Label, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { + require.Equal(t, uint(1), label.ID) require.Equal(t, []uint{99, 100}, hostIds) return nil, nil, nil } @@ -518,8 +560,8 @@ func TestNewManualLabel(t *testing.T) { }) t.Run("using IDs", func(t *testing.T) { - ds.UpdateLabelMembershipByHostIDsFunc = func(ctx context.Context, labelID uint, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { - require.Equal(t, uint(1), labelID) + ds.UpdateLabelMembershipByHostIDsFunc = func(ctx context.Context, label fleet.Label, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { + require.Equal(t, uint(1), label.ID) require.Equal(t, []uint{1, 2}, hostIds) return nil, nil, nil } @@ -536,22 +578,24 @@ func TestModifyManualLabel(t *testing.T) { svc, ctx := newTestService(t, ds, nil, nil) ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) - ds.LabelFunc = func(ctx context.Context, lid uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { - return &fleet.Label{ - ID: lid, - LabelMembershipType: fleet.LabelMembershipTypeManual, + ds.LabelFunc = func(ctx context.Context, lid uint, teamFilter fleet.TeamFilter) (*fleet.LabelWithTeamName, []uint, error) { + return &fleet.LabelWithTeamName{ + Label: fleet.Label{ + ID: lid, + LabelMembershipType: fleet.LabelMembershipTypeManual, + }, }, nil, nil } ds.HostIDsByIdentifierFunc = func(ctx context.Context, filter fleet.TeamFilter, hostnames []string) ([]uint, error) { return []uint{99, 100}, nil } - ds.SaveLabelFunc = func(ctx context.Context, lbl *fleet.Label, filter fleet.TeamFilter) (*fleet.Label, []uint, error) { + ds.SaveLabelFunc = func(ctx context.Context, lbl *fleet.Label, filter fleet.TeamFilter) (*fleet.LabelWithTeamName, []uint, error) { return nil, nil, nil } t.Run("using hostnames", func(t *testing.T) { - ds.UpdateLabelMembershipByHostIDsFunc = func(ctx context.Context, labelID uint, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { - require.Equal(t, uint(1), labelID) + ds.UpdateLabelMembershipByHostIDsFunc = func(ctx context.Context, label fleet.Label, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { + require.Equal(t, uint(1), label.ID) require.Equal(t, []uint{99, 100}, hostIds) return nil, nil, nil } @@ -562,8 +606,8 @@ func TestModifyManualLabel(t *testing.T) { }) t.Run("using IDs", func(t *testing.T) { - ds.UpdateLabelMembershipByHostIDsFunc = func(ctx context.Context, labelID uint, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { - require.Equal(t, uint(1), labelID) + ds.UpdateLabelMembershipByHostIDsFunc = func(ctx context.Context, label fleet.Label, hostIds []uint, teamFilter fleet.TeamFilter) (*fleet.Label, []uint, error) { + require.Equal(t, uint(1), label.ID) require.Equal(t, []uint{1, 2}, hostIds) return nil, nil, nil } diff --git a/server/service/labels_util.go b/server/service/labels_util.go index 40242c4049..4306219355 100644 --- a/server/service/labels_util.go +++ b/server/service/labels_util.go @@ -5,10 +5,11 @@ import ( "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/ptr" ) -func loadLabelsFromNames(ctx context.Context, ds fleet.Datastore, labelNames []string) (map[string]*fleet.Label, error) { - labelsMap, err := ds.LabelsByName(ctx, labelNames) +func loadLabelsFromNames(ctx context.Context, ds fleet.Datastore, labelNames []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { + labelsMap, err := ds.LabelsByName(ctx, labelNames, filter) if err != nil { return nil, ctxerr.Wrap(ctx, err, "get labels by name") } @@ -21,7 +22,7 @@ func loadLabelsFromNames(ctx context.Context, ds fleet.Datastore, labelNames []s return labelsMap, nil } -func verifyLabelsToAssociate(ctx context.Context, ds fleet.Datastore, entityTeamID *uint, labelNames []string) error { +func verifyLabelsToAssociate(ctx context.Context, ds fleet.Datastore, entityTeamID *uint, labelNames []string, user *fleet.User) error { if len(labelNames) == 0 { return nil } @@ -37,35 +38,17 @@ func verifyLabelsToAssociate(ctx context.Context, ds fleet.Datastore, entityTeam uniqueLabelNames = append(uniqueLabelNames, s) } - // Load data of all labels. - labels, err := loadLabelsFromNames(ctx, ds, uniqueLabelNames) + if entityTeamID == nil { // no-team/all-teams entities can only access global labels + entityTeamID = ptr.Uint(0) + } + + labels, err := loadLabelsFromNames(ctx, ds, uniqueLabelNames, fleet.TeamFilter{User: user, TeamID: entityTeamID}) if err != nil { return ctxerr.Wrap(ctx, err, "labels by name") } - // Perform team ID checks for "No team" or global entities. - if entityTeamID == nil || *entityTeamID == 0 { - // entityTeamID == nil: global entity (like "All teams" policies and "All team" queries) - // entityTeamID == 0: "no team" entity. - // For both cases, labels must be global because currently we don't support labels in "No team". - for _, label := range labels { - if label.TeamID != nil { - return ctxerr.Wrap(ctx, badRequestf("label %q is a team label", label.Name)) - } - } - return nil - } - - // Perform team ID checks for team entities. - for _, label := range labels { - // Team entities can reference global labels. - if label.TeamID == nil { - continue - } - // Team entities cannot reference labels that belong another team. - if *label.TeamID != *entityTeamID { - return ctxerr.Wrap(ctx, badRequestf("label %q belongs to a different team", label.Name)) - } + if len(labels) != len(uniqueLabelNames) { + return ctxerr.Wrap(ctx, badRequest("one or more labels specified do not exist, or cannot be applied to this entity")) } return nil diff --git a/server/service/mdm.go b/server/service/mdm.go index a852d15c35..819f780c1a 100644 --- a/server/service/mdm.go +++ b/server/service/mdm.go @@ -1828,7 +1828,7 @@ func (svc *Service) batchValidateProfileLabels(ctx context.Context, teamID *uint return nil, nil } - labels, err := svc.ds.LabelIDsByName(ctx, labelNames) + labels, err := svc.ds.LabelIDsByName(ctx, labelNames, fleet.TeamFilter{User: authz.UserFromContext(ctx)}) if err != nil { return nil, ctxerr.Wrap(ctx, err, "getting label IDs by name") } @@ -1850,7 +1850,7 @@ func (svc *Service) batchValidateProfileLabels(ctx context.Context, teamID *uint // NOTE(lucas): To not break API error string returned above // AND for code reusability we are a-ok with loading labels again in verifyLabelsToAssociate. // This can definitely be optimized if need be. - if err := verifyLabelsToAssociate(ctx, svc.ds, teamID, labelNames); err != nil { + if err := verifyLabelsToAssociate(ctx, svc.ds, teamID, labelNames, authz.UserFromContext(ctx)); err != nil { return nil, ctxerr.Wrap(ctx, err, "verify labels to associate") } diff --git a/server/service/mdm_test.go b/server/service/mdm_test.go index 46826fbbae..6bad46cef1 100644 --- a/server/service/mdm_test.go +++ b/server/service/mdm_test.go @@ -2347,9 +2347,9 @@ func TestBatchSetMDMProfilesLabels(t *testing.T) { return fleet.MDMProfilesUpdates{}, nil } var labelID uint - ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { m := map[string]uint{} - for _, label := range labels { + for _, label := range names { if label != "baddy" { labelID++ m[label] = labelID @@ -2357,7 +2357,7 @@ func TestBatchSetMDMProfilesLabels(t *testing.T) { } return m, nil } - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { m := map[string]*fleet.Label{} for _, name := range names { if name != "baddy" { diff --git a/server/service/metrics_labels.go b/server/service/metrics_labels.go index ad253d2c7b..184fa13c4d 100644 --- a/server/service/metrics_labels.go +++ b/server/service/metrics_labels.go @@ -8,17 +8,12 @@ import ( "github.com/fleetdm/fleet/v4/server/fleet" ) -func (mw metricsMiddleware) ModifyLabel(ctx context.Context, id uint, p fleet.ModifyLabelPayload) (*fleet.Label, []uint, error) { - var ( - lic *fleet.Label - hids []uint - err error - ) +func (mw metricsMiddleware) ModifyLabel(ctx context.Context, id uint, p fleet.ModifyLabelPayload) (*fleet.LabelWithTeamName, []uint, error) { + var err error defer func(begin time.Time) { lvs := []string{"method", "ModifyLabel", "error", fmt.Sprint(err != nil)} mw.requestCount.With(lvs...).Add(1) mw.requestLatency.With(lvs...).Observe(time.Since(begin).Seconds()) }(time.Now()) - lic, hids, err = mw.Service.ModifyLabel(ctx, id, p) - return lic, hids, err + return mw.Service.ModifyLabel(ctx, id, p) } diff --git a/server/service/queries.go b/server/service/queries.go index 82bfa02435..0535b83f24 100644 --- a/server/service/queries.go +++ b/server/service/queries.go @@ -281,14 +281,17 @@ func (svc *Service) NewQuery(ctx context.Context, p fleet.QueryPayload) (*fleet. }) } - if err := verifyLabelsToAssociate(ctx, svc.ds, p.TeamID, p.LabelsIncludeAny); err != nil { - return nil, ctxerr.Wrap(ctx, err, "verify labels to associate") + query := &fleet.Query{Saved: true, TeamID: p.TeamID} + + vc, ok := viewer.FromContext(ctx) + if ok { + query.AuthorID = ptr.Uint(vc.UserID()) + query.AuthorName = vc.FullName() + query.AuthorEmail = vc.Email() } - query := &fleet.Query{ - Saved: true, - - TeamID: p.TeamID, + if err := verifyLabelsToAssociate(ctx, svc.ds, p.TeamID, p.LabelsIncludeAny, vc.User); err != nil { + return nil, ctxerr.Wrap(ctx, err, "verify labels to associate") } if p.Name != nil { @@ -331,13 +334,6 @@ func (svc *Service) NewQuery(ctx context.Context, p fleet.QueryPayload) (*fleet. logging.WithExtras(ctx, "name", query.Name, "sql", query.Query) - vc, ok := viewer.FromContext(ctx) - if ok { - query.AuthorID = ptr.Uint(vc.UserID()) - query.AuthorName = vc.FullName() - query.AuthorEmail = vc.Email() - } - query, err := svc.ds.NewQuery(ctx, query) if err != nil { return nil, err @@ -422,7 +418,7 @@ func (svc *Service) ModifyQuery(ctx context.Context, id uint, p fleet.QueryPaylo } // We use query.TeamID because we do not allow changing the team - if err := verifyLabelsToAssociate(ctx, svc.ds, query.TeamID, p.LabelsIncludeAny); err != nil { + if err := verifyLabelsToAssociate(ctx, svc.ds, query.TeamID, p.LabelsIncludeAny, authz.UserFromContext(ctx)); err != nil { return nil, ctxerr.Wrap(ctx, err, "verify labels to associate") } @@ -854,7 +850,12 @@ func (svc *Service) queryFromSpec(ctx context.Context, spec *fleet.QuerySpec) (* // Find labels by name var queryLabels []fleet.LabelIdent if len(spec.LabelsIncludeAny) > 0 { - labelsMap, err := svc.ds.LabelsByName(ctx, spec.LabelsIncludeAny) + vc, ok := viewer.FromContext(ctx) + if !ok { + return nil, fleet.ErrNoContext + } + + labelsMap, err := svc.ds.LabelsByName(ctx, spec.LabelsIncludeAny, fleet.TeamFilter{User: vc.User, TeamID: teamID}) if err != nil { return nil, ctxerr.Wrap(ctx, err, "get labels by name") } diff --git a/server/service/queries_test.go b/server/service/queries_test.go index 8b73764f71..c41f300933 100644 --- a/server/service/queries_test.go +++ b/server/service/queries_test.go @@ -978,7 +978,8 @@ func TestApplyQuerySpec(t *testing.T) { ds.ApplyQueriesFunc = func(ctx context.Context, authID uint, queries []*fleet.Query, queriesToDiscardResults map[uint]struct{}) error { return nil } - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { + require.NotNil(t, filter.User) labels := make(map[string]*fleet.Label, len(names)) for _, name := range names { if name == "foo" { diff --git a/server/service/software_installers_test.go b/server/service/software_installers_test.go index 595760c70b..1219ed45bf 100644 --- a/server/service/software_installers_test.go +++ b/server/service/software_installers_test.go @@ -232,7 +232,7 @@ func TestValidateSoftwareLabels(t *testing.T) { "baz": 3, } - ds.LabelIDsByNameFunc = func(ctx context.Context, names []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { res := make(map[string]uint) if names == nil { return res, nil @@ -244,7 +244,7 @@ func TestValidateSoftwareLabels(t *testing.T) { } return res, nil } - ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + ds.LabelsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]*fleet.Label, error) { res := make(map[string]*fleet.Label) if names == nil { return res, nil @@ -381,7 +381,7 @@ func TestValidateSoftwareLabels(t *testing.T) { "baz": 3, } - ds.LabelIDsByNameFunc = func(ctx context.Context, names []string) (map[string]uint, error) { + ds.LabelIDsByNameFunc = func(ctx context.Context, names []string, filter fleet.TeamFilter) (map[string]uint, error) { res := make(map[string]uint) if names == nil { return res, nil diff --git a/server/service/team_policies.go b/server/service/team_policies.go index accfa3e7a4..a266cb29ec 100644 --- a/server/service/team_policies.go +++ b/server/service/team_policies.go @@ -91,7 +91,7 @@ func (svc Service) NewTeamPolicy(ctx context.Context, teamID uint, tp fleet.NewT }) } - if err := verifyLabelsToAssociate(ctx, svc.ds, &teamID, append(tp.LabelsIncludeAny, tp.LabelsExcludeAny...)); err != nil { + if err := verifyLabelsToAssociate(ctx, svc.ds, &teamID, append(tp.LabelsIncludeAny, tp.LabelsExcludeAny...), vc.User); err != nil { return nil, ctxerr.Wrap(ctx, err, "verify labels to associate") } @@ -570,7 +570,7 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f }) } - if err := verifyLabelsToAssociate(ctx, svc.ds, teamID, append(p.LabelsIncludeAny, p.LabelsExcludeAny...)); err != nil { + if err := verifyLabelsToAssociate(ctx, svc.ds, teamID, append(p.LabelsIncludeAny, p.LabelsExcludeAny...), authz.UserFromContext(ctx)); err != nil { return nil, ctxerr.Wrap(ctx, err, "verify labels to associate") } diff --git a/server/service/testing_client.go b/server/service/testing_client.go index 26b5ad2201..046bb464b6 100644 --- a/server/service/testing_client.go +++ b/server/service/testing_client.go @@ -160,11 +160,11 @@ func (ts *withServer) commonTearDownTest(t *testing.T) { return nil }) - lbls, err := ts.ds.ListLabels(ctx, fleet.TeamFilter{}, fleet.ListOptions{}) + lbls, err := ts.ds.ListLabels(ctx, filter, fleet.ListOptions{}, false) require.NoError(t, err) for _, lbl := range lbls { if lbl.LabelType != fleet.LabelTypeBuiltIn { - err := ts.ds.DeleteLabel(ctx, lbl.Name) + err := ts.ds.DeleteLabel(ctx, lbl.Name, filter) require.NoError(t, err) } } diff --git a/server/service/translator.go b/server/service/translator.go index 2ec4d0c794..3339f4cfe7 100644 --- a/server/service/translator.go +++ b/server/service/translator.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/fleetdm/fleet/v4/server/authz" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/service/middleware/endpoint_utils" ) @@ -39,7 +40,7 @@ func translateEmailToUserID(ctx context.Context, ds fleet.Datastore, identifier } func translateLabelToID(ctx context.Context, ds fleet.Datastore, identifier string) (uint, error) { - labelIDs, err := ds.LabelIDsByName(ctx, []string{identifier}) + labelIDs, err := ds.LabelIDsByName(ctx, []string{identifier}, fleet.TeamFilter{User: authz.UserFromContext(ctx)}) if err != nil { return 0, err }