Make decoder completely generic and simplify things (#1542)

* Make decoder completely generic and simplify things

* Add commends and unexport func
This commit is contained in:
Tomas Touceda 2021-08-03 16:56:54 -03:00 committed by GitHub
parent 53dbb2ad50
commit f2837fd4b3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 262 additions and 125 deletions

View file

@ -1,143 +1,135 @@
package service
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"reflect"
"strings"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/go-kit/kit/endpoint"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/pkg/errors"
)
type handlerFunc func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error)
func makeDecoderForType(v interface{}) func(ctx context.Context, r *http.Request) (interface{}, error) {
t := reflect.TypeOf(v)
return func(ctx context.Context, r *http.Request) (interface{}, error) {
req := reflect.New(t).Interface()
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, err
}
return req, nil
// parseTag parses a `url` tag and whether it's optional or not, which is an optional part of the tag
func parseTag(tag string) (string, bool, error) {
parts := strings.Split(tag, ",")
switch len(parts) {
case 0:
return "", false, errors.Errorf("Error parsing %s: too few parts", tag)
case 1:
return tag, false, nil
case 2:
return parts[0], parts[1] == "optional", nil
default:
return "", false, errors.Errorf("Error parsing %s: too many parts", tag)
}
}
func makeDecoderForIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) {
t := reflect.TypeOf(v)
return func(ctx context.Context, r *http.Request) (interface{}, error) {
value := reflect.New(t)
for _, idKey := range idKeys {
err := setIDFromKey(r, t, value, idKey)
if err != nil {
return nil, err
}
}
return value.Interface(), nil
// allFields returns all the fields for a struct, including the ones from embedded structs
func allFields(ifv reflect.Value) []reflect.StructField {
if ifv.Kind() == reflect.Ptr {
ifv = ifv.Elem()
}
}
func makeDecoderForTypeAndIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) {
t := reflect.TypeOf(v)
return func(ctx context.Context, r *http.Request) (interface{}, error) {
req, err := makeDecoderForType(v)(ctx, r)
if err != nil {
return nil, err
}
value := reflect.ValueOf(req)
for _, idKey := range idKeys {
err := setIDFromKey(r, t, value, idKey)
if err != nil {
return nil, err
}
}
return req, nil
}
}
func setIDFromKey(r *http.Request, t reflect.Type, v reflect.Value, idKey string) error {
id, err := idFromRequest(r, idKey)
if err != nil {
return err
}
name := ""
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Tag.Get("url") == idKey {
name = f.Name
}
}
if name == "" {
return errors.Errorf("%s not found in URL", idKey)
}
field := v.Elem().FieldByName(name)
field.SetUint(uint64(id))
return nil
}
func makeDecoderForOptionsAndIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) {
t := reflect.TypeOf(v)
return func(ctx context.Context, r *http.Request) (interface{}, error) {
req, err := makeDecoderForIDs(v, idKeys...)(ctx, r)
if err != nil {
return nil, err
}
value := reflect.ValueOf(req)
err = setListOptions(r, t, value)
if err != nil {
return nil, err
}
return req, nil
}
}
func makeDecoderForTypeOptionsAndIDs(v interface{}, idKeys ...string) func(ctx context.Context, r *http.Request) (interface{}, error) {
t := reflect.TypeOf(v)
return func(ctx context.Context, r *http.Request) (interface{}, error) {
req, err := makeDecoderForTypeAndIDs(v, idKeys...)(ctx, r)
if err != nil {
return nil, err
}
value := reflect.ValueOf(req)
err = setListOptions(r, t, value)
if err != nil {
return nil, err
}
return req, nil
}
}
func setListOptions(r *http.Request, t reflect.Type, v reflect.Value) error {
opt, err := listOptionsFromRequest(r)
if err != nil {
return err
}
name := ""
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Tag.Get("url") == "list_options" {
name = f.Name
}
}
// ListOptions are optional
if name == "" {
if ifv.Kind() != reflect.Struct {
return nil
}
field := v.Elem().FieldByName(name)
field.Set(reflect.ValueOf(opt))
var fields []reflect.StructField
return nil
if !ifv.IsValid() {
return nil
}
t := ifv.Type()
for i := 0; i < ifv.NumField(); i++ {
v := ifv.Field(i)
if v.Kind() == reflect.Struct && t.Field(i).Anonymous {
fields = append(fields, allFields(v)...)
continue
}
fields = append(fields, ifv.Type().Field(i))
}
return fields
}
// makeDecoder creates a decoder for the type for the struct passed on. If the struct has at least 1 json tag
// it'll unmarshall the body. If the struct has a `url` tag with value list-options it'll gather fleet.ListOptions
// from the URL. And finally, any other `url` tag will be treated as an ID from the URL path pattern, and it'll
// be decoded and set accordingly.
// IDs are expected to be uint, and can be optional by setting the tag as follows: `url:"some-id,optional"`
// list-options are optional by default and it'll ignore the optional portion of the tag.
func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
t := reflect.TypeOf(iface)
if t.Kind() != reflect.Struct {
panic(fmt.Sprintf("makeDecoder only understands structs, not %T", iface))
}
return func(ctx context.Context, r *http.Request) (interface{}, error) {
v := reflect.New(t)
nilBody := false
buf := bufio.NewReader(r.Body)
if _, err := buf.Peek(1); err == io.EOF {
nilBody = true
} else {
req := v.Interface()
if err := json.NewDecoder(buf).Decode(req); err != nil {
return nil, err
}
v = reflect.ValueOf(req)
}
for _, f := range allFields(v) {
field := v.Elem().FieldByName(f.Name)
urlTagValue, ok := f.Tag.Lookup("url")
optional := false
var err error
if ok {
urlTagValue, optional, err = parseTag(urlTagValue)
if err != nil {
return nil, err
}
}
if ok && urlTagValue == "list_options" {
opts, err := listOptionsFromRequest(r)
if err != nil {
return nil, err
}
field.Set(reflect.ValueOf(opts))
continue
}
if ok {
id, err := idFromRequest(r, urlTagValue)
if err != nil && err == errBadRoute && !optional {
return nil, err
}
field.SetUint(uint64(id))
continue
}
_, jsonExpected := f.Tag.Lookup("json")
if jsonExpected && nilBody {
return nil, errors.New("Expected JSON Body")
}
}
return v.Interface(), nil
}
}
func makeAuthenticatedServiceEndpoint(svc fleet.Service, f handlerFunc) endpoint.Endpoint {

View file

@ -0,0 +1,145 @@
package service
import (
"context"
"net/http/httptest"
"strings"
"testing"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUniversalDecoderIDs(t *testing.T) {
type universalStruct struct {
ID1 uint `url:"some-id"`
OptionalID uint `url:"some-other-id,optional"`
}
decoder := makeDecoder(universalStruct{})
req := httptest.NewRequest("POST", "/target", nil)
req = mux.SetURLVars(req, map[string]string{"some-id": "999"})
decoded, err := decoder(context.Background(), req)
require.NoError(t, err)
casted, ok := decoded.(*universalStruct)
require.True(t, ok)
assert.Equal(t, uint(999), casted.ID1)
assert.Equal(t, uint(0), casted.OptionalID)
// fails if non optional IDs are not provided
req = httptest.NewRequest("POST", "/target", nil)
_, err = decoder(context.Background(), req)
require.Error(t, err)
}
func TestUniversalDecoderIDsAndJSON(t *testing.T) {
type universalStruct struct {
ID1 uint `url:"some-id"`
SomeString string `json:"some_string"`
}
decoder := makeDecoder(universalStruct{})
body := `{"some_string": "hello"}`
req := httptest.NewRequest("POST", "/target", strings.NewReader(body))
req = mux.SetURLVars(req, map[string]string{"some-id": "999"})
decoded, err := decoder(context.Background(), req)
require.NoError(t, err)
casted, ok := decoded.(*universalStruct)
require.True(t, ok)
assert.Equal(t, uint(999), casted.ID1)
assert.Equal(t, "hello", casted.SomeString)
}
func TestUniversalDecoderIDsAndJSONEmbedded(t *testing.T) {
type EmbeddedJSON struct {
SomeString string `json:"some_string"`
}
type UniversalStruct struct {
ID1 uint `url:"some-id"`
EmbeddedJSON
}
decoder := makeDecoder(UniversalStruct{})
body := `{"some_string": "hello"}`
req := httptest.NewRequest("POST", "/target", strings.NewReader(body))
req = mux.SetURLVars(req, map[string]string{"some-id": "999"})
decoded, err := decoder(context.Background(), req)
require.NoError(t, err)
casted, ok := decoded.(*UniversalStruct)
require.True(t, ok)
assert.Equal(t, uint(999), casted.ID1)
assert.Equal(t, "hello", casted.SomeString)
}
func TestUniversalDecoderIDsAndListOptions(t *testing.T) {
type universalStruct struct {
ID1 uint `url:"some-id"`
Opts fleet.ListOptions `url:"list_options"`
SomeString string `json:"some_string"`
}
decoder := makeDecoder(universalStruct{})
body := `{"some_string": "bye"}`
req := httptest.NewRequest("POST", "/target?per_page=77&page=4", strings.NewReader(body))
req = mux.SetURLVars(req, map[string]string{"some-id": "123"})
decoded, err := decoder(context.Background(), req)
require.NoError(t, err)
casted, ok := decoded.(*universalStruct)
require.True(t, ok)
assert.Equal(t, uint(123), casted.ID1)
assert.Equal(t, "bye", casted.SomeString)
assert.Equal(t, uint(77), casted.Opts.PerPage)
assert.Equal(t, uint(4), casted.Opts.Page)
}
func TestUniversalDecoderHandlersEmbeddedAndNot(t *testing.T) {
type EmbeddedJSON struct {
SomeString string `json:"some_string"`
}
type universalStruct struct {
ID1 uint `url:"some-id"`
Opts fleet.ListOptions `url:"list_options"`
EmbeddedJSON
}
decoder := makeDecoder(universalStruct{})
body := `{"some_string": "o/"}`
req := httptest.NewRequest("POST", "/target?per_page=77&page=4", strings.NewReader(body))
req = mux.SetURLVars(req, map[string]string{"some-id": "123"})
decoded, err := decoder(context.Background(), req)
require.NoError(t, err)
casted, ok := decoded.(*universalStruct)
require.True(t, ok)
assert.Equal(t, uint(123), casted.ID1)
assert.Equal(t, "o/", casted.SomeString)
assert.Equal(t, uint(77), casted.Opts.PerPage)
assert.Equal(t, uint(4), casted.Opts.Page)
}
func TestUniversalDecoderListOptions(t *testing.T) {
type universalStruct struct {
ID1 uint `url:"some-id"`
Opts fleet.ListOptions `url:"list_options"`
}
decoder := makeDecoder(universalStruct{})
req := httptest.NewRequest("POST", "/target", nil)
req = mux.SetURLVars(req, map[string]string{"some-id": "123"})
decoded, err := decoder(context.Background(), req)
require.NoError(t, err)
_, ok := decoded.(*universalStruct)
require.True(t, ok)
}

View file

@ -503,7 +503,7 @@ func TestTeamSchedule(t *testing.T) {
ts = getTeamScheduleResponse{}
doJSONReq(t, nil, "GET", server, fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), token, http.StatusOK, &ts)
assert.Len(t, ts.Scheduled, 1)
require.Len(t, ts.Scheduled, 1)
assert.Equal(t, uint(42), ts.Scheduled[0].Interval)
assert.Equal(t, "TestQuery", ts.Scheduled[0].Name)
assert.Equal(t, qr.ID, ts.Scheduled[0].QueryID)

View file

@ -25,7 +25,7 @@ func (r getTeamScheduleResponse) error() error { return r.Err }
func makeGetTeamScheduleEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
return newServer(
makeAuthenticatedServiceEndpoint(svc, getTeamScheduleEndpoint),
makeDecoderForOptionsAndIDs(getTeamScheduleRequest{}, "team_id"),
makeDecoder(getTeamScheduleRequest{}),
opts,
)
}
@ -77,7 +77,7 @@ func (r teamScheduleQueryResponse) error() error { return r.Err }
func makeTeamScheduleQueryEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
return newServer(
makeAuthenticatedServiceEndpoint(svc, teamScheduleQueryEndpoint),
makeDecoderForTypeAndIDs(teamScheduleQueryRequest{}, "team_id"),
makeDecoder(teamScheduleQueryRequest{}),
opts,
)
}
@ -148,7 +148,7 @@ func (r modifyTeamScheduleResponse) error() error { return r.Err }
func makeModifyTeamScheduleEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
return newServer(
makeAuthenticatedServiceEndpoint(svc, modifyTeamScheduleEndpoint),
makeDecoderForTypeAndIDs(modifyTeamScheduleRequest{}, "team_id", "scheduled_query_id"),
makeDecoder(modifyTeamScheduleRequest{}),
opts,
)
}
@ -197,7 +197,7 @@ func (r deleteTeamScheduleResponse) error() error { return r.Err }
func makeDeleteTeamScheduleEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
return newServer(
makeAuthenticatedServiceEndpoint(svc, deleteTeamScheduleEndpoint),
makeDecoderForIDs(deleteTeamScheduleRequest{}, "team_id", "scheduled_query_id"),
makeDecoder(deleteTeamScheduleRequest{}),
opts,
)
}

View file

@ -24,7 +24,7 @@ func (r applyTeamSpecsResponse) error() error { return r.Err }
func makeApplyTeamSpecsEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
return newServer(
makeAuthenticatedServiceEndpoint(svc, applyTeamSpecsEndpoint),
makeDecoderForType(applyTeamSpecsRequest{}),
makeDecoder(applyTeamSpecsRequest{}),
opts,
)
}

View file

@ -22,7 +22,7 @@ func (r translatorResponse) error() error { return r.Err }
func makeTranslatorEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
return newServer(
makeAuthenticatedServiceEndpoint(svc, translatorEndpoint),
makeDecoderForType(translatorRequest{}),
makeDecoder(translatorRequest{}),
opts,
)
}

View file

@ -22,7 +22,7 @@ func (r applyUserRoleSpecsResponse) error() error { return r.Err }
func makeApplyUserRoleSpecsEndpoint(svc fleet.Service, opts []kithttp.ServerOption) http.Handler {
return newServer(
makeAuthenticatedServiceEndpoint(svc, applyUserRoleSpecsEndpoint),
makeDecoderForType(applyUserRoleSpecsRequest{}),
makeDecoder(applyUserRoleSpecsRequest{}),
opts,
)
}