From 7464e72ba881a75bc9eb18805a7384d28aaea665 Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Wed, 1 Dec 2021 15:45:29 -0500 Subject: [PATCH] Move carves endpoints to new endpoint pattern (#3148) --- server/datastore/mysql/carves.go | 12 ++ server/service/carves.go | 133 +++++++++++++++++ server/service/carves_test.go | 186 ++++++++++++++++++++++++ server/service/endpoint_carves.go | 87 ----------- server/service/endpoint_utils.go | 12 +- server/service/handler.go | 20 +-- server/service/integration_core_test.go | 77 ++++++++++ server/service/service_carves.go | 42 ------ server/service/service_carves_test.go | 86 ----------- server/service/transport.go | 21 +++ server/service/transport_carves.go | 39 ----- 11 files changed, 444 insertions(+), 271 deletions(-) create mode 100644 server/service/carves.go create mode 100644 server/service/carves_test.go diff --git a/server/datastore/mysql/carves.go b/server/datastore/mysql/carves.go index 01ca15241d..89f3a76072 100644 --- a/server/datastore/mysql/carves.go +++ b/server/datastore/mysql/carves.go @@ -167,6 +167,9 @@ func (d *Datastore) Carve(ctx context.Context, carveId int64) (*fleet.CarveMetad var metadata fleet.CarveMetadata if err := sqlx.GetContext(ctx, d.reader, &metadata, stmt, carveId); err != nil { + if err == sql.ErrNoRows { + return nil, ctxerr.Wrap(ctx, notFound("Carve").WithID(uint(carveId))) + } return nil, ctxerr.Wrap(ctx, err, "get carve by ID") } @@ -183,6 +186,9 @@ func (d *Datastore) CarveBySessionId(ctx context.Context, sessionId string) (*fl var metadata fleet.CarveMetadata if err := sqlx.GetContext(ctx, d.reader, &metadata, stmt, sessionId); err != nil { + if err == sql.ErrNoRows { + return nil, ctxerr.Wrap(ctx, notFound("CarveBySessionId").WithName(sessionId)) + } return nil, ctxerr.Wrap(ctx, err, "get carve by session ID") } @@ -199,6 +205,9 @@ func (d *Datastore) CarveByName(ctx context.Context, name string) (*fleet.CarveM var metadata fleet.CarveMetadata if err := sqlx.GetContext(ctx, d.reader, &metadata, stmt, name); err != nil { + if err == sql.ErrNoRows { + return nil, ctxerr.Wrap(ctx, notFound("Carve").WithName(name)) + } return nil, ctxerr.Wrap(ctx, err, "get carve by name") } @@ -259,6 +268,9 @@ func (d *Datastore) GetBlock(ctx context.Context, metadata *fleet.CarveMetadata, ` var data []byte if err := sqlx.GetContext(ctx, d.reader, &data, stmt, metadata.ID, blockId); err != nil { + if err == sql.ErrNoRows { + return nil, ctxerr.Wrap(ctx, notFound("CarveBlock").WithID(uint(blockId))) + } return nil, ctxerr.Wrap(ctx, err, "select data") } diff --git a/server/service/carves.go b/server/service/carves.go new file mode 100644 index 0000000000..e5e659d850 --- /dev/null +++ b/server/service/carves.go @@ -0,0 +1,133 @@ +package service + +import ( + "context" + "errors" + "fmt" + + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + "github.com/fleetdm/fleet/v4/server/fleet" +) + +//////////////////////////////////////////////////////////////////////////////// +// List Carves +//////////////////////////////////////////////////////////////////////////////// + +type listCarvesRequest struct { + ListOptions fleet.CarveListOptions `url:"carve_options"` +} + +type listCarvesResponse struct { + Carves []fleet.CarveMetadata `json:"carves"` + Err error `json:"error,omitempty"` +} + +func (r listCarvesResponse) error() error { return r.Err } + +func listCarvesEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*listCarvesRequest) + carves, err := svc.ListCarves(ctx, req.ListOptions) + if err != nil { + return listCarvesResponse{Err: err}, nil + } + + resp := listCarvesResponse{} + for _, carve := range carves { + resp.Carves = append(resp.Carves, *carve) + } + return resp, nil +} + +func (svc *Service) ListCarves(ctx context.Context, opt fleet.CarveListOptions) ([]*fleet.CarveMetadata, error) { + if err := svc.authz.Authorize(ctx, &fleet.CarveMetadata{}, fleet.ActionRead); err != nil { + return nil, err + } + + return svc.carveStore.ListCarves(ctx, opt) +} + +//////////////////////////////////////////////////////////////////////////////// +// Get Carve +//////////////////////////////////////////////////////////////////////////////// + +type getCarveRequest struct { + ID int64 `url:"id"` +} + +type getCarveResponse struct { + Carve fleet.CarveMetadata `json:"carve"` + Err error `json:"error,omitempty"` +} + +func (r getCarveResponse) error() error { return r.Err } + +func getCarveEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*getCarveRequest) + carve, err := svc.GetCarve(ctx, req.ID) + if err != nil { + return getCarveResponse{Err: err}, nil + } + + return getCarveResponse{Carve: *carve}, nil + +} + +func (svc *Service) GetCarve(ctx context.Context, id int64) (*fleet.CarveMetadata, error) { + if err := svc.authz.Authorize(ctx, &fleet.CarveMetadata{}, fleet.ActionRead); err != nil { + return nil, err + } + + return svc.carveStore.Carve(ctx, id) +} + +//////////////////////////////////////////////////////////////////////////////// +// Get Carve Block +//////////////////////////////////////////////////////////////////////////////// + +type getCarveBlockRequest struct { + ID int64 `url:"id"` + BlockId int64 `url:"block_id"` +} + +type getCarveBlockResponse struct { + Data []byte `json:"data"` + Err error `json:"error,omitempty"` +} + +func (r getCarveBlockResponse) error() error { return r.Err } + +func getCarveBlockEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*getCarveBlockRequest) + data, err := svc.GetBlock(ctx, req.ID, req.BlockId) + if err != nil { + return getCarveBlockResponse{Err: err}, nil + } + + return getCarveBlockResponse{Data: data}, nil +} + +func (svc *Service) GetBlock(ctx context.Context, carveId, blockId int64) ([]byte, error) { + if err := svc.authz.Authorize(ctx, &fleet.CarveMetadata{}, fleet.ActionRead); err != nil { + return nil, err + } + + metadata, err := svc.carveStore.Carve(ctx, carveId) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "get carve by name") + } + + if metadata.Expired { + return nil, errors.New("cannot get block for expired carve") + } + + if blockId > metadata.MaxBlock { + return nil, fmt.Errorf("block %d not yet available", blockId) + } + + data, err := svc.carveStore.GetBlock(ctx, metadata, blockId) + if err != nil { + return nil, ctxerr.Wrapf(ctx, err, "get block %d", blockId) + } + + return data, nil +} diff --git a/server/service/carves_test.go b/server/service/carves_test.go new file mode 100644 index 0000000000..987e1cefdb --- /dev/null +++ b/server/service/carves_test.go @@ -0,0 +1,186 @@ +package service + +import ( + "context" + "errors" + "testing" + + "github.com/fleetdm/fleet/v4/server/authz" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/mock" + "github.com/fleetdm/fleet/v4/server/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestListCarves(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + ds.ListCarvesFunc = func(ctx context.Context, opts fleet.CarveListOptions) ([]*fleet.CarveMetadata, error) { + return []*fleet.CarveMetadata{ + {ID: 1}, + {ID: 2}, + }, nil + } + + // admin user + carves, err := svc.ListCarves(test.UserContext(test.UserAdmin), fleet.CarveListOptions{}) + require.NoError(t, err) + require.Len(t, carves, 2) + + // only global admin can read carves + _, err = svc.ListCarves(test.UserContext(test.UserNoRoles), fleet.CarveListOptions{}) + require.Error(t, err) + require.Contains(t, err.Error(), authz.ForbiddenErrorMessage) + + // no user in context + _, err = svc.ListCarves(context.Background(), fleet.CarveListOptions{}) + require.Error(t, err) + require.Contains(t, err.Error(), authz.ForbiddenErrorMessage) +} + +func TestGetCarve(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + ds.CarveFunc = func(ctx context.Context, id int64) (*fleet.CarveMetadata, error) { + return &fleet.CarveMetadata{ + ID: id, + }, nil + } + + // admin user + carve, err := svc.GetCarve(test.UserContext(test.UserAdmin), 1) + require.NoError(t, err) + require.Equal(t, int64(1), carve.ID) + + // only global admin can read carves + _, err = svc.GetCarve(test.UserContext(test.UserNoRoles), 1) + require.Error(t, err) + require.Contains(t, err.Error(), authz.ForbiddenErrorMessage) + + // no user in context + _, err = svc.GetCarve(context.Background(), 1) + require.Error(t, err) + require.Contains(t, err.Error(), authz.ForbiddenErrorMessage) +} + +func TestCarveGetBlock(t *testing.T) { + ds := new(mock.Store) + svc := &Service{carveStore: ds, authz: authz.Must()} + + metadata := &fleet.CarveMetadata{ + ID: 2, + HostId: 3, + BlockCount: 23, + BlockSize: 64, + CarveSize: 23 * 64, + RequestId: "carve_request", + SessionId: "foobar", + MaxBlock: 3, + } + + ds.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) { + assert.Equal(t, metadata.ID, carveId) + return metadata, nil + } + ds.GetBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64) ([]byte, error) { + assert.Equal(t, metadata.ID, carve.ID) + assert.Equal(t, int64(3), blockId) + return []byte("foobar"), nil + } + + data, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 3) + require.NoError(t, err) + assert.Equal(t, []byte("foobar"), data) + + // only global admin can read carves + _, err = svc.GetBlock(test.UserContext(test.UserNoRoles), metadata.ID, 2) + require.Error(t, err) + require.Contains(t, err.Error(), authz.ForbiddenErrorMessage) +} + +func TestCarveGetBlockNotAvailableError(t *testing.T) { + ds := new(mock.Store) + svc := &Service{carveStore: ds, authz: authz.Must()} + + metadata := &fleet.CarveMetadata{ + ID: 2, + HostId: 3, + BlockCount: 23, + BlockSize: 64, + CarveSize: 23 * 64, + RequestId: "carve_request", + SessionId: "foobar", + MaxBlock: 3, + } + + ds.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) { + assert.Equal(t, metadata.ID, carveId) + return metadata, nil + } + + // Block requested is greater than max block + _, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 7) + require.Error(t, err) + assert.Contains(t, err.Error(), "not yet available") +} + +func TestCarveGetBlockGetBlockError(t *testing.T) { + ds := new(mock.Store) + svc := &Service{carveStore: ds, authz: authz.Must()} + + metadata := &fleet.CarveMetadata{ + ID: 2, + HostId: 3, + BlockCount: 23, + BlockSize: 64, + CarveSize: 23 * 64, + RequestId: "carve_request", + SessionId: "foobar", + MaxBlock: 3, + } + + ds.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) { + assert.Equal(t, metadata.ID, carveId) + return metadata, nil + } + ds.GetBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64) ([]byte, error) { + assert.Equal(t, metadata.ID, carve.ID) + assert.Equal(t, int64(3), blockId) + return nil, errors.New("yow!!") + } + + // GetBlock failed + _, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 3) + require.Error(t, err) + assert.Contains(t, err.Error(), "yow!!") +} + +func TestCarveGetBlockExpired(t *testing.T) { + ds := new(mock.Store) + svc := &Service{carveStore: ds, authz: authz.Must()} + + metadata := &fleet.CarveMetadata{ + ID: 2, + HostId: 3, + BlockCount: 23, + BlockSize: 64, + CarveSize: 23 * 64, + RequestId: "carve_request", + SessionId: "foobar", + MaxBlock: 3, + Expired: true, + } + + ds.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) { + assert.Equal(t, metadata.ID, carveId) + return metadata, nil + } + + // carve is expired + _, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 3) + require.Error(t, err) + assert.Contains(t, err.Error(), "expired carve") +} diff --git a/server/service/endpoint_carves.go b/server/service/endpoint_carves.go index cb5f299dcf..ec94fc1694 100644 --- a/server/service/endpoint_carves.go +++ b/server/service/endpoint_carves.go @@ -87,90 +87,3 @@ func makeCarveBlockEndpoint(svc fleet.Service) endpoint.Endpoint { return carveBlockResponse{Success: true}, nil } } - -//////////////////////////////////////////////////////////////////////////////// -// Get Carve -//////////////////////////////////////////////////////////////////////////////// - -type getCarveRequest struct { - ID int64 -} - -type getCarveResponse struct { - Carve fleet.CarveMetadata `json:"carve"` - Err error `json:"error,omitempty"` -} - -func (r getCarveResponse) error() error { return r.Err } - -func makeGetCarveEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(getCarveRequest) - carve, err := svc.GetCarve(ctx, req.ID) - if err != nil { - return getCarveResponse{Err: err}, nil - } - - return getCarveResponse{Carve: *carve}, nil - - } -} - -//////////////////////////////////////////////////////////////////////////////// -// List Carves -//////////////////////////////////////////////////////////////////////////////// - -type listCarvesRequest struct { - ListOptions fleet.CarveListOptions -} - -type listCarvesResponse struct { - Carves []fleet.CarveMetadata `json:"carves"` - Err error `json:"error,omitempty"` -} - -func (r listCarvesResponse) error() error { return r.Err } - -func makeListCarvesEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(listCarvesRequest) - carves, err := svc.ListCarves(ctx, req.ListOptions) - if err != nil { - return listCarvesResponse{Err: err}, nil - } - - resp := listCarvesResponse{} - for _, carve := range carves { - resp.Carves = append(resp.Carves, *carve) - } - return resp, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Get Carve Block -//////////////////////////////////////////////////////////////////////////////// - -type getCarveBlockRequest struct { - ID int64 - BlockId int64 -} - -type getCarveBlockResponse struct { - Data []byte `json:"data"` - Err error `json:"error,omitempty"` -} - -func (r getCarveBlockResponse) error() error { return r.Err } - -func makeGetCarveBlockEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(getCarveBlockRequest) - data, err := svc.GetBlock(ctx, req.ID, req.BlockId) - if err != nil { - return getCarveBlockResponse{Err: err}, nil - } - - return getCarveBlockResponse{Data: data}, nil - } -} diff --git a/server/service/endpoint_utils.go b/server/service/endpoint_utils.go index f3974737bc..a8211b49b9 100644 --- a/server/service/endpoint_utils.go +++ b/server/service/endpoint_utils.go @@ -122,6 +122,12 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc { return nil, err } field.Set(reflect.ValueOf(opts)) + case "carve_options": + opts, err := carveListOptionsFromRequest(r) + if err != nil { + return nil, err + } + field.Set(reflect.ValueOf(opts)) default: id, err := idFromRequest(r, urlTagValue) if err != nil { @@ -131,7 +137,11 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc { return nil, err } - field.SetUint(uint64(id)) + if field.Kind() == reflect.Int64 { + field.SetInt(int64(id)) + } else { + field.SetUint(uint64(id)) + } } } diff --git a/server/service/handler.go b/server/service/handler.go index 7b8825e251..dda3833ad4 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -107,9 +107,6 @@ type FleetEndpoints struct { SSOSettings endpoint.Endpoint StatusResultStore endpoint.Endpoint StatusLiveQuery endpoint.Endpoint - ListCarves endpoint.Endpoint - GetCarve endpoint.Endpoint - GetCarveBlock endpoint.Endpoint Version endpoint.Endpoint CreateTeam endpoint.Endpoint ModifyTeam endpoint.Endpoint @@ -214,9 +211,6 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th SearchTargets: authenticatedUser(svc, makeSearchTargetsEndpoint(svc)), GetCertificate: authenticatedUser(svc, makeCertificateEndpoint(svc)), ChangeEmail: authenticatedUser(svc, makeChangeEmailEndpoint(svc)), - ListCarves: authenticatedUser(svc, makeListCarvesEndpoint(svc)), - GetCarve: authenticatedUser(svc, makeGetCarveEndpoint(svc)), - GetCarveBlock: authenticatedUser(svc, makeGetCarveBlockEndpoint(svc)), Version: authenticatedUser(svc, makeVersionEndpoint(svc)), CreateTeam: authenticatedUser(svc, makeCreateTeamEndpoint(svc)), ModifyTeam: authenticatedUser(svc, makeModifyTeamEndpoint(svc)), @@ -333,9 +327,6 @@ type fleetHandlers struct { SettingsSSO http.Handler StatusResultStore http.Handler StatusLiveQuery http.Handler - ListCarves http.Handler - GetCarve http.Handler - GetCarveBlock http.Handler Version http.Handler CreateTeam http.Handler ModifyTeam http.Handler @@ -439,9 +430,6 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle SettingsSSO: newServer(e.SSOSettings, decodeNoParamsRequest), StatusResultStore: newServer(e.StatusResultStore, decodeNoParamsRequest), StatusLiveQuery: newServer(e.StatusLiveQuery, decodeNoParamsRequest), - ListCarves: newServer(e.ListCarves, decodeListCarvesRequest), - GetCarve: newServer(e.GetCarve, decodeGetCarveRequest), - GetCarveBlock: newServer(e.GetCarveBlock, decodeGetCarveBlockRequest), Version: newServer(e.Version, decodeNoParamsRequest), CreateTeam: newServer(e.CreateTeam, decodeCreateTeamRequest), ModifyTeam: newServer(e.ModifyTeam, decodeModifyTeamRequest), @@ -642,10 +630,6 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) { r.Handle("/api/v1/fleet/status/result_store", h.StatusResultStore).Methods("GET").Name("status_result_store") r.Handle("/api/v1/fleet/status/live_query", h.StatusLiveQuery).Methods("GET").Name("status_live_query") - r.Handle("/api/v1/fleet/carves", h.ListCarves).Methods("GET").Name("list_carves") - r.Handle("/api/v1/fleet/carves/{id:[0-9]+}", h.GetCarve).Methods("GET").Name("get_carve") - r.Handle("/api/v1/fleet/carves/{id:[0-9]+}/block/{block_id}", h.GetCarveBlock).Methods("GET").Name("get_carve_block") - r.Handle("/api/v1/fleet/teams", h.CreateTeam).Methods("POST").Name("create_team") r.Handle("/api/v1/fleet/teams", h.ListTeams).Methods("GET").Name("list_teams") r.Handle("/api/v1/fleet/teams/{id:[0-9]+}", h.ModifyTeam).Methods("PATCH").Name("modify_team") @@ -719,6 +703,10 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht e.PATCH("/api/v1/fleet/invites/{id:[0-9]+}", updateInviteEndpoint, updateInviteRequest{}) e.GET("/api/v1/fleet/activities", listActivitiesEndpoint, listActivitiesRequest{}) + + e.GET("/api/v1/fleet/carves", listCarvesEndpoint, listCarvesRequest{}) + e.GET("/api/v1/fleet/carves/{id:[0-9]+}", getCarveEndpoint, getCarveRequest{}) + e.GET("/api/v1/fleet/carves/{id:[0-9]+}/block/{block_id}", getCarveBlockEndpoint, getCarveBlockRequest{}) } // TODO: this duplicates the one in makeKitHandler diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 567487ee06..e744a5a8bd 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -1185,3 +1185,80 @@ func (s *integrationTestSuite) TestListActivities() { require.Len(t, listResp.Activities, 1) assert.Equal(t, fleet.ActivityTypeEditedPack, listResp.Activities[0].Type) } + +func (s *integrationTestSuite) TestListGetCarves() { + t := s.T() + + ctx := context.Background() + + hosts := s.createHosts(t) + c1, err := s.ds.NewCarve(ctx, &fleet.CarveMetadata{ + CreatedAt: time.Now(), + HostId: hosts[0].ID, + Name: t.Name() + "_1", + SessionId: "ssn1", + }) + require.NoError(t, err) + c2, err := s.ds.NewCarve(ctx, &fleet.CarveMetadata{ + CreatedAt: time.Now(), + HostId: hosts[1].ID, + Name: t.Name() + "_2", + SessionId: "ssn2", + }) + require.NoError(t, err) + c3, err := s.ds.NewCarve(ctx, &fleet.CarveMetadata{ + CreatedAt: time.Now(), + HostId: hosts[2].ID, + Name: t.Name() + "_3", + SessionId: "ssn3", + }) + require.NoError(t, err) + + // set c1 max block + c1.MaxBlock = 3 + require.NoError(t, s.ds.UpdateCarve(ctx, c1)) + // make c2 expired, set max block + c2.Expired = true + c2.MaxBlock = 3 + require.NoError(t, s.ds.UpdateCarve(ctx, c2)) + + var listResp listCarvesResponse + s.DoJSON("GET", "/api/v1/fleet/carves", nil, http.StatusOK, &listResp, "per_page", "2", "order_key", "id") + require.Len(t, listResp.Carves, 2) + assert.Equal(t, c1.ID, listResp.Carves[0].ID) + assert.Equal(t, c3.ID, listResp.Carves[1].ID) + + // include expired + s.DoJSON("GET", "/api/v1/fleet/carves", nil, http.StatusOK, &listResp, "per_page", "2", "order_key", "id", "expired", "1") + require.Len(t, listResp.Carves, 2) + assert.Equal(t, c1.ID, listResp.Carves[0].ID) + assert.Equal(t, c2.ID, listResp.Carves[1].ID) + + // empty page + s.DoJSON("GET", "/api/v1/fleet/carves", nil, http.StatusOK, &listResp, "page", "3", "per_page", "2", "order_key", "id", "expired", "1") + require.Len(t, listResp.Carves, 0) + + // get specific carve + var getResp getCarveResponse + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/carves/%d", c2.ID), nil, http.StatusOK, &getResp) + require.Equal(t, c2.ID, getResp.Carve.ID) + require.True(t, getResp.Carve.Expired) + + // get non-existing carve + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/carves/%d", c3.ID+1), nil, http.StatusNotFound, &getResp) + + // get expired carve block + var blkResp getCarveBlockResponse + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/carves/%d/block/%d", c2.ID, 1), nil, http.StatusInternalServerError, &blkResp) + + // get valid carve block, but block not inserted yet + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/carves/%d/block/%d", c1.ID, 1), nil, http.StatusNotFound, &blkResp) + + require.NoError(t, s.ds.NewBlock(ctx, c1, 1, []byte("block1"))) + require.NoError(t, s.ds.NewBlock(ctx, c1, 2, []byte("block2"))) + require.NoError(t, s.ds.NewBlock(ctx, c1, 3, []byte("block3"))) + + // get valid carve block + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/carves/%d/block/%d", c1.ID, 1), nil, http.StatusOK, &blkResp) + require.Equal(t, "block1", string(blkResp.Data)) +} diff --git a/server/service/service_carves.go b/server/service/service_carves.go index 52ad1e245d..1d53b70390 100644 --- a/server/service/service_carves.go +++ b/server/service/service_carves.go @@ -105,45 +105,3 @@ func (svc *Service) CarveBlock(ctx context.Context, payload fleet.CarveBlockPayl return nil } - -func (svc *Service) GetCarve(ctx context.Context, id int64) (*fleet.CarveMetadata, error) { - if err := svc.authz.Authorize(ctx, &fleet.CarveMetadata{}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.carveStore.Carve(ctx, id) -} - -func (svc *Service) ListCarves(ctx context.Context, opt fleet.CarveListOptions) ([]*fleet.CarveMetadata, error) { - if err := svc.authz.Authorize(ctx, &fleet.CarveMetadata{}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.carveStore.ListCarves(ctx, opt) -} - -func (svc *Service) GetBlock(ctx context.Context, carveId, blockId int64) ([]byte, error) { - if err := svc.authz.Authorize(ctx, &fleet.CarveMetadata{}, fleet.ActionRead); err != nil { - return nil, err - } - - metadata, err := svc.carveStore.Carve(ctx, carveId) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "get carve by name") - } - - if metadata.Expired { - return nil, errors.New("cannot get block for expired carve") - } - - if blockId > metadata.MaxBlock { - return nil, fmt.Errorf("block %d not yet available", blockId) - } - - data, err := svc.carveStore.GetBlock(ctx, metadata, blockId) - if err != nil { - return nil, ctxerr.Wrapf(ctx, err, "get block %d", blockId) - } - - return data, nil -} diff --git a/server/service/service_carves_test.go b/server/service/service_carves_test.go index d884d4918f..ab1a0b7cbf 100644 --- a/server/service/service_carves_test.go +++ b/server/service/service_carves_test.go @@ -6,12 +6,10 @@ import ( "testing" "time" - "github.com/fleetdm/fleet/v4/server/authz" "github.com/fleetdm/fleet/v4/server/fleet" hostctx "github.com/fleetdm/fleet/v4/server/contexts/host" "github.com/fleetdm/fleet/v4/server/mock" - "github.com/fleetdm/fleet/v4/server/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -358,87 +356,3 @@ func TestCarveCarveBlock(t *testing.T) { require.NoError(t, err) assert.True(t, ms.NewBlockFuncInvoked) } - -func TestCarveGetBlock(t *testing.T) { - sessionId := "foobar" - metadata := &fleet.CarveMetadata{ - ID: 2, - HostId: 3, - BlockCount: 23, - BlockSize: 64, - CarveSize: 23 * 64, - RequestId: "carve_request", - SessionId: sessionId, - MaxBlock: 3, - } - ms := new(mock.Store) - svc := &Service{carveStore: ms, authz: authz.Must()} - ms.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) { - assert.Equal(t, metadata.ID, carveId) - return metadata, nil - } - ms.GetBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64) ([]byte, error) { - assert.Equal(t, metadata.ID, carve.ID) - assert.Equal(t, int64(3), blockId) - return []byte("foobar"), nil - } - - data, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 3) - require.NoError(t, err) - assert.Equal(t, []byte("foobar"), data) -} - -func TestCarveGetBlockNotAvailableError(t *testing.T) { - sessionId := "foobar" - metadata := &fleet.CarveMetadata{ - ID: 2, - HostId: 3, - BlockCount: 23, - BlockSize: 64, - CarveSize: 23 * 64, - RequestId: "carve_request", - SessionId: sessionId, - MaxBlock: 3, - } - ms := new(mock.Store) - svc := &Service{carveStore: ms, authz: authz.Must()} - ms.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) { - assert.Equal(t, metadata.ID, carveId) - return metadata, nil - } - - // Block requested is great than max block - _, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 7) - require.Error(t, err) - assert.Contains(t, err.Error(), "not yet available") -} - -func TestCarveGetBlockGetBlockError(t *testing.T) { - sessionId := "foobar" - metadata := &fleet.CarveMetadata{ - ID: 2, - HostId: 3, - BlockCount: 23, - BlockSize: 64, - CarveSize: 23 * 64, - RequestId: "carve_request", - SessionId: sessionId, - MaxBlock: 3, - } - ms := new(mock.Store) - svc := &Service{carveStore: ms, authz: authz.Must()} - ms.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) { - assert.Equal(t, metadata.ID, carveId) - return metadata, nil - } - ms.GetBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64) ([]byte, error) { - assert.Equal(t, metadata.ID, carve.ID) - assert.Equal(t, int64(3), blockId) - return nil, errors.New("yow!!") - } - - // Block requested is greater than max block - _, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 3) - require.Error(t, err) - assert.Contains(t, err.Error(), "yow!!") -} diff --git a/server/service/transport.go b/server/service/transport.go index 68b9b93fc7..0ba76da896 100644 --- a/server/service/transport.go +++ b/server/service/transport.go @@ -241,6 +241,27 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error) return hopt, nil } +func carveListOptionsFromRequest(r *http.Request) (fleet.CarveListOptions, error) { + opt, err := listOptionsFromRequest(r) + if err != nil { + return fleet.CarveListOptions{}, err + } + + copt := fleet.CarveListOptions{ListOptions: opt} + + expired := r.URL.Query().Get("expired") + // TODO(mna): allow the same bool encodings as strconv.ParseBool and use it? + switch expired { + case "1", "true": + copt.Expired = true + case "0", "": + copt.Expired = false + default: + return copt, ctxerr.Errorf(r.Context(), "invalid expired value %s", expired) + } + return copt, nil +} + func userListOptionsFromRequest(r *http.Request) (fleet.UserListOptions, error) { opt, err := listOptionsFromRequest(r) if err != nil { diff --git a/server/service/transport_carves.go b/server/service/transport_carves.go index 49034dc0b5..1629fd73a4 100644 --- a/server/service/transport_carves.go +++ b/server/service/transport_carves.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - "github.com/fleetdm/fleet/v4/server/fleet" ) func decodeCarveBeginRequest(ctx context.Context, r *http.Request) (interface{}, error) { @@ -30,41 +29,3 @@ func decodeCarveBlockRequest(ctx context.Context, r *http.Request) (interface{}, return req, nil } - -func decodeGetCarveRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := idFromRequest(r, "id") - if err != nil { - return nil, err - } - return getCarveRequest{ID: int64(id)}, nil -} - -func decodeListCarvesRequest(ctx context.Context, r *http.Request) (interface{}, error) { - opt, err := listOptionsFromRequest(r) - if err != nil { - return nil, err - } - copt := fleet.CarveListOptions{ListOptions: opt} - expired := r.URL.Query().Get("expired") - switch expired { - case "1", "true": - copt.Expired = true - case "0", "": - copt.Expired = false - default: - return nil, ctxerr.Errorf(ctx, "invalid expired value %s", expired) - } - return listCarvesRequest{ListOptions: copt}, nil -} - -func decodeGetCarveBlockRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := idFromRequest(r, "id") - if err != nil { - return nil, err - } - blockId, err := idFromRequest(r, "block_id") - if err != nil { - return nil, err - } - return getCarveBlockRequest{ID: int64(id), BlockId: int64(blockId)}, nil -}