From f2837fd4b3167a5e4d456f2ce80ee7e37adb0a2f Mon Sep 17 00:00:00 2001 From: Tomas Touceda Date: Tue, 3 Aug 2021 16:56:54 -0300 Subject: [PATCH] Make decoder completely generic and simplify things (#1542) * Make decoder completely generic and simplify things * Add commends and unexport func --- server/service/endpoint_utils.go | 226 +++++++++++++------------- server/service/endpoint_utils_test.go | 145 +++++++++++++++++ server/service/integration_test.go | 2 +- server/service/team_schedule.go | 8 +- server/service/teams.go | 2 +- server/service/translator.go | 2 +- server/service/user_roles.go | 2 +- 7 files changed, 262 insertions(+), 125 deletions(-) create mode 100644 server/service/endpoint_utils_test.go diff --git a/server/service/endpoint_utils.go b/server/service/endpoint_utils.go index 0730893fbb..8cdc50770c 100644 --- a/server/service/endpoint_utils.go +++ b/server/service/endpoint_utils.go @@ -1,143 +1,135 @@ package service import ( + "bufio" "context" "encoding/json" + "fmt" + "io" "net/http" "reflect" + "strings" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/go-kit/kit/endpoint" + kithttp "github.com/go-kit/kit/transport/http" "github.com/pkg/errors" ) type handlerFunc func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) -func makeDecoderForType(v interface{}) func(ctx context.Context, r *http.Request) (interface{}, error) { - t := reflect.TypeOf(v) - return func(ctx context.Context, r *http.Request) (interface{}, error) { - req := reflect.New(t).Interface() - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, err - } - return req, nil +// parseTag parses a `url` tag and whether it's optional or not, which is an optional part of the tag +func parseTag(tag string) (string, bool, error) { + parts := strings.Split(tag, ",") + switch len(parts) { + case 0: + return "", false, errors.Errorf("Error parsing %s: too few parts", tag) + case 1: + return tag, false, nil + case 2: + return parts[0], parts[1] == "optional", nil + default: + return "", false, errors.Errorf("Error parsing %s: too many parts", tag) } } -func makeDecoderForIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) { - t := reflect.TypeOf(v) - return func(ctx context.Context, r *http.Request) (interface{}, error) { - value := reflect.New(t) - for _, idKey := range idKeys { - err := setIDFromKey(r, t, value, idKey) - if err != nil { - return nil, err - } - } - - return value.Interface(), nil +// allFields returns all the fields for a struct, including the ones from embedded structs +func allFields(ifv reflect.Value) []reflect.StructField { + if ifv.Kind() == reflect.Ptr { + ifv = ifv.Elem() } -} - -func makeDecoderForTypeAndIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) { - t := reflect.TypeOf(v) - return func(ctx context.Context, r *http.Request) (interface{}, error) { - req, err := makeDecoderForType(v)(ctx, r) - if err != nil { - return nil, err - } - - value := reflect.ValueOf(req) - for _, idKey := range idKeys { - err := setIDFromKey(r, t, value, idKey) - if err != nil { - return nil, err - } - } - - return req, nil - } -} - -func setIDFromKey(r *http.Request, t reflect.Type, v reflect.Value, idKey string) error { - id, err := idFromRequest(r, idKey) - if err != nil { - return err - } - name := "" - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) - if f.Tag.Get("url") == idKey { - name = f.Name - } - } - if name == "" { - return errors.Errorf("%s not found in URL", idKey) - } - - field := v.Elem().FieldByName(name) - field.SetUint(uint64(id)) - - return nil -} - -func makeDecoderForOptionsAndIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) { - t := reflect.TypeOf(v) - return func(ctx context.Context, r *http.Request) (interface{}, error) { - req, err := makeDecoderForIDs(v, idKeys...)(ctx, r) - if err != nil { - return nil, err - } - - value := reflect.ValueOf(req) - err = setListOptions(r, t, value) - if err != nil { - return nil, err - } - - return req, nil - } -} - -func makeDecoderForTypeOptionsAndIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) { - t := reflect.TypeOf(v) - return func(ctx context.Context, r *http.Request) (interface{}, error) { - req, err := makeDecoderForTypeAndIDs(v, idKeys...)(ctx, r) - if err != nil { - return nil, err - } - - value := reflect.ValueOf(req) - err = setListOptions(r, t, value) - if err != nil { - return nil, err - } - - return req, nil - } -} - -func setListOptions(r *http.Request, t reflect.Type, v reflect.Value) error { - opt, err := listOptionsFromRequest(r) - if err != nil { - return err - } - name := "" - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) - if f.Tag.Get("url") == "list_options" { - name = f.Name - } - } - // ListOptions are optional - if name == "" { + if ifv.Kind() != reflect.Struct { return nil } - field := v.Elem().FieldByName(name) - field.Set(reflect.ValueOf(opt)) + var fields []reflect.StructField - return nil + if !ifv.IsValid() { + return nil + } + + t := ifv.Type() + + for i := 0; i < ifv.NumField(); i++ { + v := ifv.Field(i) + + if v.Kind() == reflect.Struct && t.Field(i).Anonymous { + fields = append(fields, allFields(v)...) + continue + } + fields = append(fields, ifv.Type().Field(i)) + } + + return fields +} + +// makeDecoder creates a decoder for the type for the struct passed on. If the struct has at least 1 json tag +// it'll unmarshall the body. If the struct has a `url` tag with value list-options it'll gather fleet.ListOptions +// from the URL. And finally, any other `url` tag will be treated as an ID from the URL path pattern, and it'll +// be decoded and set accordingly. +// IDs are expected to be uint, and can be optional by setting the tag as follows: `url:"some-id,optional"` +// list-options are optional by default and it'll ignore the optional portion of the tag. +func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc { + t := reflect.TypeOf(iface) + if t.Kind() != reflect.Struct { + panic(fmt.Sprintf("makeDecoder only understands structs, not %T", iface)) + } + + return func(ctx context.Context, r *http.Request) (interface{}, error) { + v := reflect.New(t) + nilBody := false + + buf := bufio.NewReader(r.Body) + if _, err := buf.Peek(1); err == io.EOF { + nilBody = true + } else { + req := v.Interface() + if err := json.NewDecoder(buf).Decode(req); err != nil { + return nil, err + } + v = reflect.ValueOf(req) + } + + for _, f := range allFields(v) { + field := v.Elem().FieldByName(f.Name) + + urlTagValue, ok := f.Tag.Lookup("url") + + optional := false + var err error + if ok { + urlTagValue, optional, err = parseTag(urlTagValue) + if err != nil { + return nil, err + } + } + + if ok && urlTagValue == "list_options" { + opts, err := listOptionsFromRequest(r) + if err != nil { + return nil, err + } + field.Set(reflect.ValueOf(opts)) + continue + } + + if ok { + id, err := idFromRequest(r, urlTagValue) + if err != nil && err == errBadRoute && !optional { + return nil, err + } + field.SetUint(uint64(id)) + continue + } + + _, jsonExpected := f.Tag.Lookup("json") + if jsonExpected && nilBody { + return nil, errors.New("Expected JSON Body") + } + } + + return v.Interface(), nil + } } func makeAuthenticatedServiceEndpoint(svc fleet.Service, f handlerFunc) endpoint.Endpoint { diff --git a/server/service/endpoint_utils_test.go b/server/service/endpoint_utils_test.go new file mode 100644 index 0000000000..03a49d53a3 --- /dev/null +++ b/server/service/endpoint_utils_test.go @@ -0,0 +1,145 @@ +package service + +import ( + "context" + "net/http/httptest" + "strings" + "testing" + + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUniversalDecoderIDs(t *testing.T) { + type universalStruct struct { + ID1 uint `url:"some-id"` + OptionalID uint `url:"some-other-id,optional"` + } + decoder := makeDecoder(universalStruct{}) + + req := httptest.NewRequest("POST", "/target", nil) + req = mux.SetURLVars(req, map[string]string{"some-id": "999"}) + + decoded, err := decoder(context.Background(), req) + require.NoError(t, err) + casted, ok := decoded.(*universalStruct) + require.True(t, ok) + + assert.Equal(t, uint(999), casted.ID1) + assert.Equal(t, uint(0), casted.OptionalID) + + // fails if non optional IDs are not provided + req = httptest.NewRequest("POST", "/target", nil) + _, err = decoder(context.Background(), req) + require.Error(t, err) +} + +func TestUniversalDecoderIDsAndJSON(t *testing.T) { + type universalStruct struct { + ID1 uint `url:"some-id"` + SomeString string `json:"some_string"` + } + decoder := makeDecoder(universalStruct{}) + + body := `{"some_string": "hello"}` + req := httptest.NewRequest("POST", "/target", strings.NewReader(body)) + req = mux.SetURLVars(req, map[string]string{"some-id": "999"}) + + decoded, err := decoder(context.Background(), req) + require.NoError(t, err) + casted, ok := decoded.(*universalStruct) + require.True(t, ok) + + assert.Equal(t, uint(999), casted.ID1) + assert.Equal(t, "hello", casted.SomeString) +} + +func TestUniversalDecoderIDsAndJSONEmbedded(t *testing.T) { + type EmbeddedJSON struct { + SomeString string `json:"some_string"` + } + type UniversalStruct struct { + ID1 uint `url:"some-id"` + EmbeddedJSON + } + decoder := makeDecoder(UniversalStruct{}) + + body := `{"some_string": "hello"}` + req := httptest.NewRequest("POST", "/target", strings.NewReader(body)) + req = mux.SetURLVars(req, map[string]string{"some-id": "999"}) + + decoded, err := decoder(context.Background(), req) + require.NoError(t, err) + casted, ok := decoded.(*UniversalStruct) + require.True(t, ok) + + assert.Equal(t, uint(999), casted.ID1) + assert.Equal(t, "hello", casted.SomeString) +} + +func TestUniversalDecoderIDsAndListOptions(t *testing.T) { + type universalStruct struct { + ID1 uint `url:"some-id"` + Opts fleet.ListOptions `url:"list_options"` + SomeString string `json:"some_string"` + } + decoder := makeDecoder(universalStruct{}) + + body := `{"some_string": "bye"}` + req := httptest.NewRequest("POST", "/target?per_page=77&page=4", strings.NewReader(body)) + req = mux.SetURLVars(req, map[string]string{"some-id": "123"}) + + decoded, err := decoder(context.Background(), req) + require.NoError(t, err) + casted, ok := decoded.(*universalStruct) + require.True(t, ok) + + assert.Equal(t, uint(123), casted.ID1) + assert.Equal(t, "bye", casted.SomeString) + assert.Equal(t, uint(77), casted.Opts.PerPage) + assert.Equal(t, uint(4), casted.Opts.Page) +} + +func TestUniversalDecoderHandlersEmbeddedAndNot(t *testing.T) { + type EmbeddedJSON struct { + SomeString string `json:"some_string"` + } + type universalStruct struct { + ID1 uint `url:"some-id"` + Opts fleet.ListOptions `url:"list_options"` + EmbeddedJSON + } + decoder := makeDecoder(universalStruct{}) + + body := `{"some_string": "o/"}` + req := httptest.NewRequest("POST", "/target?per_page=77&page=4", strings.NewReader(body)) + req = mux.SetURLVars(req, map[string]string{"some-id": "123"}) + + decoded, err := decoder(context.Background(), req) + require.NoError(t, err) + casted, ok := decoded.(*universalStruct) + require.True(t, ok) + + assert.Equal(t, uint(123), casted.ID1) + assert.Equal(t, "o/", casted.SomeString) + assert.Equal(t, uint(77), casted.Opts.PerPage) + assert.Equal(t, uint(4), casted.Opts.Page) +} + +func TestUniversalDecoderListOptions(t *testing.T) { + type universalStruct struct { + ID1 uint `url:"some-id"` + Opts fleet.ListOptions `url:"list_options"` + } + decoder := makeDecoder(universalStruct{}) + + req := httptest.NewRequest("POST", "/target", nil) + req = mux.SetURLVars(req, map[string]string{"some-id": "123"}) + + decoded, err := decoder(context.Background(), req) + require.NoError(t, err) + _, ok := decoded.(*universalStruct) + require.True(t, ok) +} diff --git a/server/service/integration_test.go b/server/service/integration_test.go index 143a61bf9c..8b65e2ea96 100644 --- a/server/service/integration_test.go +++ b/server/service/integration_test.go @@ -503,7 +503,7 @@ func TestTeamSchedule(t *testing.T) { ts = getTeamScheduleResponse{} doJSONReq(t, nil, "GET", server, fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), token, http.StatusOK, &ts) - assert.Len(t, ts.Scheduled, 1) + require.Len(t, ts.Scheduled, 1) assert.Equal(t, uint(42), ts.Scheduled[0].Interval) assert.Equal(t, "TestQuery", ts.Scheduled[0].Name) assert.Equal(t, qr.ID, ts.Scheduled[0].QueryID) diff --git a/server/service/team_schedule.go b/server/service/team_schedule.go index ee9e001d71..d24efb321a 100644 --- a/server/service/team_schedule.go +++ b/server/service/team_schedule.go @@ -25,7 +25,7 @@ func (r getTeamScheduleResponse) error() error { return r.Err } func makeGetTeamScheduleEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler { return newServer( makeAuthenticatedServiceEndpoint(svc, getTeamScheduleEndpoint), - makeDecoderForOptionsAndIDs(getTeamScheduleRequest{}, "team_id"), + makeDecoder(getTeamScheduleRequest{}), opts, ) } @@ -77,7 +77,7 @@ func (r teamScheduleQueryResponse) error() error { return r.Err } func makeTeamScheduleQueryEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler { return newServer( makeAuthenticatedServiceEndpoint(svc, teamScheduleQueryEndpoint), - makeDecoderForTypeAndIDs(teamScheduleQueryRequest{}, "team_id"), + makeDecoder(teamScheduleQueryRequest{}), opts, ) } @@ -148,7 +148,7 @@ func (r modifyTeamScheduleResponse) error() error { return r.Err } func makeModifyTeamScheduleEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler { return newServer( makeAuthenticatedServiceEndpoint(svc, modifyTeamScheduleEndpoint), - makeDecoderForTypeAndIDs(modifyTeamScheduleRequest{}, "team_id", "scheduled_query_id"), + makeDecoder(modifyTeamScheduleRequest{}), opts, ) } @@ -197,7 +197,7 @@ func (r deleteTeamScheduleResponse) error() error { return r.Err } func makeDeleteTeamScheduleEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler { return newServer( makeAuthenticatedServiceEndpoint(svc, deleteTeamScheduleEndpoint), - makeDecoderForIDs(deleteTeamScheduleRequest{}, "team_id", "scheduled_query_id"), + makeDecoder(deleteTeamScheduleRequest{}), opts, ) } diff --git a/server/service/teams.go b/server/service/teams.go index 16a491e3dc..7384323577 100644 --- a/server/service/teams.go +++ b/server/service/teams.go @@ -24,7 +24,7 @@ func (r applyTeamSpecsResponse) error() error { return r.Err } func makeApplyTeamSpecsEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler { return newServer( makeAuthenticatedServiceEndpoint(svc, applyTeamSpecsEndpoint), - makeDecoderForType(applyTeamSpecsRequest{}), + makeDecoder(applyTeamSpecsRequest{}), opts, ) } diff --git a/server/service/translator.go b/server/service/translator.go index e8c7fce1df..0b0531335f 100644 --- a/server/service/translator.go +++ b/server/service/translator.go @@ -22,7 +22,7 @@ func (r translatorResponse) error() error { return r.Err } func makeTranslatorEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler { return newServer( makeAuthenticatedServiceEndpoint(svc, translatorEndpoint), - makeDecoderForType(translatorRequest{}), + makeDecoder(translatorRequest{}), opts, ) } diff --git a/server/service/user_roles.go b/server/service/user_roles.go index a2c9530204..8872a77c06 100644 --- a/server/service/user_roles.go +++ b/server/service/user_roles.go @@ -22,7 +22,7 @@ func (r applyUserRoleSpecsResponse) error() error { return r.Err } func makeApplyUserRoleSpecsEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler { return newServer( makeAuthenticatedServiceEndpoint(svc, applyUserRoleSpecsEndpoint), - makeDecoderForType(applyUserRoleSpecsRequest{}), + makeDecoder(applyUserRoleSpecsRequest{}), opts, ) }