Add server validation for query Platform field (#13923)

This commit is contained in:
Jacob Shandling 2023-09-15 13:20:39 -07:00 committed by GitHub
parent 9debc2fd2c
commit e2aa0b28c7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 362 additions and 11 deletions

View file

@ -55,7 +55,7 @@ func testQueriesApply(t *testing.T, ds *Datastore) {
Query: "select * from foo",
ObserverCanRun: true,
Interval: 10,
Platform: "macos",
Platform: "darwin",
MinOsqueryVersion: "5.2.1",
AutomationsEnabled: true,
Logging: "differential",
@ -135,6 +135,23 @@ func testQueriesApply(t *testing.T, ds *Datastore) {
require.Equal(t, zwass.Name, q.AuthorName)
}
}
// Zach tries to add a query with an invalid platform string
invalidQueries := []*fleet.Query{
{
Name: "foo",
Description: "get the foos",
Query: "select * from foo",
ObserverCanRun: true,
Interval: 10,
Platform: "not valid",
MinOsqueryVersion: "5.2.1",
AutomationsEnabled: true,
Logging: "differential",
},
}
err = ds.ApplyQueries(context.Background(), zwass.ID, invalidQueries)
require.ErrorIs(t, err, fleet.ErrQueryInvalidPlatform)
}
func testQueriesDelete(t *testing.T, ds *Datastore) {

View file

@ -142,6 +142,7 @@ func (q *Query) GetRemoved() *bool {
}
// Verify verifies the query payload is valid.
// Called when creating or modifying a query
func (q *QueryPayload) Verify() error {
if q.Name != nil {
if err := verifyQueryName(*q.Name); err != nil {
@ -158,10 +159,16 @@ func (q *QueryPayload) Verify() error {
return err
}
}
if q.Platform != nil {
if err := verifyQueryPlatforms(*q.Platform); err != nil {
return err
}
}
return nil
}
// Verify verifies the query fields are valid.
// Called when creating queries by spec
func (q *Query) Verify() error {
if err := verifyQueryName(q.Name); err != nil {
return err
@ -172,6 +179,9 @@ func (q *Query) Verify() error {
if err := verifyLogging(q.Logging); err != nil {
return err
}
if err := verifyQueryPlatforms(q.Platform); err != nil {
return err
}
return nil
}
@ -196,9 +206,10 @@ func (tq *TargetedQuery) AuthzType() string {
}
var (
errQueryEmptyName = errors.New("query name cannot be empty")
errQueryEmptyQuery = errors.New("query's SQL query cannot be empty")
errInvalidLogging = fmt.Errorf("invalid logging value, must be one of '%s', '%s', '%s'", LoggingSnapshot, LoggingDifferential, LoggingDifferentialIgnoreRemovals)
errQueryEmptyName = errors.New("query name cannot be empty")
errQueryEmptyQuery = errors.New("query's SQL query cannot be empty")
ErrQueryInvalidPlatform = errors.New("query's platform must be a comma-separated list of 'darwin', 'linux', 'windows', and/or 'chrome' in a single string")
errInvalidLogging = fmt.Errorf("invalid logging value, must be one of '%s', '%s', '%s'", LoggingSnapshot, LoggingDifferential, LoggingDifferentialIgnoreRemovals)
)
func verifyQueryName(name string) error {
@ -223,6 +234,23 @@ func verifyLogging(logging string) error {
return nil
}
func verifyQueryPlatforms(platforms string) error {
if emptyString(platforms) {
return nil
}
platformsList := strings.Split(platforms, ",")
for _, platform := range platformsList {
// TODO(jacob) should we accept these strings with spaces? If not, remove `TrimSpace`
switch strings.TrimSpace(platform) {
case "windows", "linux", "darwin", "chrome":
// OK
default:
return ErrQueryInvalidPlatform
}
}
return nil
}
const (
QueryKind = "query"
)

View file

@ -180,3 +180,33 @@ func TestRoundtripQueriesYaml(t *testing.T) {
})
}
}
func TestVerifyQueryPlatforms(t *testing.T) {
testCases := []struct {
name string
platformString string
shouldErr bool
}{
{"empty platform string okay", "", false},
{"platform string 'darwin' okay", "darwin", false},
{"platform string 'linux' okay", "linux", false},
{"platform string 'windows' okay", "windows", false},
{"platform string 'darwin,linux,windows' okay", "darwin,linux,windows", false},
{"platform string 'foo' invalid not a supported platform", "foo", true},
{"platform string 'charles,darwin,linux,windows' invalid 'charles' not a supported platform", "charles,darwin,linux,windows", true},
{"platform string 'darwin windows' invalid missing comma delimiter", "darwin windows", true},
{"platform string 'charles darwin' invalid 'charles' not supported and missing comma delimiter", "charles darwin", true},
{"platform string ';inux' invalid  ';inux' not a supported platform", ";inux", true},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
err := verifyQueryPlatforms(tt.platformString)
if tt.shouldErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}

View file

@ -4464,9 +4464,10 @@ func (s *integrationTestSuite) TestQueriesBadRequests() {
defer cleanupQuery(s, existingQueryID)
for _, tc := range []struct {
tname string
name string
query string
tname string
name string
query string
platform string
}{
{
tname: "empty name",
@ -4483,23 +4484,50 @@ func (s *integrationTestSuite) TestQueriesBadRequests() {
name: "Invalid query",
query: "",
},
{
tname: "unsupported platform",
name: "bad query",
query: "select 1",
platform: "oops",
},
{
tname: "unsupported platform",
name: "bad query",
query: "select 1",
platform: "charles,darwin",
},
{
tname: "missing platform comma delimeter",
name: "bad query",
query: "select 1",
platform: "linuxdarwin",
},
{
tname: "missing platform comma delimeter",
name: "bad query",
query: "select 1",
platform: "windows darwin",
},
} {
t.Run(tc.tname, func(t *testing.T) {
reqQuery := &fleet.QueryPayload{
Name: ptr.String(tc.name),
Query: ptr.String(tc.query),
Name: ptr.String(tc.name),
Query: ptr.String(tc.query),
Platform: ptr.String(tc.platform),
}
createQueryResp := createQueryResponse{}
s.DoJSON("POST", "/api/latest/fleet/queries", reqQuery, http.StatusBadRequest, &createQueryResp)
require.Nil(t, createQueryResp.Query)
payload := fleet.QueryPayload{
Name: ptr.String(tc.name),
Query: ptr.String(tc.query),
Name: ptr.String(tc.name),
Query: ptr.String(tc.query),
Platform: ptr.String(tc.platform),
}
mResp := modifyQueryResponse{}
s.DoJSON("PATCH", fmt.Sprintf("/api/latest/fleet/queries/%d", existingQueryID), &payload, http.StatusBadRequest, &mResp)
require.Nil(t, mResp.Query)
// TODO add checks for specific errors
})
}
}
@ -5036,6 +5064,21 @@ func (s *integrationTestSuite) TestQuerySpecs() {
},
}, http.StatusOK, &applyResp)
// try to create a query with invalid platform, fail
q4 := q1 + "_4"
s.DoJSON("POST", "/api/latest/fleet/spec/queries", applyQuerySpecsRequest{
Specs: []*fleet.QuerySpec{
{Name: q4, Query: "SELECT 4", Platform: "not valid"},
},
}, http.StatusBadRequest, &applyResp)
// try to edit a query with invalid platform, fail
s.DoJSON("POST", "/api/latest/fleet/spec/queries", applyQuerySpecsRequest{
Specs: []*fleet.QuerySpec{
{Name: q3, Query: "SELECT 3", Platform: "charles darwin"},
},
}, http.StatusBadRequest, &applyResp)
// list specs - has 3, not 4 (one was an update)
s.DoJSON("GET", "/api/latest/fleet/spec/queries", nil, http.StatusOK, &getSpecsResp)
require.Len(t, getSpecsResp.Specs, 3)

View file

@ -114,6 +114,239 @@ func TestListQueries(t *testing.T) {
}
}
func TestQueryPayloadValidationCreate(t *testing.T) {
ds := new(mock.Store)
ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) {
return query, nil
}
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error {
act, ok := activity.(fleet.ActivityTypeCreatedSavedQuery)
assert.True(t, ok)
assert.NotEmpty(t, act.Name)
return nil
}
svc, ctx := newTestService(t, ds, nil, nil)
testCases := []struct {
name string
queryPayload fleet.QueryPayload
shouldErr bool
}{
{
"All valid",
fleet.QueryPayload{
Name: ptr.String("test query"),
Query: ptr.String("select 1"),
Logging: ptr.String("snapshot"),
Platform: ptr.String(""),
},
false,
},
{
"Invalid - empty string name",
fleet.QueryPayload{
Name: ptr.String(""),
Query: ptr.String("select 1"),
Logging: ptr.String("snapshot"),
Platform: ptr.String(""),
},
true,
},
{
"Empty SQL",
fleet.QueryPayload{
Name: ptr.String("bad sql"),
Query: ptr.String(""),
Logging: ptr.String("snapshot"),
Platform: ptr.String(""),
},
true,
},
{
"Invalid logging",
fleet.QueryPayload{
Name: ptr.String("bad logging"),
Query: ptr.String("select 1"),
Logging: ptr.String("hopscotch"),
Platform: ptr.String(""),
},
true,
},
{
"Unsupported platform",
fleet.QueryPayload{
Name: ptr.String("invalid platform"),
Query: ptr.String("select 1"),
Logging: ptr.String("differential"),
Platform: ptr.String("charles"),
},
true,
},
{
"Missing comma",
fleet.QueryPayload{
Name: ptr.String("invalid platform"),
Query: ptr.String("select 1"),
Logging: ptr.String("differential"),
Platform: ptr.String("darwin windows"),
},
true,
},
{
"Unsupported platform 'sphinx' ",
fleet.QueryPayload{
Name: ptr.String("invalid platform"),
Query: ptr.String("select 1"),
Logging: ptr.String("differential"),
Platform: ptr.String("darwin,windows,sphinx"),
},
true,
},
}
testAdmin := fleet.User{
ID: 1,
Teams: []fleet.UserTeam{},
GlobalRole: ptr.String(fleet.RoleAdmin),
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
viewerCtx := viewer.NewContext(ctx, viewer.Viewer{User: &testAdmin})
query, err := svc.NewQuery(viewerCtx, tt.queryPayload)
if tt.shouldErr {
assert.Error(t, err)
assert.Nil(t, query)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, query)
}
})
}
}
// similar for modify
func TestQueryPayloadValidationModify(t *testing.T) {
ds := new(mock.Store)
ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) {
return &fleet.Query{
ID: id,
Name: "mock saved query",
Description: "some desc",
Query: "select 1;",
Platform: "",
Saved: true,
ObserverCanRun: false,
}, nil
}
ds.SaveQueryFunc = func(ctx context.Context, query *fleet.Query) error {
assert.NotEmpty(t, query)
return nil
}
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error {
act, ok := activity.(fleet.ActivityTypeEditedSavedQuery)
assert.True(t, ok)
assert.NotEmpty(t, act.Name)
return nil
}
svc, ctx := newTestService(t, ds, nil, nil)
testCases := []struct {
name string
queryPayload fleet.QueryPayload
shouldErr bool
}{
{
"All valid",
fleet.QueryPayload{
Name: ptr.String("updated test query"),
Query: ptr.String("select 1"),
Logging: ptr.String("snapshot"),
Platform: ptr.String(""),
},
false,
},
{
"Invalid - empty string name",
fleet.QueryPayload{
Name: ptr.String(""),
Query: ptr.String("select 1"),
Logging: ptr.String("snapshot"),
Platform: ptr.String(""),
},
true,
},
{
"Empty SQL",
fleet.QueryPayload{
Name: ptr.String("bad sql"),
Query: ptr.String(""),
Logging: ptr.String("snapshot"),
Platform: ptr.String(""),
},
true,
},
{
"Invalid logging",
fleet.QueryPayload{
Name: ptr.String("bad logging"),
Query: ptr.String("select 1"),
Logging: ptr.String("hopscotch"),
Platform: ptr.String(""),
},
true,
},
{
"Unsupported platform",
fleet.QueryPayload{
Name: ptr.String("invalid platform"),
Query: ptr.String("select 1"),
Logging: ptr.String("differential"),
Platform: ptr.String("charles"),
},
true,
},
{
"Missing comma delimeter in platform string",
fleet.QueryPayload{
Name: ptr.String("invalid platform"),
Query: ptr.String("select 1"),
Logging: ptr.String("differential"),
Platform: ptr.String("darwin windows"),
},
true,
},
{
"Unsupported platform 2",
fleet.QueryPayload{
Name: ptr.String("invalid platform"),
Query: ptr.String("select 1"),
Logging: ptr.String("differential"),
Platform: ptr.String("darwin,windows,sphinx"),
},
true,
},
}
testAdmin := fleet.User{
ID: 1,
Teams: []fleet.UserTeam{},
GlobalRole: ptr.String(fleet.RoleAdmin),
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
viewerCtx := viewer.NewContext(ctx, viewer.Viewer{User: &testAdmin})
_, err := svc.ModifyQuery(viewerCtx, 1, tt.queryPayload)
if tt.shouldErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestQueryAuth(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil)