diff --git a/changes/issue-3704-non-empty-packs-queries-policies b/changes/issue-3704-non-empty-packs-queries-policies new file mode 100644 index 0000000000..e424af75d6 --- /dev/null +++ b/changes/issue-3704-non-empty-packs-queries-policies @@ -0,0 +1 @@ +* Return 400 when trying to create queries, packs and policies with empty names. diff --git a/server/fleet/packs.go b/server/fleet/packs.go index 3fdd726d45..21e6f010f1 100644 --- a/server/fleet/packs.go +++ b/server/fleet/packs.go @@ -1,5 +1,9 @@ package fleet +import ( + "errors" +) + type PackListOptions struct { ListOptions @@ -24,6 +28,14 @@ type Pack struct { TeamIDs []uint `json:"team_ids"` } +// Verify verifies the pack's fields are valid. +func (p *Pack) Verify() error { + if emptyString(p.Name) { + return errPackEmptyName + } + return nil +} + // EditablePackType only returns true when the pack doesn't have a specific Type set, only nil & empty string Pack.Type // is editable https://github.com/fleetdm/fleet/issues/1485 func (p *Pack) EditablePackType() bool { @@ -49,6 +61,18 @@ type PackPayload struct { TeamIDs *[]uint `json:"team_ids"` } +var errPackEmptyName = errors.New("pack name cannot be empty") + +// Verify verifies the pack's payload fields are valid. +func (p *PackPayload) Verify() error { + if p.Name != nil { + if emptyString(*p.Name) { + return errPackEmptyName + } + } + return nil +} + type PackSpec struct { ID uint `json:"id,omitempty"` Name string `json:"name"` @@ -59,6 +83,14 @@ type PackSpec struct { Queries []PackSpecQuery `json:"queries,omitempty"` } +// Verify verifies the pack's spec fields are valid. +func (p *PackSpec) Verify() error { + if emptyString(p.Name) { + return errPackEmptyName + } + return nil +} + type PackSpecTargets struct { Labels []string `json:"labels"` } diff --git a/server/fleet/policies.go b/server/fleet/policies.go index 56fcf5186e..e80bf33119 100644 --- a/server/fleet/policies.go +++ b/server/fleet/policies.go @@ -58,14 +58,18 @@ func (p PolicyPayload) Verify() error { } func verifyPolicyName(name string) error { - if name == "" { + if emptyString(name) { return errPolicyEmptyName } return nil } +func emptyString(s string) bool { + return len(strings.TrimSpace(s)) == 0 +} + func verifyPolicyQuery(query string) error { - if query == "" { + if emptyString(query) { return errPolicyEmptyQuery } if validateSQLRegexp.MatchString(query) { diff --git a/server/fleet/queries.go b/server/fleet/queries.go index 807a9e1fd1..260aa0fe98 100644 --- a/server/fleet/queries.go +++ b/server/fleet/queries.go @@ -44,15 +44,59 @@ func (q Query) AuthzType() string { return "query" } +// Verify verifies the query payload is valid. +func (q *QueryPayload) Verify() error { + if q.Name != nil { + if err := verifyQueryName(*q.Name); err != nil { + return err + } + } + if q.Query != nil { + if err := verifyQuerySQL(*q.Query); err != nil { + return err + } + } + return nil +} + +// Verify verifies the query fields are valid. +func (q *Query) Verify() error { + if err := verifyQueryName(q.Name); err != nil { + return err + } + if err := verifyQuerySQL(q.Query); err != nil { + return err + } + return nil +} + var ( - validateSQLRegexp = regexp.MustCompile(`(?i)attach[^\w]+.*[^\w]+as[^\w]+`) + validateSQLRegexp = regexp.MustCompile(`(?i)attach[^\w]+.*[^\w]+as[^\w]+`) + errQueryEmptyName = errors.New("query name cannot be empty") + errQueryEmptyQuery = errors.New("query's SQL query cannot be empty") + errQueryInvalidSQL = errors.New("invalid query's SQL") ) -// ValidateSQL performs security validations on the input query. It does not -// actually determine whether the query is well formed. -func (q Query) ValidateSQL() error { - if validateSQLRegexp.MatchString(q.Query) { - return errors.New("ATTACH not allowed in queries") +func verifyQueryName(name string) error { + if emptyString(name) { + return errQueryEmptyName + } + return nil +} + +func verifyQuerySQL(query string) error { + if emptyString(query) { + return errQueryEmptyQuery + } + if err := verifySQL(query); err != nil { + return err + } + return nil +} + +func verifySQL(query string) error { + if validateSQLRegexp.MatchString(query) { + return errQueryInvalidSQL } return nil } diff --git a/server/fleet/queries_test.go b/server/fleet/queries_test.go index 0aeefbd74e..1452efdaf2 100644 --- a/server/fleet/queries_test.go +++ b/server/fleet/queries_test.go @@ -127,7 +127,7 @@ func TestValidateSQL(t *testing.T) { for _, tt := range testCases { t.Run(tt.sql, func(t *testing.T) { - err := Query{Query: tt.sql}.ValidateSQL() + err := verifySQL(tt.sql) if tt.shouldErr { require.Error(t, err) } else { diff --git a/server/service/global_policies.go b/server/service/global_policies.go index d56d96d7d6..f027d36482 100644 --- a/server/service/global_policies.go +++ b/server/service/global_policies.go @@ -56,9 +56,9 @@ func (svc Service) NewGlobalPolicy(ctx context.Context, p fleet.PolicyPayload) ( return nil, errors.New("user must be authenticated to create team policies") } if err := p.Verify(); err != nil { - return nil, &badRequestError{ + return nil, ctxerr.Wrap(ctx, &badRequestError{ message: fmt.Sprintf("policy payload verification: %s", err), - } + }) } policy, err := svc.ds.NewGlobalPolicy(ctx, ptr.Uint(vc.UserID()), p) if err != nil { @@ -258,7 +258,9 @@ func (svc Service) ApplyPolicySpecs(ctx context.Context, policies []*fleet.Polic checkGlobalPolicyAuth := false for _, policy := range policies { if err := policy.Verify(); err != nil { - return ctxerr.Wrap(ctx, err, "verifying spec") + return ctxerr.Wrap(ctx, &badRequestError{ + message: fmt.Sprintf("policy spec payload verification: %s", err), + }) } if policy.Team != "" { team, err := svc.ds.TeamByName(ctx, policy.Team) diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index e6017c957c..849374b96a 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -1179,6 +1179,12 @@ func (s *integrationTestSuite) TestTeamPoliciesProprietaryInvalid() { name: "", query: "select 1;", }, + { + tname: "empty with space", + testUpdate: true, + name: " ", // #3704 + query: "select 1;", + }, { tname: "Invalid query", testUpdate: true, @@ -2053,6 +2059,94 @@ func (s *integrationTestSuite) TestGlobalPoliciesAutomationConfig() { require.Empty(t, config.WebhookSettings.FailingPoliciesWebhook.PolicyIDs) } +func (s *integrationTestSuite) TestQueriesBadRequests() { + t := s.T() + + reqQuery := &fleet.QueryPayload{ + Name: ptr.String("existing query"), + Query: ptr.String("select 42;"), + } + createQueryResp := createQueryResponse{} + s.DoJSON("POST", "/api/v1/fleet/queries", reqQuery, http.StatusOK, &createQueryResp) + require.NotNil(t, createQueryResp.Query) + existingQueryID := createQueryResp.Query.ID + + for _, tc := range []struct { + tname string + name string + query string + }{ + { + tname: "empty name", + name: " ", // #3704 + query: "select 42;", + }, + { + tname: "empty query", + name: "Some name", + query: "", + }, + { + tname: "Invalid query", + name: "Invalid query", + query: "ATTACH 'foo' AS bar;", + }, + } { + t.Run(tc.tname, func(t *testing.T) { + reqQuery := &fleet.QueryPayload{ + Name: ptr.String(tc.name), + Query: ptr.String(tc.query), + } + createQueryResp := createQueryResponse{} + s.DoJSON("POST", "/api/v1/fleet/queries", reqQuery, http.StatusBadRequest, &createQueryResp) + require.Nil(t, createQueryResp.Query) + + payload := fleet.QueryPayload{ + Name: ptr.String(tc.name), + Query: ptr.String(tc.query), + } + mResp := modifyQueryResponse{} + s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/queries/%d", existingQueryID), &payload, http.StatusBadRequest, &mResp) + require.Nil(t, mResp.Query) + }) + } +} + +func (s *integrationTestSuite) TestPacksBadRequests() { + t := s.T() + + reqPacks := &fleet.PackPayload{ + Name: ptr.String("existing pack"), + } + createPackResp := createPackResponse{} + s.DoJSON("POST", "/api/v1/fleet/packs", reqPacks, http.StatusOK, &createPackResp) + existingPackID := createPackResp.Pack.ID + + for _, tc := range []struct { + tname string + name string + }{ + { + tname: "empty name", + name: " ", // #3704 + }, + } { + t.Run(tc.tname, func(t *testing.T) { + reqQuery := &fleet.PackPayload{ + Name: ptr.String(tc.name), + } + createPackResp := createQueryResponse{} + s.DoJSON("POST", "/api/v1/fleet/packs", reqQuery, http.StatusBadRequest, &createPackResp) + + payload := fleet.PackPayload{ + Name: ptr.String(tc.name), + } + mResp := modifyPackResponse{} + s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/packs/%d", existingPackID), &payload, http.StatusBadRequest, &mResp) + }) + } +} + func (s *integrationTestSuite) TestTeamsEndpointsWithoutLicense() { t := s.T() diff --git a/server/service/packs.go b/server/service/packs.go index eb19e0aab6..fdb469a525 100644 --- a/server/service/packs.go +++ b/server/service/packs.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/fleetdm/fleet/v4/server/authz" + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" ) @@ -125,6 +126,12 @@ func (svc *Service) NewPack(ctx context.Context, p fleet.PackPayload) (*fleet.Pa return nil, err } + if err := p.Verify(); err != nil { + return nil, ctxerr.Wrap(ctx, &badRequestError{ + message: fmt.Sprintf("pack payload verification: %s", err), + }) + } + var pack fleet.Pack if p.Name != nil { @@ -210,6 +217,12 @@ func (svc *Service) ModifyPack(ctx context.Context, id uint, p fleet.PackPayload return nil, err } + if err := p.Verify(); err != nil { + return nil, ctxerr.Wrap(ctx, &badRequestError{ + message: fmt.Sprintf("pack payload verification: %s", err), + }) + } + pack, err := svc.ds.Pack(ctx, id) if err != nil { return nil, err @@ -452,6 +465,14 @@ func (svc *Service) ApplyPackSpecs(ctx context.Context, specs []*fleet.PackSpec) } } + for _, packSpec := range result { + if err := packSpec.Verify(); err != nil { + return nil, ctxerr.Wrap(ctx, &badRequestError{ + message: fmt.Sprintf("pack payload verification: %s", err), + }) + } + } + if err := svc.ds.ApplyPackSpecs(ctx, result); err != nil { return nil, err } diff --git a/server/service/service_campaigns.go b/server/service/service_campaigns.go index 882b6d61c2..939091105a 100644 --- a/server/service/service_campaigns.go +++ b/server/service/service_campaigns.go @@ -70,8 +70,7 @@ func (svc Service) NewDistributedQueryCampaign(ctx context.Context, queryString Saved: false, AuthorID: ptr.Uint(vc.UserID()), } - err := query.ValidateSQL() - if err != nil { + if err := query.Verify(); err != nil { return nil, err } query, err = svc.ds.NewQuery(ctx, query) diff --git a/server/service/service_queries.go b/server/service/service_queries.go index 72c70435e7..bfcf712c42 100644 --- a/server/service/service_queries.go +++ b/server/service/service_queries.go @@ -2,6 +2,7 @@ package service import ( "context" + "fmt" "github.com/fleetdm/fleet/v4/server/authz" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" @@ -43,8 +44,10 @@ func (svc Service) ApplyQuerySpecs(ctx context.Context, specs []*fleet.QuerySpec } for _, query := range queries { - if err := query.ValidateSQL(); err != nil { - return err + if err := query.Verify(); err != nil { + return ctxerr.Wrap(ctx, &badRequestError{ + message: fmt.Sprintf("query payload verification: %s", err), + }) } } @@ -143,6 +146,12 @@ func (svc *Service) NewQuery(ctx context.Context, p fleet.QueryPayload) (*fleet. return nil, err } + if err := p.Verify(); err != nil { + return nil, ctxerr.Wrap(ctx, &badRequestError{ + message: fmt.Sprintf("query payload verification: %s", err), + }) + } + query := &fleet.Query{Saved: true} if p.Name != nil { @@ -170,10 +179,6 @@ func (svc *Service) NewQuery(ctx context.Context, p fleet.QueryPayload) (*fleet. query.AuthorEmail = vc.Email() } - if err := query.ValidateSQL(); err != nil { - return nil, err - } - query, err := svc.ds.NewQuery(ctx, query) if err != nil { return nil, err @@ -197,6 +202,12 @@ func (svc *Service) ModifyQuery(ctx context.Context, id uint, p fleet.QueryPaylo return nil, err } + if err := p.Verify(); err != nil { + return nil, ctxerr.Wrap(ctx, &badRequestError{ + message: fmt.Sprintf("query payload verification: %s", err), + }) + } + query, err := svc.ds.Query(ctx, id) if err != nil { return nil, err @@ -225,10 +236,6 @@ func (svc *Service) ModifyQuery(ctx context.Context, id uint, p fleet.QueryPaylo query.ObserverCanRun = *p.ObserverCanRun } - if err := query.ValidateSQL(); err != nil { - return nil, err - } - if err := svc.ds.SaveQuery(ctx, query); err != nil { return nil, err } diff --git a/server/service/service_queries_test.go b/server/service/service_queries_test.go index d012da8ada..935c46d5b3 100644 --- a/server/service/service_queries_test.go +++ b/server/service/service_queries_test.go @@ -122,7 +122,7 @@ func TestQueryAuth(t *testing.T) { return 0, nil } - var testCases = []struct { + testCases := []struct { name string user *fleet.User qid uint diff --git a/server/service/team_policies.go b/server/service/team_policies.go index b7ddc6a730..030615274e 100644 --- a/server/service/team_policies.go +++ b/server/service/team_policies.go @@ -64,9 +64,9 @@ func (svc Service) NewTeamPolicy(ctx context.Context, teamID uint, p fleet.Polic } if err := p.Verify(); err != nil { - return nil, &badRequestError{ + return nil, ctxerr.Wrap(ctx, &badRequestError{ message: fmt.Sprintf("policy payload verification: %s", err), - } + }) } policy, err := svc.ds.NewTeamPolicy(ctx, teamID, ptr.Uint(vc.UserID()), p) if err != nil { @@ -245,9 +245,9 @@ func (svc Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p fl } if err := p.Verify(); err != nil { - return nil, &badRequestError{ + return nil, ctxerr.Wrap(ctx, &badRequestError{ message: fmt.Sprintf("policy payload verification: %s", err), - } + }) } if p.Name != nil {