diff --git a/changes/27301-add-labels-to-policies-with-gitops b/changes/27301-add-labels-to-policies-with-gitops new file mode 100644 index 0000000000..e9b2707636 --- /dev/null +++ b/changes/27301-add-labels-to-policies-with-gitops @@ -0,0 +1 @@ +- Added ability to set labels on policies via GitOps \ No newline at end of file diff --git a/cmd/fleetctl/gitops.go b/cmd/fleetctl/gitops.go index 50cc4c1e3f..f60a5d64b8 100644 --- a/cmd/fleetctl/gitops.go +++ b/cmd/fleetctl/gitops.go @@ -201,7 +201,10 @@ func gitopsCommand() *cli.Command { // so we can bail if any of the referenced labels wouldn't exist // after this run (either because they'd be deleted, never existed // in the first place). - labelsUsed := getLabelUsage(config) + labelsUsed, err := getLabelUsage(config) + if err != nil { + return err + } // Check if any used labels are not in the proposed labels list. // If there are, we'll bail out with helpful error messages. @@ -347,15 +350,6 @@ func gitopsCommand() *cli.Command { } } -// Merge sets of label names. -func concatLabels(labelArrays ...[]string) []string { - var result []string - for _, arr := range labelArrays { - result = append(result, arr...) - } - return result -} - // Given a set of referenced labels and info about who is using them, update a provided usage map. func updateLabelUsage(labels []string, ident string, usageType string, currentUsage map[string][]LabelUsage) { for _, label := range labels { @@ -374,14 +368,32 @@ func updateLabelUsage(labels []string, ident string, usageType string, currentUs // Create a map of label name -> who is using that label. // This will be used to determine if any non-existent labels are being referenced. -func getLabelUsage(config *spec.GitOps) map[string][]LabelUsage { +func getLabelUsage(config *spec.GitOps) (map[string][]LabelUsage, error) { result := make(map[string][]LabelUsage) // Get profile label usage for _, osSettingName := range []interface{}{config.Controls.MacOSSettings, config.Controls.WindowsSettings} { if osSettings, ok := getCustomSettings(osSettingName); ok { for _, setting := range osSettings { - labels := concatLabels(setting.LabelsIncludeAny, setting.LabelsIncludeAll, setting.LabelsExcludeAny) + var labels []string + err := fmt.Errorf("MDM profile '%s' has multiple label keys; please choose one of `labels_include_any`, `labels_include_all` or `labels_exclude_any`.", setting.Path) + + if len(setting.LabelsIncludeAny) > 0 { + labels = setting.LabelsIncludeAny + } + if len(setting.LabelsIncludeAll) > 0 { + if len(labels) > 0 { + return nil, err + } + labels = setting.LabelsIncludeAll + } + if len(setting.LabelsExcludeAny) > 0 { + if len(labels) > 0 { + return nil, err + } + labels = setting.LabelsExcludeAny + } + updateLabelUsage(labels, setting.Path, "MDM Profile", result) } } @@ -389,13 +401,31 @@ func getLabelUsage(config *spec.GitOps) map[string][]LabelUsage { // Get software package installer label usage for _, setting := range config.Software.Packages { - labels := concatLabels(setting.LabelsIncludeAny, setting.LabelsExcludeAny) + var labels []string + if len(setting.LabelsIncludeAny) > 0 { + labels = setting.LabelsIncludeAny + } + if len(setting.LabelsExcludeAny) > 0 { + if len(labels) > 0 { + return nil, fmt.Errorf("Software package '%s' has multiple label keys; please choose one of `labels_include_any`, `labels_exclude_any`.", setting.URL) + } + labels = setting.LabelsExcludeAny + } updateLabelUsage(labels, setting.URL, "Software Package", result) } // Get app store app installer label usage for _, setting := range config.Software.AppStoreApps { - labels := concatLabels(setting.LabelsIncludeAny, setting.LabelsExcludeAny) + var labels []string + if len(setting.LabelsIncludeAny) > 0 { + labels = setting.LabelsIncludeAny + } + if len(setting.LabelsExcludeAny) > 0 { + if len(labels) > 0 { + return nil, fmt.Errorf("App Store App '%s' has multiple label keys; please choose one of `labels_include_any`, `labels_exclude_any`.", setting.AppStoreID) + } + labels = setting.LabelsExcludeAny + } updateLabelUsage(labels, setting.AppStoreID, "App Store App", result) } @@ -404,7 +434,22 @@ func getLabelUsage(config *spec.GitOps) map[string][]LabelUsage { updateLabelUsage(query.LabelsIncludeAny, query.Name, "Query", result) } - return result + // Get policy label usage + for _, policy := range config.Policies { + var labels []string + if len(policy.LabelsIncludeAny) > 0 { + labels = policy.LabelsIncludeAny + } + if len(policy.LabelsExcludeAny) > 0 { + if len(labels) > 0 { + return nil, fmt.Errorf("Policy '%s' has multiple label keys; please choose one of `labels_include_any`, `labels_exclude_any`.", policy.Name) + } + labels = policy.LabelsExcludeAny + } + updateLabelUsage(labels, policy.Name, "Policy", result) + } + + return result, nil } func getCustomSettings(osSettings interface{}) ([]fleet.MDMProfileSpec, bool) { diff --git a/cmd/fleetctl/gitops_test.go b/cmd/fleetctl/gitops_test.go index aedeec0254..7392b1c727 100644 --- a/cmd/fleetctl/gitops_test.go +++ b/cmd/fleetctl/gitops_test.go @@ -44,7 +44,20 @@ const ( func addLabelMocks(ds *mock.Store) { var deletedLabels []string ds.GetLabelSpecsFunc = func(ctx context.Context) ([]*fleet.LabelSpec, error) { - return nil, nil + return []*fleet.LabelSpec{ + { + Name: "a", + Description: "A global label", + LabelMembershipType: fleet.LabelMembershipTypeManual, + Hosts: []string{"host2", "host3"}, + }, + { + Name: "b", + Description: "Another global label", + LabelMembershipType: fleet.LabelMembershipTypeDynamic, + Query: "SELECT 1 from osquery_info", + }, + }, nil } ds.ApplyLabelSpecsWithAuthorFunc = func(ctx context.Context, specs []*fleet.LabelSpec, authorID *uint) (err error) { return nil @@ -878,6 +891,13 @@ func TestGitOpsFullGlobal(t *testing.T) { assert.Len(t, enrolledSecrets, 2) assert.True(t, policyDeleted) assert.Len(t, appliedPolicySpecs, 5) + assert.Len(t, appliedPolicySpecs[0].LabelsIncludeAny, 1) + assert.Len(t, appliedPolicySpecs[0].LabelsExcludeAny, 0) + assert.Equal(t, appliedPolicySpecs[0].LabelsIncludeAny[0], "a") + assert.Len(t, appliedPolicySpecs[1].LabelsIncludeAny, 0) + assert.Len(t, appliedPolicySpecs[1].LabelsExcludeAny, 1) + assert.Equal(t, appliedPolicySpecs[1].LabelsExcludeAny[0], "b") + assert.True(t, queryDeleted) assert.Len(t, appliedQueries, 3) assert.Len(t, appliedScripts, 1) @@ -2260,6 +2280,21 @@ software: require.Len(t, (*savedAppConfigPtr).MDM.WindowsSettings.CustomSettings.Value, 1) assert.Equal(t, filepath.Base(cspFile.Name()), filepath.Base((*savedAppConfigPtr).MDM.WindowsSettings.CustomSettings.Value[0].Path)) assert.True(t, ds.BatchSetScriptsFuncInvoked) + + // Get applied policies for the team + teamAppliedPoliceSpecs := make([]*fleet.PolicySpec, 0) + for _, appliedPolicySpec := range appliedPolicySpecs { + if appliedPolicySpec.Team == teamName { + teamAppliedPoliceSpecs = append(teamAppliedPoliceSpecs, appliedPolicySpec) + } + } + assert.Len(t, teamAppliedPoliceSpecs, 5) + assert.Len(t, teamAppliedPoliceSpecs[0].LabelsIncludeAny, 0) + assert.Len(t, teamAppliedPoliceSpecs[0].LabelsExcludeAny, 1) + assert.Equal(t, teamAppliedPoliceSpecs[0].LabelsExcludeAny[0], "a") + assert.Len(t, teamAppliedPoliceSpecs[1].LabelsIncludeAny, 1) + assert.Len(t, teamAppliedPoliceSpecs[1].LabelsExcludeAny, 0) + assert.Equal(t, teamAppliedPoliceSpecs[1].LabelsIncludeAny[0], "b") }) } @@ -2714,13 +2749,13 @@ func TestGitOpsCustomSettings(t *testing.T) { }{ {"testdata/gitops/global_macos_windows_custom_settings_valid.yml", ""}, {"testdata/gitops/global_macos_custom_settings_valid_deprecated.yml", ""}, - {"testdata/gitops/global_windows_custom_settings_invalid_label_mix.yml", `For each profile, only one of "labels_exclude_any", "labels_include_all", "labels_include_any" or "labels" can be included`}, - {"testdata/gitops/global_windows_custom_settings_invalid_label_mix_2.yml", `For each profile, only one of "labels_exclude_any", "labels_include_all", "labels_include_any" or "labels" can be included`}, + {"testdata/gitops/global_windows_custom_settings_invalid_label_mix.yml", "please choose one of `labels_include_any`, `labels_include_all` or `labels_exclude_any`"}, + {"testdata/gitops/global_windows_custom_settings_invalid_label_mix_2.yml", "please choose one of `labels_include_any`, `labels_include_all` or `labels_exclude_any`"}, {"testdata/gitops/global_windows_custom_settings_unknown_label.yml", `Please create the missing labels, or update your settings to not refer to these labels.`}, {"testdata/gitops/team_macos_windows_custom_settings_valid.yml", ""}, {"testdata/gitops/team_macos_custom_settings_valid_deprecated.yml", ""}, - {"testdata/gitops/team_macos_windows_custom_settings_invalid_labels_mix.yml", `For each profile, only one of "labels_exclude_any", "labels_include_all", "labels_include_any" or "labels" can be included`}, - {"testdata/gitops/team_macos_windows_custom_settings_invalid_labels_mix_2.yml", `For each profile, only one of "labels_exclude_any", "labels_include_all", "labels_include_any" or "labels" can be included`}, + {"testdata/gitops/team_macos_windows_custom_settings_invalid_labels_mix.yml", "please choose one of `labels_include_any`, `labels_include_all` or `labels_exclude_any`"}, + {"testdata/gitops/team_macos_windows_custom_settings_invalid_labels_mix_2.yml", "please choose one of `labels_include_any`, `labels_include_all` or `labels_exclude_any`"}, {"testdata/gitops/team_macos_windows_custom_settings_unknown_label.yml", `Please create the missing labels, or update your settings to not refer to these labels.`}, } for _, c := range cases { diff --git a/cmd/fleetctl/testdata/gitops/global_config_no_paths.yml b/cmd/fleetctl/testdata/gitops/global_config_no_paths.yml index 50989f6457..e19164577a 100644 --- a/cmd/fleetctl/testdata/gitops/global_config_no_paths.yml +++ b/cmd/fleetctl/testdata/gitops/global_config_no_paths.yml @@ -72,11 +72,15 @@ policies: description: This policy should always fail. resolution: There is no resolution for this policy. query: SELECT 1 FROM osquery_info WHERE start_time < 0; + labels_include_any: + - a - name: Passing policy platform: linux,windows,darwin,chrome description: This policy should always pass. resolution: There is no resolution for this policy. query: SELECT 1; + labels_exclude_any: + - b - name: No root logins (macOS, Linux) platform: linux,darwin query: SELECT 1 WHERE NOT EXISTS (SELECT * FROM last diff --git a/cmd/fleetctl/testdata/gitops/team_config_no_paths.yml b/cmd/fleetctl/testdata/gitops/team_config_no_paths.yml index e671d17d29..d3db1bdd65 100644 --- a/cmd/fleetctl/testdata/gitops/team_config_no_paths.yml +++ b/cmd/fleetctl/testdata/gitops/team_config_no_paths.yml @@ -94,11 +94,15 @@ policies: resolution: There is no resolution for this policy. query: SELECT 1 FROM osquery_info WHERE start_time < 0; calendar_events_enabled: true + labels_exclude_any: + - a - name: Passing policy platform: linux,windows,darwin,chrome description: This policy should always pass. resolution: There is no resolution for this policy. query: SELECT 1; + labels_include_any: + - b - name: No root logins (macOS, Linux) platform: linux,darwin query: SELECT 1 WHERE NOT EXISTS (SELECT * FROM last diff --git a/server/datastore/mysql/policies.go b/server/datastore/mysql/policies.go index 3ccbd72b37..368caf89dc 100644 --- a/server/datastore/mysql/policies.go +++ b/server/datastore/mysql/policies.go @@ -349,7 +349,6 @@ func savePolicy(ctx context.Context, db sqlx.ExtContext, logger kitlog.Logger, p return cleanupPolicy( ctx, db, db, p.ID, p.Platform, shouldRemoveAllPolicyMemberships, removePolicyStats, logger, ) - } func assertTeamMatches(ctx context.Context, db sqlx.QueryerContext, teamID uint, softwareInstallerID *uint, scriptID *uint, vppAppsTeamsID *uint) error { @@ -1259,51 +1258,90 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs return ctxerr.Wrap(ctx, err, "exec ApplyPolicySpecs insert") } + // Get the last inserted ID -- this will be 0 if it was an update. + var lastID int64 + lastID, _ = res.LastInsertId() + + // If something was actually inserted or updated, perform any necessary cleanup. if insertOnDuplicateDidInsertOrUpdate(res) { - // when the upsert results in an UPDATE that *did* change some values, - // it returns the updated ID as last inserted id. - if lastID, _ := res.LastInsertId(); lastID > 0 { - var ( - shouldRemoveAllPolicyMemberships bool - removePolicyStats bool - ) - // Figure out if the query, platform, software installer, or VPP app changed. - var softwareInstallerID *uint - if spec.SoftwareTitleID != nil { - softwareInstallerID = softwareInstallerIDs[teamID][*spec.SoftwareTitleID] - } - if prev, ok := teamIDToPoliciesByName[teamID][spec.Name]; ok { - switch { - case prev.Query != spec.Query: - shouldRemoveAllPolicyMemberships = true - removePolicyStats = true - case teamID != nil && - ((prev.SoftwareInstallerID == nil && spec.SoftwareTitleID != nil) || - (prev.SoftwareInstallerID != nil && softwareInstallerID != nil && *prev.SoftwareInstallerID != *softwareInstallerID)): - shouldRemoveAllPolicyMemberships = true - removePolicyStats = true - case teamID != nil && - ((prev.VPPAppsTeamsID == nil && spec.SoftwareTitleID != nil) || - (prev.VPPAppsTeamsID != nil && vppAppsTeamsID != nil && *prev.VPPAppsTeamsID != *vppAppsTeamsID)): - shouldRemoveAllPolicyMemberships = true - removePolicyStats = true - case teamID != nil && - ((prev.ScriptID == nil && spec.ScriptID != nil) || - (prev.ScriptID != nil && spec.ScriptID != nil && *prev.ScriptID != *spec.ScriptID)): - shouldRemoveAllPolicyMemberships = true - removePolicyStats = true - case prev.Platforms != spec.Platform: - removePolicyStats = true - } - } - if err = cleanupPolicy( - ctx, tx, tx, uint(lastID), spec.Platform, shouldRemoveAllPolicyMemberships, //nolint:gosec // dismiss G115 - removePolicyStats, ds.logger, - ); err != nil { - return err + var ( + shouldRemoveAllPolicyMemberships bool + removePolicyStats bool + ) + // Figure out if the query, platform, software installer, or VPP app changed. + var softwareInstallerID *uint + if spec.SoftwareTitleID != nil { + softwareInstallerID = softwareInstallerIDs[teamID][*spec.SoftwareTitleID] + } + if prev, ok := teamIDToPoliciesByName[teamID][spec.Name]; ok { + switch { + case prev.Query != spec.Query: + shouldRemoveAllPolicyMemberships = true + removePolicyStats = true + case teamID != nil && + ((prev.SoftwareInstallerID == nil && spec.SoftwareTitleID != nil) || + (prev.SoftwareInstallerID != nil && softwareInstallerID != nil && *prev.SoftwareInstallerID != *softwareInstallerID)): + shouldRemoveAllPolicyMemberships = true + removePolicyStats = true + case teamID != nil && + ((prev.VPPAppsTeamsID == nil && spec.SoftwareTitleID != nil) || + (prev.VPPAppsTeamsID != nil && vppAppsTeamsID != nil && *prev.VPPAppsTeamsID != *vppAppsTeamsID)): + shouldRemoveAllPolicyMemberships = true + removePolicyStats = true + case teamID != nil && + ((prev.ScriptID == nil && spec.ScriptID != nil) || + (prev.ScriptID != nil && spec.ScriptID != nil && *prev.ScriptID != *spec.ScriptID)): + shouldRemoveAllPolicyMemberships = true + removePolicyStats = true + case prev.Platforms != spec.Platform: + removePolicyStats = true } } + if err = cleanupPolicy( + ctx, tx, tx, uint(lastID), spec.Platform, shouldRemoveAllPolicyMemberships, //nolint:gosec // dismiss G115 + removePolicyStats, ds.logger, + ); err != nil { + return err + } } + + // Even if the policy record itself wasn't updated, we still may need to update labels. + // So we'll get the ID of the policy that was just updated. + if lastID == 0 { + var err error + // Get the policy that was updated. + if teamID == nil { + err = sqlx.GetContext(ctx, tx, &lastID, "SELECT id FROM policies WHERE name = ? AND team_id is NULL", spec.Name) + } else { + err = sqlx.GetContext(ctx, tx, &lastID, "SELECT id FROM policies WHERE name = ? AND team_id = ?", spec.Name, teamID) + } + if err != nil { + return ctxerr.Wrap(ctx, err, "select policies id") + } + } + + // Create LabelIdents to send to updatePolicyLabelsTx. + // Right now we only need the names. + // @future: use IDs instead of names. + labelsIncludeAnyIdents := make([]fleet.LabelIdent, 0, len(spec.LabelsIncludeAny)) + for _, labelInclude := range spec.LabelsIncludeAny { + labelsIncludeAnyIdents = append(labelsIncludeAnyIdents, fleet.LabelIdent{LabelName: labelInclude}) + } + labelsExcludeAnyIdents := make([]fleet.LabelIdent, 0, len(spec.LabelsExcludeAny)) + for _, labelExclude := range spec.LabelsExcludeAny { + labelsExcludeAnyIdents = append(labelsExcludeAnyIdents, fleet.LabelIdent{LabelName: labelExclude}) + } + err = updatePolicyLabelsTx(ctx, tx, &fleet.Policy{ + PolicyData: fleet.PolicyData{ + ID: uint(lastID), //nolint:gosec // dismiss G115 + LabelsIncludeAny: labelsIncludeAnyIdents, + LabelsExcludeAny: labelsExcludeAnyIdents, + }, + }) + if err != nil { + return ctxerr.Wrap(ctx, err, "exec policies update labels") + } + } } return nil diff --git a/server/datastore/mysql/policies_test.go b/server/datastore/mysql/policies_test.go index 7b4d970176..cbb99b9ac5 100644 --- a/server/datastore/mysql/policies_test.go +++ b/server/datastore/mysql/policies_test.go @@ -1664,14 +1664,38 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) { unicode, _ := strconv.Unquote(`"\uAC00"`) // 가 unicodeEq, _ := strconv.Unquote(`"\u1100\u1161"`) // ᄀ + ᅡ + // Add a user-defined label + fooLabel, err := ds.NewLabel( + context.Background(), + &fleet.Label{ + Name: "Foo", + Query: "select 1", + LabelType: fleet.LabelTypeRegular, + LabelMembershipType: fleet.LabelMembershipTypeManual, + }, + ) + require.NoError(t, err) + + barLabel, err := ds.NewLabel( + context.Background(), + &fleet.Label{ + Name: "Bar", + Query: "select 1", + LabelType: fleet.LabelTypeRegular, + LabelMembershipType: fleet.LabelMembershipTypeManual, + }, + ) + require.NoError(t, err) + require.NoError(t, ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{ { - Name: "query1" + unicodeEq, - Query: "select 1;", - Description: "query1 desc", - Resolution: "some resolution", - Team: "", - Platform: "", + Name: "query1" + unicodeEq, + Query: "select 1;", + Description: "query1 desc", + Resolution: "some resolution", + Team: "", + Platform: "", + LabelsIncludeAny: []string{fooLabel.Name}, }, { Name: "query2", @@ -1681,6 +1705,7 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) { Team: "team1", Platform: "darwin", CalendarEventsEnabled: true, + LabelsExcludeAny: []string{barLabel.Name}, }, { Name: "query3", @@ -1711,6 +1736,10 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) { require.NotNil(t, policies[0].Resolution) assert.Equal(t, "some resolution", *policies[0].Resolution) assert.Equal(t, "", policies[0].Platform) + assert.Equal(t, []fleet.LabelIdent{{ + LabelName: fooLabel.Name, + LabelID: fooLabel.ID, + }}, policies[0].LabelsIncludeAny) teamPolicies, _, err := ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}) require.NoError(t, err) @@ -1724,6 +1753,10 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) { assert.Equal(t, "some other resolution", *teamPolicies[0].Resolution) assert.Equal(t, "darwin", teamPolicies[0].Platform) assert.True(t, teamPolicies[0].CalendarEventsEnabled) + assert.Equal(t, []fleet.LabelIdent{{ + LabelName: barLabel.Name, + LabelID: barLabel.ID, + }}, teamPolicies[0].LabelsExcludeAny) assert.Equal(t, "query3", teamPolicies[1].Name) assert.Equal(t, "select 3;", teamPolicies[1].Query) @@ -1753,12 +1786,13 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) { // Make sure apply is idempotent require.NoError(t, ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{ { - Name: "query1" + unicode, - Query: "select 1;", - Description: "query1 desc", - Resolution: "some resolution", - Team: "", - Platform: "", + Name: "query1" + unicode, + Query: "select 1;", + Description: "query1 desc", + Resolution: "some resolution", + Team: "", + Platform: "", + LabelsIncludeAny: []string{fooLabel.Name}, }, { Name: "query2", @@ -1768,6 +1802,7 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) { Team: "team1", Platform: "darwin", CalendarEventsEnabled: true, + LabelsExcludeAny: []string{barLabel.Name}, }, { Name: "query3", @@ -1800,12 +1835,13 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) { // Test policy updating. require.NoError(t, ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{ { - Name: "query1" + unicodeEq, - Query: "select 1 from updated;", - Description: "query1 desc updated", - Resolution: "some resolution updated", - Team: "", // No error, team did not change - Platform: "", + Name: "query1" + unicodeEq, + Query: "select 1 from updated;", + Description: "query1 desc updated", + Resolution: "some resolution updated", + Team: "", // No error, team did not change + Platform: "", + LabelsExcludeAny: []string{fooLabel.Name, barLabel.Name}, }, { Name: "query2", @@ -1830,6 +1866,14 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) { assert.Equal(t, "some resolution updated", *policies[0].Resolution) assert.Equal(t, "", policies[0].Platform) assert.False(t, policies[0].CalendarEventsEnabled) + assert.Contains(t, policies[0].LabelsExcludeAny, fleet.LabelIdent{ + LabelName: fooLabel.Name, + LabelID: fooLabel.ID, + }) + assert.Contains(t, policies[0].LabelsExcludeAny, fleet.LabelIdent{ + LabelName: barLabel.Name, + LabelID: barLabel.ID, + }) teamPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}) require.NoError(t, err) @@ -1844,6 +1888,8 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) { require.NotNil(t, teamPolicies[0].Resolution) assert.Equal(t, "some other resolution updated", *teamPolicies[0].Resolution) assert.Equal(t, "windows", teamPolicies[0].Platform) + assert.Nil(t, teamPolicies[0].LabelsIncludeAny) + assert.Nil(t, teamPolicies[0].LabelsExcludeAny) // Creating the same policy for a different team is allowed. require.NoError( @@ -2205,7 +2251,6 @@ func testApplyPolicySpecWithQueryPlatformChanges(t *testing.T, ds *Datastore) { } func testPoliciesSave(t *testing.T, ds *Datastore) { - requireLabels := func(t *testing.T, expected []string, actual []fleet.LabelIdent) { actualLabels := make([]string, 0, len(actual)) for _, label := range actual { diff --git a/server/fleet/policies.go b/server/fleet/policies.go index d430f5e85a..b662366863 100644 --- a/server/fleet/policies.go +++ b/server/fleet/policies.go @@ -361,7 +361,9 @@ type PolicySpec struct { SoftwareTitleID *uint `json:"software_title_id"` // ScriptID is the ID of the script associated with this policy (team policies only). // When editing a policy, if this is nil or 0 then the script ID is unset from the policy. - ScriptID *uint `json:"script_id"` + ScriptID *uint `json:"script_id"` + LabelsIncludeAny []string `json:"labels_include_any,omitempty"` + LabelsExcludeAny []string `json:"labels_exclude_any,omitempty"` } // PolicySoftwareTitle contains software title data for policies. diff --git a/server/service/global_policies.go b/server/service/global_policies.go index 8a52d7e05d..70d01361f9 100644 --- a/server/service/global_policies.go +++ b/server/service/global_policies.go @@ -533,6 +533,24 @@ func (svc *Service) ApplyPolicySpecs(ctx context.Context, policies []*fleet.Poli Message: fmt.Sprintf("policy spec payload verification: %s", err), }) } + + // Make sure any applied labels exist. + labels := policy.LabelsIncludeAny + labels = append(labels, policy.LabelsExcludeAny...) + if len(labels) > 0 { + labelsMap, err := svc.ds.LabelsByName(ctx, labels) + if err != nil { + return ctxerr.Wrap(ctx, err, "getting labels by name") + } + for _, label := range labels { + if _, ok := labelsMap[label]; !ok { + return ctxerr.Wrap(ctx, &fleet.BadRequestError{ + Message: fmt.Sprintf("label %q does not exist", label), + }) + } + } + } + } // An empty string indicates there are no duplicate names. @@ -551,6 +569,7 @@ func (svc *Service) ApplyPolicySpecs(ctx context.Context, policies []*fleet.Poli policies[i].Critical = false } } + if err := svc.ds.ApplyPolicySpecs(ctx, vc.UserID(), policies); err != nil { return ctxerr.Wrap(ctx, err, "applying policy specs") } diff --git a/server/service/global_policies_test.go b/server/service/global_policies_test.go index db6ddacc66..a0153dfd24 100644 --- a/server/service/global_policies_test.go +++ b/server/service/global_policies_test.go @@ -257,3 +257,66 @@ func TestApplyPolicySpecsReturnsErrorOnDuplicatePolicyNamesInSpecs(t *testing.T) require.ErrorAs(t, err, &badRequestError) require.Equal(t, "duplicate policy names not allowed", badRequestError.Message) } + +func TestApplyPolicySpecsLabelsValidation(t *testing.T) { + ds := new(mock.Store) + ds.NewGlobalPolicyFunc = func(ctx context.Context, authorID *uint, args fleet.PolicyPayload) (*fleet.Policy, error) { + return &fleet.Policy{}, nil + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + ds.NewActivityFunc = func( + ctx context.Context, user *fleet.User, activity fleet.ActivityDetails, details []byte, createdAt time.Time, + ) error { + return nil + } + ds.ApplyPolicySpecsFunc = func(ctx context.Context, authorID uint, specs []*fleet.PolicySpec) error { + return nil + } + ds.LabelsByNameFunc = func(ctx context.Context, names []string) (map[string]*fleet.Label, error) { + labels := make(map[string]*fleet.Label, len(names)) + for _, name := range names { + if name == "foo" { + labels["foo"] = &fleet.Label{ + Name: "foo", + ID: 1, + } + } + } + return labels, nil + } + + svc, ctx := newTestService(t, ds, nil, nil) + + testAdmin := fleet.User{ + ID: 1, + Teams: []fleet.UserTeam{}, + GlobalRole: ptr.String(fleet.RoleAdmin), + } + viewerCtx := viewer.NewContext(ctx, viewer.Viewer{User: &testAdmin}) + + // Test that a query spec with a label that exists doesn't return an error + err := svc.ApplyPolicySpecs(viewerCtx, []*fleet.PolicySpec{ + { + Name: "test query", + Query: "select 1", + LabelsIncludeAny: []string{"foo"}, + Platform: "darwin,windows", + }, + }) + // Check that no error is returned + require.NoError(t, err) + + // Test that a query spec with a label that doesn't exist returns an error. + err = svc.ApplyPolicySpecs(viewerCtx, []*fleet.PolicySpec{ + { + Name: "test query", + Query: "select 1", + LabelsIncludeAny: []string{"nope"}, + Platform: "darwin,windows", + }, + }) + + require.Error(t, err) +}