From 49b04ba4a50d6a3c3d4d5d5f86c010e2918a7335 Mon Sep 17 00:00:00 2001
From: Jacob Shandling <61553566+jacobshandling@users.noreply.github.com>
Date: Wed, 17 May 2023 13:41:30 -0700
Subject: [PATCH] For requests with invalid list options, return 400 instead of
500 (#11632)
## Addresses #11272
- For requests with invalid list options (`page`, `per_page`,
`order_key`, `order_direction`), return `400` instead of `500`
# Checklist for submitter
If some of the following don't apply, delete the relevant line.
- [x] Manual QA for all new/changed functionality
---------
Co-authored-by: Jacob Shandling
---
server/service/transport.go | 13 ++++++------
server/service/transport_test.go | 35 +++++++++++++++++---------------
2 files changed, 25 insertions(+), 23 deletions(-)
diff --git a/server/service/transport.go b/server/service/transport.go
index 98fbdd9dc3..815740f8e3 100644
--- a/server/service/transport.go
+++ b/server/service/transport.go
@@ -130,10 +130,10 @@ func listOptionsFromRequest(r *http.Request) (fleet.ListOptions, error) {
if pageString != "" {
page, err = strconv.Atoi(pageString)
if err != nil {
- return fleet.ListOptions{}, ctxerr.New(r.Context(), "non-int page value")
+ return fleet.ListOptions{}, ctxerr.Wrap(r.Context(), badRequest("non-int page value"))
}
if page < 0 {
- return fleet.ListOptions{}, ctxerr.New(r.Context(), "negative page value")
+ return fleet.ListOptions{}, ctxerr.Wrap(r.Context(), badRequest("negative page value"))
}
}
@@ -143,10 +143,10 @@ func listOptionsFromRequest(r *http.Request) (fleet.ListOptions, error) {
if perPageString != "" {
perPage, err = strconv.Atoi(perPageString)
if err != nil {
- return fleet.ListOptions{}, ctxerr.New(r.Context(), "non-int per_page value")
+ return fleet.ListOptions{}, ctxerr.Wrap(r.Context(), badRequest("non-int per_page value"))
}
if perPage <= 0 {
- return fleet.ListOptions{}, ctxerr.New(r.Context(), "invalid per_page value")
+ return fleet.ListOptions{}, ctxerr.Wrap(r.Context(), badRequest("invalid per_page value"))
}
}
@@ -158,8 +158,7 @@ func listOptionsFromRequest(r *http.Request) (fleet.ListOptions, error) {
}
if orderKey == "" && orderDirectionString != "" {
- return fleet.ListOptions{},
- ctxerr.New(r.Context(), "order_key must be specified with order_direction")
+ return fleet.ListOptions{}, ctxerr.Wrap(r.Context(), badRequest("order_key must be specified with order_direction"))
}
var orderDirection fleet.OrderDirection
@@ -172,7 +171,7 @@ func listOptionsFromRequest(r *http.Request) (fleet.ListOptions, error) {
orderDirection = fleet.OrderAscending
default:
return fleet.ListOptions{},
- ctxerr.New(r.Context(), "unknown order_direction: "+orderDirectionString)
+ ctxerr.Wrap(r.Context(), badRequest("unknown order_direction: "+orderDirectionString))
}
diff --git a/server/service/transport_test.go b/server/service/transport_test.go
index 6f67e95f7d..38d286ab4c 100644
--- a/server/service/transport_test.go
+++ b/server/service/transport_test.go
@@ -7,6 +7,7 @@ import (
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestListOptionsFromRequest(t *testing.T) {
@@ -15,8 +16,8 @@ func TestListOptionsFromRequest(t *testing.T) {
url string
// expected list options
listOptions fleet.ListOptions
- // should cause an error
- shouldErr bool
+ // should cause a BadRequest error
+ shouldErr400 bool
}{
// both params provided
{
@@ -67,30 +68,30 @@ func TestListOptionsFromRequest(t *testing.T) {
},
},
- // various error cases
+ // various 400 error cases
{
- url: "/foo?page=foo&per_page=10",
- shouldErr: true,
+ url: "/foo?page=foo&per_page=10",
+ shouldErr400: true,
},
{
- url: "/foo?page=1&per_page=foo",
- shouldErr: true,
+ url: "/foo?page=1&per_page=foo",
+ shouldErr400: true,
},
{
- url: "/foo?page=-1",
- shouldErr: true,
+ url: "/foo?page=-1",
+ shouldErr400: true,
},
{
- url: "/foo?page=-1&per_page=-10",
- shouldErr: true,
+ url: "/foo?page=-1&per_page=-10",
+ shouldErr400: true,
},
{
- url: "/foo?page=1&order_direction=desc",
- shouldErr: true,
+ url: "/foo?page=1&order_direction=desc",
+ shouldErr400: true,
},
{
- url: "/foo?&order_direction=foo&order_key=",
- shouldErr: true,
+ url: "/foo?&order_direction=foo&order_key=",
+ shouldErr400: true,
},
}
@@ -100,8 +101,10 @@ func TestListOptionsFromRequest(t *testing.T) {
req := &http.Request{URL: url}
opt, err := listOptionsFromRequest(req)
- if tt.shouldErr {
+ if tt.shouldErr400 {
assert.NotNil(t, err)
+ var be *fleet.BadRequestError
+ require.ErrorAs(t, err, &be)
return
}