From a3589892c32b552bb4c604eb5e601389a1b2d5b8 Mon Sep 17 00:00:00 2001 From: Mike Arpaia Date: Mon, 13 Feb 2017 14:31:22 -0700 Subject: [PATCH] A simpler attempt at using the payload pattern for scheduled queries (#1210) --- server/kolide/scheduled_queries.go | 13 +++++- server/service/endpoint_scheduled_queries.go | 7 +--- server/service/logging_scheduled_queries.go | 4 +- server/service/service_scheduled_queries.go | 40 ++++++++++++++++++- .../service/service_scheduled_queries_test.go | 7 +++- .../transport_scheduled_queries_test.go | 6 +-- 6 files changed, 63 insertions(+), 14 deletions(-) diff --git a/server/kolide/scheduled_queries.go b/server/kolide/scheduled_queries.go index d13d59752b..cc876f1aac 100644 --- a/server/kolide/scheduled_queries.go +++ b/server/kolide/scheduled_queries.go @@ -17,7 +17,7 @@ type ScheduledQueryService interface { GetScheduledQueriesInPack(ctx context.Context, id uint, opts ListOptions) (queries []*ScheduledQuery, err error) ScheduleQuery(ctx context.Context, sq *ScheduledQuery) (query *ScheduledQuery, err error) DeleteScheduledQuery(ctx context.Context, id uint) (err error) - ModifyScheduledQuery(ctx context.Context, sq *ScheduledQuery) (query *ScheduledQuery, err error) + ModifyScheduledQuery(ctx context.Context, id uint, p ScheduledQueryPayload) (query *ScheduledQuery, err error) } type ScheduledQuery struct { @@ -35,3 +35,14 @@ type ScheduledQuery struct { Version *string `json:"version"` Shard *uint `json:"shard"` } + +type ScheduledQueryPayload struct { + PackID *uint `json:"pack_id"` + QueryID *uint `json:"query_id"` + Interval *uint `json:"interval"` + Snapshot *bool `json:"snapshot"` + Removed *bool `json:"removed"` + Platform *string `json:"platform"` + Version *string `json:"version"` + Shard *uint `json:"shard"` +} diff --git a/server/service/endpoint_scheduled_queries.go b/server/service/endpoint_scheduled_queries.go index f8e97f5dce..4961174e31 100644 --- a/server/service/endpoint_scheduled_queries.go +++ b/server/service/endpoint_scheduled_queries.go @@ -127,7 +127,7 @@ func makeScheduleQueryEndpoint(svc kolide.Service) endpoint.Endpoint { type modifyScheduledQueryRequest struct { ID uint - payload *kolide.ScheduledQuery + payload kolide.ScheduledQueryPayload } type modifyScheduledQueryResponse struct { @@ -141,10 +141,7 @@ func makeModifyScheduledQueryEndpoint(svc kolide.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { req := request.(modifyScheduledQueryRequest) - sq := req.payload - sq.ID = req.ID - - sq, err := svc.ModifyScheduledQuery(ctx, sq) + sq, err := svc.ModifyScheduledQuery(ctx, req.ID, req.payload) if err != nil { return modifyScheduledQueryResponse{Err: err}, nil } diff --git a/server/service/logging_scheduled_queries.go b/server/service/logging_scheduled_queries.go index 326e56ddc5..fbf74e84ee 100644 --- a/server/service/logging_scheduled_queries.go +++ b/server/service/logging_scheduled_queries.go @@ -78,7 +78,7 @@ func (mw loggingMiddleware) DeleteScheduledQuery(ctx context.Context, id uint) e return err } -func (mw loggingMiddleware) ModifyScheduledQuery(ctx context.Context, sq *kolide.ScheduledQuery) (*kolide.ScheduledQuery, error) { +func (mw loggingMiddleware) ModifyScheduledQuery(ctx context.Context, id uint, p kolide.ScheduledQueryPayload) (*kolide.ScheduledQuery, error) { var ( query *kolide.ScheduledQuery err error @@ -92,6 +92,6 @@ func (mw loggingMiddleware) ModifyScheduledQuery(ctx context.Context, sq *kolide ) }(time.Now()) - query, err = mw.Service.ModifyScheduledQuery(ctx, sq) + query, err = mw.Service.ModifyScheduledQuery(ctx, id, p) return query, err } diff --git a/server/service/service_scheduled_queries.go b/server/service/service_scheduled_queries.go index 7314aa7862..47c6d844b2 100644 --- a/server/service/service_scheduled_queries.go +++ b/server/service/service_scheduled_queries.go @@ -2,6 +2,7 @@ package service import ( "github.com/kolide/kolide/server/kolide" + "github.com/pkg/errors" "golang.org/x/net/context" ) @@ -17,7 +18,44 @@ func (svc service) ScheduleQuery(ctx context.Context, sq *kolide.ScheduledQuery) return svc.ds.NewScheduledQuery(sq) } -func (svc service) ModifyScheduledQuery(ctx context.Context, sq *kolide.ScheduledQuery) (*kolide.ScheduledQuery, error) { +func (svc service) ModifyScheduledQuery(ctx context.Context, id uint, p kolide.ScheduledQueryPayload) (*kolide.ScheduledQuery, error) { + sq, err := svc.GetScheduledQuery(ctx, id) + if err != nil { + return nil, errors.Wrap(err, "getting scheduled query to modify") + } + + if p.PackID != nil { + sq.PackID = *p.PackID + } + + if p.QueryID != nil { + sq.QueryID = *p.QueryID + } + + if p.Interval != nil { + sq.Interval = *p.Interval + } + + if p.Snapshot != nil { + sq.Snapshot = p.Snapshot + } + + if p.Removed != nil { + sq.Removed = p.Removed + } + + if p.Platform != nil { + sq.Platform = p.Platform + } + + if p.Version != nil { + sq.Version = p.Version + } + + if p.Shard != nil { + sq.Shard = p.Shard + } + return svc.ds.SaveScheduledQuery(sq) } diff --git a/server/service/service_scheduled_queries_test.go b/server/service/service_scheduled_queries_test.go index c81f7161c8..0fa55deafc 100644 --- a/server/service/service_scheduled_queries_test.go +++ b/server/service/service_scheduled_queries_test.go @@ -71,8 +71,11 @@ func TestModifyScheduledQuery(t *testing.T) { require.Nil(t, err) assert.Equal(t, uint(60), query.Interval) - query.Interval = uint(120) - query, err = svc.ModifyScheduledQuery(ctx, query) + interval := uint(120) + queryPayload := kolide.ScheduledQueryPayload{ + Interval: &interval, + } + query, err = svc.ModifyScheduledQuery(ctx, sq1.ID, queryPayload) assert.Equal(t, uint(120), query.Interval) queryVerify, err := svc.GetScheduledQuery(ctx, sq1.ID) diff --git a/server/service/transport_scheduled_queries_test.go b/server/service/transport_scheduled_queries_test.go index 87ef9084c9..8d48af78b6 100644 --- a/server/service/transport_scheduled_queries_test.go +++ b/server/service/transport_scheduled_queries_test.go @@ -48,10 +48,10 @@ func TestDecodeModifyScheduledQueryRequest(t *testing.T) { params := r.(modifyScheduledQueryRequest) assert.Equal(t, uint(1), params.ID) - assert.Equal(t, uint(5), params.payload.PackID) - assert.Equal(t, uint(6), params.payload.QueryID) + assert.Equal(t, uint(5), *params.payload.PackID) + assert.Equal(t, uint(6), *params.payload.QueryID) assert.Equal(t, true, *params.payload.Removed) - assert.Equal(t, uint(60), params.payload.Interval) + assert.Equal(t, uint(60), *params.payload.Interval) assert.Equal(t, uint(1), *params.payload.Shard) }).Methods("PATCH")