diff --git a/server/datastore/mysql/queries_test.go b/server/datastore/mysql/queries_test.go index d1e67446ef..29bd76c113 100644 --- a/server/datastore/mysql/queries_test.go +++ b/server/datastore/mysql/queries_test.go @@ -55,7 +55,7 @@ func testQueriesApply(t *testing.T, ds *Datastore) { Query: "select * from foo", ObserverCanRun: true, Interval: 10, - Platform: "macos", + Platform: "darwin", MinOsqueryVersion: "5.2.1", AutomationsEnabled: true, Logging: "differential", @@ -135,6 +135,23 @@ func testQueriesApply(t *testing.T, ds *Datastore) { require.Equal(t, zwass.Name, q.AuthorName) } } + + // Zach tries to add a query with an invalid platform string + invalidQueries := []*fleet.Query{ + { + Name: "foo", + Description: "get the foos", + Query: "select * from foo", + ObserverCanRun: true, + Interval: 10, + Platform: "not valid", + MinOsqueryVersion: "5.2.1", + AutomationsEnabled: true, + Logging: "differential", + }, + } + err = ds.ApplyQueries(context.Background(), zwass.ID, invalidQueries) + require.ErrorIs(t, err, fleet.ErrQueryInvalidPlatform) } func testQueriesDelete(t *testing.T, ds *Datastore) { diff --git a/server/fleet/queries.go b/server/fleet/queries.go index 7a6d515635..867680573a 100644 --- a/server/fleet/queries.go +++ b/server/fleet/queries.go @@ -142,6 +142,7 @@ func (q *Query) GetRemoved() *bool { } // Verify verifies the query payload is valid. +// Called when creating or modifying a query func (q *QueryPayload) Verify() error { if q.Name != nil { if err := verifyQueryName(*q.Name); err != nil { @@ -158,10 +159,16 @@ func (q *QueryPayload) Verify() error { return err } } + if q.Platform != nil { + if err := verifyQueryPlatforms(*q.Platform); err != nil { + return err + } + } return nil } // Verify verifies the query fields are valid. +// Called when creating queries by spec func (q *Query) Verify() error { if err := verifyQueryName(q.Name); err != nil { return err @@ -172,6 +179,9 @@ func (q *Query) Verify() error { if err := verifyLogging(q.Logging); err != nil { return err } + if err := verifyQueryPlatforms(q.Platform); err != nil { + return err + } return nil } @@ -196,9 +206,10 @@ func (tq *TargetedQuery) AuthzType() string { } var ( - errQueryEmptyName = errors.New("query name cannot be empty") - errQueryEmptyQuery = errors.New("query's SQL query cannot be empty") - errInvalidLogging = fmt.Errorf("invalid logging value, must be one of '%s', '%s', '%s'", LoggingSnapshot, LoggingDifferential, LoggingDifferentialIgnoreRemovals) + errQueryEmptyName = errors.New("query name cannot be empty") + errQueryEmptyQuery = errors.New("query's SQL query cannot be empty") + ErrQueryInvalidPlatform = errors.New("query's platform must be a comma-separated list of 'darwin', 'linux', 'windows', and/or 'chrome' in a single string") + errInvalidLogging = fmt.Errorf("invalid logging value, must be one of '%s', '%s', '%s'", LoggingSnapshot, LoggingDifferential, LoggingDifferentialIgnoreRemovals) ) func verifyQueryName(name string) error { @@ -223,6 +234,23 @@ func verifyLogging(logging string) error { return nil } +func verifyQueryPlatforms(platforms string) error { + if emptyString(platforms) { + return nil + } + platformsList := strings.Split(platforms, ",") + for _, platform := range platformsList { + // TODO(jacob) – should we accept these strings with spaces? If not, remove `TrimSpace` + switch strings.TrimSpace(platform) { + case "windows", "linux", "darwin", "chrome": + // OK + default: + return ErrQueryInvalidPlatform + } + } + return nil +} + const ( QueryKind = "query" ) diff --git a/server/fleet/queries_test.go b/server/fleet/queries_test.go index 8975a35132..ca0341d348 100644 --- a/server/fleet/queries_test.go +++ b/server/fleet/queries_test.go @@ -180,3 +180,33 @@ func TestRoundtripQueriesYaml(t *testing.T) { }) } } + +func TestVerifyQueryPlatforms(t *testing.T) { + testCases := []struct { + name string + platformString string + shouldErr bool + }{ + {"empty platform string okay", "", false}, + {"platform string 'darwin' okay", "darwin", false}, + {"platform string 'linux' okay", "linux", false}, + {"platform string 'windows' okay", "windows", false}, + {"platform string 'darwin,linux,windows' okay", "darwin,linux,windows", false}, + {"platform string 'foo' invalid – not a supported platform", "foo", true}, + {"platform string 'charles,darwin,linux,windows' invalid – 'charles' not a supported platform", "charles,darwin,linux,windows", true}, + {"platform string 'darwin windows' invalid – missing comma delimiter", "darwin windows", true}, + {"platform string 'charles darwin' invalid – 'charles' not supported and missing comma delimiter", "charles darwin", true}, + {"platform string ';inux' invalid – ';inux' not a supported platform", ";inux", true}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + err := verifyQueryPlatforms(tt.platformString) + if tt.shouldErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 2182be1df4..de724804bd 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -4464,9 +4464,10 @@ func (s *integrationTestSuite) TestQueriesBadRequests() { defer cleanupQuery(s, existingQueryID) for _, tc := range []struct { - tname string - name string - query string + tname string + name string + query string + platform string }{ { tname: "empty name", @@ -4483,23 +4484,50 @@ func (s *integrationTestSuite) TestQueriesBadRequests() { name: "Invalid query", query: "", }, + { + tname: "unsupported platform", + name: "bad query", + query: "select 1", + platform: "oops", + }, + { + tname: "unsupported platform", + name: "bad query", + query: "select 1", + platform: "charles,darwin", + }, + { + tname: "missing platform comma delimeter", + name: "bad query", + query: "select 1", + platform: "linuxdarwin", + }, + { + tname: "missing platform comma delimeter", + name: "bad query", + query: "select 1", + platform: "windows darwin", + }, } { t.Run(tc.tname, func(t *testing.T) { reqQuery := &fleet.QueryPayload{ - Name: ptr.String(tc.name), - Query: ptr.String(tc.query), + Name: ptr.String(tc.name), + Query: ptr.String(tc.query), + Platform: ptr.String(tc.platform), } createQueryResp := createQueryResponse{} s.DoJSON("POST", "/api/latest/fleet/queries", reqQuery, http.StatusBadRequest, &createQueryResp) require.Nil(t, createQueryResp.Query) payload := fleet.QueryPayload{ - Name: ptr.String(tc.name), - Query: ptr.String(tc.query), + Name: ptr.String(tc.name), + Query: ptr.String(tc.query), + Platform: ptr.String(tc.platform), } mResp := modifyQueryResponse{} s.DoJSON("PATCH", fmt.Sprintf("/api/latest/fleet/queries/%d", existingQueryID), &payload, http.StatusBadRequest, &mResp) require.Nil(t, mResp.Query) + // TODO – add checks for specific errors }) } } @@ -5036,6 +5064,21 @@ func (s *integrationTestSuite) TestQuerySpecs() { }, }, http.StatusOK, &applyResp) + // try to create a query with invalid platform, fail + q4 := q1 + "_4" + s.DoJSON("POST", "/api/latest/fleet/spec/queries", applyQuerySpecsRequest{ + Specs: []*fleet.QuerySpec{ + {Name: q4, Query: "SELECT 4", Platform: "not valid"}, + }, + }, http.StatusBadRequest, &applyResp) + + // try to edit a query with invalid platform, fail + s.DoJSON("POST", "/api/latest/fleet/spec/queries", applyQuerySpecsRequest{ + Specs: []*fleet.QuerySpec{ + {Name: q3, Query: "SELECT 3", Platform: "charles darwin"}, + }, + }, http.StatusBadRequest, &applyResp) + // list specs - has 3, not 4 (one was an update) s.DoJSON("GET", "/api/latest/fleet/spec/queries", nil, http.StatusOK, &getSpecsResp) require.Len(t, getSpecsResp.Specs, 3) diff --git a/server/service/queries_test.go b/server/service/queries_test.go index 7477e5fa4c..95b4ab764c 100644 --- a/server/service/queries_test.go +++ b/server/service/queries_test.go @@ -114,6 +114,239 @@ func TestListQueries(t *testing.T) { } } +func TestQueryPayloadValidationCreate(t *testing.T) { + ds := new(mock.Store) + ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) { + return query, nil + } + ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error { + act, ok := activity.(fleet.ActivityTypeCreatedSavedQuery) + assert.True(t, ok) + assert.NotEmpty(t, act.Name) + return nil + } + svc, ctx := newTestService(t, ds, nil, nil) + + testCases := []struct { + name string + queryPayload fleet.QueryPayload + shouldErr bool + }{ + { + "All valid", + fleet.QueryPayload{ + Name: ptr.String("test query"), + Query: ptr.String("select 1"), + Logging: ptr.String("snapshot"), + Platform: ptr.String(""), + }, + false, + }, + { + "Invalid - empty string name", + fleet.QueryPayload{ + Name: ptr.String(""), + Query: ptr.String("select 1"), + Logging: ptr.String("snapshot"), + Platform: ptr.String(""), + }, + true, + }, + { + "Empty SQL", + fleet.QueryPayload{ + Name: ptr.String("bad sql"), + Query: ptr.String(""), + Logging: ptr.String("snapshot"), + Platform: ptr.String(""), + }, + true, + }, + { + "Invalid logging", + fleet.QueryPayload{ + Name: ptr.String("bad logging"), + Query: ptr.String("select 1"), + Logging: ptr.String("hopscotch"), + Platform: ptr.String(""), + }, + true, + }, + { + "Unsupported platform", + fleet.QueryPayload{ + Name: ptr.String("invalid platform"), + Query: ptr.String("select 1"), + Logging: ptr.String("differential"), + Platform: ptr.String("charles"), + }, + true, + }, + { + "Missing comma", + fleet.QueryPayload{ + Name: ptr.String("invalid platform"), + Query: ptr.String("select 1"), + Logging: ptr.String("differential"), + Platform: ptr.String("darwin windows"), + }, + true, + }, + { + "Unsupported platform 'sphinx' ", + fleet.QueryPayload{ + Name: ptr.String("invalid platform"), + Query: ptr.String("select 1"), + Logging: ptr.String("differential"), + Platform: ptr.String("darwin,windows,sphinx"), + }, + true, + }, + } + + testAdmin := fleet.User{ + ID: 1, + Teams: []fleet.UserTeam{}, + GlobalRole: ptr.String(fleet.RoleAdmin), + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + viewerCtx := viewer.NewContext(ctx, viewer.Viewer{User: &testAdmin}) + query, err := svc.NewQuery(viewerCtx, tt.queryPayload) + if tt.shouldErr { + assert.Error(t, err) + assert.Nil(t, query) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, query) + } + }) + } +} + +// similar for modify +func TestQueryPayloadValidationModify(t *testing.T) { + ds := new(mock.Store) + ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) { + return &fleet.Query{ + ID: id, + Name: "mock saved query", + Description: "some desc", + Query: "select 1;", + Platform: "", + Saved: true, + ObserverCanRun: false, + }, nil + } + ds.SaveQueryFunc = func(ctx context.Context, query *fleet.Query) error { + assert.NotEmpty(t, query) + return nil + } + + ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error { + act, ok := activity.(fleet.ActivityTypeEditedSavedQuery) + assert.True(t, ok) + assert.NotEmpty(t, act.Name) + return nil + } + svc, ctx := newTestService(t, ds, nil, nil) + + testCases := []struct { + name string + queryPayload fleet.QueryPayload + shouldErr bool + }{ + { + "All valid", + fleet.QueryPayload{ + Name: ptr.String("updated test query"), + Query: ptr.String("select 1"), + Logging: ptr.String("snapshot"), + Platform: ptr.String(""), + }, + false, + }, + { + "Invalid - empty string name", + fleet.QueryPayload{ + Name: ptr.String(""), + Query: ptr.String("select 1"), + Logging: ptr.String("snapshot"), + Platform: ptr.String(""), + }, + true, + }, + { + "Empty SQL", + fleet.QueryPayload{ + Name: ptr.String("bad sql"), + Query: ptr.String(""), + Logging: ptr.String("snapshot"), + Platform: ptr.String(""), + }, + true, + }, + { + "Invalid logging", + fleet.QueryPayload{ + Name: ptr.String("bad logging"), + Query: ptr.String("select 1"), + Logging: ptr.String("hopscotch"), + Platform: ptr.String(""), + }, + true, + }, + { + "Unsupported platform", + fleet.QueryPayload{ + Name: ptr.String("invalid platform"), + Query: ptr.String("select 1"), + Logging: ptr.String("differential"), + Platform: ptr.String("charles"), + }, + true, + }, + { + "Missing comma delimeter in platform string", + fleet.QueryPayload{ + Name: ptr.String("invalid platform"), + Query: ptr.String("select 1"), + Logging: ptr.String("differential"), + Platform: ptr.String("darwin windows"), + }, + true, + }, + { + "Unsupported platform 2", + fleet.QueryPayload{ + Name: ptr.String("invalid platform"), + Query: ptr.String("select 1"), + Logging: ptr.String("differential"), + Platform: ptr.String("darwin,windows,sphinx"), + }, + true, + }, + } + + testAdmin := fleet.User{ + ID: 1, + Teams: []fleet.UserTeam{}, + GlobalRole: ptr.String(fleet.RoleAdmin), + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + viewerCtx := viewer.NewContext(ctx, viewer.Viewer{User: &testAdmin}) + _, err := svc.ModifyQuery(viewerCtx, 1, tt.queryPayload) + if tt.shouldErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + func TestQueryAuth(t *testing.T) { ds := new(mock.Store) svc, ctx := newTestService(t, ds, nil, nil)