Move carves endpoints to new endpoint pattern (#3148)

This commit is contained in:
Martin Angers 2021-12-01 15:45:29 -05:00 committed by GitHub
parent 3a031e946d
commit 7464e72ba8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 444 additions and 271 deletions

View file

@ -167,6 +167,9 @@ func (d *Datastore) Carve(ctx context.Context, carveId int64) (*fleet.CarveMetad
var metadata fleet.CarveMetadata
if err := sqlx.GetContext(ctx, d.reader, &metadata, stmt, carveId); err != nil {
if err == sql.ErrNoRows {
return nil, ctxerr.Wrap(ctx, notFound("Carve").WithID(uint(carveId)))
}
return nil, ctxerr.Wrap(ctx, err, "get carve by ID")
}
@ -183,6 +186,9 @@ func (d *Datastore) CarveBySessionId(ctx context.Context, sessionId string) (*fl
var metadata fleet.CarveMetadata
if err := sqlx.GetContext(ctx, d.reader, &metadata, stmt, sessionId); err != nil {
if err == sql.ErrNoRows {
return nil, ctxerr.Wrap(ctx, notFound("CarveBySessionId").WithName(sessionId))
}
return nil, ctxerr.Wrap(ctx, err, "get carve by session ID")
}
@ -199,6 +205,9 @@ func (d *Datastore) CarveByName(ctx context.Context, name string) (*fleet.CarveM
var metadata fleet.CarveMetadata
if err := sqlx.GetContext(ctx, d.reader, &metadata, stmt, name); err != nil {
if err == sql.ErrNoRows {
return nil, ctxerr.Wrap(ctx, notFound("Carve").WithName(name))
}
return nil, ctxerr.Wrap(ctx, err, "get carve by name")
}
@ -259,6 +268,9 @@ func (d *Datastore) GetBlock(ctx context.Context, metadata *fleet.CarveMetadata,
`
var data []byte
if err := sqlx.GetContext(ctx, d.reader, &data, stmt, metadata.ID, blockId); err != nil {
if err == sql.ErrNoRows {
return nil, ctxerr.Wrap(ctx, notFound("CarveBlock").WithID(uint(blockId)))
}
return nil, ctxerr.Wrap(ctx, err, "select data")
}

133
server/service/carves.go Normal file
View file

@ -0,0 +1,133 @@
package service
import (
"context"
"errors"
"fmt"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
)
////////////////////////////////////////////////////////////////////////////////
// List Carves
////////////////////////////////////////////////////////////////////////////////
type listCarvesRequest struct {
ListOptions fleet.CarveListOptions `url:"carve_options"`
}
type listCarvesResponse struct {
Carves []fleet.CarveMetadata `json:"carves"`
Err error `json:"error,omitempty"`
}
func (r listCarvesResponse) error() error { return r.Err }
func listCarvesEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*listCarvesRequest)
carves, err := svc.ListCarves(ctx, req.ListOptions)
if err != nil {
return listCarvesResponse{Err: err}, nil
}
resp := listCarvesResponse{}
for _, carve := range carves {
resp.Carves = append(resp.Carves, *carve)
}
return resp, nil
}
func (svc *Service) ListCarves(ctx context.Context, opt fleet.CarveListOptions) ([]*fleet.CarveMetadata, error) {
if err := svc.authz.Authorize(ctx, &fleet.CarveMetadata{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.carveStore.ListCarves(ctx, opt)
}
////////////////////////////////////////////////////////////////////////////////
// Get Carve
////////////////////////////////////////////////////////////////////////////////
type getCarveRequest struct {
ID int64 `url:"id"`
}
type getCarveResponse struct {
Carve fleet.CarveMetadata `json:"carve"`
Err error `json:"error,omitempty"`
}
func (r getCarveResponse) error() error { return r.Err }
func getCarveEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*getCarveRequest)
carve, err := svc.GetCarve(ctx, req.ID)
if err != nil {
return getCarveResponse{Err: err}, nil
}
return getCarveResponse{Carve: *carve}, nil
}
func (svc *Service) GetCarve(ctx context.Context, id int64) (*fleet.CarveMetadata, error) {
if err := svc.authz.Authorize(ctx, &fleet.CarveMetadata{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.carveStore.Carve(ctx, id)
}
////////////////////////////////////////////////////////////////////////////////
// Get Carve Block
////////////////////////////////////////////////////////////////////////////////
type getCarveBlockRequest struct {
ID int64 `url:"id"`
BlockId int64 `url:"block_id"`
}
type getCarveBlockResponse struct {
Data []byte `json:"data"`
Err error `json:"error,omitempty"`
}
func (r getCarveBlockResponse) error() error { return r.Err }
func getCarveBlockEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*getCarveBlockRequest)
data, err := svc.GetBlock(ctx, req.ID, req.BlockId)
if err != nil {
return getCarveBlockResponse{Err: err}, nil
}
return getCarveBlockResponse{Data: data}, nil
}
func (svc *Service) GetBlock(ctx context.Context, carveId, blockId int64) ([]byte, error) {
if err := svc.authz.Authorize(ctx, &fleet.CarveMetadata{}, fleet.ActionRead); err != nil {
return nil, err
}
metadata, err := svc.carveStore.Carve(ctx, carveId)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get carve by name")
}
if metadata.Expired {
return nil, errors.New("cannot get block for expired carve")
}
if blockId > metadata.MaxBlock {
return nil, fmt.Errorf("block %d not yet available", blockId)
}
data, err := svc.carveStore.GetBlock(ctx, metadata, blockId)
if err != nil {
return nil, ctxerr.Wrapf(ctx, err, "get block %d", blockId)
}
return data, nil
}

View file

@ -0,0 +1,186 @@
package service
import (
"context"
"errors"
"testing"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestListCarves(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
ds.ListCarvesFunc = func(ctx context.Context, opts fleet.CarveListOptions) ([]*fleet.CarveMetadata, error) {
return []*fleet.CarveMetadata{
{ID: 1},
{ID: 2},
}, nil
}
// admin user
carves, err := svc.ListCarves(test.UserContext(test.UserAdmin), fleet.CarveListOptions{})
require.NoError(t, err)
require.Len(t, carves, 2)
// only global admin can read carves
_, err = svc.ListCarves(test.UserContext(test.UserNoRoles), fleet.CarveListOptions{})
require.Error(t, err)
require.Contains(t, err.Error(), authz.ForbiddenErrorMessage)
// no user in context
_, err = svc.ListCarves(context.Background(), fleet.CarveListOptions{})
require.Error(t, err)
require.Contains(t, err.Error(), authz.ForbiddenErrorMessage)
}
func TestGetCarve(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
ds.CarveFunc = func(ctx context.Context, id int64) (*fleet.CarveMetadata, error) {
return &fleet.CarveMetadata{
ID: id,
}, nil
}
// admin user
carve, err := svc.GetCarve(test.UserContext(test.UserAdmin), 1)
require.NoError(t, err)
require.Equal(t, int64(1), carve.ID)
// only global admin can read carves
_, err = svc.GetCarve(test.UserContext(test.UserNoRoles), 1)
require.Error(t, err)
require.Contains(t, err.Error(), authz.ForbiddenErrorMessage)
// no user in context
_, err = svc.GetCarve(context.Background(), 1)
require.Error(t, err)
require.Contains(t, err.Error(), authz.ForbiddenErrorMessage)
}
func TestCarveGetBlock(t *testing.T) {
ds := new(mock.Store)
svc := &Service{carveStore: ds, authz: authz.Must()}
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: "foobar",
MaxBlock: 3,
}
ds.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.ID, carveId)
return metadata, nil
}
ds.GetBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64) ([]byte, error) {
assert.Equal(t, metadata.ID, carve.ID)
assert.Equal(t, int64(3), blockId)
return []byte("foobar"), nil
}
data, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 3)
require.NoError(t, err)
assert.Equal(t, []byte("foobar"), data)
// only global admin can read carves
_, err = svc.GetBlock(test.UserContext(test.UserNoRoles), metadata.ID, 2)
require.Error(t, err)
require.Contains(t, err.Error(), authz.ForbiddenErrorMessage)
}
func TestCarveGetBlockNotAvailableError(t *testing.T) {
ds := new(mock.Store)
svc := &Service{carveStore: ds, authz: authz.Must()}
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: "foobar",
MaxBlock: 3,
}
ds.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.ID, carveId)
return metadata, nil
}
// Block requested is greater than max block
_, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 7)
require.Error(t, err)
assert.Contains(t, err.Error(), "not yet available")
}
func TestCarveGetBlockGetBlockError(t *testing.T) {
ds := new(mock.Store)
svc := &Service{carveStore: ds, authz: authz.Must()}
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: "foobar",
MaxBlock: 3,
}
ds.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.ID, carveId)
return metadata, nil
}
ds.GetBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64) ([]byte, error) {
assert.Equal(t, metadata.ID, carve.ID)
assert.Equal(t, int64(3), blockId)
return nil, errors.New("yow!!")
}
// GetBlock failed
_, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 3)
require.Error(t, err)
assert.Contains(t, err.Error(), "yow!!")
}
func TestCarveGetBlockExpired(t *testing.T) {
ds := new(mock.Store)
svc := &Service{carveStore: ds, authz: authz.Must()}
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: "foobar",
MaxBlock: 3,
Expired: true,
}
ds.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.ID, carveId)
return metadata, nil
}
// carve is expired
_, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 3)
require.Error(t, err)
assert.Contains(t, err.Error(), "expired carve")
}

View file

@ -87,90 +87,3 @@ func makeCarveBlockEndpoint(svc fleet.Service) endpoint.Endpoint {
return carveBlockResponse{Success: true}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Get Carve
////////////////////////////////////////////////////////////////////////////////
type getCarveRequest struct {
ID int64
}
type getCarveResponse struct {
Carve fleet.CarveMetadata `json:"carve"`
Err error `json:"error,omitempty"`
}
func (r getCarveResponse) error() error { return r.Err }
func makeGetCarveEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(getCarveRequest)
carve, err := svc.GetCarve(ctx, req.ID)
if err != nil {
return getCarveResponse{Err: err}, nil
}
return getCarveResponse{Carve: *carve}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// List Carves
////////////////////////////////////////////////////////////////////////////////
type listCarvesRequest struct {
ListOptions fleet.CarveListOptions
}
type listCarvesResponse struct {
Carves []fleet.CarveMetadata `json:"carves"`
Err error `json:"error,omitempty"`
}
func (r listCarvesResponse) error() error { return r.Err }
func makeListCarvesEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(listCarvesRequest)
carves, err := svc.ListCarves(ctx, req.ListOptions)
if err != nil {
return listCarvesResponse{Err: err}, nil
}
resp := listCarvesResponse{}
for _, carve := range carves {
resp.Carves = append(resp.Carves, *carve)
}
return resp, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Get Carve Block
////////////////////////////////////////////////////////////////////////////////
type getCarveBlockRequest struct {
ID int64
BlockId int64
}
type getCarveBlockResponse struct {
Data []byte `json:"data"`
Err error `json:"error,omitempty"`
}
func (r getCarveBlockResponse) error() error { return r.Err }
func makeGetCarveBlockEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(getCarveBlockRequest)
data, err := svc.GetBlock(ctx, req.ID, req.BlockId)
if err != nil {
return getCarveBlockResponse{Err: err}, nil
}
return getCarveBlockResponse{Data: data}, nil
}
}

View file

@ -122,6 +122,12 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
return nil, err
}
field.Set(reflect.ValueOf(opts))
case "carve_options":
opts, err := carveListOptionsFromRequest(r)
if err != nil {
return nil, err
}
field.Set(reflect.ValueOf(opts))
default:
id, err := idFromRequest(r, urlTagValue)
if err != nil {
@ -131,7 +137,11 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
return nil, err
}
field.SetUint(uint64(id))
if field.Kind() == reflect.Int64 {
field.SetInt(int64(id))
} else {
field.SetUint(uint64(id))
}
}
}

View file

@ -107,9 +107,6 @@ type FleetEndpoints struct {
SSOSettings endpoint.Endpoint
StatusResultStore endpoint.Endpoint
StatusLiveQuery endpoint.Endpoint
ListCarves endpoint.Endpoint
GetCarve endpoint.Endpoint
GetCarveBlock endpoint.Endpoint
Version endpoint.Endpoint
CreateTeam endpoint.Endpoint
ModifyTeam endpoint.Endpoint
@ -214,9 +211,6 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th
SearchTargets: authenticatedUser(svc, makeSearchTargetsEndpoint(svc)),
GetCertificate: authenticatedUser(svc, makeCertificateEndpoint(svc)),
ChangeEmail: authenticatedUser(svc, makeChangeEmailEndpoint(svc)),
ListCarves: authenticatedUser(svc, makeListCarvesEndpoint(svc)),
GetCarve: authenticatedUser(svc, makeGetCarveEndpoint(svc)),
GetCarveBlock: authenticatedUser(svc, makeGetCarveBlockEndpoint(svc)),
Version: authenticatedUser(svc, makeVersionEndpoint(svc)),
CreateTeam: authenticatedUser(svc, makeCreateTeamEndpoint(svc)),
ModifyTeam: authenticatedUser(svc, makeModifyTeamEndpoint(svc)),
@ -333,9 +327,6 @@ type fleetHandlers struct {
SettingsSSO http.Handler
StatusResultStore http.Handler
StatusLiveQuery http.Handler
ListCarves http.Handler
GetCarve http.Handler
GetCarveBlock http.Handler
Version http.Handler
CreateTeam http.Handler
ModifyTeam http.Handler
@ -439,9 +430,6 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle
SettingsSSO: newServer(e.SSOSettings, decodeNoParamsRequest),
StatusResultStore: newServer(e.StatusResultStore, decodeNoParamsRequest),
StatusLiveQuery: newServer(e.StatusLiveQuery, decodeNoParamsRequest),
ListCarves: newServer(e.ListCarves, decodeListCarvesRequest),
GetCarve: newServer(e.GetCarve, decodeGetCarveRequest),
GetCarveBlock: newServer(e.GetCarveBlock, decodeGetCarveBlockRequest),
Version: newServer(e.Version, decodeNoParamsRequest),
CreateTeam: newServer(e.CreateTeam, decodeCreateTeamRequest),
ModifyTeam: newServer(e.ModifyTeam, decodeModifyTeamRequest),
@ -642,10 +630,6 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
r.Handle("/api/v1/fleet/status/result_store", h.StatusResultStore).Methods("GET").Name("status_result_store")
r.Handle("/api/v1/fleet/status/live_query", h.StatusLiveQuery).Methods("GET").Name("status_live_query")
r.Handle("/api/v1/fleet/carves", h.ListCarves).Methods("GET").Name("list_carves")
r.Handle("/api/v1/fleet/carves/{id:[0-9]+}", h.GetCarve).Methods("GET").Name("get_carve")
r.Handle("/api/v1/fleet/carves/{id:[0-9]+}/block/{block_id}", h.GetCarveBlock).Methods("GET").Name("get_carve_block")
r.Handle("/api/v1/fleet/teams", h.CreateTeam).Methods("POST").Name("create_team")
r.Handle("/api/v1/fleet/teams", h.ListTeams).Methods("GET").Name("list_teams")
r.Handle("/api/v1/fleet/teams/{id:[0-9]+}", h.ModifyTeam).Methods("PATCH").Name("modify_team")
@ -719,6 +703,10 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht
e.PATCH("/api/v1/fleet/invites/{id:[0-9]+}", updateInviteEndpoint, updateInviteRequest{})
e.GET("/api/v1/fleet/activities", listActivitiesEndpoint, listActivitiesRequest{})
e.GET("/api/v1/fleet/carves", listCarvesEndpoint, listCarvesRequest{})
e.GET("/api/v1/fleet/carves/{id:[0-9]+}", getCarveEndpoint, getCarveRequest{})
e.GET("/api/v1/fleet/carves/{id:[0-9]+}/block/{block_id}", getCarveBlockEndpoint, getCarveBlockRequest{})
}
// TODO: this duplicates the one in makeKitHandler

View file

@ -1185,3 +1185,80 @@ func (s *integrationTestSuite) TestListActivities() {
require.Len(t, listResp.Activities, 1)
assert.Equal(t, fleet.ActivityTypeEditedPack, listResp.Activities[0].Type)
}
func (s *integrationTestSuite) TestListGetCarves() {
t := s.T()
ctx := context.Background()
hosts := s.createHosts(t)
c1, err := s.ds.NewCarve(ctx, &fleet.CarveMetadata{
CreatedAt: time.Now(),
HostId: hosts[0].ID,
Name: t.Name() + "_1",
SessionId: "ssn1",
})
require.NoError(t, err)
c2, err := s.ds.NewCarve(ctx, &fleet.CarveMetadata{
CreatedAt: time.Now(),
HostId: hosts[1].ID,
Name: t.Name() + "_2",
SessionId: "ssn2",
})
require.NoError(t, err)
c3, err := s.ds.NewCarve(ctx, &fleet.CarveMetadata{
CreatedAt: time.Now(),
HostId: hosts[2].ID,
Name: t.Name() + "_3",
SessionId: "ssn3",
})
require.NoError(t, err)
// set c1 max block
c1.MaxBlock = 3
require.NoError(t, s.ds.UpdateCarve(ctx, c1))
// make c2 expired, set max block
c2.Expired = true
c2.MaxBlock = 3
require.NoError(t, s.ds.UpdateCarve(ctx, c2))
var listResp listCarvesResponse
s.DoJSON("GET", "/api/v1/fleet/carves", nil, http.StatusOK, &listResp, "per_page", "2", "order_key", "id")
require.Len(t, listResp.Carves, 2)
assert.Equal(t, c1.ID, listResp.Carves[0].ID)
assert.Equal(t, c3.ID, listResp.Carves[1].ID)
// include expired
s.DoJSON("GET", "/api/v1/fleet/carves", nil, http.StatusOK, &listResp, "per_page", "2", "order_key", "id", "expired", "1")
require.Len(t, listResp.Carves, 2)
assert.Equal(t, c1.ID, listResp.Carves[0].ID)
assert.Equal(t, c2.ID, listResp.Carves[1].ID)
// empty page
s.DoJSON("GET", "/api/v1/fleet/carves", nil, http.StatusOK, &listResp, "page", "3", "per_page", "2", "order_key", "id", "expired", "1")
require.Len(t, listResp.Carves, 0)
// get specific carve
var getResp getCarveResponse
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/carves/%d", c2.ID), nil, http.StatusOK, &getResp)
require.Equal(t, c2.ID, getResp.Carve.ID)
require.True(t, getResp.Carve.Expired)
// get non-existing carve
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/carves/%d", c3.ID+1), nil, http.StatusNotFound, &getResp)
// get expired carve block
var blkResp getCarveBlockResponse
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/carves/%d/block/%d", c2.ID, 1), nil, http.StatusInternalServerError, &blkResp)
// get valid carve block, but block not inserted yet
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/carves/%d/block/%d", c1.ID, 1), nil, http.StatusNotFound, &blkResp)
require.NoError(t, s.ds.NewBlock(ctx, c1, 1, []byte("block1")))
require.NoError(t, s.ds.NewBlock(ctx, c1, 2, []byte("block2")))
require.NoError(t, s.ds.NewBlock(ctx, c1, 3, []byte("block3")))
// get valid carve block
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/carves/%d/block/%d", c1.ID, 1), nil, http.StatusOK, &blkResp)
require.Equal(t, "block1", string(blkResp.Data))
}

View file

@ -105,45 +105,3 @@ func (svc *Service) CarveBlock(ctx context.Context, payload fleet.CarveBlockPayl
return nil
}
func (svc *Service) GetCarve(ctx context.Context, id int64) (*fleet.CarveMetadata, error) {
if err := svc.authz.Authorize(ctx, &fleet.CarveMetadata{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.carveStore.Carve(ctx, id)
}
func (svc *Service) ListCarves(ctx context.Context, opt fleet.CarveListOptions) ([]*fleet.CarveMetadata, error) {
if err := svc.authz.Authorize(ctx, &fleet.CarveMetadata{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.carveStore.ListCarves(ctx, opt)
}
func (svc *Service) GetBlock(ctx context.Context, carveId, blockId int64) ([]byte, error) {
if err := svc.authz.Authorize(ctx, &fleet.CarveMetadata{}, fleet.ActionRead); err != nil {
return nil, err
}
metadata, err := svc.carveStore.Carve(ctx, carveId)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get carve by name")
}
if metadata.Expired {
return nil, errors.New("cannot get block for expired carve")
}
if blockId > metadata.MaxBlock {
return nil, fmt.Errorf("block %d not yet available", blockId)
}
data, err := svc.carveStore.GetBlock(ctx, metadata, blockId)
if err != nil {
return nil, ctxerr.Wrapf(ctx, err, "get block %d", blockId)
}
return data, nil
}

View file

@ -6,12 +6,10 @@ import (
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/fleet"
hostctx "github.com/fleetdm/fleet/v4/server/contexts/host"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -358,87 +356,3 @@ func TestCarveCarveBlock(t *testing.T) {
require.NoError(t, err)
assert.True(t, ms.NewBlockFuncInvoked)
}
func TestCarveGetBlock(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
MaxBlock: 3,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms, authz: authz.Must()}
ms.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.ID, carveId)
return metadata, nil
}
ms.GetBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64) ([]byte, error) {
assert.Equal(t, metadata.ID, carve.ID)
assert.Equal(t, int64(3), blockId)
return []byte("foobar"), nil
}
data, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 3)
require.NoError(t, err)
assert.Equal(t, []byte("foobar"), data)
}
func TestCarveGetBlockNotAvailableError(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
MaxBlock: 3,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms, authz: authz.Must()}
ms.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.ID, carveId)
return metadata, nil
}
// Block requested is great than max block
_, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 7)
require.Error(t, err)
assert.Contains(t, err.Error(), "not yet available")
}
func TestCarveGetBlockGetBlockError(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
MaxBlock: 3,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms, authz: authz.Must()}
ms.CarveFunc = func(ctx context.Context, carveId int64) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.ID, carveId)
return metadata, nil
}
ms.GetBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64) ([]byte, error) {
assert.Equal(t, metadata.ID, carve.ID)
assert.Equal(t, int64(3), blockId)
return nil, errors.New("yow!!")
}
// Block requested is greater than max block
_, err := svc.GetBlock(test.UserContext(test.UserAdmin), metadata.ID, 3)
require.Error(t, err)
assert.Contains(t, err.Error(), "yow!!")
}

View file

@ -241,6 +241,27 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error)
return hopt, nil
}
func carveListOptionsFromRequest(r *http.Request) (fleet.CarveListOptions, error) {
opt, err := listOptionsFromRequest(r)
if err != nil {
return fleet.CarveListOptions{}, err
}
copt := fleet.CarveListOptions{ListOptions: opt}
expired := r.URL.Query().Get("expired")
// TODO(mna): allow the same bool encodings as strconv.ParseBool and use it?
switch expired {
case "1", "true":
copt.Expired = true
case "0", "":
copt.Expired = false
default:
return copt, ctxerr.Errorf(r.Context(), "invalid expired value %s", expired)
}
return copt, nil
}
func userListOptionsFromRequest(r *http.Request) (fleet.UserListOptions, error) {
opt, err := listOptionsFromRequest(r)
if err != nil {

View file

@ -6,7 +6,6 @@ import (
"net/http"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
)
func decodeCarveBeginRequest(ctx context.Context, r *http.Request) (interface{}, error) {
@ -30,41 +29,3 @@ func decodeCarveBlockRequest(ctx context.Context, r *http.Request) (interface{},
return req, nil
}
func decodeGetCarveRequest(ctx context.Context, r *http.Request) (interface{}, error) {
id, err := idFromRequest(r, "id")
if err != nil {
return nil, err
}
return getCarveRequest{ID: int64(id)}, nil
}
func decodeListCarvesRequest(ctx context.Context, r *http.Request) (interface{}, error) {
opt, err := listOptionsFromRequest(r)
if err != nil {
return nil, err
}
copt := fleet.CarveListOptions{ListOptions: opt}
expired := r.URL.Query().Get("expired")
switch expired {
case "1", "true":
copt.Expired = true
case "0", "":
copt.Expired = false
default:
return nil, ctxerr.Errorf(ctx, "invalid expired value %s", expired)
}
return listCarvesRequest{ListOptions: copt}, nil
}
func decodeGetCarveBlockRequest(ctx context.Context, r *http.Request) (interface{}, error) {
id, err := idFromRequest(r, "id")
if err != nil {
return nil, err
}
blockId, err := idFromRequest(r, "block_id")
if err != nil {
return nil, err
}
return getCarveBlockRequest{ID: int64(id), BlockId: int64(blockId)}, nil
}