From b265d56d732d5b9efabe915f1f5950b9cd360347 Mon Sep 17 00:00:00 2001 From: Jahziel Villasana-Espinoza Date: Tue, 31 Oct 2023 12:29:09 -0400 Subject: [PATCH] feat: reset yes/no count when query changes (#14776) # Checklist for submitter If some of the following don't apply, delete the relevant line. - [x] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [x] Added/updated tests - [x] Manual QA for all new/changed functionality --- server/datastore/mysql/policies.go | 24 +++++++++++- server/datastore/mysql/policies_test.go | 26 +++++++++---- server/fleet/datastore.go | 2 +- server/mock/datastore_mock.go | 6 +-- server/service/global_policies_test.go | 2 +- server/service/integration_core_test.go | 50 +++++++++++++++++++++++++ server/service/team_policies.go | 8 +++- server/service/team_policies_test.go | 2 +- 8 files changed, 105 insertions(+), 15 deletions(-) diff --git a/server/datastore/mysql/policies.go b/server/datastore/mysql/policies.go index df3e484282..d932c806c4 100644 --- a/server/datastore/mysql/policies.go +++ b/server/datastore/mysql/policies.go @@ -108,7 +108,7 @@ func policyDB(ctx context.Context, q sqlx.QueryerContext, id uint, teamID *uint) // SavePolicy updates some fields of the given policy on the datastore. // // Currently SavePolicy does not allow updating the team of an existing policy. -func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy) error { +func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error { sql := ` UPDATE policies SET name = ?, query = ?, description = ?, resolution = ?, platforms = ?, critical = ? @@ -126,6 +126,9 @@ func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy) error { return ctxerr.Wrap(ctx, notFound("Policy").WithID(p.ID)) } + if shouldRemoveAllPolicyMemberships { + return cleanupPolicyMembership(ctx, ds.writer(ctx), p.ID) + } return cleanupPolicyMembershipOnPolicyUpdate(ctx, ds.writer(ctx), p.ID, p.Platform) } @@ -782,6 +785,25 @@ func cleanupPolicyMembershipOnPolicyUpdate(ctx context.Context, db sqlx.ExecerCo return ctxerr.Wrap(ctx, err, "cleanup policy membership") } +// cleanupPolicyMembership is similar to cleanupPolicyMembershipOnPolicyUpdate but without the platform constraints. +// Used when we want to remove all policy membership. +func cleanupPolicyMembership(ctx context.Context, db sqlx.ExecerContext, policyID uint) error { + delStmt := ` + DELETE + pm + FROM + policy_membership pm + LEFT JOIN + hosts h + ON + pm.host_id = h.id + WHERE + pm.policy_id = ?` + + _, err := db.ExecContext(ctx, delStmt, policyID) + return ctxerr.Wrap(ctx, err, "cleanup policy membership") +} + // CleanupPolicyMembership deletes the host's membership from policies that // have been updated recently if those hosts don't meet the policy's criteria // anymore (e.g. if the policy's platforms has been updated from "any" - the diff --git a/server/datastore/mysql/policies_test.go b/server/datastore/mysql/policies_test.go index 215acb56ca..76506e4381 100644 --- a/server/datastore/mysql/policies_test.go +++ b/server/datastore/mysql/policies_test.go @@ -1321,7 +1321,7 @@ func testPoliciesSave(t *testing.T, ds *Datastore) { Name: "non-existent query", Query: "select 1;", }, - }) + }, false) require.Error(t, err) var nfe *notFoundError require.True(t, errors.As(err, &nfe)) @@ -1359,7 +1359,7 @@ func testPoliciesSave(t *testing.T, ds *Datastore) { gp2 := *gp gp2.Name = "global query updated" gp2.Critical = true - err = ds.SavePolicy(ctx, &gp2) + err = ds.SavePolicy(ctx, &gp2, false) require.NoError(t, err) gp, err = ds.Policy(ctx, gp.ID) require.NoError(t, err) @@ -1373,12 +1373,24 @@ func testPoliciesSave(t *testing.T, ds *Datastore) { tp2.Description = "team1 query desc updated" tp2.Resolution = ptr.String("team1 query resolution updated") tp2.Critical = false - err = ds.SavePolicy(ctx, &tp2) + err = ds.SavePolicy(ctx, &tp2, true) require.NoError(t, err) tp1, err = ds.Policy(ctx, tp1.ID) tp2.UpdateCreateTimestamps = tp1.UpdateCreateTimestamps require.NoError(t, err) require.Equal(t, tp1, &tp2) + + loadMembershipStmt, args, err := sqlx.In(`SELECT policy_id, host_id FROM policy_membership WHERE policy_id = ?`, tp2.ID) + require.NoError(t, err) + + type polHostIDs struct { + PolicyID uint `db:"policy_id"` + HostID uint `db:"host_id"` + } + var rows []polHostIDs + err = ds.writer(context.Background()).SelectContext(context.Background(), &rows, loadMembershipStmt, args...) + require.NoError(t, err) + require.Len(t, rows, 0) } func testPoliciesDelUser(t *testing.T, ds *Datastore) { @@ -1712,9 +1724,9 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) { } // updating without change works fine - err = ds.SavePolicy(ctx, polsByName["g1"]) + err = ds.SavePolicy(ctx, polsByName["g1"], false) require.NoError(t, err) - err = ds.SavePolicy(ctx, polsByName["t2"]) + err = ds.SavePolicy(ctx, polsByName["t2"], false) require.NoError(t, err) // apply specs that result in an update (without change) works fine err = ds.ApplyPolicySpecs(ctx, user.ID, []*fleet.PolicySpec{ @@ -1766,7 +1778,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) { g1 := polsByName["g1"] g1.Platform = "linux" polsByName["g1"] = g1 - err = ds.SavePolicy(ctx, g1) + err = ds.SavePolicy(ctx, g1, false) require.NoError(t, err) wantHostsByPol["g1"] = []uint{globalHosts[hostDeb].ID, globalHosts[hostLin].ID} assertPolicyMembership(t, ds, polsByName, wantHostsByPol) @@ -1775,7 +1787,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) { t1 := polsByName["t1"] t1.Platform = "windows,darwin" polsByName["t1"] = t1 - err = ds.SavePolicy(ctx, t1) + err = ds.SavePolicy(ctx, t1, false) require.NoError(t, err) wantHostsByPol["t1"] = []uint{teamHosts[hostWin].ID, teamHosts[hostMac].ID} assertPolicyMembership(t, ds, polsByName, wantHostsByPol) diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 0adc9c80a8..451dedfdb0 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -521,7 +521,7 @@ type Datastore interface { // SavePolicy updates some fields of the given policy on the datastore. // // It is also used to update team policies. - SavePolicy(ctx context.Context, p *Policy) error + SavePolicy(ctx context.Context, p *Policy, shouldRemoveAllPolicyMemberships bool) error ListGlobalPolicies(ctx context.Context, opts ListOptions) ([]*Policy, error) PoliciesByID(ctx context.Context, ids []uint) (map[uint]*Policy, error) diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 7a1aeab64a..1645433d20 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -382,7 +382,7 @@ type PolicyFunc func(ctx context.Context, id uint) (*fleet.Policy, error) type PolicyByNameFunc func(ctx context.Context, name string) (*fleet.Policy, error) -type SavePolicyFunc func(ctx context.Context, p *fleet.Policy) error +type SavePolicyFunc func(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error type ListGlobalPoliciesFunc func(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error) @@ -3041,11 +3041,11 @@ func (s *DataStore) PolicyByName(ctx context.Context, name string) (*fleet.Polic return s.PolicyByNameFunc(ctx, name) } -func (s *DataStore) SavePolicy(ctx context.Context, p *fleet.Policy) error { +func (s *DataStore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error { s.mu.Lock() s.SavePolicyFuncInvoked = true s.mu.Unlock() - return s.SavePolicyFunc(ctx, p) + return s.SavePolicyFunc(ctx, p, shouldRemoveAllPolicyMemberships) } func (s *DataStore) ListGlobalPolicies(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error) { diff --git a/server/service/global_policies_test.go b/server/service/global_policies_test.go index bdaf2e429f..615c0600f4 100644 --- a/server/service/global_policies_test.go +++ b/server/service/global_policies_test.go @@ -68,7 +68,7 @@ func TestGlobalPoliciesAuth(t *testing.T) { ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error { return nil } - ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy) error { + ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool) error { return nil } ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 1398a01aa9..4858210c07 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -1847,6 +1847,8 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() { require.NotNil(t, mgpResp.Policy.Resolution) assert.Equal(t, "some global resolution updated", *mgpResp.Policy.Resolution) assert.Equal(t, "darwin", mgpResp.Policy.Platform) + assert.Equal(t, uint(0), mgpResp.Policy.FailingHostCount) + assert.Equal(t, uint(0), mgpResp.Policy.PassingHostCount) ggpResp := getPolicyByIDResponse{} s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/policies/%d", gpResp.Policy.ID), getPolicyByIDRequest{}, http.StatusOK, &ggpResp) @@ -1857,6 +1859,8 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() { require.NotNil(t, ggpResp.Policy.Resolution) assert.Equal(t, "some global resolution updated", *ggpResp.Policy.Resolution) assert.Equal(t, "darwin", mgpResp.Policy.Platform) + assert.Equal(t, uint(0), mgpResp.Policy.FailingHostCount) + assert.Equal(t, uint(0), mgpResp.Policy.PassingHostCount) policiesResponse := listGlobalPoliciesResponse{} s.DoJSON("GET", "/api/latest/fleet/policies", nil, http.StatusOK, &policiesResponse) @@ -1867,6 +1871,8 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() { require.NotNil(t, policiesResponse.Policies[0].Resolution) assert.Equal(t, "some global resolution updated", *policiesResponse.Policies[0].Resolution) assert.Equal(t, "darwin", policiesResponse.Policies[0].Platform) + assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount) + assert.Equal(t, uint(0), policiesResponse.Policies[0].PassingHostCount) listHostsURL := fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d", policiesResponse.Policies[0].ID) listHostsResp := listHostsResponse{} @@ -1880,6 +1886,11 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() { s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) require.Len(t, listHostsResp.Hosts, 0) + listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=failing", policiesResponse.Policies[0].ID) + listHostsResp = listHostsResponse{} + s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) + require.Len(t, listHostsResp.Hosts, 0) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false)) require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false)) @@ -1888,6 +1899,45 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() { s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) require.Len(t, listHostsResp.Hosts, 1) + mgpParams = modifyGlobalPolicyRequest{ + ModifyPolicyPayload: fleet.ModifyPolicyPayload{ + Query: ptr.String("select * from users;"), + }, + } + mgpResp = modifyGlobalPolicyResponse{} + s.DoJSON("PATCH", fmt.Sprintf("/api/latest/fleet/policies/%d", gpResp.Policy.ID), mgpParams, http.StatusOK, &mgpResp) + require.NotNil(t, gpResp.Policy) + assert.Equal(t, "TestQuery4", mgpResp.Policy.Name) + assert.Equal(t, "select * from users;", mgpResp.Policy.Query) + assert.Equal(t, "Some description updated", mgpResp.Policy.Description) + require.NotNil(t, mgpResp.Policy.Resolution) + assert.Equal(t, "some global resolution updated", *mgpResp.Policy.Resolution) + assert.Equal(t, "darwin", mgpResp.Policy.Platform) + assert.Equal(t, uint(0), mgpResp.Policy.FailingHostCount) + assert.Equal(t, uint(0), mgpResp.Policy.PassingHostCount) + + listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID) + listHostsResp = listHostsResponse{} + s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) + require.Len(t, listHostsResp.Hosts, 0) + + listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=failing", policiesResponse.Policies[0].ID) + listHostsResp = listHostsResponse{} + s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) + require.Len(t, listHostsResp.Hosts, 0) + + policiesResponse = listGlobalPoliciesResponse{} + s.DoJSON("GET", "/api/latest/fleet/policies", nil, http.StatusOK, &policiesResponse) + require.Len(t, policiesResponse.Policies, 1) + assert.Equal(t, "TestQuery4", policiesResponse.Policies[0].Name) + assert.Equal(t, "select * from users;", policiesResponse.Policies[0].Query) + assert.Equal(t, "Some description updated", policiesResponse.Policies[0].Description) + require.NotNil(t, policiesResponse.Policies[0].Resolution) + assert.Equal(t, "some global resolution updated", *policiesResponse.Policies[0].Resolution) + assert.Equal(t, "darwin", policiesResponse.Policies[0].Platform) + assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount) + assert.Equal(t, uint(0), policiesResponse.Policies[0].PassingHostCount) + deletePolicyParams := deleteGlobalPoliciesRequest{IDs: []uint{policiesResponse.Policies[0].ID}} deletePolicyResp := deleteGlobalPoliciesResponse{} s.DoJSON("POST", "/api/latest/fleet/policies/delete", deletePolicyParams, http.StatusOK, &deletePolicyResp) diff --git a/server/service/team_policies.go b/server/service/team_policies.go index 483bae292f..40187769ba 100644 --- a/server/service/team_policies.go +++ b/server/service/team_policies.go @@ -366,6 +366,7 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f }) } + var shouldRemoveAll bool if p.Name != nil { policy.Name = *p.Name } @@ -373,6 +374,11 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f policy.Description = *p.Description } if p.Query != nil { + if policy.Query != *p.Query { + shouldRemoveAll = true + policy.FailingHostCount = 0 + policy.PassingHostCount = 0 + } policy.Query = *p.Query } if p.Resolution != nil { @@ -386,7 +392,7 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f } logging.WithExtras(ctx, "name", policy.Name, "sql", policy.Query) - err = svc.ds.SavePolicy(ctx, policy) + err = svc.ds.SavePolicy(ctx, policy, shouldRemoveAll) if err != nil { return nil, ctxerr.Wrap(ctx, err, "saving policy") } diff --git a/server/service/team_policies_test.go b/server/service/team_policies_test.go index fa3105a6d3..0adc05db36 100644 --- a/server/service/team_policies_test.go +++ b/server/service/team_policies_test.go @@ -44,7 +44,7 @@ func TestTeamPoliciesAuth(t *testing.T) { } return nil, nil } - ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy) error { + ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool) error { return nil } ds.DeleteTeamPoliciesFunc = func(ctx context.Context, teamID uint, ids []uint) ([]uint, error) {