Issue 1512 filter observer can run queries (#2110)

* wip

* Filter queries for observers

* Update e2e test now that we filter queries
This commit is contained in:
Tomas Touceda 2021-09-20 13:07:51 -03:00 committed by GitHub
parent e286ee387e
commit b32b441c12
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 164 additions and 35 deletions

View file

@ -0,0 +1 @@
* Only show observers queries they can run.

View file

@ -67,14 +67,7 @@ describe("Free tier - Observer user", () => {
cy.visit("/queries/manage");
cy.findByText(/get authorized/i).click();
cy.findByText(/packs/i).should("not.exist");
cy.findByLabelText(/query name/i).should("not.exist");
cy.findByLabelText(/sql/i).should("not.exist");
cy.findByLabelText(/description/i).should("not.exist");
cy.findByLabelText(/observer can run/i).should("not.exist");
cy.findByText(/show sql/i).click();
cy.findByRole("button", { name: /run query/i }).should("not.exist");
cy.findByText(/get authorized/i).should("not.exist");
// On the Profile page, they should…
// See Observer in Role section, and no Team section

View file

@ -170,7 +170,7 @@ func (d *Datastore) Query(ctx context.Context, id uint) (*fleet.Query, error) {
// ListQueries returns a list of queries with sort order and results limit
// determined by passed in fleet.ListOptions
func (d *Datastore) ListQueries(ctx context.Context, opt fleet.ListOptions) ([]*fleet.Query, error) {
func (d *Datastore) ListQueries(ctx context.Context, opt fleet.ListQueryOptions) ([]*fleet.Query, error) {
sql := `
SELECT q.*, COALESCE(u.name, '<deleted>') AS author_name
FROM queries q
@ -178,7 +178,11 @@ func (d *Datastore) ListQueries(ctx context.Context, opt fleet.ListOptions) ([]*
ON q.author_id = u.id
WHERE saved = true
`
sql = appendListOptionsToSQL(sql, opt)
if opt.OnlyObserverCanRun {
sql += " AND q.observer_can_run=true"
}
sql = appendListOptionsToSQL(sql, opt.ListOptions)
results := []*fleet.Query{}
if err := sqlx.SelectContext(ctx, d.reader, &results, sql); err != nil {

View file

@ -29,7 +29,7 @@ func TestApplyQueries(t *testing.T) {
err := ds.ApplyQueries(context.Background(), zwass.ID, expectedQueries)
require.Nil(t, err)
queries, err := ds.ListQueries(context.Background(), fleet.ListOptions{})
queries, err := ds.ListQueries(context.Background(), fleet.ListQueryOptions{})
require.Nil(t, err)
require.Len(t, queries, len(expectedQueries))
for i, q := range queries {
@ -47,7 +47,7 @@ func TestApplyQueries(t *testing.T) {
err = ds.ApplyQueries(context.Background(), groob.ID, expectedQueries)
require.Nil(t, err)
queries, err = ds.ListQueries(context.Background(), fleet.ListOptions{})
queries, err = ds.ListQueries(context.Background(), fleet.ListQueryOptions{})
require.Nil(t, err)
require.Len(t, queries, len(expectedQueries))
for i, q := range queries {
@ -65,7 +65,7 @@ func TestApplyQueries(t *testing.T) {
err = ds.ApplyQueries(context.Background(), zwass.ID, []*fleet.Query{expectedQueries[2]})
require.Nil(t, err)
queries, err = ds.ListQueries(context.Background(), fleet.ListOptions{})
queries, err = ds.ListQueries(context.Background(), fleet.ListQueryOptions{})
require.Nil(t, err)
require.Len(t, queries, len(expectedQueries))
for i, q := range queries {
@ -130,7 +130,7 @@ func TestDeleteQueries(t *testing.T) {
q3 := test.NewQuery(t, ds, "q3", "select 1", user.ID, true)
q4 := test.NewQuery(t, ds, "q4", "select * from osquery_info", user.ID, true)
queries, err := ds.ListQueries(context.Background(), fleet.ListOptions{})
queries, err := ds.ListQueries(context.Background(), fleet.ListQueryOptions{})
require.Nil(t, err)
assert.Len(t, queries, 4)
@ -138,7 +138,7 @@ func TestDeleteQueries(t *testing.T) {
require.Nil(t, err)
assert.Equal(t, uint(2), deleted)
queries, err = ds.ListQueries(context.Background(), fleet.ListOptions{})
queries, err = ds.ListQueries(context.Background(), fleet.ListQueryOptions{})
require.Nil(t, err)
assert.Len(t, queries, 2)
@ -146,7 +146,7 @@ func TestDeleteQueries(t *testing.T) {
require.Nil(t, err)
assert.Equal(t, uint(1), deleted)
queries, err = ds.ListQueries(context.Background(), fleet.ListOptions{})
queries, err = ds.ListQueries(context.Background(), fleet.ListQueryOptions{})
require.Nil(t, err)
assert.Len(t, queries, 1)
@ -154,7 +154,7 @@ func TestDeleteQueries(t *testing.T) {
require.Nil(t, err)
assert.Equal(t, uint(1), deleted)
queries, err = ds.ListQueries(context.Background(), fleet.ListOptions{})
queries, err = ds.ListQueries(context.Background(), fleet.ListQueryOptions{})
require.Nil(t, err)
assert.Len(t, queries, 0)
@ -215,7 +215,7 @@ func TestListQuery(t *testing.T) {
})
require.Nil(t, err)
opts := fleet.ListOptions{}
opts := fleet.ListQueryOptions{}
results, err := ds.ListQueries(context.Background(), opts)
assert.Nil(t, err)
assert.Equal(t, 10, len(results))
@ -234,9 +234,9 @@ func TestLoadPacksForQueries(t *testing.T) {
require.Nil(t, err)
specs := []*fleet.PackSpec{
&fleet.PackSpec{Name: "p1"},
&fleet.PackSpec{Name: "p2"},
&fleet.PackSpec{Name: "p3"},
{Name: "p1"},
{Name: "p2"},
{Name: "p3"},
}
err = ds.ApplyPackSpecs(context.Background(), specs)
require.Nil(t, err)
@ -250,10 +250,10 @@ func TestLoadPacksForQueries(t *testing.T) {
assert.Empty(t, q1.Packs)
specs = []*fleet.PackSpec{
&fleet.PackSpec{
{
Name: "p2",
Queries: []fleet.PackSpecQuery{
fleet.PackSpecQuery{
{
Name: "q0",
QueryName: queries[0].Name,
Interval: 60,
@ -275,19 +275,19 @@ func TestLoadPacksForQueries(t *testing.T) {
assert.Empty(t, q1.Packs)
specs = []*fleet.PackSpec{
&fleet.PackSpec{
{
Name: "p1",
Queries: []fleet.PackSpecQuery{
fleet.PackSpecQuery{
{
QueryName: queries[1].Name,
Interval: 60,
},
},
},
&fleet.PackSpec{
{
Name: "p3",
Queries: []fleet.PackSpecQuery{
fleet.PackSpecQuery{
{
QueryName: queries[1].Name,
Interval: 60,
},
@ -312,15 +312,15 @@ func TestLoadPacksForQueries(t *testing.T) {
}
specs = []*fleet.PackSpec{
&fleet.PackSpec{
{
Name: "p3",
Queries: []fleet.PackSpecQuery{
fleet.PackSpecQuery{
{
Name: "q0",
QueryName: queries[0].Name,
Interval: 60,
},
fleet.PackSpecQuery{
{
Name: "q1",
QueryName: queries[1].Name,
Interval: 60,
@ -370,3 +370,40 @@ func TestDuplicateNewQuery(t *testing.T) {
// is private to the individual datastore implementations
assert.Contains(t, err.Error(), "already exists")
}
func TestListQueryFiltersObserver(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()
_, err := ds.NewQuery(context.Background(), &fleet.Query{
Name: "query1",
Query: "select 1;",
Saved: true,
})
require.NoError(t, err)
_, err = ds.NewQuery(context.Background(), &fleet.Query{
Name: "query2",
Query: "select 1;",
Saved: true,
})
require.NoError(t, err)
query3, err := ds.NewQuery(context.Background(), &fleet.Query{
Name: "query3",
Query: "select 1;",
Saved: true,
ObserverCanRun: true,
})
require.NoError(t, err)
queries, err := ds.ListQueries(context.Background(), fleet.ListQueryOptions{})
require.NoError(t, err)
require.Len(t, queries, 3)
queries, err = ds.ListQueries(
context.Background(),
fleet.ListQueryOptions{OnlyObserverCanRun: true, ListOptions: fleet.ListOptions{PerPage: 1}},
)
require.NoError(t, err)
require.Len(t, queries, 1)
assert.Equal(t, query3.ID, queries[0].ID)
}

View file

@ -243,6 +243,12 @@ type ListOptions struct {
MatchQuery string
}
type ListQueryOptions struct {
ListOptions
OnlyObserverCanRun bool
}
// EnrollSecret contains information about an enroll secret, name, and active
// status. Enroll secrets are used for osquery authentication.
type EnrollSecret struct {

View file

@ -61,7 +61,7 @@ type Datastore interface {
Query(ctx context.Context, id uint) (*Query, error)
// ListQueries returns a list of queries with the provided sorting and paging options. Associated packs should also
// be loaded.
ListQueries(ctx context.Context, opt ListOptions) ([]*Query, error)
ListQueries(ctx context.Context, opt ListQueryOptions) ([]*Query, error)
// QueryByName looks up a query by name.
QueryByName(ctx context.Context, name string, opts ...OptionalArg) (*Query, error)

View file

@ -59,7 +59,7 @@ type DeleteQueriesFunc func(ctx context.Context, ids []uint) (uint, error)
type QueryFunc func(ctx context.Context, id uint) (*fleet.Query, error)
type ListQueriesFunc func(ctx context.Context, opt fleet.ListOptions) ([]*fleet.Query, error)
type ListQueriesFunc func(ctx context.Context, opt fleet.ListQueryOptions) ([]*fleet.Query, error)
type QueryByNameFunc func(ctx context.Context, name string, opts ...fleet.OptionalArg) (*fleet.Query, error)
@ -818,7 +818,7 @@ func (s *DataStore) Query(ctx context.Context, id uint) (*fleet.Query, error) {
return s.QueryFunc(ctx, id)
}
func (s *DataStore) ListQueries(ctx context.Context, opt fleet.ListOptions) ([]*fleet.Query, error) {
func (s *DataStore) ListQueries(ctx context.Context, opt fleet.ListQueryOptions) ([]*fleet.Query, error) {
s.ListQueriesFuncInvoked = true
return s.ListQueriesFunc(ctx, opt)
}

View file

@ -66,7 +66,7 @@ func (svc Service) GetQuerySpecs(ctx context.Context) ([]*fleet.QuerySpec, error
return nil, err
}
queries, err := svc.ds.ListQueries(ctx, fleet.ListOptions{})
queries, err := svc.ds.ListQueries(ctx, fleet.ListQueryOptions{})
if err != nil {
return nil, errors.Wrap(err, "getting queries")
}
@ -90,12 +90,39 @@ func (svc Service) GetQuerySpec(ctx context.Context, name string) (*fleet.QueryS
return specFromQuery(query), nil
}
func onlyShowObserverCanRunQueries(user *fleet.User) bool {
if user.GlobalRole != nil && *user.GlobalRole == fleet.RoleObserver {
return true
} else if len(user.Teams) > 0 {
allObserver := true
for _, team := range user.Teams {
if team.Role != fleet.RoleObserver {
allObserver = false
break
}
}
return allObserver
}
return false
}
func (svc Service) ListQueries(ctx context.Context, opt fleet.ListOptions) ([]*fleet.Query, error) {
if err := svc.authz.Authorize(ctx, &fleet.Query{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.ListQueries(ctx, opt)
user := authz.UserFromContext(ctx)
onlyShowObserverCanRun := onlyShowObserverCanRunQueries(user)
queries, err := svc.ds.ListQueries(ctx, fleet.ListQueryOptions{
ListOptions: opt,
OnlyObserverCanRun: onlyShowObserverCanRun,
})
if err != nil {
return nil, err
}
return queries, nil
}
func (svc *Service) GetQuery(ctx context.Context, id uint) (*fleet.Query, error) {

View file

@ -4,8 +4,11 @@ import (
"context"
"testing"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -21,3 +24,61 @@ func TestNewQueryAttach(t *testing.T) {
)
require.Error(t, err)
}
func TestFilterQueriesForObserver(t *testing.T) {
require.True(t, onlyShowObserverCanRunQueries(&fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)}))
require.False(t, onlyShowObserverCanRunQueries(&fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)}))
require.False(t, onlyShowObserverCanRunQueries(&fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}))
require.True(t, onlyShowObserverCanRunQueries(&fleet.User{Teams: []fleet.UserTeam{{Role: fleet.RoleObserver}}}))
require.True(t, onlyShowObserverCanRunQueries(&fleet.User{Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver},
{Role: fleet.RoleObserver},
}}))
require.False(t, onlyShowObserverCanRunQueries(&fleet.User{Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver},
{Role: fleet.RoleMaintainer},
}}))
}
func TestListQueries(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
cases := [...]struct {
title string
user *fleet.User
expectedOpts fleet.ListQueryOptions
}{
{
title: "global admin",
user: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
expectedOpts: fleet.ListQueryOptions{OnlyObserverCanRun: false},
},
{
title: "global observer",
user: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
expectedOpts: fleet.ListQueryOptions{OnlyObserverCanRun: true},
},
{
title: "team admin",
user: &fleet.User{Teams: []fleet.UserTeam{{Role: fleet.RoleAdmin}}},
expectedOpts: fleet.ListQueryOptions{OnlyObserverCanRun: false},
},
}
var calledWithOpts fleet.ListQueryOptions
ds.ListQueriesFunc = func(ctx context.Context, opt fleet.ListQueryOptions) ([]*fleet.Query, error) {
calledWithOpts = opt
return []*fleet.Query{}, nil
}
for _, tt := range cases {
t.Run(tt.title, func(t *testing.T) {
viewerCtx := viewer.NewContext(context.Background(), viewer.Viewer{User: tt.user})
_, err := svc.ListQueries(viewerCtx, fleet.ListOptions{})
require.NoError(t, err)
assert.Equal(t, tt.expectedOpts, calledWithOpts)
})
}
}