fleet/server/datastore/mysql/policies_test.go
Jonathan Katz 0d15fd6cd6
Override patch policy query (#42322)
<!-- Add the related story/sub-task/bug number, like Resolves #123, or
remove if NA -->
**Related issue:** Resolves #41815
### Changes
- Extracted patch policy creation to `pkg/patch_policy`
- Added a `patch_query` column to the `software_installers` table
- By default that column is empty, in which case patch policies are
generated with the default query
- On app manifest ingestion, the appropriate entry in
`software_installers` will save the override "patch" query from the
manifest in patch_query

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files)
for more information.

- [ ] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements), JS
inline code is prevented especially for url redirects, and untrusted
data interpolated into shell scripts/commands is validated against shell
metacharacters.
- [ ] If paths of existing endpoints are modified without backwards
compatibility, checked the frontend/CLI for any necessary changes

## Testing

- [x] Added/updated automated tests
- [ ] Where appropriate, [automated tests simulate multiple hosts and
test for host
isolation](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/reference/patterns-backend.md#unit-testing)
(updates to one host's records do not affect another)

- [ ] QA'd all new/changed functionality manually
- Relied on integration test for FMA version pinning

## Database migrations

- [x] Checked schema for all modified table for columns that will
auto-update timestamps during migration.
- [ ] Confirmed that updating the timestamps is acceptable, and will not
cause unwanted side effects.
- [x] Ensured the correct collation is explicitly set for character
columns (`COLLATE utf8mb4_unicode_ci`).
2026-03-25 10:32:41 -04:00

7905 lines
292 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mysql
import (
"context"
"crypto/md5" //nolint:gosec // (only used for tests)
"encoding/hex"
"errors"
"fmt"
"sort"
"strconv"
"strings"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
common_mysql "github.com/fleetdm/fleet/v4/server/platform/mysql"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
// TestPolicies is the entry point for the policy datastore tests. It obtains a
// single datastore from CreateMySQLDS and runs every row of the table below as
// a named sub-test against it, deferring a TruncateTables call in each one.
func TestPolicies(t *testing.T) {
	ds := CreateMySQLDS(t)

	type policySubtest struct {
		name string
		run  func(t *testing.T, ds *Datastore)
	}

	// membershipView adapts testPoliciesMembershipView, which takes an extra
	// leading "deferred" flag, to the common sub-test signature.
	membershipView := func(deferred bool) func(*testing.T, *Datastore) {
		return func(t *testing.T, ds *Datastore) {
			testPoliciesMembershipView(deferred, t, ds)
		}
	}

	subtests := []policySubtest{
		{"NewGlobalPolicyLegacy", testPoliciesNewGlobalPolicyLegacy},
		{"NewGlobalPolicyProprietary", testPoliciesNewGlobalPolicyProprietary},
		{"GlobalPolicyPendingScriptsAndInstalls", testGlobalPolicyPendingScriptsAndInstalls},
		{"MembershipViewDeferred", membershipView(true)},
		{"MembershipViewNotDeferred", membershipView(false)},
		{"TeamPolicyLegacy", testTeamPolicyLegacy},
		{"TeamPolicyProprietary", testTeamPolicyProprietary},
		{"TeamPolicyPendingScriptsAndInstalls", testTeamPolicyPendingScriptsAndInstalls},
		{"ListMergedTeamPolicies", testListMergedTeamPolicies},
		{"PolicyQueriesForHost", testPolicyQueriesForHost},
		{"PolicyQueriesForHostPlatforms", testPolicyQueriesForHostPlatforms},
		{"PoliciesByID", testPoliciesByID},
		{"TeamPolicyTransfer", testTeamPolicyTransfer},
		{"ApplyPolicySpec", testApplyPolicySpec},
		{"ApplyPolicySpecWithQueryPlatformChanges", testApplyPolicySpecWithQueryPlatformChanges},
		{"Save", testPoliciesSave},
		{"DelUser", testPoliciesDelUser},
		{"FlippingPoliciesForHost", testFlippingPoliciesForHost},
		{"PlatformUpdate", testPolicyPlatformUpdate},
		{"CleanupPolicyMembership", testPolicyCleanupPolicyMembership},
		{"DeleteAllPolicyMemberships", testDeleteAllPolicyMemberships},
		{"PolicyViolationDays", testPolicyViolationDays},
		{"IncreasePolicyAutomationIteration", testIncreasePolicyAutomationIteration},
		{"OutdatedAutomationBatch", testOutdatedAutomationBatch},
		{"TestListGlobalPoliciesCanPaginate", testListGlobalPoliciesCanPaginate},
		{"TestListTeamPoliciesCanPaginate", testListTeamPoliciesCanPaginate},
		{"TestCountPolicies", testCountPolicies},
		{"TestUpdatePolicyHostCounts", testUpdatePolicyHostCounts},
		{"TestCachedPolicyCountDeletesOnPolicyChange", testCachedPolicyCountDeletesOnPolicyChange},
		{"TestPoliciesListOptions", testPoliciesListOptions},
		{"TestPoliciesNameUnicode", testPoliciesNameUnicode},
		{"TestPoliciesNameEmoji", testPoliciesNameEmoji},
		{"TestPoliciesNameSort", testPoliciesNameSort},
		{"TestGetCalendarPolicies", testGetCalendarPolicies},
		{"GetTeamHostsPolicyMemberships", testGetTeamHostsPolicyMemberships},
		{"GetTeamHostsPolicyMembershipsEmailPriority", testGetTeamHostsPolicyMembershipsEmailPriority},
		{"TestPoliciesNewGlobalPolicyWithInstaller", testNewGlobalPolicyWithInstaller},
		{"TestPoliciesTeamPoliciesWithInstaller", testTeamPoliciesWithInstaller},
		{"TestPoliciesTeamPoliciesWithVPP", testTeamPoliciesWithVPP},
		{"ApplyPolicySpecWithInstallers", testApplyPolicySpecWithInstallers},
		{"TestPoliciesNewGlobalPolicyWithScript", testNewGlobalPolicyWithScript},
		{"TestPoliciesTeamPoliciesWithScript", testTeamPoliciesWithScript},
		{"TeamPoliciesNoTeam", testTeamPoliciesNoTeam},
		{"TestPoliciesBySoftwareTitleID", testPoliciesBySoftwareTitleID},
		{"TestClearAutoInstallPolicyStatusForHost", testClearAutoInstallPolicyStatusForHost},
		{"PolicyLabels", testPolicyLabels},
		{"PolicyLabelMembershipCleanup", testPolicyLabelMembershipCleanup},
		{"DeletePolicyWithSoftwareActivatesNextActivity", testDeletePolicyWithSoftwareActivatesNextActivity},
		{"DeletePolicyWithScriptActivatesNextActivity", testDeletePolicyWithScriptActivatesNextActivity},
		{"SimultaneousSavePolicy", testSimultaneousSavePolicy},
		{"IsPolicyFailing", testIsPolicyFailing},
		{"ResetAttemptsOnFailingToPassingSync", testResetAttemptsOnFailingToPassingSync},
		{"ResetAttemptsOnFailingToPassingAsync", testResetAttemptsOnFailingToPassingAsync},
		{"PolicyModificationResetsAttemptNumber", testPolicyModificationResetsAttemptNumber},
		{"TeamPatchPolicy", testTeamPatchPolicy},
		{"TeamPolicyAutomationFilter", testTeamPolicyAutomationFilter},
		{"BatchedPolicyMembershipCleanup", testBatchedPolicyMembershipCleanup},
		{"BatchedPolicyMembershipCleanupOnPolicyUpdate", testBatchedPolicyMembershipCleanupOnPolicyUpdate},
		{"ApplyPolicySpecsNeedsFullMembershipCleanupFlag", testApplyPolicySpecsNeedsFullMembershipCleanupFlag},
		{"CleanupPolicyMembershipCrashRecovery", testCleanupPolicyMembershipCrashRecovery},
	}

	for _, st := range subtests {
		t.Run(st.name, func(t *testing.T) {
			defer TruncateTables(t, ds)
			st.run(t, ds)
		})
	}
}
// testPoliciesNewGlobalPolicyLegacy creates two global policies whose payload
// carries only a QueryID, and checks that each policy takes its name,
// description and SQL from the referenced saved query. It then deletes the
// first saved query, deletes both policies, and expects the global list to be
// empty.
func testPoliciesNewGlobalPolicyLegacy(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	author := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	savedQuery1, err := ds.NewQuery(ctx, &fleet.Query{
		Name:        "query1",
		Description: "query1 desc",
		Query:       "select 1;",
		Saved:       true,
		Logging:     fleet.LoggingSnapshot,
	})
	require.NoError(t, err)

	firstPolicy, err := ds.NewGlobalPolicy(ctx, &author.ID, fleet.PolicyPayload{
		QueryID: &savedQuery1.ID,
	})
	require.NoError(t, err)
	assert.Equal(t, "query1", firstPolicy.Name)
	assert.Equal(t, "query1 desc", firstPolicy.Description)
	assert.Equal(t, "select 1;", firstPolicy.Query)
	require.NotNil(t, firstPolicy.AuthorID)
	assert.Equal(t, author.ID, *firstPolicy.AuthorID)

	savedQuery2, err := ds.NewQuery(ctx, &fleet.Query{
		Name:        "query2",
		Description: "query2 desc",
		Query:       "select 42;",
		Saved:       true,
		Logging:     fleet.LoggingSnapshot,
	})
	require.NoError(t, err)

	_, err = ds.NewGlobalPolicy(ctx, &author.ID, fleet.PolicyPayload{
		QueryID: &savedQuery2.ID,
	})
	require.NoError(t, err)

	listed, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, listed, 2)

	assert.Equal(t, savedQuery1.Name, listed[0].Name)
	assert.Equal(t, savedQuery1.Query, listed[0].Query)
	assert.Equal(t, savedQuery1.Description, listed[0].Description)

	assert.Equal(t, savedQuery2.Name, listed[1].Name)
	assert.Equal(t, savedQuery2.Query, listed[1].Query)
	assert.Equal(t, savedQuery2.Description, listed[1].Description)
	require.NotNil(t, listed[1].AuthorID)
	assert.Equal(t, author.ID, *listed[1].AuthorID)

	// The original query can be removed as the policy owns its own query.
	require.NoError(t, ds.DeleteQuery(ctx, nil, savedQuery1.Name))

	policyIDs := []uint{listed[0].ID, listed[1].ID}
	_, err = ds.DeleteGlobalPolicies(ctx, policyIDs)
	require.NoError(t, err)

	listed, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	require.Empty(t, listed)
}
// testPoliciesNewGlobalPolicyProprietary creates global policies from inline
// name/query/description/resolution payloads. It verifies the stored fields,
// that a duplicate name is rejected with an error exposing IsExists() == true,
// and that the name can be reused once the existing policies are deleted.
func testPoliciesNewGlobalPolicyProprietary(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	author := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	created, err := ds.NewGlobalPolicy(ctx, &author.ID, fleet.PolicyPayload{
		Name:        "query1",
		Query:       "select 1;",
		Description: "query1 desc",
		Resolution:  "query1 resolution",
	})
	require.NoError(t, err)
	assert.Equal(t, "query1", created.Name)
	assert.Equal(t, "query1 desc", created.Description)
	assert.Equal(t, "select 1;", created.Query)
	require.NotNil(t, created.Resolution)
	assert.Equal(t, "query1 resolution", *created.Resolution)
	require.NotNil(t, created.AuthorID)
	assert.Equal(t, author.ID, *created.AuthorID)

	_, err = ds.NewGlobalPolicy(ctx, &author.ID, fleet.PolicyPayload{
		Name:        "query2",
		Query:       "select 2;",
		Description: "query2 desc",
		Resolution:  "query2 resolution",
	})
	require.NoError(t, err)

	listed, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, listed, 2)

	// Both policies are expected back in creation order with every field intact.
	wantListed := []struct {
		name, query, description, resolution string
	}{
		{"query1", "select 1;", "query1 desc", "query1 resolution"},
		{"query2", "select 2;", "query2 desc", "query2 resolution"},
	}
	for i, want := range wantListed {
		got := listed[i]
		assert.Equal(t, want.name, got.Name)
		assert.Equal(t, want.query, got.Query)
		assert.Equal(t, want.description, got.Description)
		require.NotNil(t, got.Resolution)
		assert.Equal(t, want.resolution, *got.Resolution)
		require.NotNil(t, got.AuthorID)
		assert.Equal(t, author.ID, *got.AuthorID)
	}

	// Can't create a global policy with an existing name.
	duplicatePayload := fleet.PolicyPayload{
		Name:        "query1",
		Query:       "select 3;",
		Description: "query1 other description",
		Resolution:  "query1 other resolution",
	}
	recreated, err := ds.NewGlobalPolicy(ctx, &author.ID, duplicatePayload)
	require.Error(t, err)
	var alreadyExists interface {
		IsExists() bool
	}
	require.True(t, errors.As(err, &alreadyExists) && alreadyExists.IsExists())
	require.Nil(t, recreated)

	_, err = ds.DeleteGlobalPolicies(ctx, []uint{listed[0].ID, listed[1].ID})
	require.NoError(t, err)

	listed, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	require.Empty(t, listed)

	// Now the name is available and we can create the global policy.
	recreated, err = ds.NewGlobalPolicy(ctx, &author.ID, fleet.PolicyPayload{
		Name:        "query1",
		Query:       "select 3;",
		Description: "query1 other description",
		Resolution:  "query1 other resolution",
	})
	require.NoError(t, err)
	assert.Equal(t, "query1", recreated.Name)
	assert.Equal(t, "select 3;", recreated.Query)
	assert.Equal(t, "query1 other description", recreated.Description)
	require.NotNil(t, recreated.Resolution)
	assert.Equal(t, "query1 other resolution", *recreated.Resolution)
	require.NotNil(t, recreated.AuthorID)
	assert.Equal(t, author.ID, *recreated.AuthorID)
}
// testGlobalPolicyPendingScriptsAndInstalls verifies that deleting a global
// policy leaves no pending work behind for it. It runs two scenarios:
//
//  1. a policy linked to a script, with one pending script execution on host1;
//  2. a policy linked to a software installer, with one pending install on
//     host2.
//
// In both scenarios the pending list for the host is expected to hold one
// entry before DeleteGlobalPolicies and none afterwards.
func testGlobalPolicyPendingScriptsAndInstalls(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user := test.NewUser(t, ds, "Alice", "alice@example.com", true)
	host1 := test.NewHost(t, ds, "host1", "1", "host1key", "host1uuid", time.Now())

	// create a new script and associate with global policy
	script, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1.sh",
		ScriptContents: "echo",
		TeamID:         nil,
	})
	require.NoError(t, err)

	policy1, err := ds.NewGlobalPolicy(ctx, &user.ID, fleet.PolicyPayload{
		Name:        "query1",
		Query:       "select 1;",
		Description: "query1 desc",
		Resolution:  "query1 resolution",
	})
	require.NoError(t, err)

	// Link the script directly in the table, scoped to the policy under test
	// so that no other policy row can be touched by accident.
	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		_, err := q.ExecContext(ctx, "UPDATE policies SET script_id = ? WHERE id = ?", script.ID, policy1.ID)
		return err
	})

	policies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, policies, 1)

	// create pending script execution
	_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         host1.ID,
		ScriptContents: "echo",
		UserID:         &user.ID,
		PolicyID:       &policy1.ID,
		SyncRequest:    true,
		ScriptID:       &script.ID,
	})
	require.NoError(t, err)

	// Pending script executions are looked up for the host the request was
	// created on (host1), mirroring the host-keyed ListPendingSoftwareInstalls
	// check below. Passing policy1.ID here only worked while the host and
	// policy IDs happened to be equal.
	// NOTE(review): assumes the second argument is a host ID - confirm against
	// the datastore method's signature.
	pendingScripts, err := ds.ListPendingHostScriptExecutions(ctx, host1.ID, false)
	require.NoError(t, err)
	require.Len(t, pendingScripts, 1)

	// delete the policy
	_, err = ds.DeleteGlobalPolicies(ctx, []uint{policy1.ID})
	require.NoError(t, err)

	pendingScripts, err = ds.ListPendingHostScriptExecutions(ctx, host1.ID, false)
	require.NoError(t, err)
	require.Empty(t, pendingScripts)

	// create a new installer and associate with global policy
	host2 := test.NewHost(t, ds, "host2", "2", "host2key", "host2uuid", time.Now())
	tfr1, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
	require.NoError(t, err)

	installerID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "hello",
		PreInstallQuery:   "SELECT 1",
		PostInstallScript: "world",
		UninstallScript:   "goodbye",
		InstallerFile:     tfr1,
		StorageID:         "storage1",
		Filename:          "file1",
		Title:             "file1",
		Version:           "1.0",
		Source:            "apps",
		UserID:            user.ID,
		ValidatedLabels:   &fleet.LabelIdentsWithScope{},
	})
	require.NoError(t, err)

	policy2, err := ds.NewGlobalPolicy(ctx, &user.ID, fleet.PolicyPayload{
		Name:        "query2",
		Query:       "select 1;",
		Description: "query2 desc",
		Resolution:  "query2 resolution",
	})
	require.NoError(t, err)

	// Same direct link as above, this time to the installer and scoped to policy2.
	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		_, err := q.ExecContext(ctx, "UPDATE policies SET software_installer_id = ? WHERE id = ?", installerID, policy2.ID)
		return err
	})

	// policy1 was deleted above, so policy2 is the only global policy left.
	policies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, policies, 1)

	// create a pending software install request
	_, err = ds.InsertSoftwareInstallRequest(ctx, host2.ID, installerID, fleet.HostSoftwareInstallOptions{PolicyID: &policy2.ID})
	require.NoError(t, err)

	pendingInstalls, err := ds.ListPendingSoftwareInstalls(ctx, host2.ID)
	require.NoError(t, err)
	require.Len(t, pendingInstalls, 1)

	// delete the policy
	_, err = ds.DeleteGlobalPolicies(ctx, []uint{policy2.ID})
	require.NoError(t, err)

	pendingInstalls, err = ds.ListPendingSoftwareInstalls(ctx, host2.ID)
	require.NoError(t, err)
	require.Empty(t, pendingInstalls)
}
// testPoliciesListOptions creates five global policies and checks that
// ListGlobalPolicies with MatchQuery "apple", ordered by name ascending,
// returns exactly "apple", "apple pie" and "rotten apple" in that order.
func testPoliciesListOptions(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	author := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	// Payloads are inserted in this order; only three of the names contain
	// the search term used below.
	payloads := []fleet.PolicyPayload{
		{Name: "apple", Query: "select 1;", Description: "query1 desc", Resolution: "query1 resolution"},
		{Name: "banana", Query: "select 1;", Description: "query2 desc", Resolution: "query2 resolution"},
		{Name: "cherry", Query: "select 1;", Description: "query3 desc", Resolution: "query3 resolution"},
		{Name: "apple pie", Query: "select 1;", Description: "query4 desc", Resolution: "query4 resolution"},
		{Name: "rotten apple", Query: "select 1;", Description: "query5 desc", Resolution: "query5 resolution"},
	}
	for _, payload := range payloads {
		_, err := ds.NewGlobalPolicy(ctx, &author.ID, payload)
		require.NoError(t, err)
	}

	opts := fleet.ListOptions{
		MatchQuery:     "apple",
		OrderKey:       "name",
		OrderDirection: fleet.OrderAscending,
	}
	matched, err := ds.ListGlobalPolicies(ctx, opts)
	require.NoError(t, err)
	require.Len(t, matched, 3)

	for i, wantName := range []string{"apple", "apple pie", "rotten apple"} {
		assert.Equal(t, wantName, matched[i].Name)
	}
}
// testPoliciesMembershipView checks that the passing/failing host counts
// returned by ListGlobalPolicies and ListTeamPolicies follow the results
// recorded through RecordPolicyQueryExecutions once UpdateHostPolicyCounts
// has been called.
//
// deferred is forwarded unchanged to every RecordPolicyQueryExecutions call;
// TestPolicies runs this test once with true and once with false.
func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// record stores one batch of policy results for a host, stamped with the
	// current time, and fails the test immediately on error.
	record := func(host *fleet.Host, results map[uint]*bool) {
		t.Helper()
		require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, results, time.Now(), deferred))
	}

	user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	host1, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:   ptr.String("1234"),
		DetailUpdatedAt: time.Now(),
		LabelUpdatedAt:  time.Now(),
		PolicyUpdatedAt: time.Now(),
		SeenTime:        time.Now(),
		NodeKey:         ptr.String("1"),
		UUID:            "1",
		Hostname:        "foo.local",
	})
	require.NoError(t, err)

	host2, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:   ptr.String("5679"),
		DetailUpdatedAt: time.Now(),
		LabelUpdatedAt:  time.Now(),
		PolicyUpdatedAt: time.Now(),
		SeenTime:        time.Now(),
		NodeKey:         ptr.String("2"),
		UUID:            "2",
		Hostname:        "bar.local",
	})
	require.NoError(t, err)

	q, err := ds.NewQuery(ctx, &fleet.Query{
		Name:        "query1",
		Description: "query1 desc",
		Query:       "select 1;",
		Saved:       true,
		Logging:     fleet.LoggingSnapshot,
	})
	require.NoError(t, err)

	p, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
		QueryID: &q.ID,
	})
	require.NoError(t, err)
	assert.Equal(t, "query1", p.Name)
	assert.Equal(t, "select 1;", p.Query)
	assert.Equal(t, "query1 desc", p.Description)
	require.NotNil(t, p.AuthorID)
	assert.Equal(t, user1.ID, *p.AuthorID)

	q2, err := ds.NewQuery(ctx, &fleet.Query{
		Name:        "query2",
		Description: "query2 desc",
		Query:       "select 42;",
		Saved:       true,
		Logging:     fleet.LoggingSnapshot,
	})
	require.NoError(t, err)

	p2, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
		QueryID: &q2.ID,
	})
	require.NoError(t, err)
	assert.Equal(t, "query2", p2.Name)
	assert.Equal(t, "select 42;", p2.Query)
	assert.Equal(t, "query2 desc", p2.Description)
	require.NotNil(t, p2.AuthorID)
	assert.Equal(t, user1.ID, *p2.AuthorID)

	// host1 reports the same passing result twice; host2 reports nil, then
	// false, then true for p, and only nil for p2. The counts asserted below
	// match each host's most recent result, with nil counted as neither
	// passing nor failing.
	record(host1, map[uint]*bool{p.ID: ptr.Bool(true)})
	record(host1, map[uint]*bool{p.ID: ptr.Bool(true)})

	record(host2, map[uint]*bool{p.ID: nil})
	record(host2, map[uint]*bool{p.ID: ptr.Bool(false)})
	record(host2, map[uint]*bool{p.ID: ptr.Bool(true)})
	record(host2, map[uint]*bool{p2.ID: nil})

	require.NoError(t, ds.UpdateHostPolicyCounts(ctx))

	policies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, policies, 2)
	assert.Equal(t, p.ID, policies[0].ID)
	assert.Equal(t, uint(2), policies[0].PassingHostCount)
	assert.Equal(t, uint(0), policies[0].FailingHostCount)
	assert.Equal(t, p2.ID, policies[1].ID)
	assert.Equal(t, uint(0), policies[1].PassingHostCount)
	assert.Equal(t, uint(0), policies[1].FailingHostCount)

	// Flip host1 to failing on p and give host2 a failing result on p2.
	record(host1, map[uint]*bool{p.ID: ptr.Bool(false)})
	record(host2, map[uint]*bool{p2.ID: ptr.Bool(false)})

	require.NoError(t, ds.UpdateHostPolicyCounts(ctx))

	policies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, policies, 2)
	assert.Equal(t, p.ID, policies[0].ID)
	assert.Equal(t, uint(1), policies[0].PassingHostCount)
	assert.Equal(t, uint(1), policies[0].FailingHostCount)
	assert.Equal(t, p2.ID, policies[1].ID)
	assert.Equal(t, uint(0), policies[1].PassingHostCount)
	assert.Equal(t, uint(1), policies[1].FailingHostCount)

	// Fetching a single policy must return the same value as the list entry.
	policy, err := ds.Policy(ctx, policies[0].ID)
	require.NoError(t, err)
	assert.Equal(t, policies[0], policy)

	// Look the returned SQL up by policy ID (p, p2), not by the ID of the
	// saved query each policy was created from: the two ID sequences are
	// independent and only coincide in a freshly truncated database.
	// NOTE(review): assumes PolicyQueriesForHost keys its result by policy ID
	// rendered as a decimal string - confirm against its implementation.
	queries, err := ds.PolicyQueriesForHost(ctx, host1)
	require.NoError(t, err)
	require.Len(t, queries, 2)
	assert.Equal(t, q.Query, queries[fmt.Sprint(p.ID)])
	assert.Equal(t, q2.Query, queries[fmt.Sprint(p2.ID)])

	// create a couple teams and team-specific policies
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
	require.NoError(t, err)

	t1pol, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		Name:  "team1pol",
		Query: "SELECT 1",
	})
	require.NoError(t, err)
	t2pol, err := ds.NewTeamPolicy(ctx, team2.ID, &user1.ID, fleet.PolicyPayload{
		Name:  "team2pol",
		Query: "SELECT 2",
	})
	require.NoError(t, err)
	t2pol2, err := ds.NewTeamPolicy(ctx, team2.ID, &user1.ID, fleet.PolicyPayload{
		Name:  "team2pol2",
		Query: "SELECT 3",
	})
	require.NoError(t, err)

	// create hosts in each team
	host3, err := ds.EnrollOsquery(ctx,
		fleet.WithEnrollOsqueryHostID("3"),
		fleet.WithEnrollOsqueryNodeKey("3"),
		fleet.WithEnrollOsqueryTeamID(&team1.ID),
	)
	require.NoError(t, err)
	host4, err := ds.EnrollOsquery(ctx,
		fleet.WithEnrollOsqueryHostID("4"),
		fleet.WithEnrollOsqueryNodeKey("4"),
		fleet.WithEnrollOsqueryTeamID(&team2.ID),
	)
	require.NoError(t, err)
	host5, err := ds.EnrollOsquery(ctx,
		fleet.WithEnrollOsqueryHostID("5"),
		fleet.WithEnrollOsqueryNodeKey("5"),
		fleet.WithEnrollOsqueryTeamID(&team2.ID),
	)
	require.NoError(t, err)

	// create some policy results
	record(host3, map[uint]*bool{t1pol.ID: ptr.Bool(true), p.ID: ptr.Bool(true), p2.ID: ptr.Bool(false)})
	record(host4, map[uint]*bool{t2pol.ID: ptr.Bool(false), t2pol2.ID: ptr.Bool(true), p.ID: ptr.Bool(false)})
	record(host5, map[uint]*bool{t2pol.ID: ptr.Bool(true), t2pol2.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)})

	require.NoError(t, ds.UpdateHostPolicyCounts(ctx))

	// The inherited (global) policy counts asserted for a team only include
	// results from hosts enrolled in that team.
	t1Pols, t1Inherited, err := ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, t1Pols, 1)
	assert.Equal(t, uint(1), t1Pols[0].PassingHostCount)
	assert.Equal(t, uint(0), t1Pols[0].FailingHostCount)
	require.Len(t, t1Inherited, 2)
	require.Equal(t, p.ID, t1Inherited[0].ID)
	assert.Equal(t, uint(1), t1Inherited[0].PassingHostCount)
	assert.Equal(t, uint(0), t1Inherited[0].FailingHostCount)
	require.Equal(t, p2.ID, t1Inherited[1].ID)
	assert.Equal(t, uint(0), t1Inherited[1].PassingHostCount)
	assert.Equal(t, uint(1), t1Inherited[1].FailingHostCount)

	t2Pols, t2Inherited, err := ds.ListTeamPolicies(ctx, team2.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, t2Pols, 2)
	require.Equal(t, t2pol.ID, t2Pols[0].ID)
	assert.Equal(t, uint(1), t2Pols[0].PassingHostCount)
	assert.Equal(t, uint(1), t2Pols[0].FailingHostCount)
	require.Equal(t, t2pol2.ID, t2Pols[1].ID)
	assert.Equal(t, uint(2), t2Pols[1].PassingHostCount)
	assert.Equal(t, uint(0), t2Pols[1].FailingHostCount)
	require.Len(t, t2Inherited, 2)
	require.Equal(t, p.ID, t2Inherited[0].ID)
	assert.Equal(t, uint(0), t2Inherited[0].PassingHostCount)
	assert.Equal(t, uint(1), t2Inherited[0].FailingHostCount)
	require.Equal(t, p2.ID, t2Inherited[1].ID)
	assert.Equal(t, uint(1), t2Inherited[1].PassingHostCount)
	assert.Equal(t, uint(0), t2Inherited[1].FailingHostCount)
}
// testTeamPolicyLegacy creates team policies whose payload references a saved
// query by QueryID. It checks that creation fails for an unknown team ID, that
// team policies do not appear in the global list, that each team lists only
// its own policy plus the inherited global one, and that deleting a team
// policy leaves the inherited policy in place.
func testTeamPolicyLegacy(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	author := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	teamA, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	savedQuery1, err := ds.NewQuery(ctx, &fleet.Query{
		Name:        "query1",
		Description: "query1 desc",
		Query:       "select 1;",
		Saved:       true,
		Logging:     fleet.LoggingSnapshot,
	})
	require.NoError(t, err)

	teamB, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
	require.NoError(t, err)

	savedQuery2, err := ds.NewQuery(ctx, &fleet.Query{
		Name:        "query2",
		Description: "query2 desc",
		Query:       "select 1;",
		Saved:       true,
		Logging:     fleet.LoggingSnapshot,
	})
	require.NoError(t, err)

	// No global policies exist before anything is created.
	initialGlobal, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	require.Empty(t, initialGlobal)

	// Creating a policy for a team ID that was never created must fail.
	_, err = ds.NewTeamPolicy(ctx, 99999999, &author.ID, fleet.PolicyPayload{
		QueryID: &savedQuery1.ID,
	})
	require.Error(t, err)

	teamAPolicy, err := ds.NewTeamPolicy(ctx, teamA.ID, &author.ID, fleet.PolicyPayload{
		QueryID:    &savedQuery1.ID,
		Resolution: "some resolution",
	})
	require.NoError(t, err)
	assert.Equal(t, "query1", teamAPolicy.Name)
	assert.Equal(t, "select 1;", teamAPolicy.Query)
	assert.Equal(t, "query1 desc", teamAPolicy.Description)
	require.NotNil(t, teamAPolicy.AuthorID)
	assert.Equal(t, author.ID, *teamAPolicy.AuthorID)
	require.NotNil(t, teamAPolicy.Resolution)
	assert.Equal(t, "some resolution", *teamAPolicy.Resolution)

	globalPolicy, err := ds.NewGlobalPolicy(ctx, &author.ID, fleet.PolicyPayload{
		Name:  "global_1",
		Query: "SELECT 1",
	})
	require.NoError(t, err)

	// Only the global policy is listed globally; the team policy is not.
	globalList, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, globalList, 1)

	teamBPolicy, err := ds.NewTeamPolicy(ctx, teamB.ID, &author.ID, fleet.PolicyPayload{
		QueryID: &savedQuery2.ID,
	})
	require.NoError(t, err)
	assert.Equal(t, "query2", teamBPolicy.Name)
	assert.Equal(t, "select 1;", teamBPolicy.Query)
	assert.Equal(t, "query2 desc", teamBPolicy.Description)
	require.NotNil(t, teamBPolicy.AuthorID)
	assert.Equal(t, author.ID, *teamBPolicy.AuthorID)

	teamAList, teamAInherited, err := ds.ListTeamPolicies(ctx, teamA.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, teamAList, 1)
	assert.Equal(t, savedQuery1.Name, teamAList[0].Name)
	assert.Equal(t, savedQuery1.Query, teamAList[0].Query)
	assert.Equal(t, savedQuery1.Description, teamAList[0].Description)
	require.NotNil(t, teamAList[0].AuthorID)
	require.Equal(t, author.ID, *teamAList[0].AuthorID)
	require.Len(t, teamAInherited, 1)
	require.Equal(t, globalPolicy, teamAInherited[0])

	teamBList, teamBInherited, err := ds.ListTeamPolicies(ctx, teamB.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, teamBList, 1)
	assert.Equal(t, savedQuery2.Name, teamBList[0].Name)
	assert.Equal(t, savedQuery2.Query, teamBList[0].Query)
	assert.Equal(t, savedQuery2.Description, teamBList[0].Description)
	require.NotNil(t, teamBList[0].AuthorID)
	require.Equal(t, author.ID, *teamBList[0].AuthorID)
	require.Len(t, teamBInherited, 1)
	require.Equal(t, globalPolicy, teamBInherited[0])

	// Removing the first team's own policy keeps the inherited global policy.
	_, err = ds.DeleteTeamPolicies(ctx, teamA.ID, []uint{teamAList[0].ID})
	require.NoError(t, err)

	teamAList, teamAInherited, err = ds.ListTeamPolicies(ctx, teamA.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Empty(t, teamAList)
	require.Len(t, teamAInherited, 1)
}
func testTeamPolicyProprietary(t *testing.T, ds *Datastore) {
ctx := context.Background()
requireLabels := func(t *testing.T, expected []string, actual []fleet.LabelIdent) {
actualLabels := make([]string, 0, len(actual))
for _, label := range actual {
actualLabels = append(actualLabels, label.LabelName)
}
require.Equal(t, expected, actualLabels)
}
user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
require.NoError(t, err)
team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
require.NoError(t, err)
label1, err := ds.NewLabel(ctx, &fleet.Label{Name: "label1"})
require.NoError(t, err)
label2, err := ds.NewLabel(ctx, &fleet.Label{Name: "label2"})
require.NoError(t, err)
gpol, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
Name: "existing-query-global-1",
Query: "select 1;",
Description: "query1 desc",
Resolution: "query1 resolution",
LabelsIncludeAny: []string{label1.Name, label2.Name},
})
require.NoError(t, err)
requireLabels(t, []string{label1.Name, label2.Name}, gpol.LabelsIncludeAny)
// Cannot create a policy with inclusive and exclusive labels set
gpol1, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
Name: "global-query-bad-both-labels",
Query: "select 1;",
Description: "query1 desc",
Resolution: "query1 resolution",
LabelsExcludeAny: []string{label1.Name},
LabelsIncludeAny: []string{label2.Name},
})
require.Error(t, err)
require.Nil(t, gpol1)
// Cannot create policy with invalid label set
gpol1, err = ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
Name: "global-query-invalid-label",
Query: "select 1;",
Description: "query1 desc",
Resolution: "query1 resolution",
LabelsExcludeAny: []string{"invalid"},
})
require.Error(t, err)
require.Nil(t, gpol1)
prevPolicies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, prevPolicies, 1)
requireLabels(t, []string{label1.Name, label2.Name}, prevPolicies[0].LabelsIncludeAny)
require.Equal(t, gpol, prevPolicies[0])
// team does not exist
_, err = ds.NewTeamPolicy(ctx, 99999999, &user1.ID, fleet.PolicyPayload{
Name: "query1",
Query: "select 1;",
Description: "query1 desc",
Resolution: "query1 resolution",
})
require.Error(t, err)
p, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "query1",
Query: "select 1;",
Description: "query1 desc",
Resolution: "query1 resolution",
CalendarEventsEnabled: true,
LabelsExcludeAny: []string{label1.Name, label2.Name},
})
require.NoError(t, err)
// Can't create a team policy with same team id and name.
_, err = ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "query1",
Query: "select 1;",
})
require.Error(t, err)
var isExist interface {
IsExists() bool
}
require.True(t, errors.As(err, &isExist) && isExist.IsExists(), err)
// Can't create a global policy with an existing global name.
_, err = ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
Name: "existing-query-global-1",
Query: "select 1;",
})
require.Error(t, err)
require.True(t, errors.As(err, &isExist) && isExist.IsExists(), err)
assert.Equal(t, "query1", p.Name)
assert.Equal(t, "select 1;", p.Query)
assert.Equal(t, "query1 desc", p.Description)
require.NotNil(t, p.Resolution)
assert.Equal(t, "query1 resolution", *p.Resolution)
require.NotNil(t, p.AuthorID)
assert.Equal(t, user1.ID, *p.AuthorID)
assert.True(t, p.CalendarEventsEnabled)
requireLabels(t, []string{label1.Name, label2.Name}, p.LabelsExcludeAny)
globalPolicies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, globalPolicies, len(prevPolicies))
p2, err := ds.NewTeamPolicy(ctx, team2.ID, &user1.ID, fleet.PolicyPayload{
Name: "query2",
Query: "select 2;",
Description: "query2 desc",
Resolution: "query2 resolution",
})
require.NoError(t, err)
assert.Equal(t, "query2", p2.Name)
assert.Equal(t, "select 2;", p2.Query)
assert.Equal(t, "query2 desc", p2.Description)
require.NotNil(t, p2.Resolution)
assert.Equal(t, "query2 resolution", *p2.Resolution)
require.NotNil(t, p2.AuthorID)
assert.Equal(t, user1.ID, *p2.AuthorID)
teamPolicies, inherited1, err := ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, teamPolicies, 1)
assert.Equal(t, "query1", teamPolicies[0].Name)
assert.Equal(t, "select 1;", teamPolicies[0].Query)
assert.Equal(t, "query1 desc", teamPolicies[0].Description)
require.NotNil(t, teamPolicies[0].Resolution)
assert.Equal(t, "query1 resolution", *teamPolicies[0].Resolution)
require.NotNil(t, teamPolicies[0].AuthorID)
require.Equal(t, user1.ID, *teamPolicies[0].AuthorID)
requireLabels(t, []string{label1.Name, label2.Name}, teamPolicies[0].LabelsExcludeAny)
require.Len(t, inherited1, 1)
require.Equal(t, gpol, inherited1[0])
team2Policies, inherited2, err := ds.ListTeamPolicies(ctx, team2.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, team2Policies, 1)
assert.Equal(t, "query2", team2Policies[0].Name)
assert.Equal(t, "select 2;", team2Policies[0].Query)
assert.Equal(t, "query2 desc", team2Policies[0].Description)
require.NotNil(t, team2Policies[0].Resolution)
assert.Equal(t, "query2 resolution", *team2Policies[0].Resolution)
require.NotNil(t, team2Policies[0].AuthorID)
require.Equal(t, user1.ID, *team2Policies[0].AuthorID)
require.Len(t, inherited2, 1)
require.Equal(t, gpol, inherited2[0])
// Can't create a policy with the same name on the same team.
p3, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "query1",
Query: "select 2;",
Description: "query2 other description",
Resolution: "query2 other resolution",
})
require.Error(t, err)
require.Nil(t, p3)
// Can't create a policy with both include-any and exclude-any labels.
p3, err = ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "query-bothlabel",
Query: "select 2;",
Description: "query2 other description",
Resolution: "query2 other resolution",
LabelsExcludeAny: []string{label1.Name},
LabelsIncludeAny: []string{label2.Name},
})
require.Error(t, err)
require.Nil(t, p3)
// Can't create a policy with a non-existent label.
p3, err = ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "query-nolabel",
Query: "select 2;",
Description: "query2 other description",
Resolution: "query2 other resolution",
LabelsExcludeAny: []string{"invalid"},
})
require.Error(t, err)
require.Nil(t, p3)
_, err = ds.DeleteTeamPolicies(ctx, team1.ID, []uint{teamPolicies[0].ID})
require.NoError(t, err)
teamPolicies, inherited1, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, teamPolicies, 0)
require.Len(t, inherited1, 1)
require.Equal(t, gpol, inherited1[0])
// Now the name is available and we can create the policy in the team.
_, err = ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "query1",
Query: "select 2;",
Description: "query2 other description",
Resolution: "query2 other resolution",
})
require.NoError(t, err)
teamPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, teamPolicies, 1)
assert.Equal(t, "query1", teamPolicies[0].Name)
assert.Equal(t, "select 2;", teamPolicies[0].Query)
assert.Equal(t, "query2 other description", teamPolicies[0].Description)
require.NotNil(t, teamPolicies[0].Resolution)
assert.Equal(t, "query2 other resolution", *teamPolicies[0].Resolution)
require.NotNil(t, team2Policies[0].AuthorID)
require.Equal(t, user1.ID, *team2Policies[0].AuthorID)
}
// testTeamPolicyPendingScriptsAndInstalls checks that deleting a team policy
// leaves no pending script executions or pending software installs behind for
// the hosts that had work queued on behalf of that policy.
func testTeamPolicyPendingScriptsAndInstalls(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	// create a script and associate it with a team policy
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	host1 := test.NewHost(t, ds, "host1", "1", "host1key", "host1uuid", time.Now())
	script, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1.sh",
		ScriptContents: "echo",
		TeamID:         &team1.ID,
	})
	require.NoError(t, err)
	policy1, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
		Name:        "team1",
		Query:       "select 1;",
		Description: "description",
		Resolution:  "resolution",
		ScriptID:    &script.ID,
	})
	require.NoError(t, err)

	// create pending script execution
	_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         host1.ID,
		ScriptContents: "echo",
		UserID:         &user.ID,
		PolicyID:       &policy1.ID,
		SyncRequest:    true,
		ScriptID:       &script.ID,
	})
	require.NoError(t, err)

	// Pending executions are looked up by host, so query with host1's ID. The
	// policy ID used previously only worked because both auto-increment IDs
	// happened to have the same value.
	pendingScripts, err := ds.ListPendingHostScriptExecutions(ctx, host1.ID, false)
	require.NoError(t, err)
	require.Len(t, pendingScripts, 1)

	// delete the policy
	_, err = ds.DeleteTeamPolicies(ctx, team1.ID, []uint{policy1.ID})
	require.NoError(t, err)

	pendingScripts, err = ds.ListPendingHostScriptExecutions(ctx, host1.ID, false)
	require.NoError(t, err)
	require.Empty(t, pendingScripts)

	// create a software install and associate it with a team policy
	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
	require.NoError(t, err)
	host2 := test.NewHost(t, ds, "host2", "2", "host2key", "host2uuid", time.Now())
	tfr1, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
	require.NoError(t, err)
	installerID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "hello",
		PreInstallQuery:   "SELECT 1",
		PostInstallScript: "world",
		UninstallScript:   "goodbye",
		InstallerFile:     tfr1,
		StorageID:         "storage1",
		Filename:          "file1",
		Title:             "file1",
		Version:           "1.0",
		Source:            "apps",
		UserID:            user.ID,
		TeamID:            &team2.ID,
		ValidatedLabels:   &fleet.LabelIdentsWithScope{},
	})
	require.NoError(t, err)
	policy2, err := ds.NewTeamPolicy(ctx, team2.ID, nil, fleet.PolicyPayload{
		Name:                "team2",
		Query:               "select 1;",
		Description:         "description2",
		Resolution:          "resolution2",
		SoftwareInstallerID: &installerID,
	})
	require.NoError(t, err)

	// create a pending software install request
	_, err = ds.InsertSoftwareInstallRequest(ctx, host2.ID, installerID, fleet.HostSoftwareInstallOptions{PolicyID: &policy2.ID})
	require.NoError(t, err)

	pendingInstalls, err := ds.ListPendingSoftwareInstalls(ctx, host2.ID)
	require.NoError(t, err)
	require.Len(t, pendingInstalls, 1)

	// delete the policy
	_, err = ds.DeleteTeamPolicies(ctx, team2.ID, []uint{policy2.ID})
	require.NoError(t, err)

	pendingInstalls, err = ds.ListPendingSoftwareInstalls(ctx, host2.ID)
	require.NoError(t, err)
	require.Empty(t, pendingInstalls)
}
// testListMergedTeamPolicies exercises ListMergedTeamPolicies: ordering and
// filtering across the global and team policies, and the passing/failing host
// counts reported for a team before and after a host is moved into it.
func testListMergedTeamPolicies(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	gpol, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{
		Name:        "query1 global",
		Query:       "select 1;",
		Description: "query1 desc",
		Resolution:  "query1 resolution",
	})
	require.NoError(t, err)

	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	team1policy, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
		Name:        "query2 team1",
		Query:       "select 1;",
		Description: "query1 desc",
		Resolution:  "query1 resolution",
	})
	require.NoError(t, err)

	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
	require.NoError(t, err)
	team2policy, err := ds.NewTeamPolicy(ctx, team2.ID, nil, fleet.PolicyPayload{
		Name:        "query3 team2",
		Query:       "select 2;",
		Description: "query2 desc",
		Resolution:  "query2 resolution",
	})
	require.NoError(t, err)

	// List options apply to both the global and the team policies.
	descByName := fleet.ListOptions{
		OrderKey:       "name",
		OrderDirection: fleet.OrderDescending,
	}
	merged, err := ds.ListMergedTeamPolicies(ctx, team1.ID, descByName, "")
	require.NoError(t, err)
	require.Len(t, merged, 2)
	assert.Equal(t, team1policy.ID, merged[0].ID)
	assert.Equal(t, gpol.ID, merged[1].ID)

	// Filtering by match query.
	merged, err = ds.ListMergedTeamPolicies(ctx, team1.ID, fleet.ListOptions{MatchQuery: "query1"}, "")
	require.NoError(t, err)
	require.Len(t, merged, 1)
	assert.Equal(t, gpol.ID, merged[0].ID)

	merged, err = ds.ListMergedTeamPolicies(ctx, team1.ID, fleet.ListOptions{MatchQuery: "query2"}, "")
	require.NoError(t, err)
	require.Len(t, merged, 1)
	assert.Equal(t, team1policy.ID, merged[0].ID)

	// expectCounts lists the merged policies of teamID ordered by name and
	// asserts that the global policy comes first, followed by teamPolicyID,
	// both with wantPassing passing hosts and no failing hosts.
	expectCounts := func(teamID, teamPolicyID, wantPassing uint) {
		t.Helper()
		list, err := ds.ListMergedTeamPolicies(ctx, teamID, fleet.ListOptions{OrderKey: "name"}, "")
		require.NoError(t, err)
		require.Len(t, list, 2)
		assert.Equal(t, gpol.ID, list[0].ID)
		assert.Equal(t, wantPassing, list[0].PassingHostCount)
		assert.Equal(t, uint(0), list[0].FailingHostCount)
		assert.Equal(t, teamPolicyID, list[1].ID)
		assert.Equal(t, wantPassing, list[1].PassingHostCount)
		assert.Equal(t, uint(0), list[1].FailingHostCount)
	}

	// Host policy counts: start with a host that has no team.
	hostSpec := &fleet.Host{
		OsqueryHostID: ptr.String("host1"),
		NodeKey:       ptr.String(fmt.Sprint("host1", 1)),
	}
	host, err := ds.NewHost(ctx, hostSpec)
	require.NoError(t, err)

	err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{gpol.ID: ptr.Bool(true)}, time.Now(), false)
	require.NoError(t, err)
	err = ds.UpdateHostPolicyCounts(ctx)
	require.NoError(t, err)

	// team 1 shows no host counts
	expectCounts(team1.ID, team1policy.ID, 0)

	// move host to team1
	err = ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(&team1.ID, []uint{host.ID}))
	require.NoError(t, err)
	err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{team1policy.ID: ptr.Bool(true)}, time.Now(), false)
	require.NoError(t, err)
	err = ds.UpdateHostPolicyCounts(ctx)
	require.NoError(t, err)

	// team 1 shows host counts
	expectCounts(team1.ID, team1policy.ID, 1)

	// team2 shows no host counts
	expectCounts(team2.ID, team2policy.ID, 0)
}
// newTestHostWithPlatform creates a host with the given hostname and platform,
// using a random node key and random UUIDs. When teamID is non-nil the host is
// added to that team and re-read from the datastore before being returned.
func newTestHostWithPlatform(t *testing.T, ds *Datastore, hostname, platform string, teamID *uint) *fleet.Host {
	ctx := context.Background()

	nodeKey, err := server.GenerateRandomText(32)
	require.NoError(t, err)

	spec := &fleet.Host{
		OsqueryHostID:   ptr.String(uuid.NewString()),
		DetailUpdatedAt: time.Now(),
		LabelUpdatedAt:  time.Now(),
		PolicyUpdatedAt: time.Now(),
		SeenTime:        time.Now(),
		NodeKey:         &nodeKey,
		UUID:            uuid.NewString(),
		Hostname:        hostname,
		Platform:        platform,
		OSVersion:       "15.4.1",
		ComputerName:    hostname,
	}
	created, err := ds.NewHost(ctx, spec)
	require.NoError(t, err)

	// Hosts without a team are returned exactly as created.
	if teamID == nil {
		return created
	}

	err = ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(teamID, []uint{created.ID}))
	require.NoError(t, err)

	// Reload the host after the team assignment.
	reloaded, err := ds.Host(ctx, created.ID)
	require.NoError(t, err)
	return reloaded
}
// newTestPolicy creates a policy named name whose query is "select <name>;"
// and which targets the given platforms. A nil teamID creates a global policy;
// otherwise the policy is created on that team. user is recorded as author.
func newTestPolicy(t *testing.T, ds *Datastore, user *fleet.User, name, platforms string, teamID *uint) *fleet.Policy {
	ctx := context.Background()
	payload := fleet.PolicyPayload{
		Name:     name,
		Query:    fmt.Sprintf("select %s;", name),
		Platform: platforms,
	}

	var (
		created *fleet.Policy
		err     error
	)
	if teamID != nil {
		created, err = ds.NewTeamPolicy(ctx, *teamID, &user.ID, payload)
	} else {
		created, err = ds.NewGlobalPolicy(ctx, &user.ID, payload)
	}
	require.NoError(t, err)
	return created
}
// expectedPolicyResults bundles the two expectations a test compares against
// for a host: the policy queries map and the list of host policies.
type expectedPolicyResults struct {
// policyQueries maps a policy ID (formatted as a string) to its query.
policyQueries map[string]string
// hostPolicies holds one HostPolicy per expected policy.
hostPolicies []*fleet.HostPolicy
}
// expectedPolicyQueries builds the expectations for the given policies: a map
// from each policy's ID (as a string) to its query, and a slice holding one
// HostPolicy per policy that carries only the policy's PolicyData.
func expectedPolicyQueries(policies ...*fleet.Policy) expectedPolicyResults {
	want := expectedPolicyResults{
		policyQueries: make(map[string]string),
		hostPolicies:  make([]*fleet.HostPolicy, len(policies)),
	}
	for idx, pol := range policies {
		want.policyQueries[fmt.Sprint(pol.ID)] = pol.Query
		want.hostPolicies[idx] = &fleet.HostPolicy{PolicyData: pol.PolicyData}
	}
	return want
}
// testPolicyQueriesForHostPlatforms creates hosts of several platforms (with
// no team, in team1 and in team2) together with global and team policies that
// target different platforms, then checks for every host that
// PolicyQueriesForHost and ListPoliciesForHost return exactly the expected
// set of policies.
func testPolicyQueriesForHostPlatforms(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
	require.NoError(t, err)

	// A nil team ID stands for "global" in the helpers below.
	var global *uint

	// Global hosts:
	host1GlobalUbuntu := newTestHostWithPlatform(t, ds, "host1_global_ubuntu", "ubuntu", global)
	host2GlobalDarwin := newTestHostWithPlatform(t, ds, "host2_global_darwin", "darwin", global)
	host3GlobalWindows := newTestHostWithPlatform(t, ds, "host3_global_windows", "windows", global)
	host4GlobalEmpty := newTestHostWithPlatform(t, ds, "host4_global_empty_platform", "", global)

	// team1 hosts:
	host1t1Rhel := newTestHostWithPlatform(t, ds, "host1_team1_ubuntu", "rhel", &team1.ID)
	host2t1Darwin := newTestHostWithPlatform(t, ds, "host2_team1_darwin", "darwin", &team1.ID)
	host3t1Windows := newTestHostWithPlatform(t, ds, "host3_team1_windows", "windows", &team1.ID)
	host4t1Empty := newTestHostWithPlatform(t, ds, "host4_team1_empty_platform", "", &team1.ID)

	// team2 hosts:
	host1t2Debian := newTestHostWithPlatform(t, ds, "host1_team2_ubuntu", "debian", &team2.ID)
	host2t2Darwin := newTestHostWithPlatform(t, ds, "host2_team2_darwin", "darwin", &team2.ID)
	host3t2Windows := newTestHostWithPlatform(t, ds, "host3_team2_windows", "windows", &team2.ID)
	host4t2Empty := newTestHostWithPlatform(t, ds, "host4_team2_empty_platform", "", &team2.ID)

	// Global policies:
	policy1GlobalLinuxDarwin := newTestPolicy(t, ds, user1, "policy1_global_linux_darwin", "linux,darwin", global)
	policy2GlobalWindows := newTestPolicy(t, ds, user1, "policy2_global_windows", "windows", global)
	policy3GlobalAll := newTestPolicy(t, ds, user1, "policy3_global_all", "", global)

	// Team1 policies:
	policy1t1Darwin := newTestPolicy(t, ds, user1, "policy1_team1_darwin", "darwin", &team1.ID)
	policy2t1Windows := newTestPolicy(t, ds, user1, "policy2_team1_windows", "windows", &team1.ID)
	policy3t1All := newTestPolicy(t, ds, user1, "policy3_team1_all", "", &team1.ID)

	// Team2 policies:
	policy1t2LinuxDarwin := newTestPolicy(t, ds, user1, "policy1_team2_linux_darwin", "linux,darwin", &team2.ID)
	policy2t2Windows := newTestPolicy(t, ds, user1, "policy2_team2_windows", "windows", &team2.ID)
	policy3t2All1 := newTestPolicy(t, ds, user1, "policy3_team2_all1", "linux,darwin,windows", &team2.ID)
	policy4t2All2 := newTestPolicy(t, ds, user1, "policy4_team2_all2", "", &team2.ID)

	// platformCase pairs a host with the policies expected to apply to it.
	type platformCase struct {
		host             *fleet.Host
		expectedPolicies expectedPolicyResults
	}

	cases := []platformCase{
		{
			host: host1GlobalUbuntu,
			expectedPolicies: expectedPolicyQueries(
				policy1GlobalLinuxDarwin,
				policy3GlobalAll,
			),
		},
		{
			host: host2GlobalDarwin,
			expectedPolicies: expectedPolicyQueries(
				policy1GlobalLinuxDarwin,
				policy3GlobalAll,
			),
		},
		{
			host: host3GlobalWindows,
			expectedPolicies: expectedPolicyQueries(
				policy2GlobalWindows,
				policy3GlobalAll,
			),
		},
		{
			host: host4GlobalEmpty,
			expectedPolicies: expectedPolicyQueries(
				policy3GlobalAll,
			),
		},
		{
			host: host1t1Rhel,
			expectedPolicies: expectedPolicyQueries(
				policy1GlobalLinuxDarwin,
				policy3GlobalAll,
				policy3t1All,
			),
		},
		{
			host: host2t1Darwin,
			expectedPolicies: expectedPolicyQueries(
				policy1GlobalLinuxDarwin,
				policy3GlobalAll,
				policy3t1All,
				policy1t1Darwin,
			),
		},
		{
			host: host3t1Windows,
			expectedPolicies: expectedPolicyQueries(
				policy2GlobalWindows,
				policy3GlobalAll,
				policy3t1All,
				policy2t1Windows,
			),
		},
		{
			host: host4t1Empty,
			expectedPolicies: expectedPolicyQueries(
				policy3GlobalAll,
				policy3t1All,
			),
		},
		{
			host: host1t2Debian,
			expectedPolicies: expectedPolicyQueries(
				policy1GlobalLinuxDarwin,
				policy3GlobalAll,
				policy1t2LinuxDarwin,
				policy3t2All1,
				policy4t2All2,
			),
		},
		{
			host: host2t2Darwin,
			expectedPolicies: expectedPolicyQueries(
				policy1GlobalLinuxDarwin,
				policy3GlobalAll,
				policy1t2LinuxDarwin,
				policy3t2All1,
				policy4t2All2,
			),
		},
		{
			host: host3t2Windows,
			expectedPolicies: expectedPolicyQueries(
				policy2GlobalWindows,
				policy3GlobalAll,
				policy2t2Windows,
				policy3t2All1,
				policy4t2All2,
			),
		},
		{
			host: host4t2Empty,
			expectedPolicies: expectedPolicyQueries(
				policy3GlobalAll,
				policy4t2All2,
			),
		},
	}

	for _, tc := range cases {
		t.Run(tc.host.Hostname, func(t *testing.T) {
			// PolicyQueriesForHost is the endpoint used by osquery agents when they check in.
			gotQueries, err := ds.PolicyQueriesForHost(ctx, tc.host)
			require.NoError(t, err)
			require.Equal(t, tc.expectedPolicies.policyQueries, gotQueries)

			// ListPoliciesForHost is the endpoint used by fleet UI/API clients.
			gotPolicies, err := ds.ListPoliciesForHost(ctx, tc.host)
			require.NoError(t, err)

			// Order both sides by policy ID so the comparison does not depend
			// on the order in which the datastore returns the policies.
			wantPolicies := tc.expectedPolicies.hostPolicies
			sort.Slice(gotPolicies, func(a, b int) bool {
				return gotPolicies[a].ID < gotPolicies[b].ID
			})
			sort.Slice(wantPolicies, func(a, b int) bool {
				return wantPolicies[a].ID < wantPolicies[b].ID
			})
			require.Equal(t, wantPolicies, gotPolicies)
		})
	}
}
// testPolicyQueriesForHost checks PolicyQueriesForHost and ListPoliciesForHost
// for a host in a team (host1) and a host without a team (host2): which
// queries each host receives, the fields and ordering of the listed policies
// after results are recorded, and how a policy row inserted with a NULL
// resolution is reported.
func testPolicyQueriesForHost(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	host1, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:   ptr.String("1234"),
		DetailUpdatedAt: time.Now(),
		LabelUpdatedAt:  time.Now(),
		PolicyUpdatedAt: time.Now(),
		SeenTime:        time.Now(),
		NodeKey:         ptr.String("1"),
		UUID:            "1",
		Hostname:        "foo.local",
	})
	require.NoError(t, err)
	require.NoError(t, ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(&team1.ID, []uint{host1.ID})))
	host1, err = ds.Host(ctx, host1.ID)
	require.NoError(t, err)

	host2, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:   ptr.String("5679"),
		DetailUpdatedAt: time.Now(),
		LabelUpdatedAt:  time.Now(),
		PolicyUpdatedAt: time.Now(),
		SeenTime:        time.Now(),
		NodeKey:         ptr.String("2"),
		UUID:            "2",
		Hostname:        "bar.local",
	})
	require.NoError(t, err)

	q, err := ds.NewQuery(ctx, &fleet.Query{
		Name:        "query1",
		Description: "query1 desc",
		Query:       "select 1;",
		Saved:       true,
		Logging:     fleet.LoggingSnapshot,
	})
	require.NoError(t, err)
	gp, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
		QueryID:    &q.ID,
		Resolution: "some gp resolution",
	})
	require.NoError(t, err)

	q2, err := ds.NewQuery(ctx, &fleet.Query{
		Name:        "query2",
		Description: "query2 desc",
		Query:       "select 42;",
		Saved:       true,
		Logging:     fleet.LoggingSnapshot,
	})
	require.NoError(t, err)
	tp, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		QueryID:                  &q2.ID,
		Resolution:               "some other gp resolution",
		ConditionalAccessEnabled: true,
	})
	require.NoError(t, err)

	// host1 is in team1, so it gets both the global and the team query.
	queries, err := ds.PolicyQueriesForHost(ctx, host1)
	require.NoError(t, err)
	require.Len(t, queries, 2)
	assert.Equal(t, q.Query, queries[fmt.Sprint(q.ID)])
	assert.Equal(t, q2.Query, queries[fmt.Sprint(q2.ID)])

	// host2 has no team, so it only gets the global query.
	queries, err = ds.PolicyQueriesForHost(ctx, host2)
	require.NoError(t, err)
	require.Len(t, queries, 1)
	assert.Equal(t, q.Query, queries[fmt.Sprint(q.ID)])

	// Team policy ran with failing result.
	require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{tp.ID: ptr.Bool(false), gp.ID: nil}, time.Now(), false))

	policies, err := ds.ListPoliciesForHost(ctx, host1)
	require.NoError(t, err)
	require.Len(t, policies, 2)

	// checkGlobalPolicy asserts that policy carries the global policy's data.
	// Pointer fields are guarded with require so that a nil value fails the
	// test instead of panicking on the dereference that follows.
	checkGlobalPolicy := func(policy *fleet.HostPolicy) {
		t.Helper()
		assert.Equal(t, "query1", policy.Name)
		assert.Equal(t, "select 1;", policy.Query)
		assert.Equal(t, "query1 desc", policy.Description)
		require.NotNil(t, policy.AuthorID)
		assert.Equal(t, user1.ID, *policy.AuthorID)
		assert.Equal(t, "Alice", policy.AuthorName)
		assert.Equal(t, "alice@example.com", policy.AuthorEmail)
		require.NotNil(t, policy.Resolution)
		assert.Equal(t, "some gp resolution", *policy.Resolution)
		assert.False(t, policy.ConditionalAccessEnabled)
	}

	// Failing policy is listed first.
	assert.Equal(t, "fail", policies[0].Response)
	assert.Equal(t, "query2", policies[0].Name)
	assert.Equal(t, "select 42;", policies[0].Query)
	assert.Equal(t, "query2 desc", policies[0].Description)
	require.NotNil(t, policies[0].AuthorID)
	assert.Equal(t, user1.ID, *policies[0].AuthorID)
	assert.Equal(t, "Alice", policies[0].AuthorName)
	assert.Equal(t, "alice@example.com", policies[0].AuthorEmail)
	require.NotNil(t, policies[0].Resolution)
	assert.Equal(t, "some other gp resolution", *policies[0].Resolution)
	assert.True(t, policies[0].ConditionalAccessEnabled)

	checkGlobalPolicy(policies[1])
	assert.Equal(t, "", policies[1].Response)

	policies, err = ds.ListPoliciesForHost(ctx, host2)
	require.NoError(t, err)
	require.Len(t, policies, 1)
	checkGlobalPolicy(policies[0])
	assert.Equal(t, "", policies[0].Response)

	// Global policy ran with passing result.
	require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{gp.ID: ptr.Bool(true)}, time.Now(), false))

	policies, err = ds.ListPoliciesForHost(ctx, host2)
	require.NoError(t, err)
	require.Len(t, policies, 1)
	checkGlobalPolicy(policies[0])
	assert.Equal(t, "pass", policies[0].Response)

	// Manually insert a global policy with null resolution.
	res, err := ds.writer(ctx).ExecContext(
		ctx,
		fmt.Sprintf(`INSERT INTO policies (name, query, description, checksum) VALUES (?, ?, ?, %s)`, policiesChecksumComputedColumn()),
		q.Name+"2", q.Query, q.Description+"2",
	)
	require.NoError(t, err)
	id, err := res.LastInsertId()
	require.NoError(t, err)
	require.NoError(t,
		ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{uint(id): nil}, //nolint:gosec // dismiss G115
			time.Now(), false))

	policies, err = ds.ListPoliciesForHost(ctx, host2)
	require.NoError(t, err)
	require.Len(t, policies, 2)

	// Global policy with null resolution is listed first, followed by passing policy.
	assert.Equal(t, "query1 desc2", policies[0].Description)
	require.NotNil(t, policies[0].Resolution)
	assert.Empty(t, *policies[0].Resolution)
	require.NotNil(t, policies[1].Resolution)
	assert.Equal(t, "some gp resolution", *policies[1].Resolution)
}
// testPoliciesByID checks that PoliciesByID returns the requested policies
// keyed by ID, including the passing host count and the associated software
// installer, and that it returns a not-found error when one of the requested
// IDs does not exist.
func testPoliciesByID(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
	policy1 := newTestPolicy(t, ds, user1, "policy1", "darwin", nil)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	policy2 := newTestPolicy(t, ds, user1, "policy2", "darwin", &team1.ID)
	host1 := newTestHostWithPlatform(t, ds, "host1", "darwin", nil)

	// Associate an installer to policy2
	installer, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
	require.NoError(t, err)
	installerID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "hello",
		PreInstallQuery:   "SELECT 1",
		PostInstallScript: "world",
		InstallerFile:     installer,
		StorageID:         "storage1",
		Filename:          "file1",
		Title:             "file1",
		Version:           "1.0",
		Source:            "apps",
		UserID:            user1.ID,
		TeamID:            &team1.ID,
		ValidatedLabels:   &fleet.LabelIdentsWithScope{},
	})
	require.NoError(t, err)
	policy2.SoftwareInstallerID = ptr.Uint(installerID)
	err = ds.SavePolicy(ctx, policy2, false, false)
	require.NoError(t, err)

	require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{policy1.ID: ptr.Bool(true)}, time.Now(), false))
	require.NoError(t, ds.UpdateHostPolicyCounts(ctx))

	// Use the IDs of the policies created above rather than assuming the
	// auto-increment values the database happened to assign.
	policiesByID, err := ds.PoliciesByID(ctx, []uint{policy1.ID, policy2.ID})
	require.NoError(t, err)
	require.Len(t, policiesByID, 2)

	// Guard the lookups so a missing entry fails the test instead of
	// panicking on a nil pointer.
	require.Contains(t, policiesByID, policy1.ID)
	got1 := policiesByID[policy1.ID]
	require.NotNil(t, got1)
	assert.Equal(t, policy1.ID, got1.ID)
	assert.Equal(t, policy1.Name, got1.Name)
	assert.Nil(t, got1.SoftwareInstallerID)
	assert.Equal(t, uint(1), got1.PassingHostCount)

	require.Contains(t, policiesByID, policy2.ID)
	got2 := policiesByID[policy2.ID]
	require.NotNil(t, got2)
	assert.Equal(t, policy2.ID, got2.ID)
	assert.Equal(t, "policy2", got2.Name)
	require.NotNil(t, got2.SoftwareInstallerID)
	assert.Equal(t, installerID, *got2.SoftwareInstallerID)

	// Only two policies were created, so the ID right after policy2's does
	// not exist and the lookup must fail with a not-found error.
	_, err = ds.PoliciesByID(ctx, []uint{policy1.ID, policy2.ID, policy2.ID + 1})
	require.Error(t, err)
	var nfe fleet.NotFoundError
	require.ErrorAs(t, err, &nfe)
}
// testTeamPolicyTransfer tracks the passing host counts of a team1 policy and
// a global policy while hosts are moved between teams and re-enrolled. After
// each step it asserts the counts reported for team1's own policy, team1's
// inherited policy, team2's inherited policy and the global policy list.
func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team1"})
	require.NoError(t, err)
	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team2"})
	require.NoError(t, err)

	// host1 is created without a team and added to team1 below; host2 is
	// enrolled directly into team1.
	host1, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:   ptr.String("1234"),
		DetailUpdatedAt: time.Now(),
		LabelUpdatedAt:  time.Now(),
		PolicyUpdatedAt: time.Now(),
		SeenTime:        time.Now(),
		NodeKey:         ptr.String("1"),
		UUID:            "1",
		Hostname:        "foo.local",
	})
	require.NoError(t, err)

	host2, err := ds.EnrollOsquery(ctx,
		fleet.WithEnrollOsqueryHostID("2"),
		fleet.WithEnrollOsqueryNodeKey("2"),
		fleet.WithEnrollOsqueryTeamID(&team1.ID),
	)
	require.NoError(t, err)

	err = ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(&team1.ID, []uint{host1.ID}))
	require.NoError(t, err)
	host1, err = ds.Host(ctx, host1.ID)
	require.NoError(t, err)

	tq, err := ds.NewQuery(ctx, &fleet.Query{
		Name:        "query1",
		Description: "query1 desc",
		Query:       "select 1;",
		Saved:       true,
		Logging:     fleet.LoggingSnapshot,
	})
	require.NoError(t, err)
	team1Policy, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		QueryID: &tq.ID,
	})
	require.NoError(t, err)

	gq, err := ds.NewQuery(ctx, &fleet.Query{
		Name:        "query2",
		Description: "query2 desc",
		Query:       "select 2;",
		Saved:       true,
		Logging:     fleet.LoggingSnapshot,
	})
	require.NoError(t, err)
	globalPolicy, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
		QueryID: &gq.ID,
	})
	require.NoError(t, err)

	// recordResults stores a result for the team1 policy (passing or failing
	// per teamPolicyPasses) and a passing result for the global policy.
	recordResults := func(h *fleet.Host, teamPolicyPasses bool) {
		t.Helper()
		results := map[uint]*bool{
			team1Policy.ID:  ptr.Bool(teamPolicyPasses),
			globalPolicy.ID: ptr.Bool(true),
		}
		require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, h, results, time.Now(), false))
	}

	// Each host first fails and then passes the team policy.
	recordResults(host1, false)
	recordResults(host1, true)
	recordResults(host2, false)
	recordResults(host2, true)
	require.NoError(t, ds.UpdateHostPolicyCounts(ctx))

	// assertPassingCounts refreshes the host policy counts and then compares
	// the passing host counts against the expected values.
	assertPassingCounts := func(wantTeam1, wantTeam1Inherited, wantTeam2Inherited, wantGlobal uint) {
		t.Helper()
		require.NoError(t, ds.UpdateHostPolicyCounts(ctx))

		own, inherited, err := ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
		require.NoError(t, err)
		require.Len(t, own, 1)
		assert.Equal(t, wantTeam1, own[0].PassingHostCount)
		require.Len(t, inherited, 1)
		assert.Equal(t, wantTeam1Inherited, inherited[0].PassingHostCount)

		own, inherited, err = ds.ListTeamPolicies(ctx, team2.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
		require.NoError(t, err)
		require.Len(t, own, 0) // team 2 has no policies of its own
		require.Len(t, inherited, 1)
		assert.Equal(t, wantTeam2Inherited, inherited[0].PassingHostCount)

		own, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
		require.NoError(t, err)
		require.Len(t, own, 1)
		assert.Equal(t, wantGlobal, own[0].PassingHostCount)
	}

	// both hosts belong to team1 and pass the team and the global policy
	assertPassingCounts(2, 2, 0, 2)

	// Move host1 to team2: team policies are removed when AddHostsToTeam is called.
	err = ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(ptr.Uint(team2.ID), []uint{host1.ID}))
	require.NoError(t, err)
	// host2 passes tm1 and the global (so team1's inherited too), host1 passes the team2's inherited and the global
	assertPassingCounts(1, 1, 1, 2)

	// Re-enroll host2 into the same team (team1): all of its host policies are removed.
	_, err = ds.EnrollOsquery(ctx,
		fleet.WithEnrollOsqueryHostID("2"),
		fleet.WithEnrollOsqueryNodeKey("2"),
		fleet.WithEnrollOsqueryTeamID(&team1.ID),
	)
	require.NoError(t, err)
	assertPassingCounts(0, 0, 1, 1)

	// Enroll host2 into a different team (team2).
	_, err = ds.EnrollOsquery(ctx,
		fleet.WithEnrollOsqueryHostID("2"),
		fleet.WithEnrollOsqueryNodeKey("2"),
		fleet.WithEnrollOsqueryTeamID(&team2.ID),
	)
	require.NoError(t, err)
	// both hosts are now in team2
	assertPassingCounts(0, 0, 1, 1)

	// Record passing results for host2 again (team1 policy and global policy)
	// before re-enrolling it without a team.
	recordResults(host2, true)
	assertPassingCounts(1, 0, 2, 2)

	// Re-enroll host2 without a team: all of its host policies are removed.
	_, err = ds.EnrollOsquery(ctx,
		fleet.WithEnrollOsqueryHostID("2"),
		fleet.WithEnrollOsqueryNodeKey("2"),
	)
	require.NoError(t, err)
	assertPassingCounts(0, 0, 1, 1)
}
func testApplyPolicySpec(t *testing.T, ds *Datastore) {
user1 := test.NewUser(t, ds, "User1", "user1@example.com", true)
ctx := context.Background()
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
require.NoError(t, err)
unicode, _ := strconv.Unquote(`"\uAC00"`) // 가
unicodeEq, _ := strconv.Unquote(`"\u1100\u1161"`) // ᄀ + ᅡ
// Add a user-defined label
fooLabel, err := ds.NewLabel(
context.Background(),
&fleet.Label{
Name: "Foo",
Query: "select 1",
LabelType: fleet.LabelTypeRegular,
LabelMembershipType: fleet.LabelMembershipTypeManual,
},
)
require.NoError(t, err)
barLabel, err := ds.NewLabel(
context.Background(),
&fleet.Label{
Name: "Bar",
Query: "select 1",
LabelType: fleet.LabelTypeRegular,
LabelMembershipType: fleet.LabelMembershipTypeManual,
},
)
require.NoError(t, err)
require.NoError(t, ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
{
Name: "query1" + unicodeEq,
Query: "select 1;",
Description: "query1 desc",
Resolution: "some resolution",
Team: "",
Platform: "",
LabelsIncludeAny: []string{fooLabel.Name},
Type: fleet.PolicyTypeDynamic,
},
{
Name: "query2",
Query: "select 2;",
Description: "query2 desc",
Resolution: "some other resolution",
Team: "team1",
Platform: "darwin",
CalendarEventsEnabled: true,
LabelsExcludeAny: []string{barLabel.Name},
Type: fleet.PolicyTypeDynamic,
},
{
Name: "query3",
Query: "select 3;",
Description: "query3 desc",
Resolution: "some other good resolution",
Team: "team1",
Platform: "windows,linux",
Type: fleet.PolicyTypeDynamic,
},
{
Name: "query4",
Query: "select 4;",
Description: "query4 desc",
Resolution: "some other good resolution 2",
Team: "No team",
Platform: "",
Type: fleet.PolicyTypeDynamic,
},
}))
policies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, policies, 1)
assert.Equal(t, "query1"+unicode, policies[0].Name)
assert.Equal(t, "select 1;", policies[0].Query)
assert.Equal(t, "query1 desc", policies[0].Description)
require.NotNil(t, policies[0].AuthorID)
assert.Equal(t, user1.ID, *policies[0].AuthorID)
require.NotNil(t, policies[0].Resolution)
assert.Equal(t, "some resolution", *policies[0].Resolution)
assert.Equal(t, "", policies[0].Platform)
assert.Equal(t, []fleet.LabelIdent{{
LabelName: fooLabel.Name,
LabelID: fooLabel.ID,
}}, policies[0].LabelsIncludeAny)
assert.Equal(t, policies[0].Type, fleet.PolicyTypeDynamic)
teamPolicies, _, err := ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, teamPolicies, 2)
assert.Equal(t, "query2", teamPolicies[0].Name)
assert.Equal(t, "select 2;", teamPolicies[0].Query)
assert.Equal(t, "query2 desc", teamPolicies[0].Description)
require.NotNil(t, teamPolicies[0].AuthorID)
assert.Equal(t, user1.ID, *teamPolicies[0].AuthorID)
require.NotNil(t, teamPolicies[0].Resolution)
assert.Equal(t, "some other resolution", *teamPolicies[0].Resolution)
assert.Equal(t, "darwin", teamPolicies[0].Platform)
assert.True(t, teamPolicies[0].CalendarEventsEnabled)
assert.Equal(t, []fleet.LabelIdent{{
LabelName: barLabel.Name,
LabelID: barLabel.ID,
}}, teamPolicies[0].LabelsExcludeAny)
assert.Equal(t, "query3", teamPolicies[1].Name)
assert.Equal(t, "select 3;", teamPolicies[1].Query)
assert.Equal(t, "query3 desc", teamPolicies[1].Description)
require.NotNil(t, teamPolicies[1].AuthorID)
assert.Equal(t, user1.ID, *teamPolicies[1].AuthorID)
require.NotNil(t, teamPolicies[1].Resolution)
assert.Equal(t, "some other good resolution", *teamPolicies[1].Resolution)
assert.Equal(t, "windows,linux", teamPolicies[1].Platform)
assert.False(t, teamPolicies[1].CalendarEventsEnabled)
noTeamPolicies, _, err := ds.ListTeamPolicies(ctx, fleet.PolicyNoTeamID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, noTeamPolicies, 1)
assert.Equal(t, "query4", noTeamPolicies[0].Name)
assert.Equal(t, "select 4;", noTeamPolicies[0].Query)
assert.Equal(t, "query4 desc", noTeamPolicies[0].Description)
require.NotNil(t, noTeamPolicies[0].AuthorID)
assert.Equal(t, user1.ID, *noTeamPolicies[0].AuthorID)
require.NotNil(t, noTeamPolicies[0].Resolution)
assert.Equal(t, "some other good resolution 2", *noTeamPolicies[0].Resolution)
assert.Equal(t, "", noTeamPolicies[0].Platform)
assert.False(t, noTeamPolicies[0].CalendarEventsEnabled)
assert.NotNil(t, noTeamPolicies[0].TeamID)
assert.Zero(t, *noTeamPolicies[0].TeamID)
// Make sure apply is idempotent
require.NoError(t, ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
{
Name: "query1" + unicode,
Query: "select 1;",
Description: "query1 desc",
Resolution: "some resolution",
Team: "",
Platform: "",
LabelsIncludeAny: []string{fooLabel.Name},
Type: fleet.PolicyTypeDynamic,
},
{
Name: "query2",
Query: "select 2;",
Description: "query2 desc",
Resolution: "some other resolution",
Team: "team1",
Platform: "darwin",
CalendarEventsEnabled: true,
LabelsExcludeAny: []string{barLabel.Name},
Type: fleet.PolicyTypeDynamic,
},
{
Name: "query3",
Query: "select 3;",
Description: "query3 desc",
Resolution: "some other good resolution",
Team: "team1",
Platform: "windows,linux",
Type: fleet.PolicyTypeDynamic,
},
{
Name: "query4",
Query: "select 4;",
Description: "query4 desc",
Resolution: "some other good resolution 2",
Team: "No team",
Platform: "",
Type: fleet.PolicyTypeDynamic,
},
}))
policies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, policies, 1)
teamPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, teamPolicies, 2)
noTeamPolicies, _, err = ds.ListTeamPolicies(ctx, fleet.PolicyNoTeamID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, noTeamPolicies, 1)
// Test policy updating.
require.NoError(t, ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
{
Name: "query1" + unicodeEq,
Query: "select 1 from updated;",
Description: "query1 desc updated",
Resolution: "some resolution updated",
Team: "", // No error, team did not change
Platform: "",
LabelsExcludeAny: []string{fooLabel.Name, barLabel.Name},
Type: fleet.PolicyTypeDynamic,
},
{
Name: "query2",
Query: "select 2 from updated;",
Description: "query2 desc updated",
Resolution: "some other resolution updated",
Team: "team1", // No error, team did not change
Platform: "windows",
CalendarEventsEnabled: false,
Type: fleet.PolicyTypeDynamic,
},
}))
policies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, policies, 1)
assert.Equal(t, "query1"+unicode, policies[0].Name)
assert.Equal(t, "select 1 from updated;", policies[0].Query)
assert.Equal(t, "query1 desc updated", policies[0].Description)
require.NotNil(t, policies[0].AuthorID)
assert.Equal(t, user1.ID, *policies[0].AuthorID)
require.NotNil(t, policies[0].Resolution)
assert.Equal(t, "some resolution updated", *policies[0].Resolution)
assert.Equal(t, "", policies[0].Platform)
assert.False(t, policies[0].CalendarEventsEnabled)
assert.Contains(t, policies[0].LabelsExcludeAny, fleet.LabelIdent{
LabelName: fooLabel.Name,
LabelID: fooLabel.ID,
})
assert.Contains(t, policies[0].LabelsExcludeAny, fleet.LabelIdent{
LabelName: barLabel.Name,
LabelID: barLabel.ID,
})
teamPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, teamPolicies, 2)
assert.Equal(t, "query2", teamPolicies[0].Name)
assert.Equal(t, "select 2 from updated;", teamPolicies[0].Query)
assert.Equal(t, "query2 desc updated", teamPolicies[0].Description)
require.NotNil(t, teamPolicies[0].AuthorID)
assert.Equal(t, user1.ID, *teamPolicies[0].AuthorID)
assert.Equal(t, team1.ID, *teamPolicies[0].TeamID)
require.NotNil(t, teamPolicies[0].Resolution)
assert.Equal(t, "some other resolution updated", *teamPolicies[0].Resolution)
assert.Equal(t, "windows", teamPolicies[0].Platform)
assert.Nil(t, teamPolicies[0].LabelsIncludeAny)
assert.Nil(t, teamPolicies[0].LabelsExcludeAny)
// Creating the same policy for a different team is allowed.
require.NoError(
t, ds.ApplyPolicySpecs(
ctx, user1.ID, []*fleet.PolicySpec{
{
Name: "query1" + unicode,
Query: "select 1 from updated again;",
Description: "query1 desc updated again",
Resolution: "some resolution updated again",
Team: "team1",
Platform: "",
Type: fleet.PolicyTypeDynamic,
},
}))
}
// updatePolicyFailureCountsForHosts refreshes, in place, the failing-policy
// counters of the given hosts from the policy_membership table and returns the
// same slice.
//
// For every host, both TotalIssuesCount and FailingPoliciesCount are set to the
// number of policy_membership rows with passes = 0 for that host; a host with
// no such row gets 0. An empty hosts slice is returned as-is without querying.
// Errors from building or running the query are wrapped with ctxerr.
func updatePolicyFailureCountsForHosts(ctx context.Context, ds *Datastore, hosts []*fleet.Host) ([]*fleet.Host, error) {
	if len(hosts) == 0 {
		return hosts, nil
	}

	// Gather the host IDs that sqlx.In expands into the IN (?) placeholder.
	ids := make([]uint, len(hosts))
	for i := range hosts {
		ids[i] = hosts[i].ID
	}

	stmt, stmtArgs, err := sqlx.In(
		`
SELECT
pm.host_id,
COUNT(*) AS failing_policy_count
FROM
policy_membership pm
WHERE
pm.passes = 0 AND
pm.host_id IN (?)
GROUP BY
pm.host_id
`, ids,
	)
	if err != nil {
		return nil, ctxerr.Wrap(ctx, err, "build policy failure count query")
	}

	// One row per host that has at least one failing policy.
	type failureRow struct {
		HostID             uint   `db:"host_id"`
		FailingPolicyCount uint64 `db:"failing_policy_count"`
	}
	var failureRows []failureRow
	err = sqlx.SelectContext(ctx, ds.reader(ctx), &failureRows, stmt, stmtArgs...)
	if err != nil {
		return nil, ctxerr.Wrap(ctx, err, "get policy failure counts for hosts")
	}

	// Index the counts by host ID; hosts missing from the result read back as 0.
	failuresByHost := make(map[uint]uint64)
	for _, row := range failureRows {
		failuresByHost[row.HostID] = row.FailingPolicyCount
	}

	for _, h := range hosts {
		failing := failuresByHost[h.ID]
		h.TotalIssuesCount = failing
		h.FailingPoliciesCount = failing
	}
	return hosts, nil
}
// testApplyPolicySpecWithQueryPlatformChanges applies a set of global and team
// policy specs, records failing results for every host/policy pair, and then
// re-applies the specs with a changed description, query, platform or calendar
// setting. It asserts the per-host failing-policy counts and the per-policy
// FailingHostCount both right after the re-apply and after
// UpdateHostPolicyCounts has run again.
func testApplyPolicySpecWithQueryPlatformChanges(t *testing.T, ds *Datastore) {
ctx := context.Background()
// unicode is the precomposed Hangul syllable and unicodeEq its two-jamo
// spelling. The Unquote errors are ignored because both inputs are fixed,
// valid quoted literals.
unicode, _ := strconv.Unquote(`"\uAC00"`) // 가
unicodeEq, _ := strconv.Unquote(`"\u1100\u1161"`) // ᄀ + ᅡ
user1 := test.NewUser(t, ds, "User1", "user1@example.com", true)
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1" + unicode})
require.NoError(t, err)
globalNames := []string{"global query1" + unicode, "global query2" + unicode, "global query3" + unicode}
teamNames := []string{"team query1", "team query2", "team query3"}
// Initial specs: three global and three team policies with different platforms.
// The last team spec names the team with unicodeEq instead of unicode; the test
// expects it to land in the same team (three team policies are asserted below).
require.NoError(
t, ds.ApplyPolicySpecs(
ctx, user1.ID, []*fleet.PolicySpec{
{
Name: globalNames[0],
Query: "select 1;",
Team: "",
Platform: "",
Type: fleet.PolicyTypeDynamic,
},
{
Name: globalNames[1],
Query: "select 2;",
Team: "",
Platform: "darwin",
Type: fleet.PolicyTypeDynamic,
},
{
Name: globalNames[2],
Query: "select 3;",
Team: "",
Platform: "darwin,linux",
Type: fleet.PolicyTypeDynamic,
},
{
Name: teamNames[0],
Query: "select 1;",
Team: "team1" + unicode,
Platform: "",
Type: fleet.PolicyTypeDynamic,
},
{
Name: teamNames[1],
Query: "select 2;",
Team: "team1" + unicode,
Platform: "darwin",
Type: fleet.PolicyTypeDynamic,
},
{
Name: teamNames[2],
Query: "select 3;",
Team: "team1" + unicodeEq,
Platform: "darwin,linux",
Type: fleet.PolicyTypeDynamic,
},
},
),
)
// create hosts with different platforms, for that team
// The host* constants index into platforms, teamHosts and globalHosts.
const hostWin, hostMac, hostDeb, hostLin = 0, 1, 2, 3
platforms := []string{"windows", "darwin", "debian", "linux"}
teamHosts := make([]*fleet.Host, len(platforms))
for i, pl := range platforms {
id := fmt.Sprintf("%s-%d", strings.ReplaceAll(t.Name(), "/", "_"), i)
h, err := ds.NewHost(
ctx, &fleet.Host{
OsqueryHostID: &id,
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: &id,
UUID: id,
Hostname: id,
Platform: pl,
TeamID: ptr.Uint(team1.ID),
},
)
require.NoError(t, err)
teamHosts[i] = h
}
// create hosts with different platforms, without team
globalHosts := make([]*fleet.Host, len(platforms))
for i, pl := range platforms {
id := fmt.Sprintf("g%s-%d", strings.ReplaceAll(t.Name(), "/", "_"), i)
h, err := ds.NewHost(
ctx, &fleet.Host{
OsqueryHostID: &id,
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: &id,
UUID: id,
Hostname: id,
Platform: pl,
},
)
require.NoError(t, err)
globalHosts[i] = h
}
// load the global policies
gPolicies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, gPolicies, 3)
// load the team policies
tPolicies, _, err := ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, tPolicies, 3)
// index the policies by name for easier access in the rest of the test
// polsByName holds global and team policies; globalPolsByName only the global ones.
polsByName := make(map[string]*fleet.Policy, len(gPolicies)+len(tPolicies))
globalPolsByName := make(map[string]*fleet.Policy, len(gPolicies))
for _, pol := range tPolicies {
polsByName[pol.Name] = pol
}
for _, pol := range gPolicies {
globalPolsByName[pol.Name] = pol
polsByName[pol.Name] = pol
}
// record some results for each policy
// Note: we are adding results to hosts that shouldn't have results, based on their platform.
// Team hosts get a failing result for all six policies.
for _, h := range teamHosts {
res := make(map[uint]*bool, len(polsByName))
for _, pol := range polsByName {
res[pol.ID] = ptr.Bool(false)
}
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
require.NoError(t, err)
}
// Hosts without a team get a failing result for the three global policies only.
for _, h := range globalHosts {
res := make(map[uint]*bool, len(globalPolsByName))
for _, pol := range globalPolsByName {
res[pol.ID] = ptr.Bool(false)
}
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
require.NoError(t, err)
}
err = ds.UpdateHostPolicyCounts(ctx)
require.NoError(t, err)
// Update host failure counts and ensure they are correct
// Expected: 6 per team host (3 global + 3 team policies), 3 per host without a team.
teamHosts, err = updatePolicyFailureCountsForHosts(ctx, ds, teamHosts)
require.NoError(t, err)
assert.Equal(t, uint64(6), teamHosts[hostWin].FailingPoliciesCount)
assert.Equal(t, uint64(6), teamHosts[hostMac].FailingPoliciesCount)
assert.Equal(t, uint64(6), teamHosts[hostDeb].FailingPoliciesCount)
assert.Equal(t, uint64(6), teamHosts[hostLin].FailingPoliciesCount)
globalHosts, err = updatePolicyFailureCountsForHosts(ctx, ds, globalHosts)
require.NoError(t, err)
assert.Equal(t, uint64(3), globalHosts[hostWin].FailingPoliciesCount)
assert.Equal(t, uint64(3), globalHosts[hostMac].FailingPoliciesCount)
assert.Equal(t, uint64(3), globalHosts[hostDeb].FailingPoliciesCount)
assert.Equal(t, uint64(3), globalHosts[hostLin].FailingPoliciesCount)
// Ensure policy passing and failing counts are correct
// Expected: each global policy fails on all 8 hosts, each team policy on the 4 team hosts.
gPolicies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, gPolicies, 3)
tPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, tPolicies, 3)
for _, pol := range gPolicies {
polsByName[pol.Name] = pol
}
for _, pol := range tPolicies {
polsByName[pol.Name] = pol
}
assert.Equal(t, uint(8), polsByName[globalNames[0]].FailingHostCount)
assert.Equal(t, uint(8), polsByName[globalNames[1]].FailingHostCount)
assert.Equal(t, uint(8), polsByName[globalNames[2]].FailingHostCount)
assert.Equal(t, uint(4), polsByName[teamNames[0]].FailingHostCount)
assert.Equal(t, uint(4), polsByName[teamNames[1]].FailingHostCount)
assert.Equal(t, uint(4), polsByName[teamNames[2]].FailingHostCount)
// Update policies
// Each existing spec changes at most one attribute (see the inline notes), and
// one new global and one new team policy are added.
require.NoError(
t, ds.ApplyPolicySpecs(
ctx, user1.ID, []*fleet.PolicySpec{
{
Name: globalNames[0],
Query: "select 1;",
Team: "",
Platform: "",
Description: "updated", // update description
Type: fleet.PolicyTypeDynamic,
},
{
Name: globalNames[1],
Query: "select 2 updated;", // update query
Team: "",
Platform: "darwin",
Type: fleet.PolicyTypeDynamic,
},
{
Name: globalNames[2],
Query: "select 3;",
Team: "",
Platform: "darwin", // update platform
Type: fleet.PolicyTypeDynamic,
},
{
Name: "new global query",
Query: "select 4;",
Team: "",
Platform: "",
Type: fleet.PolicyTypeDynamic,
},
{
Name: teamNames[0],
Query: "select 1;",
Team: "team1" + unicode,
Platform: "linux", // update platform
Type: fleet.PolicyTypeDynamic,
},
{
Name: teamNames[1],
Query: "select 2;",
Team: "team1" + unicode,
Platform: "darwin",
CalendarEventsEnabled: true, // update calendar events
Type: fleet.PolicyTypeDynamic,
},
{
Name: teamNames[2],
Query: "select 3 updated;", // update query
Team: "team1" + unicodeEq,
Platform: "darwin,linux",
Type: fleet.PolicyTypeDynamic,
},
{
Name: "new team query",
Query: "select 4;",
Team: "team1" + unicode,
Platform: "",
Type: fleet.PolicyTypeDynamic,
},
},
),
)
// Update host failure counts and ensure they are correct
// NOTE(review): the expected values imply that re-applying the specs removed
// results for policies whose query changed and for hosts whose platform no
// longer matches the policy -- confirm against ApplyPolicySpecs.
teamHosts, err = updatePolicyFailureCountsForHosts(ctx, ds, teamHosts)
require.NoError(t, err)
assert.Equal(t, uint64(1), teamHosts[hostWin].FailingPoliciesCount) // kept result from globalNames[0]
assert.Equal(t, uint64(3), teamHosts[hostMac].FailingPoliciesCount)
assert.Equal(t, uint64(2), teamHosts[hostDeb].FailingPoliciesCount)
assert.Equal(t, uint64(2), teamHosts[hostLin].FailingPoliciesCount)
globalHosts, err = updatePolicyFailureCountsForHosts(ctx, ds, globalHosts)
require.NoError(t, err)
assert.Equal(t, uint64(1), globalHosts[hostWin].FailingPoliciesCount)
assert.Equal(t, uint64(2), globalHosts[hostMac].FailingPoliciesCount)
assert.Equal(t, uint64(1), globalHosts[hostDeb].FailingPoliciesCount)
assert.Equal(t, uint64(1), globalHosts[hostLin].FailingPoliciesCount)
// Ensure policy passing and failing counts are correct
// These are the counts reported before UpdateHostPolicyCounts runs again:
// policies with a changed query or platform report 0, the others keep their
// earlier value.
gPolicies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, gPolicies, 4)
tPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, tPolicies, 4)
for _, pol := range gPolicies {
polsByName[pol.Name] = pol
}
for _, pol := range tPolicies {
polsByName[pol.Name] = pol
}
assert.Equal(t, uint(8), polsByName[globalNames[0]].FailingHostCount)
assert.Equal(t, uint(0), polsByName[globalNames[1]].FailingHostCount) // updated query
assert.Equal(t, uint(0), polsByName[globalNames[2]].FailingHostCount) // updated platform
assert.Equal(t, uint(0), polsByName[teamNames[0]].FailingHostCount) // updated platform
assert.Equal(t, uint(4), polsByName[teamNames[1]].FailingHostCount)
assert.Equal(t, uint(0), polsByName[teamNames[2]].FailingHostCount) // updated query
// Recompute the counts and check them again; the "no change" notes below refer
// to the policy's platform setting, not to the asserted count.
err = ds.UpdateHostPolicyCounts(ctx)
require.NoError(t, err)
gPolicies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, gPolicies, 4)
tPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, tPolicies, 4)
for _, pol := range gPolicies {
polsByName[pol.Name] = pol
}
for _, pol := range tPolicies {
polsByName[pol.Name] = pol
}
assert.Equal(t, uint(8), polsByName[globalNames[0]].FailingHostCount) // platform is "" -- no change
assert.Equal(t, uint(0), polsByName[globalNames[1]].FailingHostCount) // updated query
assert.Equal(t, uint(2), polsByName[globalNames[2]].FailingHostCount) // updated platform
assert.Equal(t, uint(2), polsByName[teamNames[0]].FailingHostCount) // updated platform
assert.Equal(t, uint(1), polsByName[teamNames[1]].FailingHostCount) // platform is "darwin" -- no change
assert.Equal(t, uint(0), polsByName[teamNames[2]].FailingHostCount) // updated query
}
// testPoliciesSave exercises SavePolicy for a global and a team policy.
//
// It checks that saving an unknown policy ID yields a NotFoundError, that
// edited fields and swapped include/exclude labels round-trip through
// ds.Policy, that the stored checksum follows the policy name (and team ID),
// that a policy carrying both include and exclude labels is rejected, and that
// the saved team policy has no policy_membership rows.
func testPoliciesSave(t *testing.T, ds *Datastore) {
	// requireLabels asserts that the label names in actual equal expected, in order.
	requireLabels := func(t *testing.T, expected []string, actual []fleet.LabelIdent) {
		actualLabels := make([]string, 0, len(actual))
		for _, label := range actual {
			actualLabels = append(actualLabels, label.LabelName)
		}
		require.Equal(t, expected, actualLabels)
	}

	user1 := test.NewUser(t, ds, "User1", "user1@example.com", true)
	ctx := context.Background()
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	label1, err := ds.NewLabel(ctx, &fleet.Label{Name: "label1"})
	require.NoError(t, err)
	label2, err := ds.NewLabel(ctx, &fleet.Label{Name: "label2"})
	require.NoError(t, err)

	// Saving a policy that does not exist must fail with a not-found error.
	err = ds.SavePolicy(ctx, &fleet.Policy{
		PolicyData: fleet.PolicyData{
			ID:    99999999,
			Name:  "non-existent query",
			Query: "select 1;",
		},
	}, false, false,
	)
	require.Error(t, err)
	var nfe *common_mysql.NotFoundError
	require.True(t, errors.As(err, &nfe))

	payload := fleet.PolicyPayload{
		Name:             "global query",
		Query:            "select 1;",
		Description:      "global query desc",
		Resolution:       "global query resolution",
		LabelsIncludeAny: []string{label1.Name, label2.Name},
		// TODO also find out where policies get selected and add the logic for checking labels
	}
	gp, err := ds.NewGlobalPolicy(ctx, &user1.ID, payload)
	require.NoError(t, err)
	require.Equal(t, gp.Name, payload.Name)
	require.Equal(t, gp.Query, payload.Query)
	require.Equal(t, gp.Description, payload.Description)
	// Guard the dereference so a missing resolution fails the test instead of panicking.
	require.NotNil(t, gp.Resolution)
	require.Equal(t, *gp.Resolution, payload.Resolution)
	require.Equal(t, gp.Critical, payload.Critical)
	requireLabels(t, []string{label1.Name, label2.Name}, gp.LabelsIncludeAny)

	// computeChecksum returns the hex MD5 of the team ID (empty for a nil
	// TeamID) and the policy name joined by a NUL byte.
	computeChecksum := func(policy fleet.Policy) string {
		h := md5.New() //nolint:gosec // (only used for tests)
		// Compute the same way as DB does.
		teamStr := ""
		if policy.TeamID != nil {
			teamStr = fmt.Sprint(*policy.TeamID)
		}
		cols := []string{teamStr, policy.Name}
		_, _ = fmt.Fprint(h, strings.Join(cols, "\x00"))
		checksum := h.Sum(nil)
		return hex.EncodeToString(checksum)
	}

	var globalChecksum []uint8
	err = ds.writer(ctx).Get(&globalChecksum, `SELECT checksum FROM policies WHERE id = ?`, gp.ID)
	require.NoError(t, err)
	assert.Equal(t, computeChecksum(*gp), hex.EncodeToString(globalChecksum))

	payload = fleet.PolicyPayload{
		Name:                  "team1 query",
		Query:                 "select 2;",
		Description:           "team1 query desc",
		Resolution:            "team1 query resolution",
		Critical:              true,
		CalendarEventsEnabled: true,
		LabelsExcludeAny:      []string{label1.Name, label2.Name},
	}
	tp1, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, payload)
	require.NoError(t, err)
	require.Equal(t, tp1.Name, payload.Name)
	require.Equal(t, tp1.Query, payload.Query)
	require.Equal(t, tp1.Description, payload.Description)
	require.NotNil(t, tp1.Resolution)
	require.Equal(t, *tp1.Resolution, payload.Resolution)
	require.Equal(t, tp1.Critical, payload.Critical)
	assert.Equal(t, tp1.CalendarEventsEnabled, payload.CalendarEventsEnabled)
	requireLabels(t, []string{label1.Name, label2.Name}, tp1.LabelsExcludeAny)

	var teamChecksum []uint8
	err = ds.writer(ctx).Get(&teamChecksum, `SELECT checksum FROM policies WHERE id = ?`, tp1.ID)
	require.NoError(t, err)
	assert.Equal(t, computeChecksum(*tp1), hex.EncodeToString(teamChecksum))

	// Change the name and criticality of a global policy.
	gp2 := *gp
	gp2.Name = "global query updated"
	gp2.Critical = true
	// Swap labels include to labels exclude
	gp2.LabelsExcludeAny = gp2.LabelsIncludeAny
	gp2.LabelsIncludeAny = nil
	err = ds.SavePolicy(ctx, &gp2, false, false)
	require.NoError(t, err)

	gp, err = ds.Policy(ctx, gp.ID)
	require.NoError(t, err)
	require.Empty(t, gp.LabelsIncludeAny)
	requireLabels(t, []string{label1.Name, label2.Name}, gp.LabelsExcludeAny)
	// Timestamps are set by the datastore, so copy them over before comparing.
	gp2.UpdateCreateTimestamps = gp.UpdateCreateTimestamps
	require.Equal(t, &gp2, gp)

	var globalChecksum2 []uint8
	err = ds.writer(ctx).Get(&globalChecksum2, `SELECT checksum FROM policies WHERE id = ?`, gp.ID)
	require.NoError(t, err)
	assert.NotEqual(t, globalChecksum, globalChecksum2, "Checksum should be different since policy name changed")
	assert.Equal(t, computeChecksum(*gp), hex.EncodeToString(globalChecksum2))

	// Cannot save a policy with both include and exclude labels
	gp2.LabelsExcludeAny = []fleet.LabelIdent{{LabelName: label1.Name}}
	gp2.LabelsIncludeAny = []fleet.LabelIdent{{LabelName: label2.Name}}
	err = ds.SavePolicy(ctx, &gp2, false, false)
	require.Error(t, err)

	// Change name, query, description and resolution of a team policy.
	tp2 := *tp1
	tp2.Name = "team1 query updated"
	tp2.Query = "select 12;"
	tp2.Description = "team1 query desc updated"
	tp2.Resolution = ptr.String("team1 query resolution updated")
	tp2.Critical = false
	tp2.CalendarEventsEnabled = false
	// Swap labels include and exclude
	tp2.LabelsIncludeAny = tp2.LabelsExcludeAny
	tp2.LabelsExcludeAny = nil
	err = ds.SavePolicy(ctx, &tp2, true, true)
	require.NoError(t, err)

	tp1, err = ds.Policy(ctx, tp1.ID)
	// Check the error before touching tp1: reading fields of a policy that
	// failed to load would panic rather than report the failure.
	require.NoError(t, err)
	require.Empty(t, tp1.LabelsExcludeAny)
	requireLabels(t, []string{label1.Name, label2.Name}, tp1.LabelsIncludeAny)
	tp2.UpdateCreateTimestamps = tp1.UpdateCreateTimestamps
	require.Equal(t, tp1, &tp2)

	var teamChecksum2 []uint8
	err = ds.writer(ctx).Get(&teamChecksum2, `SELECT checksum FROM policies WHERE id = ?`, tp1.ID)
	require.NoError(t, err)
	assert.NotEqual(t, teamChecksum, teamChecksum2, "Checksum should be different since policy name changed")
	assert.Equal(t, computeChecksum(*tp1), hex.EncodeToString(teamChecksum2))

	// The saved team policy must have no membership rows.
	loadMembershipStmt, args, err := sqlx.In(`SELECT policy_id, host_id FROM policy_membership WHERE policy_id = ?`, tp2.ID)
	require.NoError(t, err)
	type polHostIDs struct {
		PolicyID uint `db:"policy_id"`
		HostID   uint `db:"host_id"`
	}
	var rows []polHostIDs
	err = ds.writer(ctx).SelectContext(ctx, &rows, loadMembershipStmt, args...)
	require.NoError(t, err)
	require.Empty(t, rows)
}
// testCachedPolicyCountDeletesOnPolicyChange verifies that the cached
// failing-host counts of a policy, and the matching host_issues rows, are
// reset when the policy is re-saved through SavePolicy.
//
// Setup: one host in a team and one host without a team, one global policy and
// one team policy. The team host fails both policies, the other host fails the
// global policy only.
func testCachedPolicyCountDeletesOnPolicyChange(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team1"})
	require.NoError(t, err)

	teamHost, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:   ptr.String("test-1"),
		DetailUpdatedAt: time.Now(),
		LabelUpdatedAt:  time.Now(),
		PolicyUpdatedAt: time.Now(),
		SeenTime:        time.Now(),
		NodeKey:         ptr.String("test-1"),
		UUID:            "test-1",
		Hostname:        "foo.local",
		Platform:        "windows",
	})
	require.NoError(t, err)
	require.NoError(t, ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(&team1.ID, []uint{teamHost.ID})))

	globalHost, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:   ptr.String("test-2"),
		DetailUpdatedAt: time.Now(),
		LabelUpdatedAt:  time.Now(),
		PolicyUpdatedAt: time.Now(),
		SeenTime:        time.Now(),
		NodeKey:         ptr.String("test-2"),
		UUID:            "test-2",
		Hostname:        "foo.local",
		Platform:        "windows",
	})
	require.NoError(t, err)

	globalPolicy, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
		Name:        "global query",
		Query:       "select 1;",
		Description: "global query desc",
		Resolution:  "global query resolution",
	})
	require.NoError(t, err)

	teamPolicy, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		Name:        "team query",
		Query:       "select 1;",
		Description: "team query desc",
		Resolution:  "team query resolution",
	})
	require.NoError(t, err)

	// teamHost and globalHost fail all policies
	// Each result map carries a single entry; the earlier literals listed the
	// same key twice, which only overwrote the first entry.
	require.NoError(
		t, ds.RecordPolicyQueryExecutions(
			ctx, teamHost, map[uint]*bool{globalPolicy.ID: ptr.Bool(false)}, time.Now(), false,
		),
	)
	require.NoError(
		t, ds.RecordPolicyQueryExecutions(
			ctx, teamHost, map[uint]*bool{teamPolicy.ID: ptr.Bool(false)}, time.Now(), false,
		),
	)
	require.NoError(
		t, ds.RecordPolicyQueryExecutions(
			ctx, globalHost, map[uint]*bool{globalPolicy.ID: ptr.Bool(false)}, time.Now(), false,
		),
	)

	err = ds.UpdateHostPolicyCounts(ctx)
	require.NoError(t, err)

	// The global policy fails on both hosts.
	globalPolicy, err = ds.Policy(ctx, globalPolicy.ID)
	require.NoError(t, err)
	assert.Equal(t, uint(2), globalPolicy.FailingHostCount)

	// Seen from the team, the team policy and the inherited global policy each
	// fail on the single team host.
	teamPolicies, inheritedPolicies, err := ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, teamPolicies, 1)
	require.Len(t, inheritedPolicies, 1)
	assert.Equal(t, uint(1), teamPolicies[0].FailingHostCount)
	assert.Equal(t, uint(1), inheritedPolicies[0].FailingHostCount)

	// Both hosts have issues; only the team host has two of them.
	var count uint64
	require.NoError(t, ds.writer(ctx).Get(&count, "select COUNT(*) from host_issues WHERE total_issues_count > 0"))
	assert.Equal(t, uint64(2), count)
	require.NoError(t, ds.writer(ctx).Get(&count, "select COUNT(*) from host_issues WHERE total_issues_count = 2"))
	assert.Equal(t, uint64(1), count)

	// Re-save the global policy unchanged with both trailing flags set, to
	// trigger a cache invalidation.
	// NOTE(review): the flags presumably request removal of the policy's
	// memberships and stats -- confirm against SavePolicy.
	err = ds.SavePolicy(ctx, globalPolicy, true, true)
	require.NoError(t, err)

	globalPolicy, err = ds.Policy(ctx, globalPolicy.ID)
	require.NoError(t, err)
	assert.Equal(t, uint(0), globalPolicy.FailingHostCount)

	teamPolicies, inheritedPolicies, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, teamPolicies, 1)
	require.Len(t, inheritedPolicies, 1)
	assert.Equal(t, uint(1), teamPolicies[0].FailingHostCount)
	assert.Equal(t, uint(0), inheritedPolicies[0].FailingHostCount)

	// Only the team host now has issues
	require.NoError(t, ds.writer(ctx).Get(&count, "select COUNT(*) from host_issues WHERE total_issues_count > 0"))
	assert.Equal(t, uint64(1), count)
	require.NoError(
		t, ds.writer(ctx).Get(&count, "select COUNT(*) from host_issues WHERE total_issues_count = 1 AND host_id = ?", teamHost.ID),
	)
	assert.Equal(t, uint64(1), count)

	// Re-save the team policy unchanged with only the last flag set, to trigger
	// a cache invalidation; its cached failing count must drop to 0 as well.
	err = ds.SavePolicy(ctx, teamPolicy, false, true)
	require.NoError(t, err)

	teamPolicies, inheritedPolicies, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, teamPolicies, 1)
	require.Len(t, inheritedPolicies, 1)
	assert.Equal(t, uint(0), teamPolicies[0].FailingHostCount)
	assert.Equal(t, uint(0), inheritedPolicies[0].FailingHostCount)
}
// testPoliciesDelUser checks that policies outlive their authors: after the
// users that created a global and a team policy are deleted, both policies
// still load, with a nil AuthorID, the author name "<deleted>" and an empty
// author email.
func testPoliciesDelUser(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	globalAuthor := test.NewUser(t, ds, "User1", "user1@example.com", true)
	teamAuthor := test.NewUser(t, ds, "User2", "user2@example.com", true)

	team, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	// One policy per author: a global one and a team one.
	globalPayload := fleet.PolicyPayload{
		Name:        "global query",
		Query:       "select 1;",
		Description: "global query desc",
		Resolution:  "global query resolution",
	}
	globalPol, err := ds.NewGlobalPolicy(ctx, &globalAuthor.ID, globalPayload)
	require.NoError(t, err)

	teamPayload := fleet.PolicyPayload{
		Name:        "team1 query",
		Query:       "select 2;",
		Description: "team1 query desc",
		Resolution:  "team1 query resolution",
	}
	teamPol, err := ds.NewTeamPolicy(ctx, team.ID, &teamAuthor.ID, teamPayload)
	require.NoError(t, err)

	// Remove both authors.
	for _, userID := range []uint{globalAuthor.ID, teamAuthor.ID} {
		require.NoError(t, ds.DeleteUser(ctx, userID))
	}

	// Reload each policy (team first, then global) and check the author fields.
	for _, policyID := range []uint{teamPol.ID, globalPol.ID} {
		reloaded, err := ds.Policy(ctx, policyID)
		require.NoError(t, err)
		assert.Nil(t, reloaded.AuthorID)
		assert.Equal(t, "<deleted>", reloaded.AuthorName)
		assert.Empty(t, reloaded.AuthorEmail)
	}
}
// testFlippingPoliciesForHost walks FlippingPoliciesForHost through a sequence
// of incoming results for a single host and checks which policy IDs it reports
// as newly failing and newly passing.
//
// Notation used in the comments below: "yes" is a passing result (true), "no"
// is a failing result (false), and "---" means the policy did not run (nil).
// "a => b" reads "stored state a, incoming result b"; an empty left side means
// no result is stored yet. Only results passed to RecordPolicyQueryExecutions
// become the stored state for the following steps.
func testFlippingPoliciesForHost(t *testing.T, ds *Datastore) {
user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
ctx := context.Background()
host1, err := ds.NewHost(ctx, &fleet.Host{
OsqueryHostID: ptr.String("test-1"),
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: ptr.String("test-1"),
UUID: "test-1",
Hostname: "foo.local",
Platform: "windows",
})
require.NoError(t, err)
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
require.NoError(t, err)
p1, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
Name: "policy1",
Query: "select 41;",
})
require.NoError(t, err)
p2, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "policy2",
Query: "select 42;",
})
require.NoError(t, err)
// Create some unused policy.
_, err = ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
Name: "policy3",
Query: "select 43;",
})
require.NoError(t, err)
// pfailed is only ever reported with a nil (did not run) result.
pfailed, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "policy_failed",
Query: "select * from unexistent_table;",
})
require.NoError(t, err)
p4, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "policy_failed_to_run_then_pass",
Query: "select 42;",
})
require.NoError(t, err)
p5, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "policy_failed_to_run_then_fail",
Query: "select 42;",
})
require.NoError(t, err)
// Unknown policies will be considered their first execution.
newFailing, newPassing, err := ds.FlippingPoliciesForHost(ctx, host1.ID, map[uint]*bool{
99997: nil, // considered as didn't run.
99998: ptr.Bool(false),
99999: ptr.Bool(true),
})
require.NoError(t, err)
// Sort so the comparison below does not depend on the returned order.
sort.Slice(newFailing, func(i, j int) bool {
return newFailing[i] < newFailing[j]
})
require.Equal(t, []uint{99998}, newFailing)
require.Empty(t, newPassing) // because this would be the first run.
// Unknown host.
newFailing, newPassing, err = ds.FlippingPoliciesForHost(ctx, 99999, map[uint]*bool{
p1.ID: ptr.Bool(false),
})
require.NoError(t, err)
require.Equal(t, []uint{p1.ID}, newFailing)
require.Empty(t, newPassing)
// Empty incoming results.
newFailing, newPassing, err = ds.FlippingPoliciesForHost(ctx, host1.ID, map[uint]*bool{})
require.NoError(t, err)
require.Empty(t, newFailing)
require.Empty(t, newPassing)
// incoming policy 1 with first new failing result: => no
// incoming policy 2 with first new passing result: => yes
incoming := map[uint]*bool{
p1.ID: ptr.Bool(false),
p2.ID: ptr.Bool(true),
}
newFailing, newPassing, err = ds.FlippingPoliciesForHost(ctx, host1.ID, incoming)
require.NoError(t, err)
require.Equal(t, []uint{p1.ID}, newFailing)
require.Empty(t, newPassing) // because this would be the first run.
// Record the above executions.
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
require.NoError(t, err)
// incoming policy 1 with passing result: no => yes
// incoming policy 2 with failing result: yes => no
incoming = map[uint]*bool{
p1.ID: ptr.Bool(true),
p2.ID: ptr.Bool(false),
}
newFailing, newPassing, err = ds.FlippingPoliciesForHost(ctx, host1.ID, incoming)
require.NoError(t, err)
require.Equal(t, []uint{p2.ID}, newFailing)
require.Equal(t, []uint{p1.ID}, newPassing)
// Record the above executions.
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
require.NoError(t, err)
// incoming policy 1 with passing result: yes => yes
// incoming policy 2 with failing result: no => no
incoming = map[uint]*bool{
p1.ID: ptr.Bool(true),
p2.ID: ptr.Bool(false),
}
newFailing, newPassing, err = ds.FlippingPoliciesForHost(ctx, host1.ID, incoming)
require.NoError(t, err)
require.Empty(t, newFailing)
require.Empty(t, newPassing)
// Record the above executions.
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
require.NoError(t, err)
// incoming policy 1 with failing result: yes => no
// incoming policy 2 with failing result: no => no
// (these results are not recorded afterwards)
incoming = map[uint]*bool{
p1.ID: ptr.Bool(false),
p2.ID: ptr.Bool(false),
}
newFailing, newPassing, err = ds.FlippingPoliciesForHost(ctx, host1.ID, incoming)
require.NoError(t, err)
require.Equal(t, []uint{p1.ID}, newFailing)
require.Empty(t, newPassing)
// incoming pfailed failed to execute: => ---
incoming = map[uint]*bool{
pfailed.ID: nil,
}
newFailing, newPassing, err = ds.FlippingPoliciesForHost(ctx, host1.ID, incoming)
require.NoError(t, err)
require.Empty(t, newFailing)
require.Empty(t, newPassing)
// Record the above executions.
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
require.NoError(t, err)
// incoming pfailed again failed to execute: --- => ---
newFailing, newPassing, err = ds.FlippingPoliciesForHost(ctx, host1.ID, incoming)
require.NoError(t, err)
require.Empty(t, newFailing)
require.Empty(t, newPassing)
// incoming policy 4 failed to run: => ---
// incoming policy 5 failed to run: => ---
incoming = map[uint]*bool{
p4.ID: nil,
p5.ID: nil,
}
newFailing, newPassing, err = ds.FlippingPoliciesForHost(ctx, host1.ID, incoming)
require.NoError(t, err)
require.Empty(t, newFailing)
require.Empty(t, newPassing)
// Record the above executions.
err = ds.RecordPolicyQueryExecutions(ctx, host1, incoming, time.Now(), false)
require.NoError(t, err)
// incoming policy 4 with first new failing result: --- => no
// incoming policy 5 with first new passing result: --- => yes
// (these results are not recorded afterwards)
incoming = map[uint]*bool{
p4.ID: ptr.Bool(false),
p5.ID: ptr.Bool(true),
}
newFailing, newPassing, err = ds.FlippingPoliciesForHost(ctx, host1.ID, incoming)
require.NoError(t, err)
require.Equal(t, []uint{p4.ID}, newFailing)
require.Empty(t, newPassing) // because this would be the first run.
// incoming policy 4 fails to execute again: --- => ---
// incoming policy 5 fails to execute again: --- => ---
// (the no/yes results above were never recorded, so the stored state is still ---)
incoming = map[uint]*bool{
p4.ID: nil,
p5.ID: nil,
}
newFailing, newPassing, err = ds.FlippingPoliciesForHost(ctx, host1.ID, incoming)
require.NoError(t, err)
require.Empty(t, newFailing)
require.Empty(t, newPassing)
}
// testPolicyPlatformUpdate verifies which policy_membership rows remain after
// a policy's platform is changed, both through SavePolicy and through
// ApplyPolicySpecs, for global and team policies.
func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	author := test.NewUser(t, ds, "Alice", "alice@example.com", true)
	team, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name()})
	require.NoError(t, err)

	// Positions of each platform's host in the slices built below; they follow
	// the order of hostPlatforms.
	const (
		hostWin = iota
		hostMac
		hostDeb
		hostLin
	)
	hostPlatforms := []string{"windows", "darwin", "debian", "linux"}
	sanitizedTestName := strings.ReplaceAll(t.Name(), "/", "_")

	// makeHosts creates one host per platform. idPrefix keeps the identifiers
	// of the two groups distinct; inTeam assigns the hosts to the team.
	makeHosts := func(idPrefix string, inTeam bool) []*fleet.Host {
		created := make([]*fleet.Host, 0, len(hostPlatforms))
		for idx, platform := range hostPlatforms {
			ident := fmt.Sprintf("%s%s-%d", idPrefix, sanitizedTestName, idx)
			spec := &fleet.Host{
				OsqueryHostID:   &ident,
				DetailUpdatedAt: time.Now(),
				LabelUpdatedAt:  time.Now(),
				PolicyUpdatedAt: time.Now(),
				SeenTime:        time.Now(),
				NodeKey:         &ident,
				UUID:            ident,
				Hostname:        ident,
				Platform:        platform,
			}
			if inTeam {
				spec.TeamID = ptr.Uint(team.ID)
			}
			host, err := ds.NewHost(ctx, spec)
			require.NoError(t, err)
			created = append(created, host)
		}
		return created
	}
	teamHosts := makeHosts("", true)
	globalHosts := makeHosts("g", false)

	// One "any platform" policy per scope, created through the datastore API.
	_, err = ds.NewGlobalPolicy(ctx, ptr.Uint(author.ID), fleet.PolicyPayload{Name: "g1", Query: "select 1", Platform: ""})
	require.NoError(t, err)
	_, err = ds.NewTeamPolicy(ctx, team.ID, ptr.Uint(author.ID), fleet.PolicyPayload{Name: "t1", Query: "select 1", Platform: ""})
	require.NoError(t, err)

	// One Linux-only policy per scope, created through apply spec.
	err = ds.ApplyPolicySpecs(ctx, author.ID, []*fleet.PolicySpec{
		{Name: "g2", Query: "select 2", Platform: "linux", Type: fleet.PolicyTypeDynamic},
		{Name: "t2", Query: "select 2", Team: team.Name, Platform: "linux", Type: fleet.PolicyTypeDynamic},
	})
	require.NoError(t, err)

	// Load both sets back and key them by name for the rest of the test.
	globalPolicies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, globalPolicies, 2)
	teamPolicies, _, err := ds.ListTeamPolicies(ctx, team.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, teamPolicies, 2)

	byName := make(map[string]*fleet.Policy, len(globalPolicies)+len(teamPolicies))
	for _, p := range teamPolicies {
		byName[p.Name] = p
	}
	for _, p := range globalPolicies {
		byName[p.Name] = p
	}

	// specFor builds an apply-spec entry mirroring the policy's current name,
	// query and platform; teamName is empty for global policies.
	specFor := func(p *fleet.Policy, teamName string) *fleet.PolicySpec {
		return &fleet.PolicySpec{
			Name:     p.Name,
			Query:    p.Query,
			Team:     teamName,
			Platform: p.Platform,
			Type:     fleet.PolicyTypeDynamic,
		}
	}

	// Saving a policy without modifying it must succeed.
	require.NoError(t, ds.SavePolicy(ctx, byName["g1"], false, false))
	require.NoError(t, ds.SavePolicy(ctx, byName["t2"], false, false))

	// Re-applying unchanged specs must succeed and leave the policies as they were.
	err = ds.ApplyPolicySpecs(ctx, author.ID, []*fleet.PolicySpec{
		specFor(byName["g2"], ""),
		specFor(byName["t1"], team.Name),
	})
	require.NoError(t, err)
	for _, name := range []string{"g2", "t1"} {
		reloaded, err := ds.Policy(ctx, byName[name].ID)
		require.NoError(t, err)
		require.Equal(t, byName[name], reloaded)
	}

	// Record a passing result for the "any" policy on every host, and for the
	// Linux policy on the debian and linux hosts only.
	for idx, host := range teamHosts {
		results := map[uint]*bool{byName["t1"].ID: ptr.Bool(true)}
		if idx == hostDeb || idx == hostLin {
			results[byName["t2"].ID] = ptr.Bool(true)
		}
		require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, results, time.Now(), false))
	}
	for idx, host := range globalHosts {
		results := map[uint]*bool{byName["g1"].ID: ptr.Bool(true)}
		if idx == hostDeb || idx == hostLin {
			results[byName["g2"].ID] = ptr.Bool(true)
		}
		require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, results, time.Now(), false))
	}

	want := map[string][]uint{
		"g1": {globalHosts[hostWin].ID, globalHosts[hostMac].ID, globalHosts[hostDeb].ID, globalHosts[hostLin].ID},
		"g2": {globalHosts[hostDeb].ID, globalHosts[hostLin].ID},
		"t1": {teamHosts[hostWin].ID, teamHosts[hostMac].ID, teamHosts[hostDeb].ID, teamHosts[hostLin].ID},
		"t2": {teamHosts[hostDeb].ID, teamHosts[hostLin].ID},
	}
	assertPolicyMembership(t, ds, byName, want)

	// g1: any => linux. The map holds pointers, so the change is visible
	// through byName without re-inserting the entry.
	byName["g1"].Platform = "linux"
	require.NoError(t, ds.SavePolicy(ctx, byName["g1"], false, false))
	want["g1"] = []uint{globalHosts[hostDeb].ID, globalHosts[hostLin].ID}
	assertPolicyMembership(t, ds, byName, want)

	// t1: any => windows, darwin.
	byName["t1"].Platform = "windows,darwin"
	require.NoError(t, ds.SavePolicy(ctx, byName["t1"], false, false))
	want["t1"] = []uint{teamHosts[hostWin].ID, teamHosts[hostMac].ID}
	assertPolicyMembership(t, ds, byName, want)

	// g2: linux => any and t2: linux => debian, this time through ApplyPolicySpecs.
	byName["g2"].Platform = ""
	byName["t2"].Platform = "debian"
	err = ds.ApplyPolicySpecs(ctx, author.ID, []*fleet.PolicySpec{
		specFor(byName["g2"], ""),
		specFor(byName["t2"], team.Name),
	})
	require.NoError(t, err)

	// Widening g2 to any platform leaves its membership untouched, whereas t2
	// now keeps only the debian host.
	want["t2"] = []uint{teamHosts[hostDeb].ID}
	assertPolicyMembership(t, ds, byName, want)
}
// assertPolicyMembership loads the policy_membership rows of every policy in
// polsByName and requires that, for each name in wantPolNameToHostIDs, the
// member host IDs match the expected ones regardless of order. A name that is
// missing from polsByName must come with an empty expectation.
func assertPolicyMembership(t *testing.T, ds *Datastore, polsByName map[string]*fleet.Policy, wantPolNameToHostIDs map[string][]uint) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	ctx := context.Background()

	policyIDs := make([]uint, 0, len(polsByName))
	for _, pol := range polsByName {
		policyIDs = append(policyIDs, pol.ID)
	}

	type polHostIDs struct {
		PolicyID uint `db:"policy_id"`
		HostID   uint `db:"host_id"`
	}
	var rows []polHostIDs
	// sqlx.In rejects an empty slice, so only query when there is at least one
	// policy; with none, every expectation below has to be empty anyway.
	if len(policyIDs) > 0 {
		loadMembershipStmt, args, err := sqlx.In(`SELECT policy_id, host_id FROM policy_membership WHERE policy_id IN (?)`, policyIDs)
		require.NoError(t, err)
		err = ds.writer(ctx).SelectContext(ctx, &rows, loadMembershipStmt, args...)
		require.NoError(t, err)
	}

	// index the host IDs by policy ID
	hostIDsByPolID := make(map[uint][]uint, len(policyIDs))
	for _, row := range rows {
		hostIDsByPolID[row.PolicyID] = append(hostIDsByPolID[row.PolicyID], row.HostID)
	}

	// assert that they match the expected list of hosts by policy
	for polNm, hostIDs := range wantPolNameToHostIDs {
		pol, ok := polsByName[polNm]
		if !ok {
			require.Empty(t, hostIDs, "policy %q is not in polsByName, so no members can be expected", polNm)
			continue
		}
		require.ElementsMatch(t, hostIDs, hostIDsByPolID[pol.ID], "members of policy %q", polNm)
	}
}
// testPolicyViolationDays exercises IncrementPolicyViolationDays against a
// single policy and three hosts. Before each increment it rewrites the
// aggregated_stats timestamps so that the last update appears to be either
// more or less than 24h old, then checks the actual/possible counters.
func testPolicyViolationDays(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	then := time.Now().Add(-48 * time.Hour)

	// setStatsTimestampDB backdates the global policy-violation-days stats row:
	// created_at is always set to `then`, updated_at to the given time.
	setStatsTimestampDB := func(updatedAt time.Time) error {
		_, err := ds.writer(ctx).ExecContext(ctx, `
			UPDATE aggregated_stats SET created_at = ?, updated_at = ? WHERE id = ? AND global_stats = ? AND type = ?
		`, then, updatedAt, 0, true, aggregatedStatsTypePolicyViolationsDays)
		return err
	}

	user := test.NewUser(t, ds, "Bob", "bob@example.com", true)

	hosts := make([]*fleet.Host, 3)
	for i, name := range []string{"h1", "h2", "h3"} {
		id := fmt.Sprintf("%s-%d", strings.ReplaceAll(t.Name(), "/", "_"), i)
		h, err := ds.NewHost(ctx, &fleet.Host{
			OsqueryHostID:   &id,
			DetailUpdatedAt: then,
			LabelUpdatedAt:  then,
			PolicyUpdatedAt: then,
			SeenTime:        then,
			NodeKey:         &id,
			UUID:            id,
			Hostname:        name,
		})
		require.NoError(t, err)
		hosts[i] = h
	}

	// Insert the policy directly so that its timestamps can be set to `then`.
	createPolStmt := fmt.Sprintf(
		`INSERT INTO policies (name, query, description, author_id, platforms, created_at, updated_at, checksum) VALUES (?, ?, '', ?, ?, ?, ?, %s)`,
		policiesChecksumComputedColumn(),
	)
	res, err := ds.writer(ctx).ExecContext(ctx, createPolStmt, "test_pol", "select 1", user.ID, "", then, then)
	require.NoError(t, err)
	polID, err := res.LastInsertId()
	require.NoError(t, err)
	pol, err := ds.Policy(ctx, uint(polID)) //nolint:gosec // dismiss G115
	require.NoError(t, err)

	require.NoError(t, ds.InitializePolicyViolationDays(ctx)) // sets starting violation count to zero

	// initialize policy statuses: 1 failing, 2 passing
	require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[0], map[uint]*bool{pol.ID: ptr.Bool(false)}, then, false))
	require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: ptr.Bool(true)}, then, false))
	require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[2], map[uint]*bool{pol.ID: ptr.Bool(true)}, then, false))

	// setup db for test: starting counts zero, more than 24h since last updated, one outstanding violation
	require.NoError(t, setStatsTimestampDB(time.Now().Add(-25*time.Hour)))
	require.NoError(t, ds.IncrementPolicyViolationDays(ctx))
	actual, possible, err := amountPolicyViolationDaysDB(ctx, ds.reader(ctx))
	require.NoError(t, err)
	// actual should increment from 0 -> 1 (+1 outstanding violation)
	require.Equal(t, 1, actual)
	// possible should increment from 0 -> 3 (3 total hosts * 1 policy)
	require.Equal(t, 3, possible)

	// reset violation counts to zero for next test
	require.NoError(t, ds.InitializePolicyViolationDays(ctx))

	// setup for test: starting counts zero, less than 24h since last updated, one outstanding violation
	require.NoError(t, setStatsTimestampDB(time.Now().Add(-1*time.Hour)))
	require.NoError(t, ds.IncrementPolicyViolationDays(ctx))
	actual, possible, err = amountPolicyViolationDaysDB(ctx, ds.reader(ctx))
	require.NoError(t, err)
	// count should not increment from zero
	require.Equal(t, 0, actual)
	// possible should not increment from zero
	require.Equal(t, 0, possible)
	// leave counts at zero for next test

	// setup for test: starting count zero, more than 24h since last updated, add second outstanding violation
	require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false))
	require.NoError(t, setStatsTimestampDB(time.Now().Add(-25*time.Hour)))
	require.NoError(t, ds.IncrementPolicyViolationDays(ctx))
	actual, possible, err = amountPolicyViolationDaysDB(ctx, ds.reader(ctx))
	require.NoError(t, err)
	// actual should increment from 0 -> 2 (+2 outstanding violations)
	require.Equal(t, 2, actual) // leave count at two for next test
	// possible should increment from 0 -> 3 (3 total hosts * 1 policy)
	require.Equal(t, 3, possible)
	// leave counts at 2 actual and 3 possible for next test

	// setup for test: starting counts at 2 actual and 3 possible, more than 24h since last updated, resolve one outstanding violation
	require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hosts[1], map[uint]*bool{pol.ID: ptr.Bool(true)}, time.Now(), false))
	require.NoError(t, setStatsTimestampDB(time.Now().Add(-25*time.Hour)))
	require.NoError(t, ds.IncrementPolicyViolationDays(ctx))
	actual, possible, err = amountPolicyViolationDaysDB(ctx, ds.reader(ctx))
	require.NoError(t, err)
	// actual should increment from 2 -> 3 (+1 outstanding violation)
	require.Equal(t, 3, actual)
	// possible should increment from 3 -> 6 (3 total hosts * 1 policy)
	require.Equal(t, 6, possible)
	// leave counts at 3 actual and 6 possible

	// attempt again immediately after last update, counts should not increment
	require.NoError(t, ds.IncrementPolicyViolationDays(ctx))
	actual, possible, err = amountPolicyViolationDaysDB(ctx, ds.reader(ctx))
	require.NoError(t, err)
	require.Equal(t, 3, actual)
	require.Equal(t, 6, possible)
}
// testPolicyCleanupPolicyMembership exercises CleanupPolicyMembership with
// three global policies and one host per platform. Policies are inserted and
// updated with explicit timestamps so the test controls whether a policy
// counts as recently updated relative to the time passed to the cleanup; it
// then checks the remaining policy_membership rows and host_issues counts.
func testPolicyCleanupPolicyMembership(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user := test.NewUser(t, ds, "Bob", "bob@example.com", true)

	// create hosts with different platforms; the constants index into hosts
	// and follow the order of platforms
	const hostWin, hostMac, hostDeb, hostLin = 0, 1, 2, 3
	platforms := []string{"windows", "darwin", "debian", "linux"}
	hosts := make([]*fleet.Host, len(platforms))
	for i, pl := range platforms {
		id := fmt.Sprintf("%s-%d", strings.ReplaceAll(t.Name(), "/", "_"), i)
		h, err := ds.NewHost(ctx, &fleet.Host{
			OsqueryHostID:   &id,
			DetailUpdatedAt: time.Now(),
			LabelUpdatedAt:  time.Now(),
			PolicyUpdatedAt: time.Now(),
			SeenTime:        time.Now(),
			NodeKey:         &id,
			UUID:            id,
			Hostname:        id,
			Platform:        pl,
		})
		require.NoError(t, err)
		hosts[i] = h
	}

	// create some policies, using direct insert statements to control the timestamps
	createPolStmt := fmt.Sprintf(
		`INSERT INTO policies (name, query, description, author_id, platforms, created_at, updated_at, checksum)
		VALUES (?, ?, '', ?, ?, ?, ?, %s)`, policiesChecksumComputedColumn(),
	)

	jan2020 := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
	feb2020 := time.Date(2020, 2, 1, 0, 0, 0, 0, time.UTC)
	mar2020 := time.Date(2020, 3, 1, 0, 0, 0, 0, time.UTC)
	apr2020 := time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC)
	may2020 := time.Date(2020, 5, 1, 0, 0, 0, 0, time.UTC)

	pols := make([]*fleet.Policy, 3)
	for i, dt := range []time.Time{jan2020, feb2020, mar2020} {
		res, err := ds.writer(ctx).ExecContext(ctx, createPolStmt, "p"+strconv.Itoa(i+1), "select 1", user.ID, "", dt, dt)
		require.NoError(t, err)
		polID, err := res.LastInsertId()
		require.NoError(t, err)
		pol, err := ds.Policy(ctx, uint(polID)) //nolint:gosec // dismiss G115
		require.NoError(t, err)
		pols[i] = pol
	}

	// index the policies by name for easier access in the rest of the test
	polsByName := make(map[string]*fleet.Policy, len(pols))
	for _, pol := range pols {
		polsByName[pol.Name] = pol
	}

	wantHostsByPol := map[string][]uint{
		"p1": {},
		"p2": {},
		"p3": {},
	}

	// no recently updated policies
	err := ds.CleanupPolicyMembership(ctx, time.Now())
	require.NoError(t, err)
	assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
	var count uint64
	require.NoError(t, ds.writer(ctx).Get(&count, "select COUNT(*) from host_issues"))
	assert.Zero(t, count)

	// record results for each policy, all hosts, even if invalid for the policy
	for _, h := range hosts {
		res := map[uint]*bool{
			polsByName["p1"].ID: ptr.Bool(false), // This failing policy will increment the host_issues count.
			polsByName["p2"].ID: ptr.Bool(true),
			polsByName["p3"].ID: ptr.Bool(true),
		}
		err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
		require.NoError(t, err)
	}
	require.NoError(t, ds.writer(ctx).Get(&count, "select COUNT(*) from host_issues WHERE total_issues_count > 0"))
	assert.Equal(t, uint64(len(hosts)), count)

	// no recently updated policies, so no host gets cleaned up
	wantHostsByPol = map[string][]uint{
		"p1": {hosts[hostWin].ID, hosts[hostMac].ID, hosts[hostDeb].ID, hosts[hostLin].ID},
		"p2": {hosts[hostWin].ID, hosts[hostMac].ID, hosts[hostDeb].ID, hosts[hostLin].ID},
		"p3": {hosts[hostWin].ID, hosts[hostMac].ID, hosts[hostDeb].ID, hosts[hostLin].ID},
	}
	err = ds.CleanupPolicyMembership(ctx, time.Now())
	require.NoError(t, err)
	assertPolicyMembership(t, ds, polsByName, wantHostsByPol)

	// update policy p1, but do not change the platform (still any)
	pols[0].Description = "updated"
	updatePolicyWithTimestamp(t, ds, pols[0], feb2020)
	err = ds.CleanupPolicyMembership(ctx, time.Now())
	require.NoError(t, err)
	assertPolicyMembership(t, ds, polsByName, wantHostsByPol)

	// update policy p1 to "windows", but cleanup with a timestamp of apr2020, so
	// not "recently updated", no changes
	pols[0].Platform = "windows"
	updatePolicyWithTimestamp(t, ds, pols[0], mar2020)
	err = ds.CleanupPolicyMembership(ctx, apr2020)
	require.NoError(t, err)
	assertPolicyMembership(t, ds, polsByName, wantHostsByPol)

	// now cleanup with a timestamp of mar2020+1h, so "recently updated", only windows
	// hosts are kept
	err = ds.CleanupPolicyMembership(ctx, mar2020.Add(time.Hour))
	require.NoError(t, err)
	wantHostsByPol["p1"] = []uint{hosts[hostWin].ID}
	assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
	require.NoError(t, ds.writer(ctx).Get(&count, "select COUNT(*) from host_issues WHERE total_issues_count > 0"))
	assert.Equal(t, uint64(1), count, "only the Windows host should have issues")

	// update policy p2 to "linux,darwin", but cleanup with a timestamp of just over 24h, so
	// not "recently updated", no changes
	pols[1].Platform = "linux,darwin"
	updatePolicyWithTimestamp(t, ds, pols[1], mar2020)
	err = ds.CleanupPolicyMembership(ctx, mar2020.Add(25*time.Hour))
	require.NoError(t, err)
	assertPolicyMembership(t, ds, polsByName, wantHostsByPol)

	// now cleanup with a timestamp of just under 24h, so it is "recently updated"
	err = ds.CleanupPolicyMembership(ctx, mar2020.Add(23*time.Hour))
	require.NoError(t, err)
	wantHostsByPol["p2"] = []uint{hosts[hostMac].ID, hosts[hostDeb].ID, hosts[hostLin].ID}
	assertPolicyMembership(t, ds, polsByName, wantHostsByPol)

	// update policy p2 to just "linux", p3 to "debian", both get cleaned up (using apr2020
	// because p3 was created with mar2020, so it will not be detected as updated if we use
	// that same timestamp for the update).
	pols[1].Platform = "linux"
	updatePolicyWithTimestamp(t, ds, pols[1], apr2020)
	pols[2].Platform = "debian"
	updatePolicyWithTimestamp(t, ds, pols[2], apr2020)
	err = ds.CleanupPolicyMembership(ctx, apr2020.Add(time.Hour))
	require.NoError(t, err)
	wantHostsByPol["p2"] = []uint{hosts[hostDeb].ID, hosts[hostLin].ID}
	wantHostsByPol["p3"] = []uint{hosts[hostDeb].ID}
	assertPolicyMembership(t, ds, polsByName, wantHostsByPol)

	// cleaning up again 1h later doesn't change anything
	err = ds.CleanupPolicyMembership(ctx, apr2020.Add(2*time.Hour))
	require.NoError(t, err)
	assertPolicyMembership(t, ds, polsByName, wantHostsByPol)

	// update policy p1 to allow any, doesn't clean up anything
	pols[0].Platform = ""
	updatePolicyWithTimestamp(t, ds, pols[0], may2020)
	err = ds.CleanupPolicyMembership(ctx, may2020.Add(time.Hour))
	require.NoError(t, err)
	assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
}
// updatePolicyWithTimestamp writes p's name, query, description, resolution
// and platform to its policies row with a direct UPDATE, setting updated_at to
// ts so that tests control the timestamp themselves. It fails the test if the
// statement returns an error.
func updatePolicyWithTimestamp(t *testing.T, ds *Datastore, p *fleet.Policy, ts time.Time) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	ctx := context.Background()

	sqlStmt := `
		UPDATE policies
			SET name = ?, query = ?, description = ?, resolution = ?, platforms = ?, updated_at = ?
			WHERE id = ?`
	_, err := ds.writer(ctx).ExecContext(
		ctx, sqlStmt, p.Name, p.Query, p.Description, p.Resolution, p.Platform, ts, p.ID,
	)
	require.NoError(t, err)
}
// testDeleteAllPolicyMemberships records one failing policy result for a host
// and checks that deleteAllPolicyMemberships leaves no policy_membership row
// and no host_issues row with a positive issue count.
func testDeleteAllPolicyMemberships(t *testing.T, ds *Datastore) {
	author := test.NewUser(t, ds, "Alan", "alan@example.com", true)
	ctx := context.Background()

	// countOf runs a COUNT(*) query on the writer and returns the result.
	countOf := func(query string) int {
		var n int
		require.NoError(t, ds.writer(ctx).Get(&n, query))
		return n
	}
	const (
		membershipCountQuery = "select COUNT(*) from policy_membership"
		hostIssuesCountQuery = "select COUNT(*) from host_issues WHERE total_issues_count > 0"
	)

	policy, err := ds.NewGlobalPolicy(ctx, &author.ID, fleet.PolicyPayload{
		Name:        "query1",
		Query:       "select 1;",
		Description: "query1 desc",
		Resolution:  "query1 resolution",
	})
	require.NoError(t, err)

	target, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:   ptr.String("567898"),
		DetailUpdatedAt: time.Now(),
		LabelUpdatedAt:  time.Now(),
		PolicyUpdatedAt: time.Now(),
		SeenTime:        time.Now(),
		NodeKey:         ptr.String("4"),
		UUID:            "4",
		Hostname:        "bar.local",
	})
	require.NoError(t, err)

	// A single failing result creates one membership row and one host issue.
	failing := map[uint]*bool{policy.ID: ptr.Bool(false)}
	require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, target, failing, time.Now(), false))

	listed, err := ds.ListPoliciesForHost(ctx, target)
	require.NoError(t, err)
	require.Len(t, listed, 1)

	require.Equal(t, 1, countOf(membershipCountQuery))
	assert.Equal(t, 1, countOf(hostIssuesCountQuery))

	// Deleting the host's memberships must clear both.
	require.NoError(t, deleteAllPolicyMemberships(ctx, ds.writer(ctx), target.ID))

	require.Equal(t, 0, countOf(membershipCountQuery))
	assert.Zero(t, countOf(hostIssuesCountQuery))
}
// testIncreasePolicyAutomationIteration calls IncreasePolicyAutomationIteration
// once for one policy and three times for another, then checks the iteration
// values stored in policy_automation_iterations.
func testIncreasePolicyAutomationIteration(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	first, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: "policy1"})
	require.NoError(t, err)
	second, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: "policy2"})
	require.NoError(t, err)

	// Bump the first policy once and the second three times, in that order.
	bumps := []struct {
		policyID uint
		times    int
	}{
		{policyID: first.ID, times: 1},
		{policyID: second.ID, times: 3},
	}
	for _, b := range bumps {
		for n := 0; n < b.times; n++ {
			require.NoError(t, ds.IncreasePolicyAutomationIteration(ctx, b.policyID))
		}
	}

	type iterationRow struct {
		PolicyID  uint `db:"policy_id"`
		Iteration int  `db:"iteration"`
	}
	var stored []iterationRow
	err = ds.writer(ctx).Select(&stored, `SELECT policy_id, iteration FROM policy_automation_iterations;`)
	require.NoError(t, err)

	expected := []iterationRow{
		{PolicyID: first.ID, Iteration: 1},
		{PolicyID: second.ID, Iteration: 3},
	}
	require.ElementsMatch(t, stored, expected)
}
// testOutdatedAutomationBatch records failing and passing results for two
// hosts and two policies, then checks what OutdatedAutomationBatch returns
// before and after the policies' automation iterations are increased,
// including that a second call right after a non-empty batch returns nothing.
func testOutdatedAutomationBatch(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	hostA, err := ds.NewHost(ctx, &fleet.Host{OsqueryHostID: ptr.String("host1"), NodeKey: ptr.String("host1")})
	require.NoError(t, err)
	hostB, err := ds.NewHost(ctx, &fleet.Host{OsqueryHostID: ptr.String("host2"), NodeKey: ptr.String("host2")})
	require.NoError(t, err)

	polA, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: "policy1"})
	require.NoError(t, err)
	polB, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: "policy2"})
	require.NoError(t, err)

	// hostA fails polA only; hostB fails both policies.
	resultsA := map[uint]*bool{polA.ID: ptr.Bool(false), polB.ID: ptr.Bool(true)}
	require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostA, resultsA, time.Now(), false))
	resultsB := map[uint]*bool{polA.ID: ptr.Bool(false), polB.ID: ptr.Bool(false)}
	require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, hostB, resultsB, time.Now(), false))

	// failure builds the expected batch entry for a policy/host pair.
	failure := func(policyID, hostID uint) fleet.PolicyFailure {
		return fleet.PolicyFailure{
			PolicyID: policyID,
			Host:     fleet.PolicySetHost{ID: hostID},
		}
	}
	none := []fleet.PolicyFailure{}

	// No iteration has been increased yet, so nothing is outdated.
	batch, err := ds.OutdatedAutomationBatch(ctx)
	require.NoError(t, err)
	require.ElementsMatch(t, batch, none)

	// Increasing polA's iteration reports both of its failing hosts once.
	require.NoError(t, ds.IncreasePolicyAutomationIteration(ctx, polA.ID))
	batch, err = ds.OutdatedAutomationBatch(ctx)
	require.NoError(t, err)
	require.ElementsMatch(t, batch, []fleet.PolicyFailure{
		failure(polA.ID, hostA.ID),
		failure(polA.ID, hostB.ID),
	})

	batch, err = ds.OutdatedAutomationBatch(ctx)
	require.NoError(t, err)
	require.ElementsMatch(t, batch, none)

	// Increasing both iterations reports every failing policy/host pair once.
	require.NoError(t, ds.IncreasePolicyAutomationIteration(ctx, polA.ID))
	require.NoError(t, ds.IncreasePolicyAutomationIteration(ctx, polB.ID))
	batch, err = ds.OutdatedAutomationBatch(ctx)
	require.NoError(t, err)
	require.ElementsMatch(t, batch, []fleet.PolicyFailure{
		failure(polA.ID, hostA.ID),
		failure(polA.ID, hostB.ID),
		failure(polB.ID, hostB.ID),
	})

	batch, err = ds.OutdatedAutomationBatch(ctx)
	require.NoError(t, err)
	require.ElementsMatch(t, batch, none)
}
// testListGlobalPoliciesCanPaginate creates 30 global and 30 team policies and
// checks that ListGlobalPolicies pages through the global ones only: 20 on
// page 0, 10 on page 1, and all 30 when no list options are given.
func testListGlobalPoliciesCanPaginate(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// create 30 policies
	for i := 0; i < 30; i++ {
		_, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: fmt.Sprintf("global policy %d", i)})
		require.NoError(t, err)
	}

	// create 30 team policies
	tm, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	for i := 0; i < 30; i++ {
		_, err := ds.NewTeamPolicy(ctx, tm.ID, nil, fleet.PolicyPayload{Name: fmt.Sprintf("team policy %d", i)})
		require.NoError(t, err)
	}

	// Page 0 contains 20 policies. The error and the length are checked before
	// indexing so a failing call reports its error instead of panicking.
	policies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{
		Page:    0,
		PerPage: 20,
	})
	require.NoError(t, err)
	require.Len(t, policies, 20)
	assert.Equal(t, "global policy 0", policies[0].Name)

	// Page 1 contains 10 policies
	policies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{
		Page:    1,
		PerPage: 20,
	})
	require.NoError(t, err)
	require.Len(t, policies, 10)
	assert.Equal(t, "global policy 20", policies[0].Name)

	// No list options returns all policies
	policies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	assert.Len(t, policies, 30)
}
// testListTeamPoliciesCanPaginate creates 30 team and 30 global policies and
// checks that the first value returned by ListTeamPolicies pages through the
// team's policies only: 20 on page 0, 10 on page 1, and all 30 when no list
// options are given.
func testListTeamPoliciesCanPaginate(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	tm, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	// create 30 team policies
	for i := 0; i < 30; i++ {
		_, err := ds.NewTeamPolicy(ctx, tm.ID, nil, fleet.PolicyPayload{Name: fmt.Sprintf("team policy %d", i)})
		require.NoError(t, err)
	}

	// create 30 global policies
	for i := 0; i < 30; i++ {
		_, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: fmt.Sprintf("global policy %d", i)})
		require.NoError(t, err)
	}

	// Page 0 contains 20 policies. The error and the length are checked before
	// indexing so a failing call reports its error instead of panicking.
	policies, _, err := ds.ListTeamPolicies(ctx, tm.ID, fleet.ListOptions{
		Page:    0,
		PerPage: 20,
	}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, policies, 20)
	assert.Equal(t, "team policy 0", policies[0].Name)

	// Page 1 contains 10 policies
	policies, _, err = ds.ListTeamPolicies(ctx, tm.ID, fleet.ListOptions{
		Page:    1,
		PerPage: 20,
	}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, policies, 10)
	assert.Equal(t, "team policy 20", policies[0].Name)

	// No list options returns all policies. Use the created team's ID rather
	// than assuming it was assigned ID 1.
	policies, _, err = ds.ListTeamPolicies(ctx, tm.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	assert.Len(t, policies, 30)
}
// testCountPolicies checks CountPolicies (global and team scope) and
// CountMergedTeamPolicies as global and team policies are added, and then with
// a name filter and an automation filter.
func testCountPolicies(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	team, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	// expectUnfiltered compares the three unfiltered counts with the wanted
	// values. The count calls only read, so their relative order is irrelevant.
	expectUnfiltered := func(wantGlobal, wantTeam, wantMerged int) {
		t.Helper()
		global, err := ds.CountPolicies(ctx, nil, "", "")
		require.NoError(t, err)
		assert.Equal(t, wantGlobal, global)

		forTeam, err := ds.CountPolicies(ctx, &team.ID, "", "")
		require.NoError(t, err)
		assert.Equal(t, wantTeam, forTeam)

		merged, err := ds.CountMergedTeamPolicies(ctx, team.ID, "", "")
		require.NoError(t, err)
		assert.Equal(t, wantMerged, merged)
	}

	// Nothing has been created yet.
	expectUnfiltered(0, 0, 0)

	// Ten global policies show up in the global and merged counts.
	for n := 0; n < 10; n++ {
		_, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: fmt.Sprintf("global policy %d", n)})
		require.NoError(t, err)
	}
	expectUnfiltered(10, 0, 10)

	// Five team policies show up in the team and merged counts.
	for n := 0; n < 5; n++ {
		_, err := ds.NewTeamPolicy(ctx, team.ID, nil, fleet.PolicyPayload{Name: fmt.Sprintf("team policy %d", n)})
		require.NoError(t, err)
	}
	expectUnfiltered(10, 5, 15)

	// A name filter narrows each count to the matching policies.
	filteredGlobal, err := ds.CountPolicies(ctx, nil, "global policy 1", "")
	require.NoError(t, err)
	assert.Equal(t, 1, filteredGlobal)

	filteredTeam, err := ds.CountPolicies(ctx, &team.ID, "team policy 1", "")
	require.NoError(t, err)
	assert.Equal(t, 1, filteredTeam)

	filteredMerged, err := ds.CountMergedTeamPolicies(ctx, team.ID, "policy 1", "")
	require.NoError(t, err)
	assert.Equal(t, 2, filteredMerged)

	// The automation filter leaves the global policy count unchanged.
	withAutomation, err := ds.CountPolicies(ctx, nil, "", "scripts")
	require.NoError(t, err)
	assert.Equal(t, 10, withAutomation)
}
// testUpdatePolicyHostCounts records policy results for global and team hosts
// and checks that UpdateHostPolicyCounts refreshes a policy's
// FailingHostCount, PassingHostCount and HostCountUpdatedAt. Before the first
// update the counts are zero and the timestamp is nil.
func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// requireHostCounts reloads the policy and checks its failing/passing
	// counts, and that HostCountUpdatedAt lies in [notBefore, notBefore+10s).
	requireHostCounts := func(policyID, wantFailing, wantPassing uint, notBefore time.Time) {
		t.Helper()
		p, err := ds.Policy(ctx, policyID)
		require.NoError(t, err)
		require.Equal(t, wantFailing, p.FailingHostCount)
		require.Equal(t, wantPassing, p.PassingHostCount)
		require.NotNil(t, p.HostCountUpdatedAt)
		notAfter := notBefore.Add(10 * time.Second)
		assert.True(
			t, p.HostCountUpdatedAt.Compare(notBefore) >= 0, "reference:%v HostCountUpdatedAt:%v", notBefore, *p.HostCountUpdatedAt,
		)
		assert.True(
			t, p.HostCountUpdatedAt.Compare(notAfter) < 0, "later:%v HostCountUpdatedAt:%v", notAfter, *p.HostCountUpdatedAt,
		)
	}

	// new policy
	policy, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: "global policy 1"})
	require.NoError(t, err)

	// create 4 global hosts
	var globalHosts []*fleet.Host
	for i := 100; i < 104; i++ {
		name := fmt.Sprintf("host%d", i)
		h, err := ds.NewHost(ctx, &fleet.Host{OsqueryHostID: ptr.String(name), NodeKey: ptr.String(name), TeamID: nil})
		require.NoError(t, err)
		globalHosts = append(globalHosts, h)
	}

	// add policy responses to global hosts
	for _, h := range globalHosts {
		res := map[uint]*bool{
			policy.ID: ptr.Bool(true),
		}
		err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
		require.NoError(t, err)
	}

	// check policy host counts before update
	policy, err = ds.Policy(ctx, policy.ID)
	require.NoError(t, err)
	require.Equal(t, uint(0), policy.FailingHostCount)
	require.Equal(t, uint(0), policy.PassingHostCount)
	assert.Nil(t, policy.HostCountUpdatedAt)

	// update policy host counts
	now := time.Now().Truncate(time.Second)
	err = ds.UpdateHostPolicyCounts(ctx)
	require.NoError(t, err)

	// check policy host counts: the 4 global hosts pass
	requireHostCounts(policy.ID, 0, 4, now)

	team, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	// create 4 team hosts
	var teamHosts []*fleet.Host
	for i := 0; i < 4; i++ {
		name := fmt.Sprintf("host%d", i)
		h, err := ds.NewHost(ctx, &fleet.Host{OsqueryHostID: ptr.String(name), NodeKey: ptr.String(name), TeamID: &team.ID})
		require.NoError(t, err)
		teamHosts = append(teamHosts, h)
	}

	// add policy responses to team hosts. The outcome is chosen by position in
	// the slice, not by host ID, so the split does not depend on which IDs the
	// database happened to assign.
	for i, h := range teamHosts {
		var result *bool // remains nil for the last host
		switch i {
		case 0, 1: // 2 fails
			result = ptr.Bool(false)
		case 2: // 1 pass
			result = ptr.Bool(true)
		}
		res := map[uint]*bool{
			policy.ID: result,
		}
		err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
		require.NoError(t, err)
	}

	// update policy host counts
	now = time.Now().Truncate(time.Second)
	err = ds.UpdateHostPolicyCounts(ctx)
	require.NoError(t, err)

	// check policy host counts: 2 team hosts fail; 4 global + 1 team host pass
	requireHostCounts(policy.ID, 2, 5, now)

	// new global policy
	policy2, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: "global policy 2"})
	require.NoError(t, err)

	// new team
	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
	require.NoError(t, err)

	// create 4 team2 hosts
	for i := 4; i < 8; i++ {
		name := fmt.Sprintf("host%d", i)
		h, err := ds.NewHost(ctx, &fleet.Host{OsqueryHostID: ptr.String(name), NodeKey: ptr.String(name), TeamID: &team2.ID})
		require.NoError(t, err)
		teamHosts = append(teamHosts, h)
	}

	// Update policy results for all hosts.
	// All fail policy 1, all pass policy 2
	for _, h := range globalHosts {
		res := map[uint]*bool{
			policy.ID:  ptr.Bool(false),
			policy2.ID: ptr.Bool(true),
		}
		err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
		require.NoError(t, err)
	}
	for _, h := range teamHosts {
		res := map[uint]*bool{
			policy.ID:  ptr.Bool(false),
			policy2.ID: ptr.Bool(true),
		}
		err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
		require.NoError(t, err)
	}

	// update policy host counts
	now = time.Now().Truncate(time.Second)
	err = ds.UpdateHostPolicyCounts(ctx)
	require.NoError(t, err)

	// check policy 1 host counts: all 12 hosts fail
	requireHostCounts(policy.ID, 12, 0, now)

	// check policy 2 host counts: all 12 hosts pass
	requireHostCounts(policy2.ID, 0, 12, now)
}
// testPoliciesNameUnicode asserts that two policy names which differ only in
// Unicode composition (a precomposed Hangul syllable vs. its decomposed jamo
// sequence) collide on create and save, and match each other in the list and
// count datastore methods, for both global and team policies.
func testPoliciesNameUnicode(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// Two spellings of the same visible syllable. The Unquote errors are
	// ignored because the inputs are fixed, well-formed quoted literals.
	precomposed, _ := strconv.Unquote(`"\uAC00"`)      // 가
	decomposed, _ := strconv.Unquote(`"\u1100\u1161"`) // ᄀ + ᅡ

	// Creating a global policy under the first spelling succeeds.
	globalPolicy, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: precomposed})
	require.NoError(t, err)
	assert.Equal(t, precomposed, globalPolicy.Name)

	// Creating a second global policy under the other spelling must be rejected.
	var existsErr *existsError
	_, err = ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: decomposed})
	assert.ErrorAs(t, err, &existsErr)

	// Renaming a different policy to the equivalent spelling is not allowed.
	emojiPolicy, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: "💻"})
	require.NoError(t, err)
	renamed := &fleet.Policy{PolicyData: fleet.PolicyData{ID: emojiPolicy.ID, Name: decomposed}}
	err = ds.SavePolicy(ctx, renamed, false, false)
	assert.True(t, IsDuplicate(err), err)

	// Searching by the second spelling finds the policy stored under the first.
	matchOpts := fleet.ListOptions{MatchQuery: decomposed}
	globalMatches, err := ds.ListGlobalPolicies(ctx, matchOpts)
	assert.NoError(t, err)
	require.Len(t, globalMatches, 1)
	assert.Equal(t, precomposed, globalMatches[0].Name)

	// Same checks through the team-scoped methods.
	team, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	teamPolicy, err := ds.NewTeamPolicy(ctx, team.ID, nil, fleet.PolicyPayload{Name: precomposed})
	require.NoError(t, err)
	assert.Equal(t, precomposed, teamPolicy.Name)

	// A second team policy with the equivalent spelling is not allowed.
	_, err = ds.NewTeamPolicy(ctx, team.ID, nil, fleet.PolicyPayload{Name: decomposed})
	assert.ErrorAs(t, err, &existsErr)

	// ListTeamPolicies: both the team's own and the inherited (global) policy match.
	ownMatches, inheritedMatches, err := ds.ListTeamPolicies(ctx, team.ID, matchOpts, matchOpts, "")
	assert.NoError(t, err)
	require.Len(t, ownMatches, 1)
	assert.Equal(t, precomposed, ownMatches[0].Name)
	require.Len(t, inheritedMatches, 1)
	assert.Equal(t, precomposed, inheritedMatches[0].Name)

	// CountPolicies, team-scoped and then global.
	teamCount, err := ds.CountPolicies(ctx, &team.ID, decomposed, "")
	assert.NoError(t, err)
	assert.Equal(t, 1, teamCount)

	globalCount, err := ds.CountPolicies(ctx, nil, decomposed, "")
	assert.NoError(t, err)
	assert.Equal(t, 1, globalCount)
}
// testPoliciesNameEmoji asserts that global policies whose names are single
// emojis can be created, and that searching by one emoji returns exactly the
// policy with that name and not the other one.
func testPoliciesNameEmoji(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	const (
		emoji0 = "🔥"
		emoji1 = "💻"
	)

	// Save both emoji-named policies.
	_, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: emoji0})
	require.NoError(t, err)

	created, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: emoji1})
	require.NoError(t, err)
	assert.Equal(t, emoji1, created.Name)

	// Each emoji, used as a match query, must find only its own policy.
	for _, name := range []string{emoji0, emoji1} {
		found, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{MatchQuery: name})
		assert.NoError(t, err)
		require.Len(t, found, 1)
		assert.Equal(t, name, found[0].Name)
	}
}
// testPoliciesNameSort ensures a case-insensitive sort order for policy names:
// policies created with mixed-case Cyrillic names must be returned by
// ListGlobalPolicies (ordered by "name") in the order recorded in the
// expectation array below.
func testPoliciesNameSort(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// policies holds the created policies at the index where the sorted
	// listing is expected to return them; they are deliberately created out
	// of order.
	var policies [3]*fleet.Policy
	var err error
	policies[1], err = ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: "В"})
	require.NoError(t, err)
	policies[2], err = ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: "о"})
	require.NoError(t, err)
	policies[0], err = ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: "а"})
	require.NoError(t, err)

	policiesResult, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{OrderKey: "name"})
	assert.NoError(t, err)
	// Check the length of the listing rather than of the fixed-size
	// expectation array (which is always 3), so that a short result fails
	// here instead of panicking on the index below.
	require.Len(t, policiesResult, len(policies))
	for i, policy := range policies {
		assert.Equal(t, policy.Name, policiesResult[i].Name)
	}
}
// testGetCalendarPolicies checks ds.GetCalendarPolicies against a non-existent
// team, a team without policies, a team whose only policy has calendar events
// disabled, and finally a team with two calendar-enabled policies, which must
// be the only two policies returned.
func testGetCalendarPolicies(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// A team ID that does not exist must not produce an error.
	_, err := ds.GetCalendarPolicies(ctx, 999)
	require.NoError(t, err)

	team, err := ds.NewTeam(ctx, &fleet.Team{Name: "Foobar"})
	require.NoError(t, err)

	// Neither must a team that has no policies at all.
	_, err = ds.GetCalendarPolicies(ctx, team.ID)
	require.NoError(t, err)

	// A global policy exists so that the final assertions prove only team
	// policies are returned.
	globalPayload := fleet.PolicyPayload{
		Name:  "Global Policy",
		Query: "SELECT * FROM time;",
	}
	_, err = ds.NewGlobalPolicy(ctx, nil, globalPayload)
	require.NoError(t, err)

	disabledPayload := fleet.PolicyPayload{
		Name:                  "Team Policy 1",
		Query:                 "SELECT * FROM system_info;",
		CalendarEventsEnabled: false,
	}
	_, err = ds.NewTeamPolicy(ctx, team.ID, nil, disabledPayload)
	require.NoError(t, err)

	// The team has policies now, but none is configured for calendar.
	_, err = ds.GetCalendarPolicies(ctx, team.ID)
	require.NoError(t, err)

	// Two calendar-enabled team policies.
	firstCalendar, err := ds.NewTeamPolicy(ctx, team.ID, nil, fleet.PolicyPayload{
		Name:                  "Team Policy 2",
		Query:                 "SELECT * FROM osquery_info;",
		CalendarEventsEnabled: true,
	})
	require.NoError(t, err)
	secondCalendar, err := ds.NewTeamPolicy(ctx, team.ID, nil, fleet.PolicyPayload{
		Name:                  "Team Policy 3",
		Query:                 "SELECT * FROM os_version;",
		CalendarEventsEnabled: true,
	})
	require.NoError(t, err)

	got, err := ds.GetCalendarPolicies(ctx, team.ID)
	require.NoError(t, err)
	require.Len(t, got, 2)
	require.Equal(t, got[0].ID, firstCalendar.ID)
	require.Equal(t, got[1].ID, secondCalendar.ID)
}
// testGetTeamHostsPolicyMemberships exercises ds.GetTeamHostsPolicyMemberships
// across two teams while hosts gain emails, policy results and calendar
// events, are moved between teams, and have their results invalidated.
// After each step it asserts which hosts are returned and with which email,
// passing flag, hardware serial and display name.
func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// sortByHostID orders a result set by ascending host ID so the positional
	// assertions below do not depend on the order the datastore returns rows.
	// Every sort in this test goes through this helper so the comparator
	// always indexes the slice that is actually being sorted.
	sortByHostID := func(hosts []fleet.HostPolicyMembershipData) {
		sort.Slice(hosts, func(i, j int) bool {
			return hosts[i].HostID < hosts[j].HostID
		})
	}

	//
	// Test setup:
	//
	//	team1:
	//		team1Policy1 (calendar), team1Policy2
	//		host1, host5, host6
	//
	//	team2:
	//		team2Policy1 (calendar), team2Policy2 (calendar)
	//		host2, host3
	//
	//	global:
	//		Global Policy 1
	//		host4
	//
	//
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
	require.NoError(t, err)

	team1Policy1, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
		Name:                  "Team 1 Policy 1",
		Query:                 "SELECT * FROM osquery_info;",
		CalendarEventsEnabled: true,
	})
	require.NoError(t, err)
	team1Policy2, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
		Name:                  "Team 1 Policy 2",
		Query:                 "SELECT * FROM system_info;",
		CalendarEventsEnabled: false,
	})
	require.NoError(t, err)
	team2Policy1, err := ds.NewTeamPolicy(ctx, team2.ID, nil, fleet.PolicyPayload{
		Name:                  "Team 2 Policy 1",
		Query:                 "SELECT * FROM os_version;",
		CalendarEventsEnabled: true,
	})
	require.NoError(t, err)
	team2Policy2, err := ds.NewTeamPolicy(ctx, team2.ID, nil, fleet.PolicyPayload{
		Name:                  "Team 2 Policy 2",
		Query:                 "SELECT * FROM processes;",
		CalendarEventsEnabled: true,
	})
	require.NoError(t, err)
	_, err = ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{
		Name:  "Global Policy 1",
		Query: "SELECT * FROM foobar;",
	})
	require.NoError(t, err)

	// Empty teams.
	hostsTeam1, err := ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team1.ID, []uint{team1Policy1.ID, team1Policy2.ID}, nil)
	require.NoError(t, err)
	require.Empty(t, hostsTeam1)

	host1, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:  ptr.String("host1"),
		NodeKey:        ptr.String("host1"),
		HardwareSerial: "serial1",
		ComputerName:   "display_name1",
		TeamID:         &team1.ID,
	})
	require.NoError(t, err)
	host2, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:  ptr.String("host2"),
		NodeKey:        ptr.String("host2"),
		HardwareSerial: "serial2",
		ComputerName:   "display_name2",
		TeamID:         &team2.ID,
	})
	require.NoError(t, err)
	host3, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:  ptr.String("host3"),
		NodeKey:        ptr.String("host3"),
		HardwareSerial: "serial3",
		ComputerName:   "display_name3",
		TeamID:         &team2.ID,
	})
	require.NoError(t, err)
	host4, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:  ptr.String("host4"),
		NodeKey:        ptr.String("host4"),
		HardwareSerial: "serial4",
		ComputerName:   "display_name4",
	})
	require.NoError(t, err)
	host5, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:  ptr.String("host5"),
		NodeKey:        ptr.String("host5"),
		HardwareSerial: "serial5",
		ComputerName:   "display_name5",
		TeamID:         &team1.ID,
	})
	require.NoError(t, err)
	host6, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:  ptr.String("host6"),
		NodeKey:        ptr.String("host6"),
		HardwareSerial: "serial6",
		ComputerName:   "display_name6",
		TeamID:         &team1.ID,
	})
	require.NoError(t, err)

	// Some domain that doesn't exist on any of the hosts
	hostsTeam1, err = ds.GetTeamHostsPolicyMemberships(ctx, "not-exists.com", team1.ID, []uint{team1Policy1.ID, team1Policy2.ID}, nil)
	require.NoError(t, err)
	require.Empty(t, hostsTeam1)

	// No policy results yet (and no calendar events).
	hostsTeam1, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team1.ID, []uint{team1Policy1.ID, team1Policy2.ID}, nil)
	require.NoError(t, err)
	require.Empty(t, hostsTeam1)

	//
	// Email setup
	//
	//	host1 has foo@example.com, zoo@example.com
	//	host2 has foo@example.com, foo@other.com
	//	host3 has zoo@example.com
	//	host4 has foo@example.com
	//	host5 has foo@other.com
	//	host6 has bar@example.com
	//
	err = ds.ReplaceHostDeviceMapping(ctx, host1.ID, []*fleet.HostDeviceMapping{
		{HostID: host1.ID, Email: "foo@example.com", Source: "google_chrome_profiles"},
	}, "google_chrome_profiles")
	require.NoError(t, err)
	err = ds.ReplaceHostDeviceMapping(ctx, host1.ID, []*fleet.HostDeviceMapping{
		{HostID: host1.ID, Email: "zoo@example.com", Source: "custom"},
	}, "custom")
	require.NoError(t, err)
	err = ds.ReplaceHostDeviceMapping(ctx, host2.ID, []*fleet.HostDeviceMapping{
		{HostID: host2.ID, Email: "foo@example.com", Source: "custom"},
	}, "custom")
	require.NoError(t, err)
	err = ds.ReplaceHostDeviceMapping(ctx, host2.ID, []*fleet.HostDeviceMapping{
		{HostID: host2.ID, Email: "foo@other.com", Source: "google_chrome_profiles"},
	}, "google_chrome_profiles")
	require.NoError(t, err)
	err = ds.ReplaceHostDeviceMapping(ctx, host3.ID, []*fleet.HostDeviceMapping{
		{HostID: host3.ID, Email: "zoo@example.com", Source: "google_chrome_profiles"},
	}, "google_chrome_profiles")
	require.NoError(t, err)
	err = ds.ReplaceHostDeviceMapping(ctx, host4.ID, []*fleet.HostDeviceMapping{
		{HostID: host4.ID, Email: "foo@example.com", Source: "google_chrome_profiles"},
	}, "google_chrome_profiles")
	require.NoError(t, err)
	err = ds.ReplaceHostDeviceMapping(ctx, host5.ID, []*fleet.HostDeviceMapping{
		{HostID: host5.ID, Email: "foo@other.com", Source: "google_chrome_profiles"},
	}, "google_chrome_profiles")
	require.NoError(t, err)
	err = ds.ReplaceHostDeviceMapping(ctx, host6.ID, []*fleet.HostDeviceMapping{
		{HostID: host6.ID, Email: "bar@example.com", Source: "google_chrome_profiles"},
	}, "google_chrome_profiles")
	require.NoError(t, err)

	//
	// Results setup
	//
	//	host1 (team1) is passing team1Policy1 (calendar) and failing team1Policy2.
	//	host2 (team2) is failing team2Policy1 (calendar) and passing team2Policy2 (calendar).
	//	host3 (team2) is passing all policies.
	//	host5 (team1) is failing all policies.
	//	host6 (team1) has not returned results.
	//
	err = ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{
		team1Policy1.ID: ptr.Bool(true),
		team1Policy2.ID: ptr.Bool(false),
	}, time.Now(), false)
	require.NoError(t, err)
	err = ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{
		team2Policy1.ID: ptr.Bool(false),
		team2Policy2.ID: ptr.Bool(true),
	}, time.Now(), false)
	require.NoError(t, err)
	err = ds.RecordPolicyQueryExecutions(ctx, host3, map[uint]*bool{
		team2Policy1.ID: ptr.Bool(true),
		team2Policy2.ID: ptr.Bool(true),
	}, time.Now(), false)
	require.NoError(t, err)
	err = ds.RecordPolicyQueryExecutions(ctx, host5, map[uint]*bool{
		team1Policy1.ID: ptr.Bool(false),
		team1Policy2.ID: ptr.Bool(false),
	}, time.Now(), false)
	require.NoError(t, err)

	team1Policies, err := ds.GetCalendarPolicies(ctx, team1.ID)
	require.NoError(t, err)
	require.Len(t, team1Policies, 1)
	team2Policies, err := ds.GetCalendarPolicies(ctx, team2.ID)
	require.NoError(t, err)
	require.Len(t, team2Policies, 2)

	// Only returns the failing host, because the passing hosts do not have a calendar event.
	hostsTeam1, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team1.ID, []uint{team1Policies[0].ID}, nil)
	require.NoError(t, err)
	sortByHostID(hostsTeam1)
	require.Len(t, hostsTeam1, 1)
	require.Equal(t, host5.ID, hostsTeam1[0].HostID)
	require.Empty(t, hostsTeam1[0].Email)
	require.False(t, hostsTeam1[0].Passing)
	require.Equal(t, "serial5", hostsTeam1[0].HostHardwareSerial)
	require.Equal(t, "display_name5", hostsTeam1[0].HostDisplayName)

	//
	// Create a calendar event on host1 and host6.
	//
	tZ := "America/Argentina/Buenos_Aires"
	now := time.Now()
	eventUUID1 := uuid.New().String()
	_, err = ds.CreateOrUpdateCalendarEvent(ctx, eventUUID1, "foo@example.com", now, now.Add(30*time.Minute), []byte(`{"foo": "bar"}`), &tZ,
		host1.ID, fleet.CalendarWebhookStatusPending)
	require.NoError(t, err)
	eventUUID2 := uuid.New().String()
	_, err = ds.CreateOrUpdateCalendarEvent(ctx, eventUUID2, "bar@example.com", now, now.Add(30*time.Minute), []byte(`{"foo": "bar"}`), &tZ,
		host6.ID, fleet.CalendarWebhookStatusPending)
	require.NoError(t, err)

	hostsTeam1, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team1.ID, []uint{team1Policies[0].ID}, nil)
	require.NoError(t, err)
	sortByHostID(hostsTeam1)
	require.Len(t, hostsTeam1, 3)
	require.Equal(t, host1.ID, hostsTeam1[0].HostID)
	require.Equal(t, "foo@example.com", hostsTeam1[0].Email)
	require.True(t, hostsTeam1[0].Passing)
	require.Equal(t, "serial1", hostsTeam1[0].HostHardwareSerial)
	require.Equal(t, "display_name1", hostsTeam1[0].HostDisplayName)
	require.Equal(t, host5.ID, hostsTeam1[1].HostID)
	require.Empty(t, hostsTeam1[1].Email)
	require.False(t, hostsTeam1[1].Passing)
	require.Equal(t, "serial5", hostsTeam1[1].HostHardwareSerial)
	require.Equal(t, "display_name5", hostsTeam1[1].HostDisplayName)
	require.Equal(t, host6.ID, hostsTeam1[2].HostID)
	require.Equal(t, "bar@example.com", hostsTeam1[2].Email)
	require.True(t, hostsTeam1[2].Passing)
	require.Equal(t, "serial6", hostsTeam1[2].HostHardwareSerial)
	require.Equal(t, "display_name6", hostsTeam1[2].HostDisplayName)

	//
	// Move host 4 to team1 and have it fail all team1 policies.
	//
	err = ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(&team1.ID, []uint{host4.ID}))
	require.NoError(t, err)
	err = ds.RecordPolicyQueryExecutions(ctx, host4, map[uint]*bool{
		team1Policy1.ID: ptr.Bool(false),
		team1Policy2.ID: ptr.Bool(false),
	}, time.Now(), false)
	require.NoError(t, err)

	hostsTeam1, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team1.ID, []uint{team1Policies[0].ID}, nil)
	require.NoError(t, err)
	require.Len(t, hostsTeam1, 4)
	sortByHostID(hostsTeam1)
	require.Equal(t, host1.ID, hostsTeam1[0].HostID)
	require.Equal(t, "foo@example.com", hostsTeam1[0].Email)
	require.True(t, hostsTeam1[0].Passing)
	require.Equal(t, "serial1", hostsTeam1[0].HostHardwareSerial)
	require.Equal(t, "display_name1", hostsTeam1[0].HostDisplayName)
	require.Equal(t, host4.ID, hostsTeam1[1].HostID)
	require.Equal(t, "foo@example.com", hostsTeam1[1].Email)
	require.False(t, hostsTeam1[1].Passing)
	require.Equal(t, "serial4", hostsTeam1[1].HostHardwareSerial)
	require.Equal(t, "display_name4", hostsTeam1[1].HostDisplayName)
	require.Equal(t, host5.ID, hostsTeam1[2].HostID)
	require.Empty(t, hostsTeam1[2].Email)
	require.False(t, hostsTeam1[2].Passing)
	require.Equal(t, "serial5", hostsTeam1[2].HostHardwareSerial)
	require.Equal(t, "display_name5", hostsTeam1[2].HostDisplayName)
	require.Equal(t, host6.ID, hostsTeam1[3].HostID)
	require.Equal(t, "bar@example.com", hostsTeam1[3].Email)
	require.True(t, hostsTeam1[3].Passing)
	require.Equal(t, "serial6", hostsTeam1[3].HostHardwareSerial)
	require.Equal(t, "display_name6", hostsTeam1[3].HostDisplayName)

	//
	// host3 doesn't have a calendar event so it's not returned by GetTeamHostsPolicyMemberships.
	//
	hostsTeam2, err := ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team2.ID, []uint{team2Policies[0].ID, team2Policies[1].ID}, nil)
	require.NoError(t, err)
	require.Len(t, hostsTeam2, 1)
	require.Equal(t, host2.ID, hostsTeam2[0].HostID)
	require.Equal(t, "foo@example.com", hostsTeam2[0].Email)
	require.False(t, hostsTeam2[0].Passing)
	require.Equal(t, "serial2", hostsTeam2[0].HostHardwareSerial)
	require.Equal(t, "display_name2", hostsTeam2[0].HostDisplayName)

	//
	// Create a calendar event on host2 and host3.
	//
	now = time.Now()
	_, err = ds.CreateOrUpdateCalendarEvent(ctx, eventUUID1, "foo@example.com", now, now.Add(30*time.Minute), []byte(`{"foo": "bar"}`), &tZ,
		host2.ID, fleet.CalendarWebhookStatusPending)
	require.NoError(t, err)
	eventUUID3 := uuid.New().String()
	calendarEventHost3, err := ds.CreateOrUpdateCalendarEvent(ctx, eventUUID3, "zoo@example.com", now, now.Add(30*time.Minute),
		[]byte(`{"foo": "bar"}`), &tZ, host3.ID, fleet.CalendarWebhookStatusPending)
	require.NoError(t, err)

	//
	// Now it should return host3 because it's passing and has a calendar event.
	//
	hostsTeam2, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team2.ID, []uint{team2Policies[0].ID, team2Policies[1].ID}, nil)
	require.NoError(t, err)
	require.Len(t, hostsTeam2, 2)
	sortByHostID(hostsTeam2)
	require.Equal(t, host2.ID, hostsTeam2[0].HostID)
	require.Equal(t, "foo@example.com", hostsTeam2[0].Email)
	require.False(t, hostsTeam2[0].Passing)
	require.Equal(t, "serial2", hostsTeam2[0].HostHardwareSerial)
	require.Equal(t, "display_name2", hostsTeam2[0].HostDisplayName)
	require.Equal(t, host3.ID, hostsTeam2[1].HostID)
	require.Equal(t, "zoo@example.com", hostsTeam2[1].Email)
	require.True(t, hostsTeam2[1].Passing)
	require.Equal(t, "serial3", hostsTeam2[1].HostHardwareSerial)
	require.Equal(t, "display_name3", hostsTeam2[1].HostDisplayName)

	//
	// Make host2 policy results invalid (NULL).
	//
	err = ds.RecordPolicyQueryExecutions(
		ctx, host2, map[uint]*bool{
			team2Policy1.ID: nil,
			team2Policy2.ID: nil,
		}, time.Now(), false,
	)
	require.NoError(t, err)

	hostsTeam2, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team2.ID, []uint{team2Policies[0].ID, team2Policies[1].ID}, nil)
	require.NoError(t, err)
	require.Len(t, hostsTeam2, 2)
	sortByHostID(hostsTeam2)
	require.Equal(t, host2.ID, hostsTeam2[0].HostID)
	require.Equal(t, "foo@example.com", hostsTeam2[0].Email)
	require.True(t, hostsTeam2[0].Passing)
	require.Equal(t, "serial2", hostsTeam2[0].HostHardwareSerial)
	require.Equal(t, "display_name2", hostsTeam2[0].HostDisplayName)
	require.Equal(t, host3.ID, hostsTeam2[1].HostID)
	require.Equal(t, "zoo@example.com", hostsTeam2[1].Email)
	require.True(t, hostsTeam2[1].Passing)
	require.Equal(t, "serial3", hostsTeam2[1].HostHardwareSerial)
	require.Equal(t, "display_name3", hostsTeam2[1].HostDisplayName)

	//
	// Make host2 pass all policies.
	//
	err = ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{
		team2Policy1.ID: ptr.Bool(true),
		team2Policy2.ID: ptr.Bool(true),
	}, time.Now(), false)
	require.NoError(t, err)

	hostsTeam2, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team2.ID, []uint{team2Policies[0].ID, team2Policies[1].ID}, nil)
	require.NoError(t, err)
	require.Len(t, hostsTeam2, 2)
	sortByHostID(hostsTeam2)
	require.Equal(t, host2.ID, hostsTeam2[0].HostID)
	require.Equal(t, "foo@example.com", hostsTeam2[0].Email)
	require.True(t, hostsTeam2[0].Passing)
	require.Equal(t, "serial2", hostsTeam2[0].HostHardwareSerial)
	require.Equal(t, "display_name2", hostsTeam2[0].HostDisplayName)
	require.Equal(t, host3.ID, hostsTeam2[1].HostID)
	require.Equal(t, "zoo@example.com", hostsTeam2[1].Email)
	require.True(t, hostsTeam2[1].Passing)
	require.Equal(t, "serial3", hostsTeam2[1].HostHardwareSerial)
	require.Equal(t, "display_name3", hostsTeam2[1].HostDisplayName)

	// Retrieve the data only for host2.
	hostsTeam2, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team2.ID, []uint{team2Policies[0].ID, team2Policies[1].ID},
		&host2.ID)
	require.NoError(t, err)
	require.Len(t, hostsTeam2, 1)
	require.Equal(t, host2.ID, hostsTeam2[0].HostID)
	require.Equal(t, "foo@example.com", hostsTeam2[0].Email)
	require.True(t, hostsTeam2[0].Passing)
	require.Equal(t, "serial2", hostsTeam2[0].HostHardwareSerial)
	require.Equal(t, "display_name2", hostsTeam2[0].HostDisplayName)

	//
	// Delete host3 calendar event
	//
	err = ds.DeleteCalendarEvent(ctx, calendarEventHost3.ID)
	require.NoError(t, err)

	hostsTeam2, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team2.ID, []uint{team2Policies[0].ID, team2Policies[1].ID}, nil)
	require.NoError(t, err)
	require.Len(t, hostsTeam2, 1)
	require.Equal(t, host2.ID, hostsTeam2[0].HostID)
	require.Equal(t, "foo@example.com", hostsTeam2[0].Email)
	require.True(t, hostsTeam2[0].Passing)
	require.Equal(t, "serial2", hostsTeam2[0].HostHardwareSerial)
	require.Equal(t, "display_name2", hostsTeam2[0].HostDisplayName)

	//
	// Edit team2Policy1 platform (which removes all its policy_membership entries).
	//
	// NOTE(review): the assignments and saves below look interleaved --
	// team1Policy1 is saved before its Platform is changed, and team2Policy1
	// is the policy that ends up saved with the new platform. Kept as-is;
	// confirm whether saving team1Policy1 here is intentional.
	team2Policy1.Platform = "darwin"
	err = ds.SavePolicy(ctx, team1Policy1, false, true)
	require.NoError(t, err)
	team1Policy1.Platform = "darwin"
	err = ds.SavePolicy(ctx, team2Policy1, false, true)
	require.NoError(t, err)

	//
	// We should still get host2 as passing because it has an associated calendar event.
	//
	hostsTeam2, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team2.ID, []uint{team2Policies[0].ID, team2Policies[1].ID}, nil)
	require.NoError(t, err)
	require.Len(t, hostsTeam2, 1)
	require.Equal(t, host2.ID, hostsTeam2[0].HostID)
	require.Equal(t, "foo@example.com", hostsTeam2[0].Email)
	require.True(t, hostsTeam2[0].Passing)
	require.Equal(t, "serial2", hostsTeam2[0].HostHardwareSerial)
	require.Equal(t, "display_name2", hostsTeam2[0].HostDisplayName)
}
// testGetTeamHostsPolicyMembershipsEmailPriority checks which email
// ds.GetTeamHostsPolicyMemberships reports for a host that has several device
// mappings. Each table entry creates one host failing a calendar policy, gives
// it the listed mappings, and asserts the email returned for the
// "example.com" domain.
func testGetTeamHostsPolicyMembershipsEmailPriority(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	team, err := ds.NewTeam(ctx, &fleet.Team{Name: "test-team"})
	require.NoError(t, err)

	calendarPolicy, err := ds.NewTeamPolicy(ctx, team.ID, nil, fleet.PolicyPayload{
		Name:                  "Calendar Policy",
		Query:                 "SELECT 1;",
		CalendarEventsEnabled: true,
	})
	require.NoError(t, err)

	// createFailingHost adds a host to the team and records a failing result
	// for the calendar policy, so the host always appears in the results.
	createFailingHost := func(name string) *fleet.Host {
		created, err := ds.NewHost(ctx, &fleet.Host{
			OsqueryHostID:  ptr.String(name),
			NodeKey:        ptr.String(name),
			HardwareSerial: name + "-serial",
			ComputerName:   name,
			TeamID:         &team.ID,
		})
		require.NoError(t, err)
		failing := map[uint]*bool{calendarPolicy.ID: ptr.Bool(false)}
		err = ds.RecordPolicyQueryExecutions(ctx, created, failing, time.Now(), false)
		require.NoError(t, err)
		return created
	}

	// replaceMappings groups the mappings by source and issues one
	// ReplaceHostDeviceMapping call per source, matching production usage
	// (each source replaces its own slice).
	replaceMappings := func(hostID uint, mappings ...*fleet.HostDeviceMapping) {
		grouped := make(map[string][]*fleet.HostDeviceMapping)
		for _, mapping := range mappings {
			grouped[mapping.Source] = append(grouped[mapping.Source], mapping)
		}
		for source, group := range grouped {
			require.NoError(t, ds.ReplaceHostDeviceMapping(ctx, hostID, group, source))
		}
	}

	// listMemberships fetches the team's memberships for the calendar policy
	// and the given domain, sorted by ascending host ID.
	listMemberships := func(domain string) []fleet.HostPolicyMembershipData {
		memberships, err := ds.GetTeamHostsPolicyMemberships(ctx, domain, team.ID, []uint{calendarPolicy.ID}, nil)
		require.NoError(t, err)
		sort.Slice(memberships, func(a, b int) bool {
			return memberships[a].HostID < memberships[b].HostID
		})
		return memberships
	}

	// lookupEmail returns the email reported for the host, failing the test
	// when the host is missing from the results.
	lookupEmail := func(memberships []fleet.HostPolicyMembershipData, host *fleet.Host) string {
		for i := range memberships {
			if memberships[i].HostID == host.ID {
				return memberships[i].Email
			}
		}
		t.Fatalf("host %d not found in results", host.ID)
		return ""
	}

	testCases := []struct {
		name     string
		mappings []*fleet.HostDeviceMapping
		expected string
	}{
		{
			// An MDM IdP account email is expected over a Chrome profile email.
			name: "idp-vs-chrome",
			mappings: []*fleet.HostDeviceMapping{
				{Email: "chrome@example.com", Source: fleet.DeviceMappingGoogleChromeProfiles},
				{Email: "idp@example.com", Source: fleet.DeviceMappingMDMIdpAccounts},
			},
			expected: "idp@example.com",
		},
		{
			// An IdP-sourced email is expected over a Chrome profile email.
			name: "idpsrc-vs-chrome",
			mappings: []*fleet.HostDeviceMapping{
				{Email: "chrome@example.com", Source: fleet.DeviceMappingGoogleChromeProfiles},
				{Email: "idpsrc@example.com", Source: fleet.DeviceMappingIDP},
			},
			expected: "idpsrc@example.com",
		},
		{
			// A Chrome profile email is expected over a custom email.
			name: "custom-vs-chrome",
			mappings: []*fleet.HostDeviceMapping{
				{Email: "chrome@example.com", Source: fleet.DeviceMappingGoogleChromeProfiles},
				{Email: "custom@example.com", Source: "custom"},
			},
			expected: "chrome@example.com",
		},
		{
			// Two emails from the same IdP source: "apple" is expected over "zebra".
			name: "multi-idp",
			mappings: []*fleet.HostDeviceMapping{
				{Email: "zebra@example.com", Source: fleet.DeviceMappingMDMIdpAccounts},
				{Email: "apple@example.com", Source: fleet.DeviceMappingMDMIdpAccounts},
			},
			expected: "apple@example.com",
		},
		{
			// Two emails from the same Chrome source: "apple" is expected over "zebra".
			name: "multi-chrome",
			mappings: []*fleet.HostDeviceMapping{
				{Email: "zebra@example.com", Source: fleet.DeviceMappingGoogleChromeProfiles},
				{Email: "apple@example.com", Source: fleet.DeviceMappingGoogleChromeProfiles},
			},
			expected: "apple@example.com",
		},
		{
			// The IdP email is on another domain, so the Chrome email is expected.
			name: "idp-wrong-domain",
			mappings: []*fleet.HostDeviceMapping{
				{Email: "idp@other.com", Source: fleet.DeviceMappingMDMIdpAccounts},
				{Email: "chrome@example.com", Source: fleet.DeviceMappingGoogleChromeProfiles},
			},
			expected: "chrome@example.com",
		},
	}

	for i := range testCases {
		tc := testCases[i]
		host := createFailingHost(tc.name)
		// The table omits host IDs; fill them in now that the host exists.
		for _, mapping := range tc.mappings {
			mapping.HostID = host.ID
		}
		replaceMappings(host.ID, tc.mappings...)
		got := lookupEmail(listMemberships("example.com"), host)
		require.Equal(t, tc.expected, got, tc.name)
	}
}
// testNewGlobalPolicyWithInstaller asserts that creating a global policy with
// a software installer ID set fails with errSoftwareTitleIDOnGlobalPolicy.
func testNewGlobalPolicyWithInstaller(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	author := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	payload := fleet.PolicyPayload{
		Query:               "SELECT 1;",
		SoftwareInstallerID: ptr.Uint(1),
	}
	_, err := ds.NewGlobalPolicy(ctx, &author.ID, payload)
	require.Error(t, err)
	require.ErrorIs(t, err, errSoftwareTitleIDOnGlobalPolicy)
}
// testNewGlobalPolicyWithScript asserts that creating a global policy with a
// script ID set fails with errScriptIDOnGlobalPolicy.
func testNewGlobalPolicyWithScript(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	author := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	payload := fleet.PolicyPayload{
		Query:    "SELECT 1;",
		ScriptID: ptr.Uint(1),
	}
	_, err := ds.NewGlobalPolicy(ctx, &author.ID, payload)
	require.Error(t, err)
	require.ErrorIs(t, err, errScriptIDOnGlobalPolicy)
}
// testTeamPoliciesWithInstaller covers team policies that reference a software
// installer: creating them, rejecting an installer that was uploaded for a
// different team (on create and on save), and querying
// ds.GetPoliciesWithAssociatedInstaller per team.
func testTeamPoliciesWithInstaller(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	alice := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"}) // team2 has no policies
	require.NoError(t, err)

	// uploadInstaller stores the same small installer for the given team and
	// returns the ID of the resulting software installer.
	uploadInstaller := func(teamID *uint) uint {
		file, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
		require.NoError(t, err)
		id, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
			InstallScript:     "hello",
			PreInstallQuery:   "SELECT 1",
			PostInstallScript: "world",
			InstallerFile:     file,
			StorageID:         "storage1",
			Filename:          "file1",
			Title:             "file1",
			Version:           "1.0",
			Source:            "apps",
			UserID:            alice.ID,
			TeamID:            teamID,
			ValidatedLabels:   &fleet.LabelIdentsWithScope{},
		})
		require.NoError(t, err)
		return id
	}

	// Policy p1 has no associated installer.
	p1, err := ds.NewTeamPolicy(ctx, team1.ID, &alice.ID, fleet.PolicyPayload{
		Name:                "p1",
		Query:               "SELECT 1;",
		SoftwareInstallerID: nil,
	})
	require.NoError(t, err)

	// Create an installer on team1 and associate it to p2.
	installerID := uploadInstaller(&team1.ID)
	require.Nil(t, p1.SoftwareInstallerID)

	p2, err := ds.NewTeamPolicy(ctx, team1.ID, &alice.ID, fleet.PolicyPayload{
		Name:                "p2",
		Query:               "SELECT 1;",
		SoftwareInstallerID: ptr.Uint(installerID),
	})
	require.NoError(t, err)
	require.NotNil(t, p2.SoftwareInstallerID)
	require.Equal(t, installerID, *p2.SoftwareInstallerID)

	// Create p3 as global policy.
	_, err = ds.NewGlobalPolicy(ctx, &alice.ID, fleet.PolicyPayload{
		Name:  "p3",
		Query: "SELECT 1;",
	})
	require.NoError(t, err)

	// Reloading p2 still shows its installer.
	p2, err = ds.Policy(ctx, p2.ID)
	require.NoError(t, err)
	require.NotNil(t, p2.SoftwareInstallerID)
	require.Equal(t, installerID, *p2.SoftwareInstallerID)

	// Policy p4 in "No team" with associated installer.
	noTeamInstallerID := uploadInstaller(ptr.Uint(fleet.PolicyNoTeamID))
	p4, err := ds.NewTeamPolicy(ctx, fleet.PolicyNoTeamID, &alice.ID, fleet.PolicyPayload{
		Name:                "p4",
		Query:               "SELECT 4;",
		SoftwareInstallerID: ptr.Uint(noTeamInstallerID),
	})
	require.NoError(t, err)

	// A "No team" policy pointing at team1's installer must be rejected.
	_, err = ds.NewTeamPolicy(ctx, fleet.PolicyNoTeamID, &alice.ID, fleet.PolicyPayload{
		Name:                "p4",
		Query:               "SELECT 4;",
		SoftwareInstallerID: ptr.Uint(installerID),
	})
	require.Error(t, err, "software installer is associated with a different team")

	withInstallers, err := ds.GetPoliciesWithAssociatedInstaller(ctx, fleet.PolicyNoTeamID, []uint{p4.ID})
	require.NoError(t, err)
	require.Len(t, withInstallers, 1)
	require.Equal(t, p4.ID, withInstallers[0].ID)

	// An empty list of policy IDs yields an empty result.
	withInstallers, err = ds.GetPoliciesWithAssociatedInstaller(ctx, team1.ID, []uint{})
	require.NoError(t, err)
	require.Empty(t, withInstallers)

	// p1 has no associated installers.
	withInstallers, err = ds.GetPoliciesWithAssociatedInstaller(ctx, team1.ID, []uint{p1.ID})
	require.NoError(t, err)
	require.Empty(t, withInstallers)

	withInstallers, err = ds.GetPoliciesWithAssociatedInstaller(ctx, team1.ID, []uint{p2.ID})
	require.NoError(t, err)
	require.Len(t, withInstallers, 1)
	require.Equal(t, p2.ID, withInstallers[0].ID)
	require.Equal(t, installerID, withInstallers[0].InstallerID)

	// p2 has associated installer but belongs to team1.
	withInstallers, err = ds.GetPoliciesWithAssociatedInstaller(ctx, team2.ID, []uint{p2.ID})
	require.NoError(t, err)
	require.Empty(t, withInstallers)

	// Saving p1 with team1's installer works; with the "No team" installer it fails.
	p1.SoftwareInstallerID = ptr.Uint(installerID)
	err = ds.SavePolicy(ctx, p1, false, false)
	require.NoError(t, err)
	p1.SoftwareInstallerID = ptr.Uint(noTeamInstallerID)
	err = ds.SavePolicy(ctx, p1, false, false)
	require.Error(t, err, "software installer is associated with a different team")

	p2, err = ds.Policy(ctx, p2.ID)
	require.NoError(t, err)
	require.NotNil(t, p2.SoftwareInstallerID)
	require.Equal(t, installerID, *p2.SoftwareInstallerID)

	// Both team1 policies now report team1's installer.
	withInstallers, err = ds.GetPoliciesWithAssociatedInstaller(ctx, team1.ID, []uint{p1.ID, p2.ID})
	require.NoError(t, err)
	require.Len(t, withInstallers, 2)
	require.Equal(t, p1.ID, withInstallers[0].ID)
	require.Equal(t, installerID, withInstallers[0].InstallerID)
	require.Equal(t, p2.ID, withInstallers[1].ID)
	require.Equal(t, installerID, withInstallers[1].InstallerID)

	// Asking through team2 returns nothing for the same policy IDs.
	withInstallers, err = ds.GetPoliciesWithAssociatedInstaller(ctx, team2.ID, []uint{p1.ID, p2.ID})
	require.NoError(t, err)
	require.Empty(t, withInstallers)
}
// testTeamPoliciesWithVPP covers associating VPP apps with team and "No team"
// policies: creating policies with an app, rejecting apps that belong to a
// different team, listing policies through GetPoliciesWithAssociatedVPP, and
// the automatic policy created when a VPP app is added with
// AddAutoInstallPolicy set.
func testTeamPoliciesWithVPP(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"}) // team2 has no policies
	require.NoError(t, err)

	test.CreateInsertGlobalVPPToken(t, ds)

	// create team1 app
	team1App, err := ds.InsertVPPAppWithTeam(ctx, &fleet.VPPApp{
		Name: "vpp1", BundleIdentifier: "com.app.appy",
		VPPAppTeam: fleet.VPPAppTeam{VPPAppID: fleet.VPPAppID{AdamID: "adam_app", Platform: fleet.MacOSPlatform}},
	}, &team1.ID)
	require.NoError(t, err)
	team1Meta, err := ds.GetVPPAppMetadataByTeamAndTitleID(ctx, &team1.ID, team1App.TitleID)
	require.NoError(t, err)

	// Policy p1 has no associated app.
	p1, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		Name:  "p1",
		Query: "SELECT 1;",
	})
	require.NoError(t, err)

	// Create and associate an app to p2.
	p2, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		Name:           "p2",
		Query:          "SELECT 1;",
		VPPAppsTeamsID: &team1Meta.VPPAppsTeamsID,
	})
	require.NoError(t, err)
	// Guard the dereference so a missing association fails the test instead of panicking.
	require.NotNil(t, p2.VPPAppsTeamsID)
	require.Equal(t, team1Meta.VPPAppsTeamsID, *p2.VPPAppsTeamsID)

	// The association must also be visible when the policy is re-read.
	p2, err = ds.Policy(ctx, p2.ID)
	require.NoError(t, err)
	require.NotNil(t, p2.VPPAppsTeamsID)
	require.Equal(t, team1Meta.VPPAppsTeamsID, *p2.VPPAppsTeamsID)

	// create no-team app (same Adam ID as the team1 app)
	noTeamApp, err := ds.InsertVPPAppWithTeam(ctx, &fleet.VPPApp{
		Name: "vpp1", BundleIdentifier: "com.app.appy",
		VPPAppTeam: fleet.VPPAppTeam{VPPAppID: fleet.VPPAppID{AdamID: "adam_app", Platform: fleet.MacOSPlatform}},
	}, nil)
	require.NoError(t, err)
	noTeamMeta, err := ds.GetVPPAppMetadataByTeamAndTitleID(ctx, ptr.Uint(0), noTeamApp.TitleID)
	require.NoError(t, err)
	require.NotEqual(t, noTeamMeta.VPPAppsTeamsID, team1Meta.VPPAppsTeamsID)
	require.Equal(t, noTeamMeta.AdamID, team1Meta.AdamID)

	// Policy p4 in "No team" with associated app.
	p4, err := ds.NewTeamPolicy(ctx, fleet.PolicyNoTeamID, &user1.ID, fleet.PolicyPayload{
		Name:           "p4",
		Query:          "SELECT 4;",
		VPPAppsTeamsID: ptr.Uint(noTeamMeta.VPPAppsTeamsID),
	})
	require.NoError(t, err)

	// create another team1 app
	team1App2, err := ds.InsertVPPAppWithTeam(ctx, &fleet.VPPApp{
		Name: "vpp2", BundleIdentifier: "com.app.vpp2",
		VPPAppTeam: fleet.VPPAppTeam{VPPAppID: fleet.VPPAppID{AdamID: "adam_vpp2", Platform: fleet.MacOSPlatform}},
	}, &team1.ID)
	require.NoError(t, err)
	team1Meta2, err := ds.GetVPPAppMetadataByTeamAndTitleID(ctx, &team1.ID, team1App2.TitleID)
	require.NoError(t, err)

	// A "No team" policy cannot reference an app that belongs to team1.
	// (The string passed to require.Error is the failure message, not a matcher.)
	_, err = ds.NewTeamPolicy(ctx, fleet.PolicyNoTeamID, &user1.ID, fleet.PolicyPayload{
		Name:           "p5",
		Query:          "SELECT 5;",
		VPPAppsTeamsID: ptr.Uint(team1Meta2.VPPAppsTeamsID),
	})
	require.Error(t, err, "app is associated with a different team")

	policiesWithVPPs, err := ds.GetPoliciesWithAssociatedVPP(ctx, fleet.PolicyNoTeamID, []uint{p4.ID})
	require.NoError(t, err)
	require.Len(t, policiesWithVPPs, 1)
	require.Equal(t, p4.ID, policiesWithVPPs[0].ID)
	require.Equal(t, noTeamMeta.AdamID, policiesWithVPPs[0].AdamID)
	require.Equal(t, noTeamMeta.Platform, policiesWithVPPs[0].Platform)

	// An empty list of policy IDs yields no results.
	policiesWithVPPs, err = ds.GetPoliciesWithAssociatedVPP(ctx, team1.ID, []uint{})
	require.NoError(t, err)
	require.Empty(t, policiesWithVPPs)

	// p1 has no associated apps.
	policiesWithVPPs, err = ds.GetPoliciesWithAssociatedVPP(ctx, team1.ID, []uint{p1.ID})
	require.NoError(t, err)
	require.Empty(t, policiesWithVPPs)

	policiesWithVPPs, err = ds.GetPoliciesWithAssociatedVPP(ctx, team1.ID, []uint{p2.ID})
	require.NoError(t, err)
	require.Len(t, policiesWithVPPs, 1)
	require.Equal(t, p2.ID, policiesWithVPPs[0].ID)
	require.Equal(t, team1Meta.AdamID, policiesWithVPPs[0].AdamID)

	// p2 has associated app but belongs to team1.
	policiesWithVPPs, err = ds.GetPoliciesWithAssociatedVPP(ctx, team2.ID, []uint{p2.ID})
	require.NoError(t, err)
	require.Empty(t, policiesWithVPPs)

	// Associating p1 with a team1 app succeeds.
	p1.VPPAppsTeamsID = ptr.Uint(team1Meta.VPPAppsTeamsID)
	err = ds.SavePolicy(ctx, p1, false, false)
	require.NoError(t, err)

	// Associating p1 (team1) with the "No team" app is rejected.
	p1.VPPAppsTeamsID = ptr.Uint(noTeamMeta.VPPAppsTeamsID)
	err = ds.SavePolicy(ctx, p1, false, false)
	require.Error(t, err, "VPP app is associated with a different team")

	p2, err = ds.Policy(ctx, p2.ID)
	require.NoError(t, err)
	require.NotNil(t, p2.VPPAppsTeamsID)
	require.Equal(t, team1Meta.VPPAppsTeamsID, *p2.VPPAppsTeamsID)

	policiesWithVPPs, err = ds.GetPoliciesWithAssociatedVPP(ctx, team1.ID, []uint{p1.ID, p2.ID})
	require.NoError(t, err)
	require.Len(t, policiesWithVPPs, 2)
	require.Equal(t, p1.ID, policiesWithVPPs[0].ID)
	require.Equal(t, team1Meta.AdamID, policiesWithVPPs[0].AdamID)
	require.Equal(t, p2.ID, policiesWithVPPs[1].ID)
	require.Equal(t, team1Meta.AdamID, policiesWithVPPs[1].AdamID)

	policiesWithVPPs, err = ds.GetPoliciesWithAssociatedVPP(ctx, team2.ID, []uint{p1.ID, p2.ID})
	require.NoError(t, err)
	require.Empty(t, policiesWithVPPs)

	// create another team1 app, this time with an automatic policy
	team1App3, err := ds.InsertVPPAppWithTeam(ctx, &fleet.VPPApp{
		Name: "vpp3", BundleIdentifier: "com.app.vpp3",
		VPPAppTeam: fleet.VPPAppTeam{VPPAppID: fleet.VPPAppID{AdamID: "adam_vpp3", Platform: fleet.MacOSPlatform}, AddAutoInstallPolicy: true},
	}, &team1.ID)
	require.NoError(t, err)

	automaticPolicies, err := ds.getPoliciesBySoftwareTitleIDs(ctx, []uint{team1App3.TitleID}, team1.ID)
	require.NoError(t, err)
	require.Len(t, automaticPolicies, 1)

	policyWithVPP, err := ds.Policy(ctx, automaticPolicies[0].ID)
	require.NoError(t, err)
	require.NotNil(t, policyWithVPP.VPPAppsTeamsID)
	require.Equal(t, team1App3.VPPAppTeam.AppTeamID, *policyWithVPP.VPPAppsTeamsID)
	require.Equal(t, `SELECT 1 FROM apps WHERE bundle_identifier = 'com.app.vpp3';`, policyWithVPP.Query)
}
// testTeamPoliciesWithScript covers associating scripts with team and
// "No team" policies: creating policies with a script, rejecting scripts that
// belong to a different team (on create and on save), and listing policies
// through GetPoliciesWithAssociatedScript.
func testTeamPoliciesWithScript(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user1 := test.NewUser(t, ds, "Sierra", "sierra@example.com", true)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"}) // team2 has no policies
	require.NoError(t, err)

	// Policy p1 has no associated script.
	p1, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		Name:                "p1",
		Query:               "SELECT 1;",
		SoftwareInstallerID: nil,
	})
	require.NoError(t, err)
	require.Nil(t, p1.ScriptID)

	// Create and associate a script to p2.
	script, err := ds.NewScript(ctx, &fleet.Script{
		TeamID:         &team1.ID,
		Name:           "hello-world.sh",
		ScriptContents: "echo 'Hello World'",
	})
	require.NoError(t, err)
	p2, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		Name:     "p2",
		Query:    "SELECT 1;",
		ScriptID: &script.ID,
	})
	require.NoError(t, err)
	require.NotNil(t, p2.ScriptID)
	require.Equal(t, script.ID, *p2.ScriptID)

	// Create p3 as global policy.
	_, err = ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
		Name:  "p3",
		Query: "SELECT 1;",
	})
	require.NoError(t, err)

	// The association must also be visible when the policy is re-read.
	p2, err = ds.Policy(ctx, p2.ID)
	require.NoError(t, err)
	require.NotNil(t, p2.ScriptID)
	require.Equal(t, script.ID, *p2.ScriptID)

	// Policy p4 in "No team" with associated script.
	noTeamScript, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "hello-world.sh",
		ScriptContents: "echo 'Hello NoTeam'",
	})
	require.NoError(t, err)
	p4, err := ds.NewTeamPolicy(ctx, fleet.PolicyNoTeamID, &user1.ID, fleet.PolicyPayload{
		Name:     "p4",
		Query:    "SELECT 4;",
		ScriptID: &noTeamScript.ID,
	})
	require.NoError(t, err)

	// A "No team" policy cannot reference a script that belongs to team1.
	// Use a name that is not already taken in "No team" so that the expected
	// error can only come from the team mismatch and not from a name conflict
	// with p4. (The string passed to require.Error is the failure message,
	// not a matcher.)
	_, err = ds.NewTeamPolicy(ctx, fleet.PolicyNoTeamID, &user1.ID, fleet.PolicyPayload{
		Name:     "p5",
		Query:    "SELECT 5;",
		ScriptID: &script.ID,
	})
	require.Error(t, err, "script is associated with a different team")

	policiesWithScripts, err := ds.GetPoliciesWithAssociatedScript(ctx, fleet.PolicyNoTeamID, []uint{p4.ID})
	require.NoError(t, err)
	require.Len(t, policiesWithScripts, 1)
	require.Equal(t, p4.ID, policiesWithScripts[0].ID)

	// An empty list of policy IDs yields no results.
	policiesWithScripts, err = ds.GetPoliciesWithAssociatedScript(ctx, team1.ID, []uint{})
	require.NoError(t, err)
	require.Empty(t, policiesWithScripts)

	// p1 has no associated scripts.
	policiesWithScripts, err = ds.GetPoliciesWithAssociatedScript(ctx, team1.ID, []uint{p1.ID})
	require.NoError(t, err)
	require.Empty(t, policiesWithScripts)

	policiesWithScripts, err = ds.GetPoliciesWithAssociatedScript(ctx, team1.ID, []uint{p2.ID})
	require.NoError(t, err)
	require.Len(t, policiesWithScripts, 1)
	require.Equal(t, p2.ID, policiesWithScripts[0].ID)
	require.Equal(t, script.ID, policiesWithScripts[0].ScriptID)

	// p2 has associated script but belongs to team1.
	policiesWithScripts, err = ds.GetPoliciesWithAssociatedScript(ctx, team2.ID, []uint{p2.ID})
	require.NoError(t, err)
	require.Empty(t, policiesWithScripts)

	// Associating p1 with a team1 script succeeds.
	p1.ScriptID = ptr.Uint(script.ID)
	err = ds.SavePolicy(ctx, p1, false, false)
	require.NoError(t, err)

	// Associating p1 (team1) with the "No team" script is rejected.
	p1.ScriptID = ptr.Uint(noTeamScript.ID)
	err = ds.SavePolicy(ctx, p1, false, false)
	require.Error(t, err, "script is associated with a different team")

	p2, err = ds.Policy(ctx, p2.ID)
	require.NoError(t, err)
	require.NotNil(t, p2.ScriptID)
	require.Equal(t, script.ID, *p2.ScriptID)

	policiesWithScripts, err = ds.GetPoliciesWithAssociatedScript(ctx, team1.ID, []uint{p1.ID, p2.ID})
	require.NoError(t, err)
	require.Len(t, policiesWithScripts, 2)
	require.Equal(t, p1.ID, policiesWithScripts[0].ID)
	require.Equal(t, script.ID, policiesWithScripts[0].ScriptID)
	require.Equal(t, p2.ID, policiesWithScripts[1].ID)
	require.Equal(t, script.ID, policiesWithScripts[1].ScriptID)

	policiesWithScripts, err = ds.GetPoliciesWithAssociatedScript(ctx, team2.ID, []uint{p1.ID, p2.ID})
	require.NoError(t, err)
	require.Empty(t, policiesWithScripts)
}
// testApplyPolicySpecWithInstallers exercises ApplyPolicySpecs with policies
// that reference software titles backed by software installers or VPP apps.
// It checks that:
//   - global policies cannot reference a software title,
//   - team and "No team" policies get the installer/VPP app of their own team,
//   - titles from another team or unknown titles produce a not-found error,
//   - unsetting an installer (nil or 0) or re-applying the same one keeps the
//     recorded policy results, while setting or changing it clears them.
func testApplyPolicySpecWithInstallers(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user1 := test.NewUser(t, ds, "User1", "user1@example.com", true)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
	require.NoError(t, err)

	// newHost creates a host with random identifiers on the given team.
	newHost := func(name string, teamID *uint, platform string) *fleet.Host {
		h, err := ds.NewHost(ctx, &fleet.Host{
			OsqueryHostID:   ptr.String(uuid.New().String()),
			DetailUpdatedAt: time.Now(),
			LabelUpdatedAt:  time.Now(),
			PolicyUpdatedAt: time.Now(),
			SeenTime:        time.Now(),
			NodeKey:         ptr.String(uuid.New().String()),
			UUID:            uuid.New().String(),
			Hostname:        name,
			TeamID:          teamID,
			Platform:        platform,
		})
		require.NoError(t, err)
		return h
	}
	host1Team1 := newHost("host1Team1", &team1.ID, "darwin")

	// Installer on team1.
	tfr1, err := fleet.NewTempFileReader(strings.NewReader("hello1"), t.TempDir)
	require.NoError(t, err)
	installer1ID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "hello",
		PreInstallQuery:   "SELECT 1;",
		PostInstallScript: "world1",
		InstallerFile:     tfr1,
		StorageID:         "storage1",
		Filename:          "file1",
		Title:             "file1",
		Version:           "1.0",
		Source:            "apps",
		UserID:            user1.ID,
		TeamID:            &team1.ID,
		ValidatedLabels:   &fleet.LabelIdentsWithScope{},
	})
	require.NoError(t, err)
	installer1, err := ds.GetSoftwareInstallerMetadataByID(ctx, installer1ID)
	require.NoError(t, err)
	require.NotNil(t, installer1.TitleID)

	// Installer on team2.
	tfr2, err := fleet.NewTempFileReader(strings.NewReader("hello2"), t.TempDir)
	require.NoError(t, err)
	installer2ID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "hello2",
		PreInstallQuery:   "SELECT 2;",
		PostInstallScript: "world2",
		InstallerFile:     tfr2,
		StorageID:         "storage2",
		Filename:          "file2",
		Title:             "file2",
		Version:           "1.0",
		Source:            "deb_packages",
		UserID:            user1.ID,
		TeamID:            &team2.ID,
		ValidatedLabels:   &fleet.LabelIdentsWithScope{},
	})
	require.NoError(t, err)
	installer2, err := ds.GetSoftwareInstallerMetadataByID(ctx, installer2ID)
	require.NoError(t, err)
	require.NotNil(t, installer2.TitleID)

	// Installer on "No team".
	tfr3, err := fleet.NewTempFileReader(strings.NewReader("hello3"), t.TempDir)
	require.NoError(t, err)
	installer3ID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "hello3",
		PreInstallQuery:   "SELECT 3;",
		PostInstallScript: "world3",
		InstallerFile:     tfr3,
		StorageID:         "storage3",
		Filename:          "file3",
		Title:             "file3",
		Version:           "1.0",
		Source:            "rpm_packages",
		UserID:            user1.ID,
		TeamID:            nil,
		ValidatedLabels:   &fleet.LabelIdentsWithScope{},
	})
	require.NoError(t, err)
	installer3, err := ds.GetSoftwareInstallerMetadataByID(ctx, installer3ID)
	require.NoError(t, err)
	require.NotNil(t, installer3.TitleID)

	// Another installer on team1 to test changing installers.
	tfr5, err := fleet.NewTempFileReader(strings.NewReader("hello5"), t.TempDir)
	require.NoError(t, err)
	installer5ID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "hello5",
		PreInstallQuery:   "SELECT 5;",
		PostInstallScript: "world5",
		InstallerFile:     tfr5,
		StorageID:         "storage5",
		Filename:          "file5",
		Title:             "file5",
		Version:           "1.0",
		Source:            "programs",
		UserID:            user1.ID,
		TeamID:            &team1.ID,
		ValidatedLabels:   &fleet.LabelIdentsWithScope{},
	})
	require.NoError(t, err)
	installer5, err := ds.GetSoftwareInstallerMetadataByID(ctx, installer5ID)
	require.NoError(t, err)
	require.NotNil(t, installer5.TitleID)

	test.CreateInsertGlobalVPPToken(t, ds)

	// create VPP apps
	va1, err := ds.InsertVPPAppWithTeam(ctx, &fleet.VPPApp{
		Name: "vpp1", BundleIdentifier: "com.app.vpp1",
		VPPAppTeam: fleet.VPPAppTeam{VPPAppID: fleet.VPPAppID{AdamID: "adam_vpp_app_1", Platform: fleet.MacOSPlatform}},
	}, &team1.ID)
	require.NoError(t, err)
	va2, err := ds.InsertVPPAppWithTeam(ctx, &fleet.VPPApp{
		Name: "vpp2", BundleIdentifier: "com.app.vpp2",
		VPPAppTeam: fleet.VPPAppTeam{VPPAppID: fleet.VPPAppID{AdamID: "adam_vpp_app_2", Platform: fleet.MacOSPlatform}},
	}, &team2.ID)
	require.NoError(t, err)
	va1NoTeam, err := ds.InsertVPPAppWithTeam(ctx, &fleet.VPPApp{
		Name: "vpp1", BundleIdentifier: "com.app.vpp1",
		VPPAppTeam: fleet.VPPAppTeam{VPPAppID: fleet.VPPAppID{AdamID: "adam_vpp_app_1", Platform: fleet.MacOSPlatform}},
	}, nil)
	require.NoError(t, err)

	// Installers cannot be assigned to global policies.
	err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
		{
			Name:            "Global policy",
			Query:           "SELECT 1;",
			Description:     "Description",
			Resolution:      "Resolution",
			Team:            "",
			Platform:        "darwin",
			SoftwareTitleID: installer1.TitleID,
		},
	})
	require.Error(t, err)
	require.ErrorIs(t, err, errSoftwareTitleIDOnGlobalPolicy)

	// Apply two team policies associated to two installers and a "No team" policy associated to an installer.
	// Do the same with VPP apps.
	err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
		{
			Name:            "Team policy 1",
			Query:           "SELECT 1;",
			Description:     "Description 1",
			Resolution:      "Resolution 1",
			Team:            "team1",
			Platform:        "darwin",
			SoftwareTitleID: installer1.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
		{
			Name:            "Team policy 2",
			Query:           "SELECT 2;",
			Description:     "Description 2",
			Resolution:      "Resolution 2",
			Team:            "team2",
			Platform:        "linux",
			SoftwareTitleID: installer2.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
		{
			Name:            "No team policy 3",
			Query:           "SELECT 3;",
			Description:     "Description 3",
			Resolution:      "Resolution 3",
			Team:            "No team",
			Platform:        "linux",
			SoftwareTitleID: installer3.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
		{
			Name:            "VPP Team policy 1",
			Query:           "SELECT 1;",
			Description:     "Description 1",
			Resolution:      "Resolution 1",
			Team:            "team1",
			Platform:        "darwin",
			SoftwareTitleID: &va1.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
		{
			Name:            "VPP Team policy 2",
			Query:           "SELECT 2;",
			Description:     "Description 2",
			Resolution:      "Resolution 2",
			Team:            "team2",
			Platform:        "linux",
			SoftwareTitleID: &va2.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
		{
			Name:            "VPP No team policy 3",
			Query:           "SELECT 3;",
			Description:     "Description 3",
			Resolution:      "Resolution 3",
			Team:            "No team",
			Platform:        "linux",
			SoftwareTitleID: &va1NoTeam.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
	})
	require.NoError(t, err)

	team1Policies, _, err := ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, team1Policies, 2)
	require.NotNil(t, team1Policies[0].SoftwareInstallerID)
	require.NotNil(t, team1Policies[1].VPPAppsTeamsID)
	policy1Team1 := team1Policies[0]
	require.Equal(t, installer1.InstallerID, *team1Policies[0].SoftwareInstallerID)
	vppPolicy1Team1 := team1Policies[1]
	va1Meta, err := ds.GetVPPAppMetadataByTeamAndTitleID(ctx, &team1.ID, va1.TitleID)
	require.NoError(t, err)
	require.Equal(t, va1Meta.VPPAppsTeamsID, *vppPolicy1Team1.VPPAppsTeamsID)

	team2Policies, _, err := ds.ListTeamPolicies(ctx, team2.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, team2Policies, 2)
	require.NotNil(t, team2Policies[0].SoftwareInstallerID)
	require.Equal(t, installer2.InstallerID, *team2Policies[0].SoftwareInstallerID)
	vppPolicy2Team2 := team2Policies[1]
	va2Meta, err := ds.GetVPPAppMetadataByTeamAndTitleID(ctx, &team2.ID, va2.TitleID)
	require.NoError(t, err)
	require.NotNil(t, vppPolicy2Team2.VPPAppsTeamsID)
	require.Equal(t, va2Meta.VPPAppsTeamsID, *vppPolicy2Team2.VPPAppsTeamsID)

	noTeamPolicies, _, err := ds.ListTeamPolicies(ctx, fleet.PolicyNoTeamID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, noTeamPolicies, 2)
	require.NotNil(t, noTeamPolicies[0].SoftwareInstallerID)
	require.Equal(t, installer3.InstallerID, *noTeamPolicies[0].SoftwareInstallerID)
	vppNoTeamPolicy := noTeamPolicies[1]
	vNoTeamMeta, err := ds.GetVPPAppMetadataByTeamAndTitleID(ctx, ptr.Uint(0), va1NoTeam.TitleID)
	require.NoError(t, err)
	require.NotNil(t, vppNoTeamPolicy.VPPAppsTeamsID)
	require.Equal(t, vNoTeamMeta.VPPAppsTeamsID, *vppNoTeamPolicy.VPPAppsTeamsID)

	// Record policy execution on policy1Team1 and on the VPP policy of team1.
	err = ds.RecordPolicyQueryExecutions(ctx, host1Team1, map[uint]*bool{
		policy1Team1.ID:    ptr.Bool(false),
		vppPolicy1Team1.ID: ptr.Bool(false),
	}, time.Now(), false)
	require.NoError(t, err)
	err = ds.UpdateHostPolicyCounts(ctx)
	require.NoError(t, err)

	// Unset software installer from "Team policy 1" and the VPP policy.
	err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
		{
			Name:            "Team policy 1",
			Query:           "SELECT 1;",
			Description:     "Description 1",
			Resolution:      "Resolution 1",
			Team:            "team1",
			Platform:        "darwin",
			SoftwareTitleID: nil,
			Type:            fleet.PolicyTypeDynamic,
		},
		{
			Name:            "VPP Team policy 1",
			Query:           "SELECT 1;",
			Description:     "Description 1",
			Resolution:      "Resolution 1",
			Team:            "team1",
			Platform:        "darwin",
			SoftwareTitleID: nil,
			Type:            fleet.PolicyTypeDynamic,
		},
	})
	require.NoError(t, err)
	team1Policies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, team1Policies, 2)
	require.Nil(t, team1Policies[0].SoftwareInstallerID)
	require.Nil(t, team1Policies[1].VPPAppsTeamsID)
	// Should not clear results because we only cleared the installer (we did not change it or set a new one).
	require.Equal(t, uint(1), team1Policies[0].FailingHostCount)
	require.Equal(t, uint(1), team1Policies[1].FailingHostCount)

	// Set "Team policy 1" to a software installer on team2.
	err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
		{
			Name:            "Team policy 1",
			Query:           "SELECT 1;",
			Description:     "Description 1",
			Resolution:      "Resolution 1",
			Team:            "team1",
			Platform:        "darwin",
			SoftwareTitleID: installer2.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
	})
	require.Error(t, err)
	var notFoundErr *common_mysql.NotFoundError
	require.ErrorAs(t, err, &notFoundErr)

	// Set "Team policy 1" to a VPP app on team2.
	err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
		{
			Name:            "Team policy 1",
			Query:           "SELECT 1;",
			Description:     "Description 1",
			Resolution:      "Resolution 1",
			Team:            "team1",
			Platform:        "darwin",
			SoftwareTitleID: &va2.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
	})
	require.Error(t, err)
	require.ErrorAs(t, err, &notFoundErr)

	// Set "No team policy 3" to a software installer on team2.
	err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
		{
			Name:            "No team policy 3",
			Query:           "SELECT 3;",
			Description:     "Description 3",
			Resolution:      "Resolution 3",
			Team:            "No team",
			Platform:        "darwin",
			SoftwareTitleID: installer2.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
	})
	require.Error(t, err)
	require.ErrorAs(t, err, &notFoundErr)

	// Set "No team policy 3" to a VPP app on team2.
	err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
		{
			Name:            "No team policy 3",
			Query:           "SELECT 3;",
			Description:     "Description 3",
			Resolution:      "Resolution 3",
			Team:            "No team",
			Platform:        "darwin",
			SoftwareTitleID: &va2.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
	})
	require.Error(t, err)
	require.ErrorAs(t, err, &notFoundErr)

	// Set "Team policy 1" to a software title that doesn't exist.
	err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
		{
			Name:            "Team policy 1",
			Query:           "SELECT 1;",
			Description:     "Description 1",
			Resolution:      "Resolution 1",
			Team:            "team1",
			Platform:        "darwin",
			SoftwareTitleID: ptr.Uint(999_999),
			Type:            fleet.PolicyTypeDynamic,
		},
	})
	require.Error(t, err)
	require.ErrorAs(t, err, &notFoundErr)

	// Set "No team policy 3" to a software title that doesn't exist.
	err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
		{
			Name:            "No team policy 3",
			Query:           "SELECT 3;",
			Description:     "Description 3",
			Resolution:      "Resolution 3",
			Team:            "No team",
			Platform:        "darwin",
			SoftwareTitleID: ptr.Uint(999_999),
			Type:            fleet.PolicyTypeDynamic,
		},
	})
	require.Error(t, err)
	require.ErrorAs(t, err, &notFoundErr)

	// Unset software installer from "Team policy 2" using 0.
	err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
		{
			Name:            "Team policy 2",
			Query:           "SELECT 2;",
			Description:     "Description 2",
			Resolution:      "Resolution 2",
			Team:            "team2",
			Platform:        "linux",
			SoftwareTitleID: ptr.Uint(0),
			Type:            fleet.PolicyTypeDynamic,
		},
	})
	require.NoError(t, err)
	team2Policies, _, err = ds.ListTeamPolicies(ctx, team2.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, team2Policies, 2)
	require.Nil(t, team2Policies[0].SoftwareInstallerID)
	require.NotNil(t, team2Policies[1].VPPAppsTeamsID)
	require.Equal(t, va2Meta.VPPAppsTeamsID, *team2Policies[1].VPPAppsTeamsID) // stays set since Apply doesn't delete

	// Apply team policies associated to two installers (again, with two installers with the same title), and same with VPP apps
	tfr4, err := fleet.NewTempFileReader(strings.NewReader("hello3"), t.TempDir)
	require.NoError(t, err)
	installer4ID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "hello3",
		PreInstallQuery:   "SELECT 3;",
		PostInstallScript: "world3",
		InstallerFile:     tfr4,
		StorageID:         "storage3",
		Filename:          "file1",
		Title:             "file1", // same title as installer1.
		Version:           "1.0",
		Source:            "apps",
		UserID:            user1.ID,
		TeamID:            &team2.ID,
		ValidatedLabels:   &fleet.LabelIdentsWithScope{},
	})
	require.NoError(t, err)
	installer4, err := ds.GetSoftwareInstallerMetadataByID(ctx, installer4ID)
	require.NoError(t, err)
	// Check the installer that was just loaded; its TitleID is used in the spec below.
	require.NotNil(t, installer4.TitleID)
	va4Team2, err := ds.InsertVPPAppWithTeam(ctx, &fleet.VPPApp{
		Name: "vpp4", BundleIdentifier: "com.app.vpp4",
		VPPAppTeam: fleet.VPPAppTeam{VPPAppID: fleet.VPPAppID{AdamID: "adam_vpp_app_4", Platform: fleet.MacOSPlatform}},
	}, &team2.ID)
	require.NoError(t, err)

	err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
		{
			Name:            "Team policy 1",
			Query:           "SELECT 1;",
			Description:     "Description 1",
			Resolution:      "Resolution 1",
			Team:            "team1",
			Platform:        "darwin",
			SoftwareTitleID: installer1.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
		{
			Name:            "Team policy 2",
			Query:           "SELECT 2;",
			Description:     "Description 2",
			Resolution:      "Resolution 2",
			Team:            "team2",
			Platform:        "linux",
			SoftwareTitleID: installer4.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
		{
			Name:            "VPP Team policy 1",
			Query:           "SELECT 1;",
			Description:     "Description 1",
			Resolution:      "Resolution 1",
			Team:            "team1",
			Platform:        "darwin",
			SoftwareTitleID: &va1.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
		{
			Name:            "VPP Team policy 2",
			Query:           "SELECT 2;",
			Description:     "Description 2",
			Resolution:      "Resolution 2",
			Team:            "team2",
			Platform:        "linux",
			SoftwareTitleID: &va4Team2.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
	})
	require.NoError(t, err)

	team1Policies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, team1Policies, 2)
	require.NotNil(t, team1Policies[0].SoftwareInstallerID)
	require.Equal(t, installer1.InstallerID, *team1Policies[0].SoftwareInstallerID)
	require.NotNil(t, team1Policies[1].VPPAppsTeamsID)
	require.Equal(t, va1Meta.VPPAppsTeamsID, *team1Policies[1].VPPAppsTeamsID)
	// Should clear results because we are setting an installer.
	require.Equal(t, uint(0), team1Policies[0].FailingHostCount)
	// Seed with the opposite of the expected value so the query must overwrite it.
	countBiggerThanZero := true
	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		return sqlx.GetContext(ctx, q,
			&countBiggerThanZero,
			`SELECT COUNT(*) > 0 FROM policy_membership WHERE policy_id = ?`,
			team1Policies[0].ID,
		)
	})
	require.False(t, countBiggerThanZero)

	team2Policies, _, err = ds.ListTeamPolicies(ctx, team2.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, team2Policies, 2)
	require.NotNil(t, team2Policies[0].SoftwareInstallerID)
	require.Equal(t, installer4.InstallerID, *team2Policies[0].SoftwareInstallerID)
	require.NotNil(t, team2Policies[1].VPPAppsTeamsID)
	va4Team2Meta, err := ds.GetVPPAppMetadataByTeamAndTitleID(ctx, &team2.ID, va4Team2.TitleID)
	require.NoError(t, err)
	require.Equal(t, va4Team2Meta.VPPAppsTeamsID, *team2Policies[1].VPPAppsTeamsID)

	// Record policy execution on policy1Team1 + VPP equivalent to test that setting the same installer won't clear results.
	err = ds.RecordPolicyQueryExecutions(ctx, host1Team1, map[uint]*bool{
		policy1Team1.ID:    ptr.Bool(false),
		vppPolicy1Team1.ID: ptr.Bool(false),
	}, time.Now(), false)
	require.NoError(t, err)
	err = ds.UpdateHostPolicyCounts(ctx)
	require.NoError(t, err)

	err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
		{
			Name:            "Team policy 1",
			Query:           "SELECT 1;",
			Description:     "Description 1",
			Resolution:      "Resolution 1",
			Team:            "team1",
			Platform:        "darwin",
			SoftwareTitleID: installer1.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
		{
			Name:            "VPP Team policy 1",
			Query:           "SELECT 1;",
			Description:     "Description 1",
			Resolution:      "Resolution 1",
			Team:            "team1",
			Platform:        "darwin",
			SoftwareTitleID: &va1.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
	})
	require.NoError(t, err)
	team1Policies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, team1Policies, 2)
	require.Equal(t, uint(1), team1Policies[0].FailingHostCount)
	countBiggerThanZero = false
	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		return sqlx.GetContext(ctx, q,
			&countBiggerThanZero,
			`SELECT COUNT(*) > 0 FROM policy_membership WHERE policy_id = ?`,
			team1Policies[0].ID,
		)
	})
	require.True(t, countBiggerThanZero)
	require.Equal(t, uint(1), team1Policies[1].FailingHostCount)
	countBiggerThanZero = false
	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		return sqlx.GetContext(ctx, q,
			&countBiggerThanZero,
			`SELECT COUNT(*) > 0 FROM policy_membership WHERE policy_id = ?`,
			team1Policies[1].ID,
		)
	})
	require.True(t, countBiggerThanZero)

	va4Team1, err := ds.InsertVPPAppWithTeam(ctx, &fleet.VPPApp{
		Name: "vpp4", BundleIdentifier: "com.app.vpp4",
		VPPAppTeam: fleet.VPPAppTeam{VPPAppID: fleet.VPPAppID{AdamID: "adam_vpp_app_4", Platform: fleet.MacOSPlatform}},
	}, &team1.ID)
	require.NoError(t, err)

	// Now change the installer, should clear results.
	err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
		{
			Name:            "Team policy 1",
			Query:           "SELECT 1;",
			Description:     "Description 1",
			Resolution:      "Resolution 1",
			Team:            "team1",
			Platform:        "darwin",
			SoftwareTitleID: installer5.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
		{
			Name:            "VPP Team policy 1",
			Query:           "SELECT 1;",
			Description:     "Description 1",
			Resolution:      "Resolution 1",
			Team:            "team1",
			Platform:        "darwin",
			SoftwareTitleID: &va4Team1.TitleID,
			Type:            fleet.PolicyTypeDynamic,
		},
	})
	require.NoError(t, err)
	team1Policies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
	require.NoError(t, err)
	require.Len(t, team1Policies, 2)
	require.Equal(t, uint(0), team1Policies[0].FailingHostCount)
	countBiggerThanZero = true
	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		return sqlx.GetContext(ctx, q,
			&countBiggerThanZero,
			`SELECT COUNT(*) > 0 FROM policy_membership WHERE policy_id = ?`,
			team1Policies[0].ID,
		)
	})
	require.False(t, countBiggerThanZero)
	require.Equal(t, uint(0), team1Policies[1].FailingHostCount)
	countBiggerThanZero = true
	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		return sqlx.GetContext(ctx, q,
			&countBiggerThanZero,
			`SELECT COUNT(*) > 0 FROM policy_membership WHERE policy_id = ?`,
			team1Policies[1].ID,
		)
	})
	require.False(t, countBiggerThanZero)
}
func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) {
ctx := context.Background()
user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
require.NoError(t, err)
team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
require.NoError(t, err)
newHost := func(name string, teamID *uint, platform string) *fleet.Host {
h, err := ds.NewHost(ctx, &fleet.Host{
OsqueryHostID: ptr.String(uuid.New().String()),
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: ptr.String(uuid.New().String()),
UUID: uuid.New().String(),
Hostname: name,
TeamID: teamID,
Platform: platform,
})
require.NoError(t, err)
return h
}
host0NoTeam := newHost("host0NoTeam", nil, "darwin")
host1Team1 := newHost("host1Team1", &team1.ID, "darwin")
host2Team1 := newHost("host2Team1", &team1.ID, "linux")
host3Team2 := newHost("host1Team1", &team2.ID, "windows")
host5NoTeam := newHost("host5NoTeam", nil, "windows")
policy0NoTeam, err := ds.NewTeamPolicy(ctx, fleet.PolicyNoTeamID, &user1.ID, fleet.PolicyPayload{
Name: "policy0NoTeam",
Query: "SELECT 0;",
})
require.NoError(t, err)
require.NotNil(t, policy0NoTeam.TeamID)
require.Equal(t, fleet.PolicyNoTeamID, *policy0NoTeam.TeamID)
tp, err := ds.TeamPolicy(ctx, fleet.PolicyNoTeamID, policy0NoTeam.ID)
require.NoError(t, err)
require.Equal(t, tp, policy0NoTeam)
policy1Team1, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "policy1Team1",
Query: "SELECT 1;",
})
require.NoError(t, err)
policy2Team2, err := ds.NewTeamPolicy(ctx, team2.ID, &user1.ID, fleet.PolicyPayload{
Name: "policy2Team2",
Query: "SELECT 2;",
})
require.NoError(t, err)
policy3NoTeam, err := ds.NewTeamPolicy(ctx, fleet.PolicyNoTeamID, &user1.ID, fleet.PolicyPayload{
Name: "policy3NoTeam",
Query: "SELECT 3;",
})
require.NoError(t, err)
policy4Team2, err := ds.NewTeamPolicy(ctx, team2.ID, &user1.ID, fleet.PolicyPayload{
Name: "policy4Team2",
Query: "SELECT 4;",
})
require.NoError(t, err)
globalPolicy1, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
Name: "globalPolicy1",
Query: "SELECT gp1;",
})
require.NoError(t, err)
globalPolicy2, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
Name: "globalPolicy2",
Query: "SELECT gp2;",
})
require.NoError(t, err)
// Results for host0NoTeam
err = ds.RecordPolicyQueryExecutions(ctx, host0NoTeam, map[uint]*bool{
globalPolicy1.ID: ptr.Bool(false),
globalPolicy2.ID: ptr.Bool(false),
policy0NoTeam.ID: ptr.Bool(true),
policy3NoTeam.ID: ptr.Bool(false),
}, time.Now(), false)
require.NoError(t, err)
// Results for host1Team1
err = ds.RecordPolicyQueryExecutions(ctx, host1Team1, map[uint]*bool{
globalPolicy1.ID: ptr.Bool(true),
globalPolicy2.ID: nil, // failed to execute, e.g. typo on SQL.
policy1Team1.ID: ptr.Bool(true),
}, time.Now(), false)
require.NoError(t, err)
// Results for host2Team1
err = ds.RecordPolicyQueryExecutions(ctx, host2Team1, map[uint]*bool{
globalPolicy1.ID: ptr.Bool(false),
globalPolicy2.ID: ptr.Bool(true),
policy1Team1.ID: ptr.Bool(false),
}, time.Now(), false)
require.NoError(t, err)
// Results for host3Team2
err = ds.RecordPolicyQueryExecutions(ctx, host3Team2, map[uint]*bool{
globalPolicy1.ID: ptr.Bool(true),
policy2Team2.ID: ptr.Bool(true),
policy4Team2.ID: ptr.Bool(false),
}, time.Now(), false)
require.NoError(t, err)
// Results for host5NoTeam
err = ds.RecordPolicyQueryExecutions(ctx, host5NoTeam, map[uint]*bool{
globalPolicy1.ID: ptr.Bool(true),
globalPolicy2.ID: ptr.Bool(false),
policy0NoTeam.ID: ptr.Bool(false),
policy3NoTeam.ID: ptr.Bool(false),
}, time.Now(), false)
require.NoError(t, err)
err = ds.UpdateHostPolicyCounts(ctx)
require.NoError(t, err)
// Tests on global domain.
globalPolicies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, globalPolicies, 2)
require.Equal(t, globalPolicy1.ID, globalPolicies[0].ID)
require.Equal(t, uint(2), globalPolicies[0].FailingHostCount)
require.Equal(t, uint(3), globalPolicies[0].PassingHostCount)
require.Equal(t, globalPolicy2.ID, globalPolicies[1].ID)
require.Equal(t, uint(2), globalPolicies[1].FailingHostCount)
require.Equal(t, uint(1), globalPolicies[1].PassingHostCount)
ids := make([]uint, 0, len(globalPolicies))
for _, globalPolicy := range globalPolicies {
p, err := ds.Policy(ctx, globalPolicy.ID)
require.NoError(t, err)
require.Equal(t, p, globalPolicy)
ids = append(ids, globalPolicy.ID)
}
c, err := ds.CountPolicies(ctx, nil, "", "")
require.NoError(t, err)
require.Equal(t, 2, c)
globalPoliciesByID, err := ds.PoliciesByID(ctx, ids)
require.NoError(t, err)
require.Len(t, globalPoliciesByID, 2)
require.Equal(t, globalPoliciesByID[globalPolicies[0].ID], globalPolicies[0])
require.Equal(t, globalPoliciesByID[globalPolicies[1].ID], globalPolicies[1])
// Tests on team1 domain.
teamPolicies, inheritedPolicies, err := ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, teamPolicies, 1)
require.Equal(t, policy1Team1.ID, teamPolicies[0].ID)
require.Equal(t, uint(1), teamPolicies[0].FailingHostCount)
require.Equal(t, uint(1), teamPolicies[0].PassingHostCount)
require.Len(t, inheritedPolicies, 2)
require.Equal(t, globalPolicy1.ID, inheritedPolicies[0].ID)
require.Equal(t, uint(1), inheritedPolicies[0].FailingHostCount)
require.Equal(t, uint(1), inheritedPolicies[0].PassingHostCount)
require.Equal(t, globalPolicy2.ID, inheritedPolicies[1].ID)
require.Equal(t, uint(0), inheritedPolicies[1].FailingHostCount)
require.Equal(t, uint(1), inheritedPolicies[1].PassingHostCount)
ids = make([]uint, 0, len(teamPolicies))
for _, teamPolicy := range teamPolicies {
p, err := ds.Policy(ctx, teamPolicy.ID)
require.NoError(t, err)
require.Equal(t, p, teamPolicy)
ids = append(ids, teamPolicy.ID)
}
teamPoliciesByID, err := ds.PoliciesByID(ctx, ids)
require.NoError(t, err)
require.Len(t, teamPoliciesByID, 1)
require.Equal(t, teamPoliciesByID[teamPolicies[0].ID], teamPolicies[0])
c, err = ds.CountMergedTeamPolicies(ctx, team1.ID, "", "")
require.NoError(t, err)
require.Equal(t, 3, c)
c, err = ds.CountPolicies(ctx, &team1.ID, "", "")
require.NoError(t, err)
require.Equal(t, 1, c)
mergedTeamPolicies, err := ds.ListMergedTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, mergedTeamPolicies, 3)
require.Equal(t, policy1Team1.ID, mergedTeamPolicies[0].ID)
require.Equal(t, uint(1), mergedTeamPolicies[0].FailingHostCount)
require.Equal(t, uint(1), mergedTeamPolicies[0].PassingHostCount)
require.Equal(t, globalPolicy1.ID, mergedTeamPolicies[1].ID)
require.Equal(t, uint(1), mergedTeamPolicies[1].FailingHostCount)
require.Equal(t, uint(1), mergedTeamPolicies[1].PassingHostCount)
require.Equal(t, globalPolicy2.ID, mergedTeamPolicies[2].ID)
require.Equal(t, uint(0), mergedTeamPolicies[2].FailingHostCount)
require.Equal(t, uint(1), mergedTeamPolicies[2].PassingHostCount)
// Tests on team2 domain.
teamPolicies, inheritedPolicies, err = ds.ListTeamPolicies(ctx, team2.ID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, teamPolicies, 2)
require.Equal(t, policy2Team2.ID, teamPolicies[0].ID)
require.Equal(t, uint(0), teamPolicies[0].FailingHostCount)
require.Equal(t, uint(1), teamPolicies[0].PassingHostCount)
require.Equal(t, policy4Team2.ID, teamPolicies[1].ID)
require.Equal(t, uint(1), teamPolicies[1].FailingHostCount)
require.Equal(t, uint(0), teamPolicies[1].PassingHostCount)
require.Len(t, inheritedPolicies, 2)
require.Equal(t, globalPolicy1.ID, inheritedPolicies[0].ID)
require.Equal(t, uint(0), inheritedPolicies[0].FailingHostCount)
require.Equal(t, uint(1), inheritedPolicies[0].PassingHostCount)
require.Equal(t, globalPolicy2.ID, inheritedPolicies[1].ID)
require.Equal(t, uint(0), inheritedPolicies[1].FailingHostCount)
require.Equal(t, uint(0), inheritedPolicies[1].PassingHostCount)
ids = make([]uint, 0, len(teamPolicies))
for _, teamPolicy := range teamPolicies {
p, err := ds.Policy(ctx, teamPolicy.ID)
require.NoError(t, err)
require.Equal(t, p, teamPolicy)
ids = append(ids, teamPolicy.ID)
}
teamPoliciesByID, err = ds.PoliciesByID(ctx, ids)
require.NoError(t, err)
require.Len(t, teamPoliciesByID, 2)
require.Equal(t, teamPoliciesByID[teamPolicies[0].ID], teamPolicies[0])
require.Equal(t, teamPoliciesByID[teamPolicies[1].ID], teamPolicies[1])
c, err = ds.CountMergedTeamPolicies(ctx, team2.ID, "", "")
require.NoError(t, err)
require.Equal(t, 4, c)
c, err = ds.CountPolicies(ctx, &team2.ID, "", "")
require.NoError(t, err)
require.Equal(t, 2, c)
mergedTeamPolicies, err = ds.ListMergedTeamPolicies(ctx, team2.ID, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, mergedTeamPolicies, 4)
require.Equal(t, policy2Team2.ID, mergedTeamPolicies[0].ID)
require.Equal(t, uint(0), mergedTeamPolicies[0].FailingHostCount)
require.Equal(t, uint(1), mergedTeamPolicies[0].PassingHostCount)
require.Equal(t, policy4Team2.ID, mergedTeamPolicies[1].ID)
require.Equal(t, uint(1), mergedTeamPolicies[1].FailingHostCount)
require.Equal(t, uint(0), mergedTeamPolicies[1].PassingHostCount)
require.Equal(t, globalPolicy1.ID, mergedTeamPolicies[2].ID)
require.Equal(t, uint(0), mergedTeamPolicies[2].FailingHostCount)
require.Equal(t, uint(1), mergedTeamPolicies[2].PassingHostCount)
require.Equal(t, globalPolicy2.ID, mergedTeamPolicies[3].ID)
require.Equal(t, uint(0), mergedTeamPolicies[3].FailingHostCount)
require.Equal(t, uint(0), mergedTeamPolicies[3].PassingHostCount)
// Tests on "No team" domain.
teamPolicies, inheritedPolicies, err = ds.ListTeamPolicies(ctx, fleet.PolicyNoTeamID, fleet.ListOptions{}, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, teamPolicies, 2)
require.Equal(t, policy0NoTeam.ID, teamPolicies[0].ID)
require.Equal(t, uint(1), teamPolicies[0].FailingHostCount)
require.Equal(t, uint(1), teamPolicies[0].PassingHostCount)
require.Equal(t, policy3NoTeam.ID, teamPolicies[1].ID)
require.Equal(t, uint(2), teamPolicies[1].FailingHostCount)
require.Equal(t, uint(0), teamPolicies[1].PassingHostCount)
require.Len(t, inheritedPolicies, 2)
require.Equal(t, globalPolicy1.ID, inheritedPolicies[0].ID)
require.Equal(t, uint(1), inheritedPolicies[0].FailingHostCount)
require.Equal(t, uint(1), inheritedPolicies[0].PassingHostCount)
require.Equal(t, globalPolicy2.ID, inheritedPolicies[1].ID)
require.Equal(t, uint(2), inheritedPolicies[1].FailingHostCount)
require.Equal(t, uint(0), inheritedPolicies[1].PassingHostCount)
ids = make([]uint, 0, len(teamPolicies))
for _, teamPolicy := range teamPolicies {
p, err := ds.Policy(ctx, teamPolicy.ID)
require.NoError(t, err)
require.Equal(t, p, teamPolicy)
ids = append(ids, teamPolicy.ID)
}
teamPoliciesByID, err = ds.PoliciesByID(ctx, ids)
require.NoError(t, err)
require.Len(t, teamPoliciesByID, 2)
require.Equal(t, teamPoliciesByID[teamPolicies[0].ID], teamPolicies[0])
require.Equal(t, teamPoliciesByID[teamPolicies[1].ID], teamPolicies[1])
c, err = ds.CountMergedTeamPolicies(ctx, fleet.PolicyNoTeamID, "", "")
require.NoError(t, err)
require.Equal(t, 4, c)
c, err = ds.CountPolicies(ctx, ptr.Uint(fleet.PolicyNoTeamID), "", "")
require.NoError(t, err)
require.Equal(t, 2, c)
mergedTeamPolicies, err = ds.ListMergedTeamPolicies(ctx, fleet.PolicyNoTeamID, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, mergedTeamPolicies, 4)
require.Equal(t, policy0NoTeam.ID, mergedTeamPolicies[0].ID)
require.Equal(t, uint(1), mergedTeamPolicies[0].FailingHostCount)
require.Equal(t, uint(1), mergedTeamPolicies[0].PassingHostCount)
require.Equal(t, policy3NoTeam.ID, mergedTeamPolicies[1].ID)
require.Equal(t, uint(2), mergedTeamPolicies[1].FailingHostCount)
require.Equal(t, uint(0), mergedTeamPolicies[1].PassingHostCount)
require.Equal(t, globalPolicy1.ID, mergedTeamPolicies[2].ID)
require.Equal(t, uint(1), mergedTeamPolicies[2].FailingHostCount)
require.Equal(t, uint(1), mergedTeamPolicies[2].PassingHostCount)
require.Equal(t, globalPolicy2.ID, mergedTeamPolicies[3].ID)
require.Equal(t, uint(2), mergedTeamPolicies[3].FailingHostCount)
require.Equal(t, uint(0), mergedTeamPolicies[3].PassingHostCount)
// Test ListPoliciesForHost and PolicyQueriesForHost for host0NoTeam.
host0Policies, err := ds.ListPoliciesForHost(ctx, host0NoTeam)
require.NoError(t, err)
require.Len(t, host0Policies, 4)
require.Equal(t, globalPolicy1.ID, host0Policies[0].ID)
require.Equal(t, "fail", host0Policies[0].Response)
require.Equal(t, globalPolicy2.ID, host0Policies[1].ID)
require.Equal(t, "fail", host0Policies[1].Response)
require.Equal(t, policy3NoTeam.ID, host0Policies[2].ID)
require.Equal(t, "fail", host0Policies[2].Response)
require.Equal(t, policy0NoTeam.ID, host0Policies[3].ID)
require.Equal(t, "pass", host0Policies[3].Response)
host0PolicyQueries, err := ds.PolicyQueriesForHost(ctx, host0NoTeam)
require.NoError(t, err)
require.Len(t, host0PolicyQueries, 4)
require.Equal(t, "SELECT gp1;", host0PolicyQueries[strconv.FormatUint(uint64(globalPolicy1.ID), 10)])
require.Equal(t, "SELECT gp2;", host0PolicyQueries[strconv.FormatUint(uint64(globalPolicy2.ID), 10)])
require.Equal(t, "SELECT 0;", host0PolicyQueries[strconv.FormatUint(uint64(policy0NoTeam.ID), 10)])
require.Equal(t, "SELECT 3;", host0PolicyQueries[strconv.FormatUint(uint64(policy3NoTeam.ID), 10)])
// Test ListPoliciesForHost and PolicyQueriesForHost for host1Team1.
host1Policies, err := ds.ListPoliciesForHost(ctx, host1Team1)
require.NoError(t, err)
require.Len(t, host1Policies, 3)
require.Equal(t, globalPolicy2.ID, host1Policies[0].ID)
require.Equal(t, "", host1Policies[0].Response)
require.Equal(t, globalPolicy1.ID, host1Policies[1].ID)
require.Equal(t, "pass", host1Policies[1].Response)
require.Equal(t, policy1Team1.ID, host1Policies[2].ID)
require.Equal(t, "pass", host1Policies[2].Response)
host1PolicyQueries, err := ds.PolicyQueriesForHost(ctx, host1Team1)
require.NoError(t, err)
require.Len(t, host1PolicyQueries, 3)
require.Equal(t, "SELECT gp1;", host1PolicyQueries[strconv.FormatUint(uint64(globalPolicy1.ID), 10)])
require.Equal(t, "SELECT gp2;", host1PolicyQueries[strconv.FormatUint(uint64(globalPolicy2.ID), 10)])
require.Equal(t, "SELECT 1;", host1PolicyQueries[strconv.FormatUint(uint64(policy1Team1.ID), 10)])
// Test ListPoliciesForHost and PolicyQueriesForHost for host2Team1.
host2Policies, err := ds.ListPoliciesForHost(ctx, host2Team1)
require.NoError(t, err)
require.Len(t, host2Policies, 3)
require.Equal(t, globalPolicy1.ID, host2Policies[0].ID)
require.Equal(t, "fail", host2Policies[0].Response)
require.Equal(t, policy1Team1.ID, host2Policies[1].ID)
require.Equal(t, "fail", host2Policies[1].Response)
require.Equal(t, globalPolicy2.ID, host2Policies[2].ID)
require.Equal(t, "pass", host2Policies[2].Response)
host2PolicyQueries, err := ds.PolicyQueriesForHost(ctx, host2Team1)
require.NoError(t, err)
require.Len(t, host2PolicyQueries, 3)
require.Equal(t, "SELECT gp1;", host2PolicyQueries[strconv.FormatUint(uint64(globalPolicy1.ID), 10)])
require.Equal(t, "SELECT gp2;", host2PolicyQueries[strconv.FormatUint(uint64(globalPolicy2.ID), 10)])
require.Equal(t, "SELECT 1;", host2PolicyQueries[strconv.FormatUint(uint64(policy1Team1.ID), 10)])
// Test ListPoliciesForHost and PolicyQueriesForHost for host3Team2.
host3Policies, err := ds.ListPoliciesForHost(ctx, host3Team2)
require.NoError(t, err)
require.Len(t, host3Policies, 4)
require.Equal(t, policy4Team2.ID, host3Policies[0].ID)
require.Equal(t, "fail", host3Policies[0].Response)
require.Equal(t, globalPolicy2.ID, host3Policies[1].ID)
require.Equal(t, "", host3Policies[1].Response)
require.Equal(t, globalPolicy1.ID, host3Policies[2].ID)
require.Equal(t, "pass", host3Policies[2].Response)
require.Equal(t, policy2Team2.ID, host3Policies[3].ID)
require.Equal(t, "pass", host3Policies[3].Response)
host3PolicyQueries, err := ds.PolicyQueriesForHost(ctx, host3Team2)
require.NoError(t, err)
require.Len(t, host3PolicyQueries, 4)
require.Equal(t, "SELECT gp1;", host3PolicyQueries[strconv.FormatUint(uint64(globalPolicy1.ID), 10)])
require.Equal(t, "SELECT gp2;", host3PolicyQueries[strconv.FormatUint(uint64(globalPolicy2.ID), 10)])
require.Equal(t, "SELECT 2;", host3PolicyQueries[strconv.FormatUint(uint64(policy2Team2.ID), 10)])
require.Equal(t, "SELECT 4;", host3PolicyQueries[strconv.FormatUint(uint64(policy4Team2.ID), 10)])
// Test ListPoliciesForHost and PolicyQueriesForHost for host5NoTeam.
host5Policies, err := ds.ListPoliciesForHost(ctx, host5NoTeam)
require.NoError(t, err)
require.Len(t, host5Policies, 4)
require.Equal(t, globalPolicy2.ID, host5Policies[0].ID)
require.Equal(t, "fail", host5Policies[0].Response)
require.Equal(t, policy0NoTeam.ID, host5Policies[1].ID)
require.Equal(t, "fail", host5Policies[1].Response)
require.Equal(t, policy3NoTeam.ID, host5Policies[2].ID)
require.Equal(t, "fail", host5Policies[2].Response)
require.Equal(t, globalPolicy1.ID, host5Policies[3].ID)
require.Equal(t, "pass", host5Policies[3].Response)
host5PolicyQueries, err := ds.PolicyQueriesForHost(ctx, host5NoTeam)
require.NoError(t, err)
require.Len(t, host5PolicyQueries, 4)
require.Equal(t, "SELECT gp1;", host5PolicyQueries[strconv.FormatUint(uint64(globalPolicy1.ID), 10)])
require.Equal(t, "SELECT gp2;", host5PolicyQueries[strconv.FormatUint(uint64(globalPolicy2.ID), 10)])
require.Equal(t, "SELECT 0;", host5PolicyQueries[strconv.FormatUint(uint64(policy0NoTeam.ID), 10)])
require.Equal(t, "SELECT 3;", host5PolicyQueries[strconv.FormatUint(uint64(policy3NoTeam.ID), 10)])
}
// testPoliciesBySoftwareTitleID exercises getPoliciesBySoftwareTitleIDs: it
// attaches software installers to policies on two teams and on "No team", then
// checks that lookups by software title ID only return the policies that belong
// to the team being filtered on, including for a very large list of title IDs.
func testPoliciesBySoftwareTitleID(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
	require.NoError(t, err)

	policy1 := newTestPolicy(t, ds, user1, "policy 1", "darwin", &team1.ID)
	policy2 := newTestPolicy(t, ds, user1, "policy 2", "darwin", &team2.ID)

	// Get policies for an invalid title ID
	policies, err := ds.getPoliciesBySoftwareTitleIDs(ctx, []uint{999}, team1.ID)
	require.NoError(t, err)
	require.Empty(t, policies)

	installer, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
	require.NoError(t, err)

	// createInstaller uploads a minimal installer named "file<suffix>" stored
	// under "storage<suffix>" for the given team (nil is used below for the
	// "No team" installers) and returns the new installer's ID.
	createInstaller := func(suffix, installScript, preInstallQuery string, teamID *uint) uint {
		t.Helper()
		installerID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
			InstallScript:     installScript,
			PreInstallQuery:   preInstallQuery,
			PostInstallScript: "world",
			InstallerFile:     installer,
			StorageID:         "storage" + suffix,
			Filename:          "file" + suffix,
			Title:             "file" + suffix,
			Version:           "1.0",
			Source:            "apps",
			UserID:            user1.ID,
			TeamID:            teamID,
			ValidatedLabels:   &fleet.LabelIdentsWithScope{},
		})
		require.NoError(t, err)
		return installerID
	}

	// attachInstaller links the installer to the policy and persists the policy.
	attachInstaller := func(policy *fleet.Policy, installerID uint) {
		t.Helper()
		policy.SoftwareInstallerID = ptr.Uint(installerID)
		require.NoError(t, ds.SavePolicy(ctx, policy, false, false))
	}

	// titleIDFor loads the installer metadata and returns its software title
	// ID, failing the test (instead of panicking on a nil dereference later)
	// when the installer has no title associated.
	titleIDFor := func(installerID uint) uint {
		t.Helper()
		meta, err := ds.GetSoftwareInstallerMetadataByID(ctx, installerID)
		require.NoError(t, err)
		require.NotNil(t, meta.TitleID)
		return *meta.TitleID
	}

	// Associate an installer to policy 1 on team 1.
	installer1ID := createInstaller("1", "hello", "SELECT 1", &team1.ID)
	attachInstaller(policy1, installer1ID)

	// Associate an installer to policy 2 on team 2.
	installer2ID := createInstaller("2", "hello", "SELECT 1", &team2.ID)
	attachInstaller(policy2, installer2ID)

	// get the software title ids associated with the installers.
	title1ID := titleIDFor(installer1ID)
	title2ID := titleIDFor(installer2ID)

	// software title 1 should have policy 1 when filtering by team 1
	policies, err = ds.getPoliciesBySoftwareTitleIDs(ctx, []uint{title1ID}, team1.ID)
	require.NoError(t, err)
	require.Len(t, policies, 1)
	require.Equal(t, policy1.ID, policies[0].ID)
	require.Equal(t, policy1.Name, policies[0].Name)

	// software title 1 should not have any policies when filtering by team 2
	policies, err = ds.getPoliciesBySoftwareTitleIDs(ctx, []uint{title1ID}, team2.ID)
	require.NoError(t, err)
	require.Empty(t, policies)

	// software title 2 should have policy 2 when filtering by team 2
	policies, err = ds.getPoliciesBySoftwareTitleIDs(ctx, []uint{title2ID}, team2.ID)
	require.NoError(t, err)
	require.Len(t, policies, 1)
	require.Equal(t, policy2.ID, policies[0].ID)
	require.Equal(t, policy2.Name, policies[0].Name)

	// software title 2 should not have any policies when filtering by team 1
	policies, err = ds.getPoliciesBySoftwareTitleIDs(ctx, []uint{title2ID}, team1.ID)
	require.NoError(t, err)
	require.Empty(t, policies)

	// software title 2 should not have any policies when filtering by no team
	policies, err = ds.getPoliciesBySoftwareTitleIDs(ctx, []uint{title2ID}, 0)
	require.NoError(t, err)
	require.Empty(t, policies)

	// Create two installers on "No team" and associate them to policies 3 and 4.
	installer3ID := createInstaller("3noteam", "hello noteam", "SELECT 1 from noteam", nil)
	installer4ID := createInstaller("4noteam", "hello noteam", "SELECT 1 from noteam", nil)

	policy3 := newTestPolicy(t, ds, user1, "policy 3", "darwin", ptr.Uint(0))
	attachInstaller(policy3, installer3ID)
	policy4 := newTestPolicy(t, ds, user1, "policy 4", "darwin", ptr.Uint(0))
	attachInstaller(policy4, installer4ID)

	title3ID := titleIDFor(installer3ID)
	// The original re-checked installer 3 here; installer 4 is the one that
	// must be validated before its title ID is dereferenced.
	title4ID := titleIDFor(installer4ID)

	policies, err = ds.getPoliciesBySoftwareTitleIDs(ctx, []uint{title3ID, title4ID}, 0)
	require.NoError(t, err)
	require.Len(t, policies, 2)
	expected := map[uint]fleet.AutomaticInstallPolicy{
		policy3.ID: {ID: policy3.ID, Name: policy3.Name, TitleID: title3ID, Type: fleet.PolicyTypeDynamic},
		policy4.ID: {ID: policy4.ID, Name: policy4.Name, TitleID: title4ID, Type: fleet.PolicyTypeDynamic},
	}
	for _, got := range policies {
		require.Equal(t, expected[got.ID], got)
	}

	// performance test for 50_000 title ids, ensure batching works
	const megaCount = 50_000
	megaTitleIDs := make([]uint, 0, megaCount)
	megaTitleIDs = append(megaTitleIDs, title3ID)
	// Filler IDs start right after title 4's ID; the two real titles sit at
	// the first and last positions of the list.
	for i := uint(0); i < megaCount-2; i++ {
		megaTitleIDs = append(megaTitleIDs, title4ID+i+1)
	}
	megaTitleIDs = append(megaTitleIDs, title4ID)
	policies, err = ds.getPoliciesBySoftwareTitleIDs(ctx, megaTitleIDs, 0)
	require.NoError(t, err)
	require.Len(t, policies, 2)

	// "No team" titles should not have any policies when filtering by team 1
	// (use the real team ID rather than assuming it is 1).
	policies, err = ds.getPoliciesBySoftwareTitleIDs(ctx, []uint{title3ID, title4ID}, team1.ID)
	require.NoError(t, err)
	require.Empty(t, policies)
}
// testClearAutoInstallPolicyStatusForHost verifies that
// ClearSoftwareInstallerAutoInstallPolicyStatusForHosts and
// ClearVPPAppAutoInstallPolicyStatusForHosts each reset the recorded response
// of only the policy tied to the given installer / VPP app for a host, leaving
// the host's other policy results untouched.
func testClearAutoInstallPolicyStatusForHost(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1" + t.Name()})
	require.NoError(t, err)
	test.CreateInsertGlobalVPPToken(t, ds)

	// create a regular policy
	policy1 := newTestPolicy(t, ds, user1, "policy 1"+t.Name(), "darwin", &team1.ID)
	// create two policies that become automatic install policies below:
	// policy2 for a software installer and policy3 for a VPP app
	policy2 := newTestPolicy(t, ds, user1, "policy 2"+t.Name(), "darwin", &team1.ID)
	policy3 := newTestPolicy(t, ds, user1, "policy 3"+t.Name(), "darwin", &team1.ID)

	installer, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
	require.NoError(t, err)
	installer1ID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "hello",
		PreInstallQuery:   "SELECT 1",
		PostInstallScript: "world",
		InstallerFile:     installer,
		StorageID:         "storage1",
		Filename:          "file1",
		Title:             "file1",
		Version:           "1.0",
		Source:            "apps",
		UserID:            user1.ID,
		TeamID:            &team1.ID,
		ValidatedLabels:   &fleet.LabelIdentsWithScope{},
	})
	require.NoError(t, err)
	policy2.SoftwareInstallerID = ptr.Uint(installer1ID)
	err = ds.SavePolicy(ctx, policy2, false, false)
	require.NoError(t, err)

	team1App, err := ds.InsertVPPAppWithTeam(ctx, &fleet.VPPApp{
		Name: "vpp1", BundleIdentifier: "com.app.appy",
		VPPAppTeam: fleet.VPPAppTeam{VPPAppID: fleet.VPPAppID{AdamID: "adam_app", Platform: fleet.MacOSPlatform}},
	}, &team1.ID)
	require.NoError(t, err)
	team1Meta, err := ds.GetVPPAppMetadataByTeamAndTitleID(ctx, &team1.ID, team1App.TitleID)
	require.NoError(t, err)
	policy3.VPPAppsTeamsID = ptr.Uint(team1Meta.VPPAppsTeamsID)
	err = ds.SavePolicy(ctx, policy3, false, false)
	require.NoError(t, err)

	// create a host
	host, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:   ptr.String(uuid.New().String()),
		DetailUpdatedAt: time.Now(),
		LabelUpdatedAt:  time.Now(),
		PolicyUpdatedAt: time.Now(),
		SeenTime:        time.Now(),
		NodeKey:         ptr.String(uuid.New().String()),
		UUID:            uuid.New().String(),
		Hostname:        "host" + t.Name(),
		TeamID:          &team1.ID,
		Platform:        "darwin",
	})
	require.NoError(t, err)

	// record a policy run for all three policies
	err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{
		policy1.ID: ptr.Bool(true),
		policy2.ID: ptr.Bool(false), // software isn't installed on host, so Fleet should install it
		policy3.ID: ptr.Bool(false), // software isn't installed on host, so Fleet should install it
	}, time.Now(), false)
	require.NoError(t, err)

	// sortedHostPolicies lists the host's policies, checks that all three are
	// present and returns them ordered by policy ID. The index-based
	// assertions below rely on policy1 < policy2 < policy3 in ID order.
	sortedHostPolicies := func() []*fleet.HostPolicy {
		t.Helper()
		hostPolicies, err := ds.ListPoliciesForHost(ctx, host)
		require.NoError(t, err)
		require.Len(t, hostPolicies, 3)
		sort.Slice(hostPolicies, func(i, j int) bool {
			return hostPolicies[i].ID < hostPolicies[j].ID
		})
		return hostPolicies
	}

	hostPolicies := sortedHostPolicies()
	require.Equal(t, "pass", hostPolicies[0].Response)
	require.Equal(t, "fail", hostPolicies[1].Response)
	require.Equal(t, "fail", hostPolicies[2].Response)

	// clear status for the installer automatic install policy
	err = ds.ClearSoftwareInstallerAutoInstallPolicyStatusForHosts(ctx, installer1ID, []uint{host.ID})
	require.NoError(t, err)

	// the response should now be empty for the installer policy but not for
	// the "regular" one
	hostPolicies = sortedHostPolicies()
	require.Equal(t, "pass", hostPolicies[0].Response)
	require.Empty(t, hostPolicies[1].Response)
	// policy for VPP app should still be "fail"
	require.Equal(t, "fail", hostPolicies[2].Response)

	// clear status for the vpp app automatic install policy
	err = ds.ClearVPPAppAutoInstallPolicyStatusForHosts(ctx, team1Meta.VPPAppsTeamsID, []uint{host.ID})
	require.NoError(t, err)

	// the response should now be empty for both automatic install policies
	// but not for the "regular" one
	hostPolicies = sortedHostPolicies()
	require.Equal(t, "pass", hostPolicies[0].Response)
	require.Empty(t, hostPolicies[1].Response)
	require.Empty(t, hostPolicies[2].Response)
}
// testPolicyLabels checks that include-any / exclude-any label targeting
// determines which policies ListPoliciesForHost and PolicyQueriesForHost
// return for hosts carrying different combinations of two labels.
func testPolicyLabels(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// reportMismatch marks the test as failed (without stopping it) when at
	// least one policy is missing or unexpected for the named host.
	reportMismatch := func(t *testing.T, hostName string, missing, extra []string) {
		if len(missing) == 0 && len(extra) == 0 {
			return
		}
		t.Errorf("%s missing policies: %q, extra policies: %q", hostName, missing, extra)
	}

	// verifyListed compares the host policies returned by the datastore with
	// the expected policies, matching them by policy ID.
	verifyListed := func(t *testing.T, got []*fleet.HostPolicy, want []*fleet.Policy, hostName string) {
		gotIDs := make(map[uint]struct{}, len(got))
		for _, hp := range got {
			gotIDs[hp.ID] = struct{}{}
		}
		wantIDs := make(map[uint]struct{}, len(want))
		for _, p := range want {
			wantIDs[p.ID] = struct{}{}
		}

		var missing, extra []string
		for _, p := range want {
			if _, found := gotIDs[p.ID]; found {
				continue
			}
			missing = append(missing, p.Name)
		}
		for _, hp := range got {
			if _, found := wantIDs[hp.ID]; found {
				continue
			}
			extra = append(extra, hp.Name)
		}
		reportMismatch(t, hostName, missing, extra)
	}

	// verifyQueries compares the keys of the returned query map (policy IDs
	// in decimal string form) with the expected policies. Unexpected entries
	// are reported by their key because no policy name is available for them.
	verifyQueries := func(t *testing.T, got map[string]string, want []*fleet.Policy, hostName string) {
		gotKeys := make(map[uint]string, len(got))
		for key := range got {
			n, err := strconv.Atoi(key)
			require.NoError(t, err)
			gotKeys[uint(n)] = key //nolint:gosec // dismiss G115
		}
		wantIDs := make(map[uint]struct{}, len(want))
		for _, p := range want {
			wantIDs[p.ID] = struct{}{}
		}

		var missing, extra []string
		for _, p := range want {
			if _, found := gotKeys[p.ID]; found {
				continue
			}
			missing = append(missing, p.Name)
		}
		for id, key := range gotKeys {
			if _, found := wantIDs[id]; found {
				continue
			}
			extra = append(extra, key)
		}
		reportMismatch(t, hostName, missing, extra)
	}

	user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	label1, err := ds.NewLabel(ctx, &fleet.Label{Name: "label1"})
	require.NoError(t, err)
	label2, err := ds.NewLabel(ctx, &fleet.Label{Name: "label2"})
	require.NoError(t, err)

	hostNoLabels := test.NewHost(t, ds, "host-no-labels", "10.0.0.1", "key1", "uuid1", time.Now())
	hostLabel1 := test.NewHost(t, ds, "host-label1", "10.0.0.2", "key2", "uuid2", time.Now())
	hostLabel2 := test.NewHost(t, ds, "host-label2", "10.0.0.3", "key3", "uuid3", time.Now())
	hostLabelBoth := test.NewHost(t, ds, "host-label-both", "10.0.0.4", "key4", "uuid4", time.Now())

	// Give each host its label memberships.
	err = ds.AddLabelsToHost(ctx, hostLabel1.ID, []uint{label1.ID})
	require.NoError(t, err)
	err = ds.AddLabelsToHost(ctx, hostLabel2.ID, []uint{label2.ID})
	require.NoError(t, err)
	err = ds.AddLabelsToHost(ctx, hostLabelBoth.ID, []uint{label1.ID, label2.ID})
	require.NoError(t, err)

	// makePolicy creates a policy via newTestPolicy (empty platform, nil
	// team) and, when label targets are supplied, sets them and saves the
	// policy. A policy without targets is returned as created, with no save.
	makePolicy := func(name string, include, exclude []fleet.LabelIdent) *fleet.Policy {
		p := newTestPolicy(t, ds, user1, name, "", nil)
		if include == nil && exclude == nil {
			return p
		}
		if include != nil {
			p.LabelsIncludeAny = include
		}
		if exclude != nil {
			p.LabelsExcludeAny = exclude
		}
		require.NoError(t, ds.SavePolicy(ctx, p, false, false))
		return p
	}

	// The policies under test, one per column of the grid below.
	policyNoLabel := makePolicy("policy no label", nil, nil)
	policyIncludeLabel1 := makePolicy("policy include label1",
		[]fleet.LabelIdent{{LabelName: label1.Name}}, nil)
	policyExcludeLabel2 := makePolicy("policy exclude label2",
		nil, []fleet.LabelIdent{{LabelName: label2.Name}})
	policyIncludeBoth := makePolicy("policy include both",
		[]fleet.LabelIdent{{LabelName: label1.Name}, {LabelName: label2.Name}}, nil)
	policyExcludeBoth := makePolicy("policy exclude both",
		nil, []fleet.LabelIdent{{LabelName: label1.Name}, {LabelName: label2.Name}})

	// Expected host/policy matrix (X = the policy applies to the host):
	//
	// | hosts \ policies | No labels | include 1 | exclude 2 | include both | exclude both |
	// |------------------+-----------+-----------+-----------+--------------+--------------|
	// | no label         |     X     |           |     X     |              |      X       |
	// | label 1          |     X     |     X     |     X     |      X       |              |
	// | label 2          |     X     |           |           |      X       |              |
	// | label both       |     X     |     X     |           |      X       |              |
	cases := []struct {
		host *fleet.Host
		want []*fleet.Policy
	}{
		{
			host: hostNoLabels,
			want: []*fleet.Policy{policyNoLabel, policyExcludeLabel2, policyExcludeBoth},
		},
		{
			host: hostLabel1,
			want: []*fleet.Policy{policyNoLabel, policyIncludeLabel1, policyExcludeLabel2, policyIncludeBoth},
		},
		{
			host: hostLabel2,
			want: []*fleet.Policy{policyNoLabel, policyIncludeBoth},
		},
		{
			host: hostLabelBoth,
			want: []*fleet.Policy{policyNoLabel, policyIncludeLabel1, policyIncludeBoth},
		},
	}

	for _, c := range cases {
		listed, err := ds.ListPoliciesForHost(ctx, c.host)
		require.NoError(t, err)
		verifyListed(t, listed, c.want, c.host.Hostname)

		queryMap, err := ds.PolicyQueriesForHost(ctx, c.host)
		require.NoError(t, err)
		verifyQueries(t, queryMap, c.want, c.host.Hostname)
	}
}
func testPolicyLabelMembershipCleanup(t *testing.T, ds *Datastore) {
ctx := context.Background()
user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
// Create labels
label1, err := ds.NewLabel(ctx, &fleet.Label{Name: "cleanup-label1"})
require.NoError(t, err)
label2, err := ds.NewLabel(ctx, &fleet.Label{Name: "cleanup-label2"})
require.NoError(t, err)
// Create hosts with different label combinations
hostNoLabels := test.NewHost(t, ds, "cleanup-host-no-labels", "10.0.0.1", "key1", "uuid1", time.Now())
hostLabel1 := test.NewHost(t, ds, "cleanup-host-label1", "10.0.0.2", "key2", "uuid2", time.Now())
hostLabel2 := test.NewHost(t, ds, "cleanup-host-label2", "10.0.0.3", "key3", "uuid3", time.Now())
hostLabelBoth := test.NewHost(t, ds, "cleanup-host-label-both", "10.0.0.4", "key4", "uuid4", time.Now())
// Apply labels to hosts
require.NoError(t, ds.AddLabelsToHost(ctx, hostLabel1.ID, []uint{label1.ID}))
require.NoError(t, ds.AddLabelsToHost(ctx, hostLabel2.ID, []uint{label2.ID}))
require.NoError(t, ds.AddLabelsToHost(ctx, hostLabelBoth.ID, []uint{label1.ID, label2.ID}))
// Create a policy with no label targets (applies to all hosts)
policy := newTestPolicy(t, ds, user1, "cleanup test policy", "", nil)
// Record policy results for all hosts
for _, h := range []*fleet.Host{hostNoLabels, hostLabel1, hostLabel2, hostLabelBoth} {
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false)
require.NoError(t, err)
}
// Verify all hosts have membership
polsByName := map[string]*fleet.Policy{policy.Name: policy}
wantHostsByPol := map[string][]uint{
policy.Name: {hostNoLabels.ID, hostLabel1.ID, hostLabel2.ID, hostLabelBoth.ID},
}
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
// Update policy to include only label1
policy.LabelsIncludeAny = []fleet.LabelIdent{{LabelName: label1.Name}}
require.NoError(t, ds.SavePolicy(ctx, policy, false, false))
// Verify only hosts with label1 still have membership
wantHostsByPol[policy.Name] = []uint{hostLabel1.ID, hostLabelBoth.ID}
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
// Update policy to include both labels (include any means host must have at least one)
policy.LabelsIncludeAny = []fleet.LabelIdent{{LabelName: label1.Name}, {LabelName: label2.Name}}
require.NoError(t, ds.SavePolicy(ctx, policy, false, false))
// Since no new memberships were added, only hosts that had membership AND match the criteria remain
// hostLabel2 was removed in the previous step, so it won't come back
wantHostsByPol[policy.Name] = []uint{hostLabel1.ID, hostLabelBoth.ID}
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
// Re-record membership for all hosts to test exclude labels
for _, h := range []*fleet.Host{hostNoLabels, hostLabel1, hostLabel2, hostLabelBoth} {
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false)
require.NoError(t, err)
}
wantHostsByPol[policy.Name] = []uint{hostNoLabels.ID, hostLabel1.ID, hostLabel2.ID, hostLabelBoth.ID}
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
// Update policy to exclude label2
policy.LabelsIncludeAny = nil
policy.LabelsExcludeAny = []fleet.LabelIdent{{LabelName: label2.Name}}
require.NoError(t, ds.SavePolicy(ctx, policy, false, false))
// Verify hosts with label2 are removed (hostLabel2 and hostLabelBoth)
wantHostsByPol[policy.Name] = []uint{hostNoLabels.ID, hostLabel1.ID}
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
// Test ApplyPolicySpecs with label changes
// First, re-record membership for all hosts
for _, h := range []*fleet.Host{hostNoLabels, hostLabel1, hostLabel2, hostLabelBoth} {
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy.ID: ptr.Bool(true)}, time.Now(), false)
require.NoError(t, err)
}
wantHostsByPol[policy.Name] = []uint{hostNoLabels.ID, hostLabel1.ID, hostLabel2.ID, hostLabelBoth.ID}
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
// Apply spec with include label1 only
err = ds.ApplyPolicySpecs(ctx, user1.ID, []*fleet.PolicySpec{
{
Name: policy.Name,
Query: policy.Query,
LabelsIncludeAny: []string{label1.Name},
Type: fleet.PolicyTypeDynamic,
},
})
require.NoError(t, err)
// Verify only hosts with label1 remain
wantHostsByPol[policy.Name] = []uint{hostLabel1.ID, hostLabelBoth.ID}
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
// Test combined platform and label cleanup
// Create hosts with different platforms
hostWinLabel1 := test.NewHost(t, ds, "cleanup-host-win-label1", "10.0.0.5", "key5", "uuid5", time.Now())
hostWinLabel1.Platform = "windows"
require.NoError(t, ds.UpdateHost(ctx, hostWinLabel1))
require.NoError(t, ds.AddLabelsToHost(ctx, hostWinLabel1.ID, []uint{label1.ID}))
hostMacLabel1 := test.NewHost(t, ds, "cleanup-host-mac-label1", "10.0.0.6", "key6", "uuid6", time.Now())
hostMacLabel1.Platform = "darwin"
require.NoError(t, ds.UpdateHost(ctx, hostMacLabel1))
require.NoError(t, ds.AddLabelsToHost(ctx, hostMacLabel1.ID, []uint{label1.ID}))
// Create a new policy for platform + label test
policy2 := newTestPolicy(t, ds, user1, "cleanup test policy 2", "", nil)
// Record membership for all hosts with label1
for _, h := range []*fleet.Host{hostLabel1, hostLabelBoth, hostWinLabel1, hostMacLabel1} {
err = ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policy2.ID: ptr.Bool(true)}, time.Now(), false)
require.NoError(t, err)
}
polsByName[policy2.Name] = policy2
wantHostsByPol[policy2.Name] = []uint{hostLabel1.ID, hostLabelBoth.ID, hostWinLabel1.ID, hostMacLabel1.ID}
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
// Update policy2 to windows platform AND include label1
policy2.Platform = "windows"
policy2.LabelsIncludeAny = []fleet.LabelIdent{{LabelName: label1.Name}}
require.NoError(t, ds.SavePolicy(ctx, policy2, false, false))
// Only windows hosts with label1 should remain
wantHostsByPol[policy2.Name] = []uint{hostWinLabel1.ID}
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
}
// testDeletePolicyWithSoftwareActivatesNextActivity covers deletion of
// policies that have a software installer attached. It sets up one team policy
// and one "no team" policy, enqueues the policy-triggered installs by hand,
// and then checks (through checkUpcomingActivities) which execution IDs are
// still upcoming on each host once both policies are deleted.
func testDeletePolicyWithSoftwareActivatesNextActivity(t *testing.T, ds *Datastore) {
	ctx := t.Context()
	admin := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	tm, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	teamHost := test.NewHost(t, ds, "host1", "1", "host1key", "host1uuid", time.Now())
	noTeamHost := test.NewHost(t, ds, "host2", "2", "host2key", "host2uuid", time.Now())

	// Only teamHost is transferred to the team; noTeamHost keeps no team.
	require.NoError(t, ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(&tm.ID, []uint{teamHost.ID})))

	// addInstaller uploads a minimal installer and returns its ID. The file
	// contents are reused as the install script and the filename is reused as
	// the title, exactly as both fixtures need.
	addInstaller := func(contents, storageID, filename, version string, teamID *uint) uint {
		file, err := fleet.NewTempFileReader(strings.NewReader(contents), t.TempDir)
		require.NoError(t, err)
		id, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
			InstallScript:   contents,
			InstallerFile:   file,
			StorageID:       storageID,
			Filename:        filename,
			Title:           filename,
			Version:         version,
			Source:          "apps",
			UserID:          admin.ID,
			TeamID:          teamID,
			ValidatedLabels: &fleet.LabelIdentsWithScope{},
		})
		require.NoError(t, err)
		return id
	}
	teamInstallerID := addInstaller("hello", "storage1", "file1", "1.0", &tm.ID)
	noTeamInstallerID := addInstaller("hello2", "storage2", "file2", "2.0", nil)

	// One policy per scope, each tied to the installer of the same scope.
	teamPolicy, err := ds.NewTeamPolicy(ctx, tm.ID, &admin.ID, fleet.PolicyPayload{
		Name:                "p1",
		Query:               "SELECT 1;",
		SoftwareInstallerID: ptr.Uint(teamInstallerID),
	})
	require.NoError(t, err)
	noTeamPolicy, err := ds.NewTeamPolicy(ctx, fleet.PolicyNoTeamID, &admin.ID, fleet.PolicyPayload{
		Name:                "p2",
		Query:               "SELECT 2;",
		SoftwareInstallerID: ptr.Uint(noTeamInstallerID),
	})
	require.NoError(t, err)

	// Queue an unrelated script run on noTeamHost so that something sits ahead
	// of the policy-triggered install in its queue.
	adhocScript, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         noTeamHost.ID,
		ScriptContents: "echo",
		UserID:         &admin.ID,
		SyncRequest:    true,
	})
	require.NoError(t, err)

	// Both hosts report their policy as failing.
	markFailing := func(h *fleet.Host, policyID uint) {
		err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policyID: ptr.Bool(false)}, time.Now(), false)
		require.NoError(t, err)
	}
	markFailing(noTeamHost, noTeamPolicy.ID)
	markFailing(teamHost, teamPolicy.ID)

	// Stand in for "processSoftwareForNewlyFailingPolicies" by inserting the
	// install requests directly, tagged with the originating policy.
	noTeamInstallUUID, err := ds.InsertSoftwareInstallRequest(ctx, noTeamHost.ID, noTeamInstallerID,
		fleet.HostSoftwareInstallOptions{
			SelfService: false,
			PolicyID:    &noTeamPolicy.ID,
		})
	require.NoError(t, err)
	teamInstallUUID, err := ds.InsertSoftwareInstallRequest(ctx, teamHost.ID, teamInstallerID,
		fleet.HostSoftwareInstallOptions{
			SelfService: false,
			PolicyID:    &teamPolicy.ID,
		})
	require.NoError(t, err)

	// Before deletion: noTeamHost lists the ad-hoc script then the install,
	// teamHost lists only its install.
	upcoming, _, err := ds.ListHostUpcomingActivities(ctx, noTeamHost.ID, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, upcoming, 2)
	require.Equal(t, adhocScript.ExecutionID, upcoming[0].UUID)
	require.Equal(t, noTeamInstallUUID, upcoming[1].UUID)

	upcoming, _, err = ds.ListHostUpcomingActivities(ctx, teamHost.ID, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, upcoming, 1)
	require.Equal(t, teamInstallUUID, upcoming[0].UUID)

	checkUpcomingActivities(t, ds, noTeamHost, adhocScript.ExecutionID, noTeamInstallUUID)
	checkUpcomingActivities(t, ds, teamHost, teamInstallUUID)

	// Delete the team policy first, then the no-team one.
	_, err = ds.DeleteTeamPolicies(ctx, tm.ID, []uint{teamPolicy.ID})
	require.NoError(t, err)
	_, err = ds.DeleteTeamPolicies(ctx, fleet.PolicyNoTeamID, []uint{noTeamPolicy.ID})
	require.NoError(t, err)

	// After deletion only the ad-hoc script is passed as expected for
	// noTeamHost, and no execution IDs at all for teamHost.
	checkUpcomingActivities(t, ds, noTeamHost, adhocScript.ExecutionID)
	checkUpcomingActivities(t, ds, teamHost)
}
// testDeletePolicyWithScriptActivatesNextActivity covers deletion of policies
// that have a script attached. It sets up one team policy and one "no team"
// policy, enqueues the policy-triggered script runs by hand, and then checks
// (through checkUpcomingActivities) which execution IDs are still upcoming on
// each host once both policies are deleted.
func testDeletePolicyWithScriptActivatesNextActivity(t *testing.T, ds *Datastore) {
	ctx := t.Context()
	admin := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	tm, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	teamHost := test.NewHost(t, ds, "host1", "1", "host1key", "host1uuid", time.Now())
	noTeamHost := test.NewHost(t, ds, "host2", "2", "host2key", "host2uuid", time.Now())

	// Only teamHost is transferred to the team; noTeamHost keeps no team.
	require.NoError(t, ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(&tm.ID, []uint{teamHost.ID})))

	// One saved script per scope.
	teamScript, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1.sh",
		ScriptContents: "echo",
		TeamID:         &tm.ID,
	})
	require.NoError(t, err)
	noTeamScript, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script2.sh",
		ScriptContents: "echo",
		TeamID:         nil,
	})
	require.NoError(t, err)

	// One policy per scope, each tied to the script of the same scope.
	teamPolicy, err := ds.NewTeamPolicy(ctx, tm.ID, &admin.ID, fleet.PolicyPayload{
		Name:     "p1",
		Query:    "SELECT 1;",
		ScriptID: &teamScript.ID,
	})
	require.NoError(t, err)
	noTeamPolicy, err := ds.NewTeamPolicy(ctx, fleet.PolicyNoTeamID, &admin.ID, fleet.PolicyPayload{
		Name:     "p2",
		Query:    "SELECT 2;",
		ScriptID: &noTeamScript.ID,
	})
	require.NoError(t, err)

	// Queue an unrelated script run on noTeamHost so that something sits ahead
	// of the policy-triggered script in its queue.
	adhocScript, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         noTeamHost.ID,
		ScriptContents: "echo",
		UserID:         &admin.ID,
		SyncRequest:    true,
	})
	require.NoError(t, err)

	// Both hosts report their policy as failing.
	markFailing := func(h *fleet.Host, policyID uint) {
		err := ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{policyID: ptr.Bool(false)}, time.Now(), false)
		require.NoError(t, err)
	}
	markFailing(noTeamHost, noTeamPolicy.ID)
	markFailing(teamHost, teamPolicy.ID)

	// Stand in for "processScriptsForNewlyFailingPolicies" by creating the
	// script execution requests directly, tagged with the originating policy.
	noTeamPolicyRun, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         noTeamHost.ID,
		ScriptContents: "echo",
		UserID:         &admin.ID,
		PolicyID:       &noTeamPolicy.ID,
		SyncRequest:    true,
		ScriptID:       &noTeamScript.ID,
	})
	require.NoError(t, err)
	teamPolicyRun, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         teamHost.ID,
		ScriptContents: "echo",
		UserID:         &admin.ID,
		PolicyID:       &teamPolicy.ID,
		SyncRequest:    true,
		ScriptID:       &teamScript.ID,
	})
	require.NoError(t, err)

	// Before deletion: noTeamHost lists the ad-hoc script then the policy
	// script, teamHost lists only its policy script.
	upcoming, _, err := ds.ListHostUpcomingActivities(ctx, noTeamHost.ID, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, upcoming, 2)
	require.Equal(t, adhocScript.ExecutionID, upcoming[0].UUID)
	require.Equal(t, noTeamPolicyRun.ExecutionID, upcoming[1].UUID)

	upcoming, _, err = ds.ListHostUpcomingActivities(ctx, teamHost.ID, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, upcoming, 1)
	require.Equal(t, teamPolicyRun.ExecutionID, upcoming[0].UUID)

	checkUpcomingActivities(t, ds, noTeamHost, adhocScript.ExecutionID, noTeamPolicyRun.ExecutionID)
	checkUpcomingActivities(t, ds, teamHost, teamPolicyRun.ExecutionID)

	// Delete the team policy first, then the no-team one.
	_, err = ds.DeleteTeamPolicies(ctx, tm.ID, []uint{teamPolicy.ID})
	require.NoError(t, err)
	_, err = ds.DeleteTeamPolicies(ctx, fleet.PolicyNoTeamID, []uint{noTeamPolicy.ID})
	require.NoError(t, err)

	// After deletion only the ad-hoc script is passed as expected for
	// noTeamHost, and no execution IDs at all for teamHost.
	checkUpcomingActivities(t, ds, noTeamHost, adhocScript.ExecutionID)
	checkUpcomingActivities(t, ds, teamHost)
}
// testSimultaneousSavePolicy saves several policies of the same team
// concurrently and requires every save to succeed.
//
// The UI can issue simultaneous PATCH requests for policies (e.g. from the
// "Manage automations" page). Those transactions may all try to clear
// `policy_membership` rows for the same host and deadlock, so this test
// checks that the backend retries instead of surfacing the deadlock.
func testSimultaneousSavePolicy(t *testing.T, ds *Datastore) {
	const numPolicies = 10

	ctx := context.Background()
	owner := test.NewUser(t, ds, "Alice", "alice@example.com", true)
	tm, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	// Create the policies and, alongside, a passing result for each one.
	policies := make([]*fleet.Policy, 0, numPolicies)
	passing := make(map[uint]*bool, numPolicies)
	for i := range numPolicies {
		p := newTestPolicy(t, ds, owner, fmt.Sprintf("policy%d", i), "darwin", &tm.ID)
		policies = append(policies, p)
		passing[p.ID] = ptr.Bool(true)
	}

	// A single host holds membership rows for every policy, which is what
	// makes the concurrent saves contend with each other.
	sharedHost := newTestHostWithPlatform(t, ds, "host1", "darwin", &tm.ID)
	require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, sharedHost, passing, time.Now(), false))

	// Fire all saves at once; each goroutine owns exactly one policy.
	var saves errgroup.Group
	for _, p := range policies {
		saves.Go(func() error {
			p.Query += "just changing something here"
			// NOTE: shouldRemoveAllPolicyMemberships is true when the user
			// updates the software item associated to an installer, so both
			// flags are set to true here to simulate that.
			return ds.SavePolicy(ctx, p, true, true)
		})
	}
	require.NoError(t, saves.Wait())
}
// testIsPolicyFailing walks IsPolicyFailing through every membership state of
// a (policy, host) pair: no row, passes NULL, passes false and passes true.
// It also confirms that a different host or a different policy without a
// membership row is reported as failing.
func testIsPolicyFailing(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// Fixtures: one host, one user and one global policy.
	host := test.NewHost(t, ds, "host1", "10.0.0.1", "host1Key", "host1UUID", time.Now())
	owner := test.NewUser(t, ds, "User", "test@example.com", true)
	policy, err := ds.NewGlobalPolicy(ctx, &owner.ID, fleet.PolicyPayload{
		Name:  "policy",
		Query: "SELECT 1;",
	})
	require.NoError(t, err)

	// report stores a result (nil meaning NULL) for policy on host.
	report := func(passes *bool) {
		err := ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: passes}, time.Now(), false)
		require.NoError(t, err)
	}
	// expectFailing asserts what IsPolicyFailing returns for the given pair.
	expectFailing := func(policyID, hostID uint, want bool, msg string) {
		got, err := ds.IsPolicyFailing(ctx, policyID, hostID)
		require.NoError(t, err)
		require.Equal(t, want, got, msg)
	}

	// Edge case: with no membership row at all the policy counts as failing.
	expectFailing(policy.ID, host.ID, true, "policy with no membership record is considered failing")

	// Row exists with passes = NULL: failing.
	report(nil)
	expectFailing(policy.ID, host.ID, true, "policy with NULL passes should be considered still failing")

	// Row exists with passes = false: failing.
	report(ptr.Bool(false))
	expectFailing(policy.ID, host.ID, true, "policy with passes=false should be considered still failing")

	// Row exists with passes = true: the only state that is not failing.
	report(ptr.Bool(true))
	expectFailing(policy.ID, host.ID, false, "policy with passes=true should NOT be considered still failing")

	// A second host has no row for the policy, so it is failing there.
	otherHost := test.NewHost(t, ds, "host2", "10.0.0.2", "host2Key", "host2UUID", time.Now())
	expectFailing(policy.ID, otherHost.ID, true, "policy with no membership record for different host should be considered still failing")

	// A second policy has no row for the original host, so it is failing too.
	otherPolicy, err := ds.NewGlobalPolicy(ctx, &owner.ID, fleet.PolicyPayload{
		Name:  "policy 2",
		Query: "SELECT 2;",
	})
	require.NoError(t, err)
	expectFailing(otherPolicy.ID, host.ID, true, "different policy with no membership record should be considered still failing")
}
// testResetAttemptsOnFailingToPassingSync checks the attempt_number handling
// of RecordPolicyQueryExecutions. A policy that was recorded as failing and is
// then recorded as passing must have the attempt_number of its script and
// install rows set to 0 (including rows where it was NULL). A policy whose
// first recorded result is already passing must keep its rows untouched.
func testResetAttemptsOnFailingToPassingSync(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	owner := test.NewUser(t, ds, "Bob", "bob@example.com", true)

	// flipped goes failing -> passing; fresh only ever reports passing.
	flipped, err := ds.NewGlobalPolicy(ctx, &owner.ID, fleet.PolicyPayload{
		Name:  "policy-sync-1",
		Query: "SELECT 1",
	})
	require.NoError(t, err)
	fresh, err := ds.NewGlobalPolicy(ctx, &owner.ID, fleet.PolicyPayload{
		Name:  "policy-sync-2",
		Query: "SELECT 2",
	})
	require.NoError(t, err)

	host, err := ds.EnrollOsquery(ctx,
		fleet.WithEnrollOsqueryHostID("hsync1"),
		fleet.WithEnrollOsqueryNodeKey("nsync1"),
	)
	require.NoError(t, err)

	// Start with flipped failing; fresh has no result yet.
	err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{flipped.ID: ptr.Bool(false)}, time.Now(), false)
	require.NoError(t, err)

	const (
		scriptCompletedSQL = `INSERT INTO host_script_results (host_id, execution_id, output, runtime, policy_id, attempt_number) VALUES (?, ?, '', 0, ?, 1)`
		scriptPendingSQL   = `INSERT INTO host_script_results (host_id, execution_id, output, runtime, policy_id, attempt_number) VALUES (?, ?, '', 0, ?, NULL)`
	)
	// seedScript / seedInstall insert one row with a random execution ID. The
	// two tables take host_id and execution_id in opposite argument order.
	seedScript := func(stmt string, policyID uint) {
		_, err := ds.writer(ctx).Exec(stmt, host.ID, uuid.NewString(), policyID)
		require.NoError(t, err)
	}
	seedInstall := func(stmt string, policyID uint) {
		_, err := ds.writer(ctx).Exec(stmt, uuid.NewString(), host.ID, policyID)
		require.NoError(t, err)
	}

	// For each policy: a completed attempt (attempt_number = 1) and a pending
	// one (attempt_number IS NULL), in both tables.
	seedScript(scriptCompletedSQL, flipped.ID)
	seedInstall(`INSERT INTO host_software_installs (execution_id, host_id, policy_id, attempt_number, installer_filename, version) VALUES (?, ?, ?, 1, 'x', '1.0.0')`, flipped.ID)
	seedScript(scriptPendingSQL, flipped.ID)
	seedInstall(`INSERT INTO host_software_installs (execution_id, host_id, policy_id, attempt_number, installer_filename, version) VALUES (?, ?, ?, NULL, 'x-pending', '1.0.0')`, flipped.ID)

	seedScript(scriptCompletedSQL, fresh.ID)
	seedInstall(`INSERT INTO host_software_installs (execution_id, host_id, policy_id, attempt_number, installer_filename, version) VALUES (?, ?, ?, 1, 'y', '2.0.0')`, fresh.ID)
	seedScript(scriptPendingSQL, fresh.ID)
	seedInstall(`INSERT INTO host_software_installs (execution_id, host_id, policy_id, attempt_number, installer_filename, version) VALUES (?, ?, ?, NULL, 'y-pending', '2.0.0')`, fresh.ID)

	// Both policies now report passing; only flipped goes from failing to
	// passing, since fresh had no earlier result.
	err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{flipped.ID: ptr.Bool(true), fresh.ID: ptr.Bool(true)}, time.Now(), false)
	require.NoError(t, err)

	// count runs a COUNT(*) statement bound to (host, policy).
	count := func(stmt string, policyID uint) int {
		var n int
		require.NoError(t, sqlx.GetContext(ctx, ds.reader(ctx), &n, stmt, host.ID, policyID))
		return n
	}

	// flipped: completed and pending rows alike are now at attempt 0.
	require.Equal(t, 2,
		count(`SELECT COUNT(*) FROM host_script_results WHERE host_id = ? AND policy_id = ? AND attempt_number = 0`, flipped.ID),
		"both completed and pending script attempts should be reset to 0")
	require.Equal(t, 2,
		count(`SELECT COUNT(*) FROM host_software_installs WHERE host_id = ? AND policy_id = ? AND attempt_number = 0`, flipped.ID),
		"both completed and pending install attempts should be reset to 0")

	// fresh: nothing changed, one row at attempt 1 and one still NULL per table.
	require.Equal(t, 1,
		count(`SELECT COUNT(*) FROM host_script_results WHERE host_id = ? AND policy_id = ? AND attempt_number = 1`, fresh.ID))
	require.Equal(t, 1,
		count(`SELECT COUNT(*) FROM host_software_installs WHERE host_id = ? AND policy_id = ? AND attempt_number = 1`, fresh.ID))
	require.Equal(t, 1,
		count(`SELECT COUNT(*) FROM host_script_results WHERE host_id = ? AND policy_id = ? AND attempt_number IS NULL`, fresh.ID),
		"p2 pending script should remain NULL")
	require.Equal(t, 1,
		count(`SELECT COUNT(*) FROM host_software_installs WHERE host_id = ? AND policy_id = ? AND attempt_number IS NULL`, fresh.ID),
		"p2 pending install should remain NULL")
}
// testResetAttemptsOnFailingToPassingAsync is the AsyncBatchInsertPolicyMembership
// counterpart of the sync test. A policy previously recorded as failing and
// then batch-inserted as passing must have the attempt_number of its script
// and install rows set to 0 (including rows where it was NULL). A policy
// whose first result is already passing must keep its rows untouched.
func testResetAttemptsOnFailingToPassingAsync(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	owner := test.NewUser(t, ds, "Carol", "carol@example.com", true)

	// flipped goes failing -> passing; fresh only ever reports passing.
	flipped, err := ds.NewGlobalPolicy(ctx, &owner.ID, fleet.PolicyPayload{
		Name:  "policy-async-1",
		Query: "SELECT 1",
	})
	require.NoError(t, err)
	fresh, err := ds.NewGlobalPolicy(ctx, &owner.ID, fleet.PolicyPayload{
		Name:  "policy-async-2",
		Query: "SELECT 2",
	})
	require.NoError(t, err)

	host, err := ds.EnrollOsquery(ctx,
		fleet.WithEnrollOsqueryHostID("hasync1"),
		fleet.WithEnrollOsqueryNodeKey("nasync1"),
	)
	require.NoError(t, err)

	// Start with flipped failing; fresh has no result yet.
	err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{flipped.ID: ptr.Bool(false)}, time.Now(), false)
	require.NoError(t, err)

	const (
		scriptCompletedSQL = `INSERT INTO host_script_results (host_id, execution_id, output, runtime, policy_id, attempt_number) VALUES (?, ?, '', 0, ?, 1)`
		scriptPendingSQL   = `INSERT INTO host_script_results (host_id, execution_id, output, runtime, policy_id, attempt_number) VALUES (?, ?, '', 0, ?, NULL)`
	)
	// seedScript / seedInstall insert one row with a random execution ID. The
	// two tables take host_id and execution_id in opposite argument order.
	seedScript := func(stmt string, policyID uint) {
		_, err := ds.writer(ctx).Exec(stmt, host.ID, uuid.NewString(), policyID)
		require.NoError(t, err)
	}
	seedInstall := func(stmt string, policyID uint) {
		_, err := ds.writer(ctx).Exec(stmt, uuid.NewString(), host.ID, policyID)
		require.NoError(t, err)
	}

	// For each policy: a completed attempt (attempt_number = 1) and a pending
	// one (attempt_number IS NULL), in both tables.
	seedScript(scriptCompletedSQL, flipped.ID)
	seedInstall(`INSERT INTO host_software_installs (execution_id, host_id, policy_id, attempt_number, installer_filename, version) VALUES (?, ?, ?, 1, 'z', '3.0.0')`, flipped.ID)
	seedScript(scriptPendingSQL, flipped.ID)
	seedInstall(`INSERT INTO host_software_installs (execution_id, host_id, policy_id, attempt_number, installer_filename, version) VALUES (?, ?, ?, NULL, 'z-pending', '3.0.0')`, flipped.ID)

	seedScript(scriptCompletedSQL, fresh.ID)
	seedInstall(`INSERT INTO host_software_installs (execution_id, host_id, policy_id, attempt_number, installer_filename, version) VALUES (?, ?, ?, 1, 'w', '4.0.0')`, fresh.ID)
	seedScript(scriptPendingSQL, fresh.ID)
	seedInstall(`INSERT INTO host_software_installs (execution_id, host_id, policy_id, attempt_number, installer_filename, version) VALUES (?, ?, ?, NULL, 'w-pending', '4.0.0')`, fresh.ID)

	// Both policies arrive as passing in one async batch; only flipped goes
	// from failing to passing, since fresh had no earlier result.
	passingBatch := []fleet.PolicyMembershipResult{
		{HostID: host.ID, PolicyID: flipped.ID, Passes: ptr.Bool(true)},
		{HostID: host.ID, PolicyID: fresh.ID, Passes: ptr.Bool(true)},
	}
	require.NoError(t, ds.AsyncBatchInsertPolicyMembership(ctx, passingBatch))

	// count runs a COUNT(*) statement bound to (host, policy).
	count := func(stmt string, policyID uint) int {
		var n int
		require.NoError(t, sqlx.GetContext(ctx, ds.reader(ctx), &n, stmt, host.ID, policyID))
		return n
	}

	// flipped: completed and pending rows alike are now at attempt 0.
	require.Equal(t, 2,
		count(`SELECT COUNT(*) FROM host_script_results WHERE host_id = ? AND policy_id = ? AND attempt_number = 0`, flipped.ID),
		"both completed and pending script attempts should be reset to 0")
	require.Equal(t, 2,
		count(`SELECT COUNT(*) FROM host_software_installs WHERE host_id = ? AND policy_id = ? AND attempt_number = 0`, flipped.ID),
		"both completed and pending install attempts should be reset to 0")

	// fresh: nothing changed, one row at attempt 1 and one still NULL per table.
	require.Equal(t, 1,
		count(`SELECT COUNT(*) FROM host_script_results WHERE host_id = ? AND policy_id = ? AND attempt_number = 1`, fresh.ID))
	require.Equal(t, 1,
		count(`SELECT COUNT(*) FROM host_software_installs WHERE host_id = ? AND policy_id = ? AND attempt_number = 1`, fresh.ID))
	require.Equal(t, 1,
		count(`SELECT COUNT(*) FROM host_script_results WHERE host_id = ? AND policy_id = ? AND attempt_number IS NULL`, fresh.ID),
		"p2 pending script should remain NULL")
	require.Equal(t, 1,
		count(`SELECT COUNT(*) FROM host_software_installs WHERE host_id = ? AND policy_id = ? AND attempt_number IS NULL`, fresh.ID),
		"p2 pending install should remain NULL")
}
// testPolicyModificationResetsAttemptNumber seeds software-install and script
// rows tied to a policy with attempt_number 1 and 2, calls SavePolicy on that
// policy, and then requires every seeded row to have attempt_number 0.
func testPolicyModificationResetsAttemptNumber(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	team, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name()})
	require.NoError(t, err)

	// insertID runs an INSERT through ExecAdhocSQL and returns the generated ID.
	insertID := func(stmt string, args ...any) int64 {
		var id int64
		ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
			res, err := q.ExecContext(ctx, stmt, args...)
			if err != nil {
				return err
			}
			id, err = res.LastInsertId()
			return err
		})
		return id
	}
	// insertRow runs an INSERT through ExecAdhocSQL when the ID is not needed.
	insertRow := func(stmt string, args ...any) {
		ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
			_, err := q.ExecContext(ctx, stmt, args...)
			return err
		})
	}

	// Script contents row, shared by the script and by the installer's
	// install/uninstall scripts below.
	scriptContentID := insertID(`INSERT INTO script_contents (md5_checksum, contents) VALUES (?, ?)`,
		"md5hash", "echo 'test'")

	script, err := ds.NewScript(ctx, &fleet.Script{
		Name:            "test.sh",
		TeamID:          &team.ID,
		ScriptContentID: uint(scriptContentID), //nolint:gosec // dismiss G115
		ScriptContents:  "echo 'test'",
	})
	require.NoError(t, err)

	// Software title plus an installer for it, both inserted directly.
	titleID := insertID(`INSERT INTO software_titles (name, source) VALUES (?, ?)`, "Test App", "apps")
	installerID := insertID(`
INSERT INTO software_installers (team_id, global_or_team_id, title_id, storage_id, filename, extension, version, install_script_content_id, uninstall_script_content_id, platform, package_ids, patch_query)
VALUES (?, ?, ?, 'storage', 'test.pkg', 'pkg', '1.0', ?, ?, 'darwin', '', '')
`, team.ID, team.ID, titleID, scriptContentID, scriptContentID)

	policy, err := ds.NewTeamPolicy(ctx, team.ID, nil, fleet.PolicyPayload{
		Name:     t.Name(),
		Query:    "SELECT 1;",
		Platform: "darwin",
	})
	require.NoError(t, err)

	// Two install attempts for the policy: attempt 1 has an exit code,
	// attempt 2 has none.
	insertRow(`
INSERT INTO host_software_installs (execution_id, host_id, software_installer_id, user_id, self_service, policy_id, install_script_exit_code, attempt_number)
VALUES ('install-1', 1, ?, NULL, 0, ?, 1, 1)
`, installerID, policy.ID)
	insertRow(`
INSERT INTO host_software_installs (execution_id, host_id, software_installer_id, user_id, self_service, policy_id, install_script_exit_code, attempt_number)
VALUES ('install-2', 1, ?, NULL, 0, ?, NULL, 2)
`, installerID, policy.ID)

	// Two script attempts for the policy, laid out the same way.
	insertRow(`
INSERT INTO host_script_results (host_id, execution_id, script_content_id, output, exit_code, script_id, policy_id, attempt_number)
VALUES (1, 'script-1', ?, 'output', 1, ?, ?, 1)
`, scriptContentID, script.ID, policy.ID)
	insertRow(`
INSERT INTO host_script_results (host_id, execution_id, script_content_id, output, exit_code, script_id, policy_id, attempt_number)
VALUES (1, 'script-2', ?, '', NULL, ?, ?, 2)
`, scriptContentID, script.ID, policy.ID)

	// Saving the policy is expected to reset attempt_number to 0 on every
	// automation row that references it.
	require.NoError(t, ds.SavePolicy(ctx, policy, false, false))

	// attemptRow is the projection read back from both tables.
	type attemptRow struct {
		ExecutionID   string `db:"execution_id"`
		AttemptNumber *int64 `db:"attempt_number"`
	}

	// Install rows: both must be at a non-NULL attempt 0.
	var installRows []attemptRow
	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		return sqlx.SelectContext(ctx, q, &installRows, `
SELECT execution_id, attempt_number
FROM host_software_installs
WHERE policy_id = ?
ORDER BY execution_id ASC
`, policy.ID)
	})
	require.Len(t, installRows, 2)
	for _, row := range installRows {
		require.NotNil(t, row.AttemptNumber)
		require.Equal(t, int64(0), *row.AttemptNumber)
	}

	// Script rows: same reset, and the rows come back in execution ID order.
	var scriptRows []attemptRow
	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		return sqlx.SelectContext(ctx, q, &scriptRows, `
SELECT execution_id, attempt_number
FROM host_script_results
WHERE script_id = ? AND policy_id = ?
ORDER BY execution_id ASC
`, script.ID, policy.ID)
	})
	require.Len(t, scriptRows, 2)
	for i, wantExecID := range []string{"script-1", "script-2"} {
		require.Equal(t, wantExecID, scriptRows[i].ExecutionID)
		require.NotNil(t, scriptRows[i].AttemptNumber)
		require.Equal(t, int64(0), *scriptRows[i].AttemptNumber)
	}
}
// testBatchedPolicyMembershipCleanup runs cleanupPolicyMembershipForPolicy with
// policyMembershipDeleteBatchSize lowered to 2 against a policy that has 5
// member hosts, so the row count exceeds one batch. It then requires that no
// policy_membership rows remain for the policy and that none of those hosts
// still has a positive total_issues_count in host_issues.
func testBatchedPolicyMembershipCleanup(t *testing.T, ds *Datastore) {
	const hostCount = 5 // more than the batch size of 2 set below

	ctx := context.Background()
	owner := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	// Shrink the batch size for the duration of this test only.
	savedBatchSize := policyMembershipDeleteBatchSize
	policyMembershipDeleteBatchSize = 2
	t.Cleanup(func() { policyMembershipDeleteBatchSize = savedBatchSize })

	pol := newTestPolicy(t, ds, owner, "batch cleanup policy", "", nil)

	// Create the hosts, keeping their IDs so the host_issues assertions can be
	// scoped to them (avoids flakiness if other tests leave rows behind).
	hosts := make([]*fleet.Host, 0, hostCount)
	hostIDs := make([]uint, 0, hostCount)
	for i := range hostCount {
		ident := fmt.Sprintf("batch-cleanup-%d", i)
		created, err := ds.NewHost(ctx, &fleet.Host{
			OsqueryHostID:   &ident,
			DetailUpdatedAt: time.Now(),
			LabelUpdatedAt:  time.Now(),
			PolicyUpdatedAt: time.Now(),
			SeenTime:        time.Now(),
			NodeKey:         &ident,
			UUID:            ident,
			Hostname:        ident,
			Platform:        "linux",
		})
		require.NoError(t, err)
		hosts = append(hosts, created)
		hostIDs = append(hostIDs, created.ID)
	}

	// Every host fails the policy, so each gets a policy_membership row and a
	// host_issues entry.
	for _, h := range hosts {
		require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, h, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false))
	}

	hostIssuesQ, hostIssuesArgs, err := sqlx.In(
		`SELECT COUNT(*) FROM host_issues WHERE host_id IN (?) AND total_issues_count > 0`, hostIDs,
	)
	require.NoError(t, err)

	// membershipRows counts policy_membership rows for the policy.
	membershipRows := func() int {
		var n int
		require.NoError(t, ds.writer(ctx).Get(&n, `SELECT COUNT(*) FROM policy_membership WHERE policy_id = ?`, pol.ID))
		return n
	}
	// hostsWithIssues counts the test hosts that still have open issues.
	hostsWithIssues := func() int {
		var n int
		require.NoError(t, ds.writer(ctx).Get(&n, hostIssuesQ, hostIssuesArgs...))
		return n
	}

	// Precondition: all hosts are members and all have issues.
	require.Equal(t, hostCount, membershipRows())
	require.Equal(t, hostCount, hostsWithIssues())

	// Invoke the full cleanup directly (simulates what ApplyPolicySpecs
	// triggers when a query changes, i.e. shouldRemoveAllPolicyMemberships == true).
	err = cleanupPolicyMembershipForPolicy(ctx, ds.reader(ctx), ds.writer(ctx), pol.ID)
	require.NoError(t, err)

	// Nothing may be left: neither membership rows nor failing-policy issues.
	assert.Zero(t, membershipRows())
	assert.Zero(t, hostsWithIssues())
}
// testBatchedPolicyMembershipCleanupOnPolicyUpdate checks that
// cleanupPolicyMembershipOnPolicyUpdate removes every stale policy_membership row when the
// delete batch size (2) is smaller than the number of rows to remove (5). Part 1 covers
// hosts on the wrong platform; part 2 covers hosts outside the policy's include label.
func testBatchedPolicyMembershipCleanupOnPolicyUpdate(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	admin := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	// Shrink the batch size so the deletes need several rounds; restore it when the test ends.
	savedBatchSize := policyMembershipDeleteBatchSize
	policyMembershipDeleteBatchSize = 2
	t.Cleanup(func() { policyMembershipDeleteBatchSize = savedBatchSize })

	// enrollHost creates a host whose osquery ID, node key, UUID and hostname all equal id.
	enrollHost := func(id, platform string) *fleet.Host {
		host, err := ds.NewHost(ctx, &fleet.Host{
			OsqueryHostID:   &id,
			DetailUpdatedAt: time.Now(),
			LabelUpdatedAt:  time.Now(),
			PolicyUpdatedAt: time.Now(),
			SeenTime:        time.Now(),
			NodeKey:         &id,
			UUID:            id,
			Hostname:        id,
			Platform:        platform,
		})
		require.NoError(t, err)
		return host
	}

	// recordFailing stores a failing result for policyID on each host, in slice order.
	recordFailing := func(policyID uint, hosts []*fleet.Host) {
		for _, host := range hosts {
			require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policyID: ptr.Bool(false)}, time.Now(), false))
		}
	}

	// membershipCount returns the number of policy_membership rows for policyID.
	membershipCount := func(policyID uint) int {
		var n int
		require.NoError(t, ds.writer(ctx).Get(&n, `SELECT COUNT(*) FROM policy_membership WHERE policy_id = ?`, policyID))
		return n
	}

	// memberHostIDs returns the host IDs that still have a membership row for policyID.
	memberHostIDs := func(policyID uint) []uint {
		var ids []uint
		require.NoError(t, ds.writer(ctx).Select(&ids, `SELECT host_id FROM policy_membership WHERE policy_id = ?`, policyID))
		return ids
	}

	// ── Part 1: platform-based cleanup ──────────────────────────────────────

	// A windows-only policy.
	platformPol := newTestPolicy(t, ds, admin, "batch platform cleanup", "windows", nil)

	// One windows host (expected to keep its row) followed by five linux hosts (wrong platform).
	winHost := enrollHost("batch-win-0", "windows")
	platformHosts := []*fleet.Host{winHost}
	for i := 0; i < 5; i++ {
		platformHosts = append(platformHosts, enrollHost(fmt.Sprintf("batch-lin-%d", i), "linux"))
	}

	// Results are recorded for every host, as if they arrived before the platform filter applied.
	recordFailing(platformPol.ID, platformHosts)
	require.Equal(t, 6, membershipCount(platformPol.ID))

	// Platform-aware cleanup (simulates the CleanupPolicyMembership cron).
	require.NoError(t, cleanupPolicyMembershipOnPolicyUpdate(ctx, ds.reader(ctx), ds.writer(ctx), platformPol.ID, platformPol.Platform))

	// Only the windows host's row is left.
	require.ElementsMatch(t, []uint{winHost.ID}, memberHostIDs(platformPol.ID))

	// ── Part 2: label-based cleanup ─────────────────────────────────────────

	// The label that the second policy will be scoped to.
	inclLabel, err := ds.NewLabel(ctx, &fleet.Label{Name: "batch-incl-label"})
	require.NoError(t, err)

	// One host inside the label (expected to keep its row) and five outside it, which are
	// removed in batches of 2.
	lblHost := enrollHost("batch-lbl-0", "linux")
	require.NoError(t, ds.AddLabelsToHost(ctx, lblHost.ID, []uint{inclLabel.ID}))
	labelHosts := []*fleet.Host{lblHost}
	for i := 0; i < 5; i++ {
		labelHosts = append(labelHosts, enrollHost(fmt.Sprintf("batch-nonlbl-%d", i), "linux"))
	}

	// A policy with no platform restriction that includes only inclLabel.
	lblPol := newTestPolicy(t, ds, admin, "batch label cleanup", "", nil)
	lblPol.LabelsIncludeAny = []fleet.LabelIdent{{LabelName: inclLabel.Name}}
	require.NoError(t, ds.SavePolicy(ctx, lblPol, false, false))

	// Populate policy_membership for all six label-test hosts.
	recordFailing(lblPol.ID, labelHosts)
	require.Equal(t, 6, membershipCount(lblPol.ID))

	// An empty platform means only the label-based branch fires.
	require.NoError(t, cleanupPolicyMembershipOnPolicyUpdate(ctx, ds.reader(ctx), ds.writer(ctx), lblPol.ID, "" /* no platform filter */))

	// Only the labelled host's row is left.
	require.ElementsMatch(t, []uint{lblHost.ID}, memberHostIDs(lblPol.ID))
}
// testApplyPolicySpecsNeedsFullMembershipCleanupFlag applies a global policy spec, records
// failing results for three hosts, then re-applies the spec with a different query and
// asserts the end state:
//   - needs_full_membership_cleanup reads 0 (it is set inside the transaction when
//     shouldRemoveAllPolicyMemberships == true and cleared once cleanup succeeds);
//   - no policy_membership rows remain for the policy.
func testApplyPolicySpecsNeedsFullMembershipCleanupFlag(t *testing.T, ds *Datastore) {
	const policyName = "flag test policy"

	ctx := context.Background()
	admin := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	// applySpec applies the policy spec under test with the given query.
	applySpec := func(query string) {
		require.NoError(t, ds.ApplyPolicySpecs(ctx, admin.ID, []*fleet.PolicySpec{
			{Name: policyName, Query: query, Platform: "", Type: fleet.PolicyTypeDynamic},
		}))
	}

	// First application creates the policy.
	applySpec("select 1;")

	// Look the policy up by name so other global policies in the database do not matter.
	globalPolicies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	var pol *fleet.Policy
	for i := 0; i < len(globalPolicies) && pol == nil; i++ {
		if globalPolicies[i].Name == policyName {
			pol = globalPolicies[i]
		}
	}
	require.NotNil(t, pol, "policy 'flag test policy' not found")

	// membershipCount returns the number of policy_membership rows for the policy.
	membershipCount := func() int {
		var n int
		require.NoError(t, ds.writer(ctx).Get(&n, `SELECT COUNT(*) FROM policy_membership WHERE policy_id = ?`, pol.ID))
		return n
	}

	// Create three hosts first, then record a failing result for each of them.
	var hosts []*fleet.Host
	for i := 0; i < 3; i++ {
		id := fmt.Sprintf("flag-test-%d", i)
		host, err := ds.NewHost(ctx, &fleet.Host{
			OsqueryHostID:   &id,
			DetailUpdatedAt: time.Now(),
			LabelUpdatedAt:  time.Now(),
			PolicyUpdatedAt: time.Now(),
			SeenTime:        time.Now(),
			NodeKey:         &id,
			UUID:            id,
			Hostname:        id,
			Platform:        "linux",
		})
		require.NoError(t, err)
		hosts = append(hosts, host)
	}
	for _, host := range hosts {
		require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{pol.ID: ptr.Bool(false)}, time.Now(), false))
	}
	require.Equal(t, 3, membershipCount())

	// Changing the query triggers shouldRemoveAllPolicyMemberships = true.
	applySpec("select 2;")

	// After a successful run the flag is back to 0.
	var flag int
	require.NoError(t, ds.writer(ctx).Get(&flag,
		`SELECT needs_full_membership_cleanup FROM policies WHERE id = ?`, pol.ID))
	assert.Zero(t, flag, "needs_full_membership_cleanup must be cleared after successful cleanup")

	// And every membership row has been removed.
	assert.Zero(t, membershipCount())
}
// testCleanupPolicyMembershipCrashRecovery simulates an interrupted full-membership cleanup by
// setting policies.needs_full_membership_cleanup = 1 directly, then checks the recovery paths:
//
//  1. GitOps retry: re-applying the same spec through ApplyPolicySpecs removes the membership
//     rows and clears the flag without the cron.
//  2. Cron safety net: CleanupPolicyMembership removes the membership rows and clears the flag
//     when no retry happens.
//  3. Flag-only leftover: CleanupPolicyMembership clears the flag without error when no
//     membership rows exist for the policy.
func testCleanupPolicyMembershipCrashRecovery(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	admin := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	// newHosts creates n linux hosts named "<prefix>-<index>".
	newHosts := func(t *testing.T, n int, prefix string) []*fleet.Host {
		t.Helper()
		created := make([]*fleet.Host, 0, n)
		for i := 0; i < n; i++ {
			id := fmt.Sprintf("%s-%d", prefix, i)
			host, err := ds.NewHost(ctx, &fleet.Host{
				OsqueryHostID:   &id,
				DetailUpdatedAt: time.Now(),
				LabelUpdatedAt:  time.Now(),
				PolicyUpdatedAt: time.Now(),
				SeenTime:        time.Now(),
				NodeKey:         &id,
				UUID:            id,
				Hostname:        id,
				Platform:        "linux",
			})
			require.NoError(t, err)
			created = append(created, host)
		}
		return created
	}

	// recordResults stores a failing result for polID on every host.
	recordResults := func(t *testing.T, hosts []*fleet.Host, polID uint) {
		t.Helper()
		for _, host := range hosts {
			require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{polID: ptr.Bool(false)}, time.Now(), false))
		}
	}

	// membershipCount returns the number of policy_membership rows for polID.
	membershipCount := func(t *testing.T, polID uint) int {
		t.Helper()
		var n int
		require.NoError(t, ds.writer(ctx).Get(&n, `SELECT COUNT(*) FROM policy_membership WHERE policy_id = ?`, polID))
		return n
	}

	// raiseCleanupFlag sets needs_full_membership_cleanup = 1 for polID.
	raiseCleanupFlag := func(t *testing.T, polID uint) {
		t.Helper()
		_, err := ds.writer(ctx).ExecContext(ctx,
			`UPDATE policies SET needs_full_membership_cleanup = 1 WHERE id = ?`, polID)
		require.NoError(t, err)
	}

	// cleanupFlag reads needs_full_membership_cleanup for polID.
	cleanupFlag := func(t *testing.T, polID uint) int {
		t.Helper()
		var flag int
		require.NoError(t, ds.writer(ctx).Get(&flag,
			`SELECT needs_full_membership_cleanup FROM policies WHERE id = ?`, polID))
		return flag
	}

	t.Run("gitops retry re-triggers cleanup", func(t *testing.T) {
		spec := func() []*fleet.PolicySpec {
			return []*fleet.PolicySpec{
				{Name: "retry recovery policy", Query: "select 1;", Type: fleet.PolicyTypeDynamic},
			}
		}

		// Create the policy through ApplyPolicySpecs so it exists in the DB.
		require.NoError(t, ds.ApplyPolicySpecs(ctx, admin.ID, spec()))

		globalPolicies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
		require.NoError(t, err)
		var pol *fleet.Policy
		for i := 0; i < len(globalPolicies) && pol == nil; i++ {
			if globalPolicies[i].Name == "retry recovery policy" {
				pol = globalPolicies[i]
			}
		}
		require.NotNil(t, pol)

		// Populate membership rows.
		recordResults(t, newHosts(t, 4, "retry-recovery"), pol.ID)
		require.Equal(t, 4, membershipCount(t, pol.ID))

		// Simulate: the transaction committed with the flag set, but cleanup never ran.
		raiseCleanupFlag(t, pol.ID)

		// Re-apply the identical spec; ApplyPolicySpecs has to notice the flag and redo the
		// full cleanup on its own.
		require.NoError(t, ds.ApplyPolicySpecs(ctx, admin.ID, spec()))

		assert.Zero(t, cleanupFlag(t, pol.ID), "flag must be cleared by the GitOps retry")
		assert.Zero(t, membershipCount(t, pol.ID), "all policy_membership rows must be removed by the GitOps retry")
	})

	t.Run("cron cleans up when no gitops retry", func(t *testing.T) {
		pol := newTestPolicy(t, ds, admin, "cron recovery policy", "", nil)
		recordResults(t, newHosts(t, 4, "cron-recovery"), pol.ID)
		require.Equal(t, 4, membershipCount(t, pol.ID))

		// Simulate an interrupted cleanup: flag set, membership rows still present.
		raiseCleanupFlag(t, pol.ID)

		// The cron entry point should see the flag and run the full cleanup.
		require.NoError(t, ds.CleanupPolicyMembership(ctx, time.Now()))

		assert.Zero(t, cleanupFlag(t, pol.ID), "flag must be cleared by CleanupPolicyMembership")
		assert.Zero(t, membershipCount(t, pol.ID), "all policy_membership rows must be cleaned up by the cron safety net")
	})

	t.Run("cron clears flag when cleanup already completed", func(t *testing.T) {
		// Scenario: the flag was committed as 1, the membership rows were all removed, but
		// the server stopped before running
		// UPDATE policies SET needs_full_membership_cleanup = 0.
		// The cron has to treat the cleanup as a no-op and still clear the flag.
		pol := newTestPolicy(t, ds, admin, "flag-only recovery policy", "", nil)

		// No results were recorded, standing in for "cleanup already removed the rows".
		require.Zero(t, membershipCount(t, pol.ID), "precondition: no membership rows")

		// Reproduce the window between row removal and flag reset.
		raiseCleanupFlag(t, pol.ID)

		require.NoError(t, ds.CleanupPolicyMembership(ctx, time.Now()))

		assert.Zero(t, cleanupFlag(t, pol.ID), "flag must be cleared even when no membership rows remain")
	})
}
// testTeamPatchPolicy covers creating and looking up patch policies (fleet.PolicyTypePatch)
// on a team:
//   - creation is rejected while the title's installer is not linked to a Fleet-maintained app;
//   - the stored query is the generated version-comparison query, not the "SELECT 1;" passed in;
//   - name, description, resolution and platform are generated when omitted, while a supplied
//     name, description and resolution are kept;
//   - GetPatchPolicy returns the patch policy for a title and a not-found error for a title
//     that has none;
//   - a windows installer yields a query against the programs table.
func testTeamPatchPolicy(t *testing.T, ds *Datastore) {
	// Query expected for the darwin installer below (bundle ID fleet.maintained1, version 1.0).
	const darwinPatchQuery = "SELECT 1 WHERE NOT EXISTS (SELECT 1 FROM apps WHERE bundle_identifier = 'fleet.maintained1' AND version_compare(bundle_short_version, '1.0') < 0);"

	ctx := context.Background()
	user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	payload := &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "hello",
		PreInstallQuery:   "SELECT 1",
		PostInstallScript: "world",
		StorageID:         "storage1",
		Filename:          "maintained1",
		Title:             "Maintained1",
		Version:           "1.0",
		Source:            "apps",
		Platform:          "darwin",
		BundleIdentifier:  "fleet.maintained1",
		UserID:            user1.ID,
		TeamID:            &team1.ID,
		ValidatedLabels:   &fleet.LabelIdentsWithScope{},
	}
	installerID, titleID, err := ds.MatchOrCreateSoftwareInstaller(ctx, payload)
	require.NoError(t, err)

	// A patch policy for an installer with no associated FMA must be rejected.
	_, err = ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		Name:                 "p1",
		Query:                "SELECT 1;",
		Type:                 fleet.PolicyTypePatch,
		PatchSoftwareTitleID: &titleID,
	})
	require.ErrorContains(t, err, "Software installer for Fleet maintained app with title ID")

	maintainedApp, err := ds.UpsertMaintainedApp(ctx, &fleet.MaintainedApp{
		Name:             "Maintained1",
		Slug:             "maintained1",
		Platform:         "darwin",
		UniqueIdentifier: "fleet.maintained1",
	})
	require.NoError(t, err)

	// Replace the installer with one that is linked to the maintained app.
	require.NoError(t, ds.DeleteSoftwareInstaller(ctx, installerID))
	payload.FleetMaintainedAppID = &maintainedApp.ID
	_, titleID, err = ds.MatchOrCreateSoftwareInstaller(ctx, payload)
	require.NoError(t, err)

	// The supplied query is replaced by the generated patch query.
	p1, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		Name:                 "p1",
		Query:                "SELECT 1;",
		Type:                 fleet.PolicyTypePatch,
		PatchSoftwareTitleID: &titleID,
	})
	require.NoError(t, err)
	require.Equal(t, darwinPatchQuery, p1.Query)

	// A regular (non-patch) policy can be created alongside the patch policy.
	_, err = ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		Name:     "p2",
		Query:    "SELECT 1;",
		Platform: "darwin",
	})
	require.NoError(t, err)

	_, err = ds.DeleteTeamPolicies(ctx, team1.ID, []uint{p1.ID})
	require.NoError(t, err)

	// everything automatically generated
	p3, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		Type:                 fleet.PolicyTypePatch,
		PatchSoftwareTitleID: &titleID,
	})
	require.NoError(t, err)
	require.Equal(t, "macOS - Maintained1 up to date", p3.Name)
	require.Equal(t, "Outdated software might introduce security vulnerabilities or compatibility issues.", p3.Description)
	// Guard the dereference so a missing resolution fails the assertion instead of panicking.
	require.NotNil(t, p3.Resolution)
	require.Equal(t, "Install the latest version from self-service.", *p3.Resolution)
	require.Equal(t, "darwin", p3.Platform)
	require.Equal(t, darwinPatchQuery, p3.Query)

	_, err = ds.DeleteTeamPolicies(ctx, team1.ID, []uint{p3.ID})
	require.NoError(t, err)

	// some fields should not be overwritten
	p4, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		Name:                 "name",
		Description:          "description",
		Resolution:           "resolution",
		Type:                 fleet.PolicyTypePatch,
		PatchSoftwareTitleID: &titleID,
	})
	require.NoError(t, err)
	require.Equal(t, "name", p4.Name)
	require.Equal(t, "description", p4.Description)
	require.NotNil(t, p4.Resolution)
	require.Equal(t, "resolution", *p4.Resolution)
	require.Equal(t, "darwin", p4.Platform)
	require.Equal(t, darwinPatchQuery, p4.Query)

	// test GetPatchPolicy
	data, err := ds.GetPatchPolicy(ctx, &team1.ID, titleID)
	require.NoError(t, err)
	require.Equal(t, p4.ID, data.ID)
	require.Equal(t, p4.Name, data.Name)

	// A title without a patch policy yields a not-found error.
	payload2 := &fleet.UploadSoftwareInstallerPayload{
		Filename:        "bar",
		Title:           "bar",
		UserID:          user1.ID,
		TeamID:          &team1.ID,
		ValidatedLabels: &fleet.LabelIdentsWithScope{},
	}
	_, titleID2, err := ds.MatchOrCreateSoftwareInstaller(ctx, payload2)
	require.NoError(t, err)
	_, err = ds.GetPatchPolicy(ctx, &team1.ID, titleID2)
	require.Error(t, err)
	require.True(t, fleet.IsNotFound(err), "expected a not-found error, got: %v", err)

	// A windows maintained app: the generated name, platform and query follow the installer.
	maintainedApp2, err := ds.UpsertMaintainedApp(ctx, &fleet.MaintainedApp{
		Name:             "Maintained2",
		Slug:             "maintained2",
		Platform:         "windows",
		UniqueIdentifier: "fleet.maintained2",
	})
	require.NoError(t, err)

	payload3 := &fleet.UploadSoftwareInstallerPayload{
		InstallScript:        "hello",
		PreInstallQuery:      "SELECT 1",
		PostInstallScript:    "world",
		StorageID:            "storage2",
		Filename:             "maintained2",
		Title:                "Maintained2",
		Version:              "1.0",
		Source:               "programs",
		Platform:             "windows",
		BundleIdentifier:     "fleet.maintained2",
		UserID:               user1.ID,
		TeamID:               &team1.ID,
		ValidatedLabels:      &fleet.LabelIdentsWithScope{},
		FleetMaintainedAppID: &maintainedApp2.ID,
	}
	_, titleID3, err := ds.MatchOrCreateSoftwareInstaller(ctx, payload3)
	require.NoError(t, err)

	p5, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
		Type:                 fleet.PolicyTypePatch,
		PatchSoftwareTitleID: &titleID3,
	})
	require.NoError(t, err)
	require.Equal(t, "Windows - Maintained2 up to date", p5.Name)
	require.Equal(t, "windows", p5.Platform)
	require.Equal(t, "SELECT 1 WHERE NOT EXISTS (SELECT 1 FROM programs WHERE name = 'Maintained2' AND version_compare(version, '1.0') < 0);", p5.Query)
}
// testTeamPolicyAutomationFilter creates one global policy and seven policies under team ID 0,
// each tied to a different automation (software installer, VPP app, script, calendar,
// conditional access, failing-policies webhook, patch), then checks that the automation-type
// filter of ListMergedTeamPolicies, CountMergedTeamPolicies, ListTeamPolicies and CountPolicies
// returns exactly the matching policies, ordered by name.
func testTeamPolicyAutomationFilter(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	admin := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	globalPol, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{
		Name:        "query 1",
		Query:       "select 1;",
		Description: "query desc",
		Resolution:  "query resolution",
	})
	require.NoError(t, err)

	// Plain software installer (nil team ID).
	installerID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "hello",
		PreInstallQuery:   "SELECT 1",
		PostInstallScript: "world",
		StorageID:         "storage1",
		Filename:          "file1",
		Title:             "file1",
		Version:           "1.0",
		Source:            "apps",
		UserID:            admin.ID,
		TeamID:            nil,
		ValidatedLabels:   &fleet.LabelIdentsWithScope{},
	})
	require.NoError(t, err)

	// Fleet-maintained app plus an installer linked to it; the patch policy needs this title.
	fma, err := ds.UpsertMaintainedApp(ctx, &fleet.MaintainedApp{ID: 1})
	require.NoError(t, err)
	fmaInstaller := &fleet.UploadSoftwareInstallerPayload{
		InstallScript:        "hello",
		PreInstallQuery:      "SELECT 1",
		PostInstallScript:    "world",
		StorageID:            "storage1",
		Filename:             "maintained1",
		Title:                "Maintained1",
		Version:              "1.0",
		Source:               "apps",
		Platform:             "darwin",
		BundleIdentifier:     "fleet.maintained1",
		UserID:               admin.ID,
		TeamID:               nil,
		ValidatedLabels:      &fleet.LabelIdentsWithScope{},
		FleetMaintainedAppID: &fma.ID,
	}
	_, fmaTitleID, err := ds.MatchOrCreateSoftwareInstaller(ctx, fmaInstaller)
	require.NoError(t, err)

	// VPP token and a VPP app inserted with a nil team ID.
	test.CreateInsertGlobalVPPToken(t, ds)
	vppApp, err := ds.InsertVPPAppWithTeam(ctx, &fleet.VPPApp{
		Name: "vpp1", BundleIdentifier: "com.app.appy",
		VPPAppTeam: fleet.VPPAppTeam{VPPAppID: fleet.VPPAppID{AdamID: "adam_app", Platform: fleet.MacOSPlatform}},
	}, nil)
	require.NoError(t, err)
	vppMeta, err := ds.GetVPPAppMetadataByTeamAndTitleID(ctx, nil, vppApp.TitleID)
	require.NoError(t, err)

	// Policies with automations, named so that ordering by name matches creation order.
	installerPol, err := ds.NewTeamPolicy(ctx, 0, nil, fleet.PolicyPayload{
		Name:                "query 2",
		Query:               "select 1;",
		SoftwareInstallerID: &installerID,
	})
	require.NoError(t, err)

	appStorePol, err := ds.NewTeamPolicy(ctx, 0, nil, fleet.PolicyPayload{
		Name:           "query 3",
		Query:          "select 1;",
		VPPAppsTeamsID: &vppMeta.VPPAppsTeamsID,
	})
	require.NoError(t, err)

	script, err := ds.NewScript(ctx, &fleet.Script{
		TeamID:         nil,
		Name:           "hello-world.sh",
		ScriptContents: "echo 'Hello World'",
	})
	require.NoError(t, err)
	scriptPol, err := ds.NewTeamPolicy(ctx, 0, nil, fleet.PolicyPayload{
		Name:     "query 4",
		Query:    "SELECT 1;",
		ScriptID: &script.ID,
	})
	require.NoError(t, err)

	calendarPol, err := ds.NewTeamPolicy(ctx, 0, nil, fleet.PolicyPayload{
		Name:                  "query 5",
		Query:                 "SELECT 1;",
		CalendarEventsEnabled: true,
	})
	require.NoError(t, err)

	conditionalPol, err := ds.NewTeamPolicy(ctx, 0, nil, fleet.PolicyPayload{
		Name:                     "query 6",
		Query:                    "SELECT 1;",
		ConditionalAccessEnabled: true,
	})
	require.NoError(t, err)

	// This policy only becomes an "other" automation once the webhook config below lists it.
	webhookPol, err := ds.NewTeamPolicy(ctx, 0, nil, fleet.PolicyPayload{
		Name:  "query 7",
		Query: "SELECT 1;",
	})
	require.NoError(t, err)

	patchPol, err := ds.NewTeamPolicy(ctx, 0, nil, fleet.PolicyPayload{
		Name:                 "query 8",
		Query:                "SELECT 1;",
		Type:                 fleet.PolicyTypePatch,
		PatchSoftwareTitleID: &fmaTitleID,
	})
	require.NoError(t, err)

	// TODO: test ticket integration policies?
	teamConfig := fleet.TeamConfig{}
	teamConfig.WebhookSettings.FailingPoliciesWebhook.PolicyIDs = []uint{webhookPol.ID}
	require.NoError(t, ds.SaveDefaultTeamConfig(ctx, &teamConfig))

	byName := fleet.ListOptions{
		OrderKey:       "name",
		OrderDirection: fleet.OrderAscending,
	}

	// assertMerged lists and counts merged policies for the given automation filter and
	// compares the result, position by position, with wantIDs.
	assertMerged := func(filter string, wantIDs ...uint) {
		listed, err := ds.ListMergedTeamPolicies(ctx, 0, byName, filter)
		require.NoError(t, err)
		require.Len(t, listed, len(wantIDs))
		for i, wantID := range wantIDs {
			assert.Equal(t, wantID, listed[i].ID)
		}
		total, err := ds.CountMergedTeamPolicies(ctx, 0, "", filter)
		require.NoError(t, err)
		assert.Equal(t, len(wantIDs), total)
	}

	// No filter: every policy, the global one included.
	assertMerged("",
		globalPol.ID,
		installerPol.ID,
		appStorePol.ID,
		scriptPol.ID,
		calendarPol.ID,
		conditionalPol.ID,
		webhookPol.ID,
		patchPol.ID,
	)
	// Software covers installer, VPP app and patch policies.
	assertMerged("software", installerPol.ID, appStorePol.ID, patchPol.ID)
	assertMerged("scripts", scriptPol.ID)
	assertMerged("calendar", calendarPol.ID)
	assertMerged("conditional_access", conditionalPol.ID)
	// "other" matches the policy referenced by the failing-policies webhook.
	assertMerged("other", webhookPol.ID)

	// The same filter applies to the non-merged list and count.
	teamOnly, _, err := ds.ListTeamPolicies(ctx, 0, byName, fleet.ListOptions{}, "software")
	require.NoError(t, err)
	require.Len(t, teamOnly, 3)
	for i, wantID := range []uint{installerPol.ID, appStorePol.ID, patchPol.ID} {
		assert.Equal(t, wantID, teamOnly[i].ID)
	}
	teamOnlyCount, err := ds.CountPolicies(ctx, ptr.Uint(0), "", "software")
	require.NoError(t, err)
	assert.Equal(t, 3, teamOnlyCount)
}