diff --git a/server/datastore/mysql/packs.go b/server/datastore/mysql/packs.go index 391d4787cb..3cad4129a3 100644 --- a/server/datastore/mysql/packs.go +++ b/server/datastore/mysql/packs.go @@ -76,7 +76,7 @@ func applyPackSpecDB(ctx context.Context, tx sqlx.ExtContext, spec *fleet.PackSp q.Name = q.QueryName } - // Check if query exists ... we have to do this manual check because the FK + // Check if query exists ... we have to do this manually because the FK // constraint was removed as part of the work required for combining queries and schedules var count int if err := tx.QueryRowxContext( diff --git a/server/datastore/mysql/queries_test.go b/server/datastore/mysql/queries_test.go index df75445814..3d5a6b3ae9 100644 --- a/server/datastore/mysql/queries_test.go +++ b/server/datastore/mysql/queries_test.go @@ -24,12 +24,13 @@ func TestQueries(t *testing.T) { {"Delete", testQueriesDelete}, {"GetByName", testQueriesGetByName}, {"DeleteMany", testQueriesDeleteMany}, - // {"Save", testQueriesSave}, - // {"List", testQueriesList}, - // {"LoadPacksForQueries", testQueriesLoadPacksForQueries}, - // {"DuplicateNew", testQueriesDuplicateNew}, - // {"ListFiltersObservers", testQueriesListFiltersObservers}, - // {"ObserverCanRunQuery", testObserverCanRunQuery}, + {"Save", testQueriesSave}, + {"List", testQueriesList}, + {"LoadPacksForQueries", testQueriesLoadPacksForQueries}, + {"DuplicateNew", testQueriesDuplicateNew}, + {"ListFiltersObservers", testQueriesListFiltersObservers}, + {"ObserverCanRunQuery", testObserverCanRunQuery}, + {"ListFiltersByTeamID", testQueriesListFiltersByTeamID}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -79,6 +80,7 @@ func testQueriesApply(t *testing.T, ds *Datastore) { require.Equal(t, &zwass.ID, q.AuthorID) require.Equal(t, zwass.Email, q.AuthorEmail) require.Equal(t, zwass.Name, q.AuthorName) + require.True(t, q.Saved) } // Victor modifies a query (but also pushes the same version of the @@ -98,6 +100,7 @@ func testQueriesApply(t *testing.T, ds *Datastore) { assert.Equal(t, &groob.ID, q.AuthorID) require.Equal(t, groob.Email, q.AuthorEmail) require.Equal(t, groob.Name, q.AuthorName) + require.True(t, q.Saved) } // Zach adds a third query (but does not re-apply the others) @@ -118,6 +121,7 @@ func testQueriesApply(t *testing.T, ds *Datastore) { test.QueryElementsMatch(t, expectedQueries, queries) for _, q := range queries { + require.True(t, q.Saved) switch q.Name { case "foo", "bar": require.Equal(t, &groob.ID, q.AuthorID) @@ -155,14 +159,36 @@ func testQueriesDelete(t *testing.T, ds *Datastore) { func testQueriesGetByName(t *testing.T, ds *Datastore) { user := test.NewUser(t, ds, "Zach", "zwass@fleet.co", true) - q := test.NewQuery(t, ds, nil, "q1", "select * from time", user.ID, true) - actual, err := ds.QueryByName(context.Background(), q.TeamID, q.Name) + // Test we can get global queries by name + globalQ := test.NewQuery(t, ds, nil, "q1", "select * from time", user.ID, true) + + actual, err := ds.QueryByName(context.Background(), nil, globalQ.Name) require.NoError(t, err) + require.Nil(t, actual.TeamID) require.Equal(t, "q1", actual.Name) require.Equal(t, "select * from time", actual.Query) - actual, err = ds.QueryByName(context.Background(), q.TeamID, "xxx") + actual, err = ds.QueryByName(context.Background(), nil, "xxx") + require.Error(t, err) + require.True(t, fleet.IsNotFound(err)) + + // Test we can get queries in a team + teamRocket, err := ds.NewTeam(context.Background(), &fleet.Team{ + Name: "Team Rocket", + Description: "Something cheesy", + }) + require.NoError(t, err) + + teamRocketQ := test.NewQuery(t, ds, &teamRocket.ID, "q1", "select * from time", user.ID, true) + + actual, err = ds.QueryByName(context.Background(), &teamRocket.ID, teamRocketQ.Name) + require.NoError(t, err) + require.Equal(t, "q1", actual.Name) + require.Equal(t, teamRocket.ID, *actual.TeamID) + require.Equal(t, "select * from time", actual.Query) + + actual, err = ds.QueryByName(context.Background(), &teamRocket.ID, "xxx") require.Error(t, err) require.True(t, fleet.IsNotFound(err)) } @@ -215,7 +241,7 @@ func testQueriesSave(t *testing.T, ds *Datastore) { query, err := ds.NewQuery(context.Background(), query) require.NoError(t, err) require.NotNil(t, query) - assert.NotEqual(t, 0, query.ID) + require.NotEqual(t, 0, query.ID) team, err := ds.NewTeam(context.Background(), &fleet.Team{ Name: "some kind of nature", @@ -235,20 +261,15 @@ func testQueriesSave(t *testing.T, ds *Datastore) { err = ds.SaveQuery(context.Background(), query) require.NoError(t, err) - queryVerify, err := ds.Query(context.Background(), query.ID) + actual, err := ds.Query(context.Background(), query.ID) require.NoError(t, err) - require.NotNil(t, queryVerify) + require.NotNil(t, actual) - assert.Equal(t, "baz", queryVerify.Query) - assert.Equal(t, "Zach", queryVerify.AuthorName) - assert.Equal(t, "zwass@fleet.co", queryVerify.AuthorEmail) - assert.True(t, queryVerify.ObserverCanRun) - assert.Equal(t, *query.TeamID, team.ID) - assert.Equal(t, query.ScheduleInterval, uint(10)) - assert.Equal(t, *query.Platform, "macos") - assert.Equal(t, *query.MinOsqueryVersion, "5.2.1") - assert.Equal(t, query.AutomationsEnabled, true) - assert.Equal(t, query.LoggingType, "differential") + test.QueriesMatch(t, actual, query) + + require.Equal(t, "baz", actual.Query) + require.Equal(t, "Zach", actual.AuthorName) + require.Equal(t, "zwass@fleet.co", actual.AuthorEmail) } func testQueriesList(t *testing.T, ds *Datastore) { @@ -277,8 +298,8 @@ func testQueriesList(t *testing.T, ds *Datastore) { results, err := ds.ListQueries(context.Background(), opts) require.NoError(t, err) require.Equal(t, 10, len(results)) - assert.Equal(t, "Zach", results[0].AuthorName) - assert.Equal(t, "zwass@fleet.co", results[0].AuthorEmail) + require.Equal(t, "Zach", results[0].AuthorName) + require.Equal(t, "zwass@fleet.co", results[0].AuthorEmail) idWithAgg := results[0].ID @@ -290,7 +311,7 @@ func testQueriesList(t *testing.T, ds *Datastore) { results, err = ds.ListQueries(context.Background(), opts) require.NoError(t, err) - assert.Equal(t, 10, len(results)) + require.Equal(t, 10, len(results)) foundAgg := false for _, q := range results { @@ -312,7 +333,7 @@ func testQueriesLoadPacksForQueries(t *testing.T, ds *Datastore) { {Name: "q2", Query: "select * from osquery_info"}, } err := ds.ApplyQueries(context.Background(), zwass.ID, queries) - require.Nil(t, err) + require.NoError(t, err) specs := []*fleet.PackSpec{ {Name: "p1"}, @@ -439,12 +460,12 @@ func testQueriesDuplicateNew(t *testing.T, ds *Datastore) { AuthorID: &user.ID, }) require.NoError(t, err) - assert.NotZero(t, globalQ1.ID) + require.NotZero(t, globalQ1.ID) _, err = ds.NewQuery(context.Background(), &fleet.Query{ Name: "foo", Query: "select * from osquery_info;", }) - assert.Contains(t, err.Error(), "already exists") + require.Contains(t, err.Error(), "already exists") // Check uniqueness constraint on queries that belong to a team team, err := ds.NewTeam(context.Background(), &fleet.Team{ @@ -465,7 +486,7 @@ func testQueriesDuplicateNew(t *testing.T, ds *Datastore) { Query: "select * from osquery_info;", TeamID: &team.ID, }) - assert.Contains(t, err.Error(), "already exists") + require.Contains(t, err.Error(), "already exists") } func testQueriesListFiltersObservers(t *testing.T, ds *Datastore) { @@ -499,7 +520,7 @@ func testQueriesListFiltersObservers(t *testing.T, ds *Datastore) { ) require.NoError(t, err) require.Len(t, queries, 1) - assert.Equal(t, query3.ID, queries[0].ID) + require.Equal(t, query3.ID, queries[0].ID) } func testObserverCanRunQuery(t *testing.T, ds *Datastore) { @@ -532,3 +553,65 @@ func testObserverCanRunQuery(t *testing.T, ds *Datastore) { require.Equal(t, q.ObserverCanRun, canRun) } } + +func testQueriesListFiltersByTeamID(t *testing.T, ds *Datastore) { + globalQ1, err := ds.NewQuery(context.Background(), &fleet.Query{ + Name: "query1", + Query: "select 1;", + Saved: true, + }) + require.NoError(t, err) + globalQ2, err := ds.NewQuery(context.Background(), &fleet.Query{ + Name: "query2", + Query: "select 1;", + Saved: true, + }) + require.NoError(t, err) + globalQ3, err := ds.NewQuery(context.Background(), &fleet.Query{ + Name: "query3", + Query: "select 1;", + Saved: true, + }) + require.NoError(t, err) + + queries, err := ds.ListQueries(context.Background(), fleet.ListQueryOptions{}) + require.NoError(t, err) + test.QueryElementsMatch(t, queries, []*fleet.Query{globalQ1, globalQ2, globalQ3}) + + team, err := ds.NewTeam(context.Background(), &fleet.Team{ + Name: "some kind of nature", + Description: "some kind of goal", + }) + require.NoError(t, err) + + teamQ1, err := ds.NewQuery(context.Background(), &fleet.Query{ + Name: "query1", + Query: "select 1;", + Saved: true, + TeamID: &team.ID, + }) + require.NoError(t, err) + teamQ2, err := ds.NewQuery(context.Background(), &fleet.Query{ + Name: "query2", + Query: "select 1;", + Saved: true, + TeamID: &team.ID, + }) + require.NoError(t, err) + teamQ3, err := ds.NewQuery(context.Background(), &fleet.Query{ + Name: "query3", + Query: "select 1;", + Saved: true, + TeamID: &team.ID, + }) + require.NoError(t, err) + + queries, err = ds.ListQueries( + context.Background(), + fleet.ListQueryOptions{ + TeamID: &team.ID, + }, + ) + require.NoError(t, err) + test.QueryElementsMatch(t, queries, []*fleet.Query{teamQ1, teamQ2, teamQ3}) +} diff --git a/server/test/comparisons.go b/server/test/comparisons.go index 8adfaa0a24..2bb66d8b6d 100644 --- a/server/test/comparisons.go +++ b/server/test/comparisons.go @@ -255,3 +255,32 @@ func QueryElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...inte }, cmp.Ignore()) return ElementsMatchWithOptions(t, listA, listB, []cmp.Option{opt}, msgAndArgs) } + +// QueriesMatch asserts that two queries 'match'. +func QueriesMatch(t TestingT, a, b interface{}, msgAndArgs ...interface{}) (ok bool) { + t.Helper() + + opt := cmp.FilterPath(func(p cmp.Path) bool { + for _, ps := range p { + switch ps := ps.(type) { + case cmp.StructField: + switch ps.Name() { + case "ID", + "UpdateCreateTimestamps", + "AuthorID", + "AuthorName", + "AuthorEmail", + "Packs", + "Saved": + return true + } + } + } + return false + }, cmp.Ignore()) + + if !cmp.Equal(a, b, opt) { + return assert.Fail(t, cmp.Diff(a, b, opt), msgAndArgs...) + } + return true +}