mirror of
https://github.com/fleetdm/fleet
synced 2026-05-23 00:49:03 +00:00
Make decoder completely generic and simplify things (#1542)
* Make decoder completely generic and simplify things * Add commends and unexport func
This commit is contained in:
parent
53dbb2ad50
commit
f2837fd4b3
7 changed files with 262 additions and 125 deletions
|
|
@ -1,143 +1,135 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type handlerFunc func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error)
|
||||
|
||||
func makeDecoderForType(v interface{}) func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
t := reflect.TypeOf(v)
|
||||
return func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
req := reflect.New(t).Interface()
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req, nil
|
||||
// parseTag parses a `url` tag and whether it's optional or not, which is an optional part of the tag
|
||||
func parseTag(tag string) (string, bool, error) {
|
||||
parts := strings.Split(tag, ",")
|
||||
switch len(parts) {
|
||||
case 0:
|
||||
return "", false, errors.Errorf("Error parsing %s: too few parts", tag)
|
||||
case 1:
|
||||
return tag, false, nil
|
||||
case 2:
|
||||
return parts[0], parts[1] == "optional", nil
|
||||
default:
|
||||
return "", false, errors.Errorf("Error parsing %s: too many parts", tag)
|
||||
}
|
||||
}
|
||||
|
||||
func makeDecoderForIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
t := reflect.TypeOf(v)
|
||||
return func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
value := reflect.New(t)
|
||||
for _, idKey := range idKeys {
|
||||
err := setIDFromKey(r, t, value, idKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return value.Interface(), nil
|
||||
// allFields returns all the fields for a struct, including the ones from embedded structs
|
||||
func allFields(ifv reflect.Value) []reflect.StructField {
|
||||
if ifv.Kind() == reflect.Ptr {
|
||||
ifv = ifv.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
func makeDecoderForTypeAndIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
t := reflect.TypeOf(v)
|
||||
return func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
req, err := makeDecoderForType(v)(ctx, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value := reflect.ValueOf(req)
|
||||
for _, idKey := range idKeys {
|
||||
err := setIDFromKey(r, t, value, idKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
}
|
||||
|
||||
func setIDFromKey(r *http.Request, t reflect.Type, v reflect.Value, idKey string) error {
|
||||
id, err := idFromRequest(r, idKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
name := ""
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
if f.Tag.Get("url") == idKey {
|
||||
name = f.Name
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
return errors.Errorf("%s not found in URL", idKey)
|
||||
}
|
||||
|
||||
field := v.Elem().FieldByName(name)
|
||||
field.SetUint(uint64(id))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeDecoderForOptionsAndIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
t := reflect.TypeOf(v)
|
||||
return func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
req, err := makeDecoderForIDs(v, idKeys...)(ctx, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value := reflect.ValueOf(req)
|
||||
err = setListOptions(r, t, value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
}
|
||||
|
||||
func makeDecoderForTypeOptionsAndIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
t := reflect.TypeOf(v)
|
||||
return func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
req, err := makeDecoderForTypeAndIDs(v, idKeys...)(ctx, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value := reflect.ValueOf(req)
|
||||
err = setListOptions(r, t, value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
}
|
||||
|
||||
func setListOptions(r *http.Request, t reflect.Type, v reflect.Value) error {
|
||||
opt, err := listOptionsFromRequest(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
name := ""
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
if f.Tag.Get("url") == "list_options" {
|
||||
name = f.Name
|
||||
}
|
||||
}
|
||||
// ListOptions are optional
|
||||
if name == "" {
|
||||
if ifv.Kind() != reflect.Struct {
|
||||
return nil
|
||||
}
|
||||
|
||||
field := v.Elem().FieldByName(name)
|
||||
field.Set(reflect.ValueOf(opt))
|
||||
var fields []reflect.StructField
|
||||
|
||||
return nil
|
||||
if !ifv.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
t := ifv.Type()
|
||||
|
||||
for i := 0; i < ifv.NumField(); i++ {
|
||||
v := ifv.Field(i)
|
||||
|
||||
if v.Kind() == reflect.Struct && t.Field(i).Anonymous {
|
||||
fields = append(fields, allFields(v)...)
|
||||
continue
|
||||
}
|
||||
fields = append(fields, ifv.Type().Field(i))
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
// makeDecoder creates a decoder for the type for the struct passed on. If the struct has at least 1 json tag
|
||||
// it'll unmarshall the body. If the struct has a `url` tag with value list-options it'll gather fleet.ListOptions
|
||||
// from the URL. And finally, any other `url` tag will be treated as an ID from the URL path pattern, and it'll
|
||||
// be decoded and set accordingly.
|
||||
// IDs are expected to be uint, and can be optional by setting the tag as follows: `url:"some-id,optional"`
|
||||
// list-options are optional by default and it'll ignore the optional portion of the tag.
|
||||
func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
|
||||
t := reflect.TypeOf(iface)
|
||||
if t.Kind() != reflect.Struct {
|
||||
panic(fmt.Sprintf("makeDecoder only understands structs, not %T", iface))
|
||||
}
|
||||
|
||||
return func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
v := reflect.New(t)
|
||||
nilBody := false
|
||||
|
||||
buf := bufio.NewReader(r.Body)
|
||||
if _, err := buf.Peek(1); err == io.EOF {
|
||||
nilBody = true
|
||||
} else {
|
||||
req := v.Interface()
|
||||
if err := json.NewDecoder(buf).Decode(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v = reflect.ValueOf(req)
|
||||
}
|
||||
|
||||
for _, f := range allFields(v) {
|
||||
field := v.Elem().FieldByName(f.Name)
|
||||
|
||||
urlTagValue, ok := f.Tag.Lookup("url")
|
||||
|
||||
optional := false
|
||||
var err error
|
||||
if ok {
|
||||
urlTagValue, optional, err = parseTag(urlTagValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if ok && urlTagValue == "list_options" {
|
||||
opts, err := listOptionsFromRequest(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
field.Set(reflect.ValueOf(opts))
|
||||
continue
|
||||
}
|
||||
|
||||
if ok {
|
||||
id, err := idFromRequest(r, urlTagValue)
|
||||
if err != nil && err == errBadRoute && !optional {
|
||||
return nil, err
|
||||
}
|
||||
field.SetUint(uint64(id))
|
||||
continue
|
||||
}
|
||||
|
||||
_, jsonExpected := f.Tag.Lookup("json")
|
||||
if jsonExpected && nilBody {
|
||||
return nil, errors.New("Expected JSON Body")
|
||||
}
|
||||
}
|
||||
|
||||
return v.Interface(), nil
|
||||
}
|
||||
}
|
||||
|
||||
func makeAuthenticatedServiceEndpoint(svc fleet.Service, f handlerFunc) endpoint.Endpoint {
|
||||
|
|
|
|||
145
server/service/endpoint_utils_test.go
Normal file
145
server/service/endpoint_utils_test.go
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUniversalDecoderIDs(t *testing.T) {
|
||||
type universalStruct struct {
|
||||
ID1 uint `url:"some-id"`
|
||||
OptionalID uint `url:"some-other-id,optional"`
|
||||
}
|
||||
decoder := makeDecoder(universalStruct{})
|
||||
|
||||
req := httptest.NewRequest("POST", "/target", nil)
|
||||
req = mux.SetURLVars(req, map[string]string{"some-id": "999"})
|
||||
|
||||
decoded, err := decoder(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
casted, ok := decoded.(*universalStruct)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, uint(999), casted.ID1)
|
||||
assert.Equal(t, uint(0), casted.OptionalID)
|
||||
|
||||
// fails if non optional IDs are not provided
|
||||
req = httptest.NewRequest("POST", "/target", nil)
|
||||
_, err = decoder(context.Background(), req)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUniversalDecoderIDsAndJSON(t *testing.T) {
|
||||
type universalStruct struct {
|
||||
ID1 uint `url:"some-id"`
|
||||
SomeString string `json:"some_string"`
|
||||
}
|
||||
decoder := makeDecoder(universalStruct{})
|
||||
|
||||
body := `{"some_string": "hello"}`
|
||||
req := httptest.NewRequest("POST", "/target", strings.NewReader(body))
|
||||
req = mux.SetURLVars(req, map[string]string{"some-id": "999"})
|
||||
|
||||
decoded, err := decoder(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
casted, ok := decoded.(*universalStruct)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, uint(999), casted.ID1)
|
||||
assert.Equal(t, "hello", casted.SomeString)
|
||||
}
|
||||
|
||||
func TestUniversalDecoderIDsAndJSONEmbedded(t *testing.T) {
|
||||
type EmbeddedJSON struct {
|
||||
SomeString string `json:"some_string"`
|
||||
}
|
||||
type UniversalStruct struct {
|
||||
ID1 uint `url:"some-id"`
|
||||
EmbeddedJSON
|
||||
}
|
||||
decoder := makeDecoder(UniversalStruct{})
|
||||
|
||||
body := `{"some_string": "hello"}`
|
||||
req := httptest.NewRequest("POST", "/target", strings.NewReader(body))
|
||||
req = mux.SetURLVars(req, map[string]string{"some-id": "999"})
|
||||
|
||||
decoded, err := decoder(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
casted, ok := decoded.(*UniversalStruct)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, uint(999), casted.ID1)
|
||||
assert.Equal(t, "hello", casted.SomeString)
|
||||
}
|
||||
|
||||
func TestUniversalDecoderIDsAndListOptions(t *testing.T) {
|
||||
type universalStruct struct {
|
||||
ID1 uint `url:"some-id"`
|
||||
Opts fleet.ListOptions `url:"list_options"`
|
||||
SomeString string `json:"some_string"`
|
||||
}
|
||||
decoder := makeDecoder(universalStruct{})
|
||||
|
||||
body := `{"some_string": "bye"}`
|
||||
req := httptest.NewRequest("POST", "/target?per_page=77&page=4", strings.NewReader(body))
|
||||
req = mux.SetURLVars(req, map[string]string{"some-id": "123"})
|
||||
|
||||
decoded, err := decoder(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
casted, ok := decoded.(*universalStruct)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, uint(123), casted.ID1)
|
||||
assert.Equal(t, "bye", casted.SomeString)
|
||||
assert.Equal(t, uint(77), casted.Opts.PerPage)
|
||||
assert.Equal(t, uint(4), casted.Opts.Page)
|
||||
}
|
||||
|
||||
func TestUniversalDecoderHandlersEmbeddedAndNot(t *testing.T) {
|
||||
type EmbeddedJSON struct {
|
||||
SomeString string `json:"some_string"`
|
||||
}
|
||||
type universalStruct struct {
|
||||
ID1 uint `url:"some-id"`
|
||||
Opts fleet.ListOptions `url:"list_options"`
|
||||
EmbeddedJSON
|
||||
}
|
||||
decoder := makeDecoder(universalStruct{})
|
||||
|
||||
body := `{"some_string": "o/"}`
|
||||
req := httptest.NewRequest("POST", "/target?per_page=77&page=4", strings.NewReader(body))
|
||||
req = mux.SetURLVars(req, map[string]string{"some-id": "123"})
|
||||
|
||||
decoded, err := decoder(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
casted, ok := decoded.(*universalStruct)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, uint(123), casted.ID1)
|
||||
assert.Equal(t, "o/", casted.SomeString)
|
||||
assert.Equal(t, uint(77), casted.Opts.PerPage)
|
||||
assert.Equal(t, uint(4), casted.Opts.Page)
|
||||
}
|
||||
|
||||
func TestUniversalDecoderListOptions(t *testing.T) {
|
||||
type universalStruct struct {
|
||||
ID1 uint `url:"some-id"`
|
||||
Opts fleet.ListOptions `url:"list_options"`
|
||||
}
|
||||
decoder := makeDecoder(universalStruct{})
|
||||
|
||||
req := httptest.NewRequest("POST", "/target", nil)
|
||||
req = mux.SetURLVars(req, map[string]string{"some-id": "123"})
|
||||
|
||||
decoded, err := decoder(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
_, ok := decoded.(*universalStruct)
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
|
@ -503,7 +503,7 @@ func TestTeamSchedule(t *testing.T) {
|
|||
|
||||
ts = getTeamScheduleResponse{}
|
||||
doJSONReq(t, nil, "GET", server, fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), token, http.StatusOK, &ts)
|
||||
assert.Len(t, ts.Scheduled, 1)
|
||||
require.Len(t, ts.Scheduled, 1)
|
||||
assert.Equal(t, uint(42), ts.Scheduled[0].Interval)
|
||||
assert.Equal(t, "TestQuery", ts.Scheduled[0].Name)
|
||||
assert.Equal(t, qr.ID, ts.Scheduled[0].QueryID)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ func (r getTeamScheduleResponse) error() error { return r.Err }
|
|||
func makeGetTeamScheduleEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
|
||||
return newServer(
|
||||
makeAuthenticatedServiceEndpoint(svc, getTeamScheduleEndpoint),
|
||||
makeDecoderForOptionsAndIDs(getTeamScheduleRequest{}, "team_id"),
|
||||
makeDecoder(getTeamScheduleRequest{}),
|
||||
opts,
|
||||
)
|
||||
}
|
||||
|
|
@ -77,7 +77,7 @@ func (r teamScheduleQueryResponse) error() error { return r.Err }
|
|||
func makeTeamScheduleQueryEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
|
||||
return newServer(
|
||||
makeAuthenticatedServiceEndpoint(svc, teamScheduleQueryEndpoint),
|
||||
makeDecoderForTypeAndIDs(teamScheduleQueryRequest{}, "team_id"),
|
||||
makeDecoder(teamScheduleQueryRequest{}),
|
||||
opts,
|
||||
)
|
||||
}
|
||||
|
|
@ -148,7 +148,7 @@ func (r modifyTeamScheduleResponse) error() error { return r.Err }
|
|||
func makeModifyTeamScheduleEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
|
||||
return newServer(
|
||||
makeAuthenticatedServiceEndpoint(svc, modifyTeamScheduleEndpoint),
|
||||
makeDecoderForTypeAndIDs(modifyTeamScheduleRequest{}, "team_id", "scheduled_query_id"),
|
||||
makeDecoder(modifyTeamScheduleRequest{}),
|
||||
opts,
|
||||
)
|
||||
}
|
||||
|
|
@ -197,7 +197,7 @@ func (r deleteTeamScheduleResponse) error() error { return r.Err }
|
|||
func makeDeleteTeamScheduleEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
|
||||
return newServer(
|
||||
makeAuthenticatedServiceEndpoint(svc, deleteTeamScheduleEndpoint),
|
||||
makeDecoderForIDs(deleteTeamScheduleRequest{}, "team_id", "scheduled_query_id"),
|
||||
makeDecoder(deleteTeamScheduleRequest{}),
|
||||
opts,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ func (r applyTeamSpecsResponse) error() error { return r.Err }
|
|||
func makeApplyTeamSpecsEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
|
||||
return newServer(
|
||||
makeAuthenticatedServiceEndpoint(svc, applyTeamSpecsEndpoint),
|
||||
makeDecoderForType(applyTeamSpecsRequest{}),
|
||||
makeDecoder(applyTeamSpecsRequest{}),
|
||||
opts,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ func (r translatorResponse) error() error { return r.Err }
|
|||
func makeTranslatorEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
|
||||
return newServer(
|
||||
makeAuthenticatedServiceEndpoint(svc, translatorEndpoint),
|
||||
makeDecoderForType(translatorRequest{}),
|
||||
makeDecoder(translatorRequest{}),
|
||||
opts,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ func (r applyUserRoleSpecsResponse) error() error { return r.Err }
|
|||
func makeApplyUserRoleSpecsEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
|
||||
return newServer(
|
||||
makeAuthenticatedServiceEndpoint(svc, applyUserRoleSpecsEndpoint),
|
||||
makeDecoderForType(applyUserRoleSpecsRequest{}),
|
||||
makeDecoder(applyUserRoleSpecsRequest{}),
|
||||
opts,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue