Migrate special-case endpoints to new pattern (#4511)

This commit is contained in:
Martin Angers 2022-03-08 11:27:38 -05:00 committed by GitHub
parent c14640ca84
commit c8bc026d6f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 5091 additions and 5100 deletions

View file

@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -74,7 +75,10 @@ func (ds *Datastore) VerifyEnrollSecret(ctx context.Context, secret string) (*fl
var s fleet.EnrollSecret
err := sqlx.GetContext(ctx, ds.reader, &s, "SELECT team_id FROM enroll_secrets WHERE secret = ?", secret)
if err != nil {
return nil, ctxerr.New(ctx, "no matching secret found")
if errors.Is(err, sql.ErrNoRows) {
return nil, ctxerr.New(ctx, "no matching secret found")
}
return nil, ctxerr.Wrap(ctx, err, "verify enroll secret")
}
return &s, nil

View file

@ -36,7 +36,8 @@ func (ds *Datastore) DeletePasswordResetRequestsForUser(ctx context.Context, use
return nil
}
func (ds *Datastore) FindPassswordResetByToken(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) {
func (ds *Datastore) FindPasswordResetByToken(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) {
sqlStatement := `
SELECT * FROM password_reset_requests
WHERE token = ? LIMIT 1
@ -48,5 +49,4 @@ func (ds *Datastore) FindPassswordResetByToken(ctx context.Context, token string
}
return passwordResetRequest, nil
}

View file

@ -226,7 +226,7 @@ type Datastore interface {
NewPasswordResetRequest(ctx context.Context, req *PasswordResetRequest) (*PasswordResetRequest, error)
DeletePasswordResetRequestsForUser(ctx context.Context, userID uint) error
FindPassswordResetByToken(ctx context.Context, token string) (*PasswordResetRequest, error)
FindPasswordResetByToken(ctx context.Context, token string) (*PasswordResetRequest, error)
///////////////////////////////////////////////////////////////////////////////
// SessionStore is the abstract interface that all session backends must conform to.

View file

@ -1,7 +1,9 @@
package fleet
import (
"errors"
"fmt"
"unicode"
"github.com/fleetdm/fleet/v4/server"
"golang.org/x/crypto/bcrypt"
@ -69,6 +71,104 @@ type UserPayload struct {
Teams *[]UserTeam `json:"teams,omitempty"`
}
func (p *UserPayload) VerifyInviteCreate() error {
invalid := &InvalidArgumentError{}
if p.Name == nil {
invalid.Append("name", "Full name missing required argument")
} else if *p.Name == "" {
invalid.Append("name", "Full name cannot be empty")
}
// we don't need a password for single sign on
if p.SSOInvite == nil || !*p.SSOInvite {
if p.Password == nil {
invalid.Append("password", "Password missing required argument")
} else if *p.Password == "" {
invalid.Append("password", "Password cannot be empty")
} else if err := ValidatePasswordRequirements(*p.Password); err != nil {
invalid.Append("password", err.Error())
}
}
if p.Email == nil {
invalid.Append("email", "Email missing required argument")
} else if *p.Email == "" {
invalid.Append("email", "Email cannot be empty")
}
if p.InviteToken == nil {
invalid.Append("invite_token", "Invite token missing required argument")
} else if *p.InviteToken == "" {
invalid.Append("invite_token", "Invite token cannot be empty")
}
if invalid.HasErrors() {
return invalid
}
return nil
}
func (p *UserPayload) VerifyAdminCreate() error {
invalid := &InvalidArgumentError{}
if p.Name == nil {
invalid.Append("name", "Full name missing required argument")
} else if *p.Name == "" {
invalid.Append("name", "Full name cannot be empty")
}
// we don't need a password for single sign on
if (p.SSOInvite == nil || !*p.SSOInvite) && (p.SSOEnabled == nil || !*p.SSOEnabled) {
if p.Password == nil {
invalid.Append("password", "Password missing required argument")
} else if *p.Password == "" {
invalid.Append("password", "Password cannot be empty")
}
// Skip password validation in the case of admin created users
}
if p.SSOEnabled != nil && *p.SSOEnabled && p.Password != nil && len(*p.Password) > 0 {
invalid.Append("password", "not allowed for SSO users")
}
if p.Email == nil {
invalid.Append("email", "Email missing required argument")
} else if *p.Email == "" {
invalid.Append("email", "Email cannot be empty")
}
if p.InviteToken != nil {
invalid.Append("invite_token", "Invite token should not be specified with admin user creation")
}
if invalid.HasErrors() {
return invalid
}
return nil
}
func (p *UserPayload) VerifyModify(ownUser bool) error {
invalid := &InvalidArgumentError{}
if p.Name != nil && *p.Name == "" {
invalid.Append("name", "Full name cannot be empty")
}
if p.Email != nil {
if *p.Email == "" {
invalid.Append("email", "Email cannot be empty")
}
// if the user is not an admin, or if an admin is changing their own email
// address a password is required,
if ownUser && p.Password == nil {
invalid.Append("password", "Password cannot be empty if email is changed")
}
}
if invalid.HasErrors() {
return invalid
}
return nil
}
// User creates a user from payload.
func (p UserPayload) User(keySize, cost int) (*User, error) {
user := &User{
@ -130,3 +230,31 @@ func (u *User) SetPassword(plaintext string, keySize, cost int) error {
u.Password = hashed
return nil
}
// Requirements for user password:
// at least 7 character length
// at least 1 symbol
// at least 1 number
func ValidatePasswordRequirements(password string) error {
var (
number bool
symbol bool
)
for _, s := range password {
switch {
case unicode.IsNumber(s):
number = true
case unicode.IsPunct(s) || unicode.IsSymbol(s):
symbol = true
}
}
if len(password) >= 7 &&
number &&
symbol {
return nil
}
return errors.New("Password does not meet validation requirements")
}

View file

@ -42,3 +42,37 @@ func newTestUser(t *testing.T, password, email string) *User {
Email: email,
}
}
func TestUserPasswordRequirements(t *testing.T) {
passwordTests := []struct {
password string
wantErr bool
}{
{
password: "foobar",
wantErr: true,
},
{
password: "foobarbaz",
wantErr: true,
},
{
password: "foobarbaz!",
wantErr: true,
},
{
password: "foobarbaz!3",
},
}
for _, tt := range passwordTests {
t.Run(tt.password, func(t *testing.T) {
err := ValidatePasswordRequirements(tt.password)
if tt.wantErr {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
}
})
}
}

View file

@ -190,7 +190,7 @@ type NewPasswordResetRequestFunc func(ctx context.Context, req *fleet.PasswordRe
type DeletePasswordResetRequestsForUserFunc func(ctx context.Context, userID uint) error
type FindPassswordResetByTokenFunc func(ctx context.Context, token string) (*fleet.PasswordResetRequest, error)
type FindPasswordResetByTokenFunc func(ctx context.Context, token string) (*fleet.PasswordResetRequest, error)
type SessionByKeyFunc func(ctx context.Context, key string) (*fleet.Session, error)
@ -650,8 +650,8 @@ type DataStore struct {
DeletePasswordResetRequestsForUserFunc DeletePasswordResetRequestsForUserFunc
DeletePasswordResetRequestsForUserFuncInvoked bool
FindPassswordResetByTokenFunc FindPassswordResetByTokenFunc
FindPassswordResetByTokenFuncInvoked bool
FindPasswordResetByTokenFunc FindPasswordResetByTokenFunc
FindPasswordResetByTokenFuncInvoked bool
SessionByKeyFunc SessionByKeyFunc
SessionByKeyFuncInvoked bool
@ -1384,9 +1384,9 @@ func (s *DataStore) DeletePasswordResetRequestsForUser(ctx context.Context, user
return s.DeletePasswordResetRequestsForUserFunc(ctx, userID)
}
func (s *DataStore) FindPassswordResetByToken(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) {
s.FindPassswordResetByTokenFuncInvoked = true
return s.FindPassswordResetByTokenFunc(ctx, token)
func (s *DataStore) FindPasswordResetByToken(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) {
s.FindPasswordResetByTokenFuncInvoked = true
return s.FindPasswordResetByTokenFunc(ctx, token)
}
func (s *DataStore) SessionByKey(ctx context.Context, key string) (*fleet.Session, error) {

View file

@ -232,3 +232,75 @@ func (svc *Service) CarveBegin(ctx context.Context, payload fleet.CarveBeginPayl
return carve, nil
}
////////////////////////////////////////////////////////////////////////////////
// Receive Block for File Carve
////////////////////////////////////////////////////////////////////////////////
type carveBlockRequest struct {
BlockId int64 `json:"block_id"`
SessionId string `json:"session_id"`
RequestId string `json:"request_id"`
Data []byte `json:"data"`
}
type carveBlockResponse struct {
Success bool `json:"success,omitempty"`
Err error `json:"error,omitempty"`
}
func (r carveBlockResponse) error() error { return r.Err }
func carveBlockEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*carveBlockRequest)
payload := fleet.CarveBlockPayload{
SessionId: req.SessionId,
RequestId: req.RequestId,
BlockId: req.BlockId,
Data: req.Data,
}
err := svc.CarveBlock(ctx, payload)
if err != nil {
return carveBlockResponse{Err: err}, nil
}
return carveBlockResponse{Success: true}, nil
}
func (svc *Service) CarveBlock(ctx context.Context, payload fleet.CarveBlockPayload) error {
// skipauth: Authorization is currently for user endpoints only.
svc.authz.SkipAuthorization(ctx)
// Note host did not authenticate via node key. We need to authenticate them
// by the session ID and request ID
carve, err := svc.carveStore.CarveBySessionId(ctx, payload.SessionId)
if err != nil {
return ctxerr.Wrap(ctx, err, "find carve by session_id")
}
if payload.RequestId != carve.RequestId {
return errors.New("request_id does not match")
}
// Request is now authenticated
if payload.BlockId > carve.BlockCount-1 {
return fmt.Errorf("block_id exceeds expected max (%d): %d", carve.BlockCount-1, payload.BlockId)
}
if payload.BlockId != carve.MaxBlock+1 {
return fmt.Errorf("block_id does not match expected block (%d): %d", carve.MaxBlock+1, payload.BlockId)
}
if int64(len(payload.Data)) > carve.BlockSize {
return fmt.Errorf("exceeded declared block size %d: %d", carve.BlockSize, len(payload.Data))
}
if err := svc.carveStore.NewBlock(ctx, carve, payload.BlockId, payload.Data); err != nil {
return ctxerr.Wrap(ctx, err, "save block data")
}
return nil
}

View file

@ -4,8 +4,10 @@ import (
"context"
"errors"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/authz"
hostctx "github.com/fleetdm/fleet/v4/server/contexts/host"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/test"
@ -184,3 +186,420 @@ func TestCarveGetBlockExpired(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), "expired carve")
}
func TestCarveBegin(t *testing.T) {
host := fleet.Host{ID: 3}
payload := fleet.CarveBeginPayload{
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
}
ms := new(mock.Store)
ds := new(mock.Store)
svc := &Service{
carveStore: ms,
ds: ds,
}
expectedMetadata := fleet.CarveMetadata{
ID: 7,
HostId: host.ID,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
}
ms.NewCarveFunc = func(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) {
metadata.ID = 7
return metadata, nil
}
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
if host.ID != id {
return nil, errors.New("not found")
}
return &fleet.Host{
Hostname: host.Hostname,
}, nil
}
ctx := hostctx.NewContext(context.Background(), &host)
metadata, err := svc.CarveBegin(ctx, payload)
require.NoError(t, err)
assert.NotEmpty(t, metadata.SessionId)
metadata.SessionId = "" // Clear this before comparison
metadata.Name = "" // Clear this before comparison
metadata.CreatedAt = time.Time{} // Clear this before comparison
assert.Equal(t, expectedMetadata, *metadata)
}
func TestCarveBeginNewCarveError(t *testing.T) {
host := fleet.Host{ID: 3}
payload := fleet.CarveBeginPayload{
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
}
ms := new(mock.Store)
ds := new(mock.Store)
svc := &Service{
carveStore: ms,
ds: ds,
}
ms.NewCarveFunc = func(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) {
return nil, errors.New("ouch!")
}
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
if host.ID != id {
return nil, errors.New("not found")
}
return &fleet.Host{
Hostname: host.Hostname,
}, nil
}
ctx := hostctx.NewContext(context.Background(), &host)
_, err := svc.CarveBegin(ctx, payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "ouch!")
}
func TestCarveBeginEmptyError(t *testing.T) {
ms := new(mock.Store)
ds := new(mock.Store)
svc := &Service{
carveStore: ms,
ds: ds,
}
ctx := hostctx.NewContext(context.Background(), &fleet.Host{ID: 1})
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
if id != 1 {
return nil, errors.New("not found")
}
return &fleet.Host{}, nil
}
_, err := svc.CarveBegin(ctx, fleet.CarveBeginPayload{})
require.Error(t, err)
assert.Contains(t, err.Error(), "carve_size must be greater than 0")
}
func TestCarveBeginMissingHostError(t *testing.T) {
ms := new(mock.Store)
svc := &Service{carveStore: ms}
_, err := svc.CarveBegin(context.Background(), fleet.CarveBeginPayload{})
require.Error(t, err)
assert.Contains(t, err.Error(), "missing host")
}
func TestCarveBeginBlockSizeMaxError(t *testing.T) {
host := fleet.Host{ID: 3}
payload := fleet.CarveBeginPayload{
BlockCount: 10,
BlockSize: 1024 * 1024 * 1024 * 1024, // 1TB
CarveSize: 10 * 1024 * 1024 * 1024 * 1024, // 10TB
RequestId: "carve_request",
}
ms := new(mock.Store)
ds := new(mock.Store)
svc := &Service{
carveStore: ms,
ds: ds,
}
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
if host.ID != id {
return nil, errors.New("not found")
}
return &fleet.Host{
Hostname: host.Hostname,
}, nil
}
ctx := hostctx.NewContext(context.Background(), &host)
_, err := svc.CarveBegin(ctx, payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "block_size exceeds max")
}
func TestCarveBeginCarveSizeMaxError(t *testing.T) {
host := fleet.Host{ID: 3}
payload := fleet.CarveBeginPayload{
BlockCount: 1024 * 1024,
BlockSize: 10 * 1024 * 1024, // 1TB
CarveSize: 10 * 1024 * 1024 * 1024 * 1024, // 10TB
RequestId: "carve_request",
}
ms := new(mock.Store)
ds := new(mock.Store)
svc := &Service{
carveStore: ms,
ds: ds,
}
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
if host.ID != id {
return nil, errors.New("not found")
}
return &fleet.Host{
Hostname: host.Hostname,
}, nil
}
ctx := hostctx.NewContext(context.Background(), &host)
_, err := svc.CarveBegin(ctx, payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "carve_size exceeds max")
}
func TestCarveBeginCarveSizeError(t *testing.T) {
host := fleet.Host{ID: 3}
payload := fleet.CarveBeginPayload{
BlockCount: 7,
BlockSize: 13,
CarveSize: 7*13 + 1,
RequestId: "carve_request",
}
ms := new(mock.Store)
ds := new(mock.Store)
svc := &Service{
carveStore: ms,
ds: ds,
}
ctx := hostctx.NewContext(context.Background(), &host)
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
if host.ID != id {
return nil, errors.New("not found")
}
return &fleet.Host{
Hostname: host.Hostname,
}, nil
}
// Too big
_, err := svc.CarveBegin(ctx, payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "carve_size does not match")
// Too small
payload.CarveSize = 6 * 13
_, err = svc.CarveBegin(ctx, payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "carve_size does not match")
}
func TestCarveCarveBlockGetCarveError(t *testing.T) {
sessionId := "foobar"
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
return nil, errors.New("ouch!")
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :)"),
SessionId: sessionId,
}
err := svc.CarveBlock(context.Background(), payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "ouch!")
}
func TestCarveCarveBlockRequestIdError(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.SessionId, sessionId)
return metadata, nil
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :)"),
RequestId: "not_matching",
SessionId: sessionId,
}
err := svc.CarveBlock(context.Background(), payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "request_id does not match")
}
func TestCarveCarveBlockBlockCountExceedError(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.SessionId, sessionId)
return metadata, nil
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :)"),
RequestId: "carve_request",
SessionId: sessionId,
BlockId: 23,
}
err := svc.CarveBlock(context.Background(), payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "block_id exceeds expected max")
}
func TestCarveCarveBlockBlockCountMatchError(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
MaxBlock: 3,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.SessionId, sessionId)
return metadata, nil
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :)"),
RequestId: "carve_request",
SessionId: sessionId,
BlockId: 7,
}
err := svc.CarveBlock(context.Background(), payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "block_id does not match")
}
func TestCarveCarveBlockBlockSizeError(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 16,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
MaxBlock: 3,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.SessionId, sessionId)
return metadata, nil
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :) TOO LONG!!!"),
RequestId: "carve_request",
SessionId: sessionId,
BlockId: 4,
}
err := svc.CarveBlock(context.Background(), payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "exceeded declared block size")
}
func TestCarveCarveBlockNewBlockError(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
MaxBlock: 3,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.SessionId, sessionId)
return metadata, nil
}
ms.NewBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64, data []byte) error {
return errors.New("kaboom!")
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :)"),
RequestId: "carve_request",
SessionId: sessionId,
BlockId: 4,
}
err := svc.CarveBlock(context.Background(), payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "kaboom!")
}
func TestCarveCarveBlock(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
MaxBlock: 3,
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :)"),
RequestId: "carve_request",
SessionId: sessionId,
BlockId: 4,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.SessionId, sessionId)
return metadata, nil
}
ms.NewBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64, data []byte) error {
assert.Equal(t, metadata, carve)
assert.Equal(t, int64(4), blockId)
assert.Equal(t, payload.Data, data)
return nil
}
err := svc.CarveBlock(context.Background(), payload)
require.NoError(t, err)
assert.True(t, ms.NewBlockFuncInvoked)
}

View file

@ -1,47 +0,0 @@
package service
import (
"context"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/go-kit/kit/endpoint"
)
////////////////////////////////////////////////////////////////////////////////
// Receive Block for File Carve
////////////////////////////////////////////////////////////////////////////////
type carveBlockRequest struct {
NodeKey string `json:"node_key"`
BlockId int64 `json:"block_id"`
SessionId string `json:"session_id"`
RequestId string `json:"request_id"`
Data []byte `json:"data"`
}
type carveBlockResponse struct {
Success bool `json:"success,omitempty"`
Err error `json:"error,omitempty"`
}
func (r carveBlockResponse) error() error { return r.Err }
func makeCarveBlockEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(carveBlockRequest)
payload := fleet.CarveBlockPayload{
SessionId: req.SessionId,
RequestId: req.RequestId,
BlockId: req.BlockId,
Data: req.Data,
}
err := svc.CarveBlock(ctx, payload)
if err != nil {
return carveBlockResponse{Err: err}, nil
}
return carveBlockResponse{Success: true}, nil
}
}

View file

@ -1,30 +0,0 @@
package service
import (
"context"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/go-kit/kit/endpoint"
)
type verifyInviteRequest struct {
Token string
}
type verifyInviteResponse struct {
Invite *fleet.Invite `json:"invite"`
Err error `json:"error,omitempty"`
}
func (r verifyInviteResponse) error() error { return r.Err }
func makeVerifyInviteEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(verifyInviteRequest)
invite, err := svc.VerifyInvite(ctx, req.Token)
if err != nil {
return verifyInviteResponse{Err: err}, nil
}
return verifyInviteResponse{Invite: invite}, nil
}
}

View file

@ -137,6 +137,10 @@ func authenticatedUser(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpo
return logged(authUserFunc)
}
func unauthenticatedRequest(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint {
return logged(next)
}
// logged wraps an endpoint and adds the error if the context supports it
func logged(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
@ -167,16 +171,3 @@ func authViewer(ctx context.Context, sessionKey string, svc fleet.Service) (*vie
}
return &viewer.Viewer{User: user, Session: session}, nil
}
func canPerformPasswordReset(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
vc, ok := viewer.FromContext(ctx)
if !ok {
return nil, fleet.ErrNoContext
}
if !vc.CanPerformPasswordReset() {
return nil, fleet.NewPermissionError("cannot reset password")
}
return next(ctx, request)
}
}

View file

@ -1,36 +0,0 @@
package service
import (
"context"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/go-kit/kit/endpoint"
)
////////////////////////////////////////////////////////////////////////////////
// Enroll Agent
////////////////////////////////////////////////////////////////////////////////
type enrollAgentRequest struct {
EnrollSecret string `json:"enroll_secret"`
HostIdentifier string `json:"host_identifier"`
HostDetails map[string](map[string]string) `json:"host_details"`
}
type enrollAgentResponse struct {
NodeKey string `json:"node_key,omitempty"`
Err error `json:"error,omitempty"`
}
func (r enrollAgentResponse) error() error { return r.Err }
func makeEnrollAgentEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(enrollAgentRequest)
nodeKey, err := svc.EnrollAgent(ctx, req.EnrollSecret, req.HostIdentifier, req.HostDetails)
if err != nil {
return enrollAgentResponse{Err: err}, nil
}
return enrollAgentResponse{NodeKey: nodeKey}, nil
}
}

View file

@ -1,163 +0,0 @@
package service
import (
"bytes"
"context"
"errors"
"html/template"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/go-kit/kit/endpoint"
)
////////////////////////////////////////////////////////////////////////////////
// Login
////////////////////////////////////////////////////////////////////////////////
type loginRequest struct {
Email string
Password string
}
type loginResponse struct {
User *fleet.User `json:"user,omitempty"`
AvailableTeams []*fleet.TeamSummary `json:"available_teams"`
Token string `json:"token,omitempty"`
Err error `json:"error,omitempty"`
}
func (r loginResponse) error() error { return r.Err }
func makeLoginEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(loginRequest)
user, token, err := svc.Login(ctx, req.Email, req.Password)
if err != nil {
return loginResponse{Err: err}, nil
}
// Add viewer context allow access to service teams for list of available teams
v, err := authViewer(ctx, token, svc)
if err != nil {
return loginResponse{Err: err}, nil
}
ctx = viewer.NewContext(ctx, *v)
availableTeams, err := svc.ListAvailableTeamsForUser(ctx, user)
if err != nil {
if errors.Is(err, fleet.ErrMissingLicense) {
availableTeams = []*fleet.TeamSummary{}
} else {
return loginResponse{Err: err}, nil
}
}
return loginResponse{user, availableTeams, token, nil}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Logout
////////////////////////////////////////////////////////////////////////////////
type logoutResponse struct {
Err error `json:"error,omitempty"`
}
func (r logoutResponse) error() error { return r.Err }
func makeLogoutEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
err := svc.Logout(ctx)
if err != nil {
return logoutResponse{Err: err}, nil
}
return logoutResponse{}, nil
}
}
type initiateSSORequest struct {
RelayURL string `json:"relay_url"`
}
type initiateSSOResponse struct {
URL string `json:"url,omitempty"`
Err error `json:"error,omitempty"`
}
func (r initiateSSOResponse) error() error { return r.Err }
func makeInitiateSSOEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(initiateSSORequest)
idProviderURL, err := svc.InitiateSSO(ctx, req.RelayURL)
if err != nil {
return initiateSSOResponse{Err: err}, nil
}
return initiateSSOResponse{URL: idProviderURL}, nil
}
}
type callbackSSOResponse struct {
content string
Err error `json:"error,omitempty"`
}
func (r callbackSSOResponse) error() error { return r.Err }
// If html is present we return a web page
func (r callbackSSOResponse) html() string { return r.content }
func makeCallbackSSOEndpoint(svc fleet.Service, urlPrefix string) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
authResponse := request.(fleet.Auth)
session, err := svc.CallbackSSO(ctx, authResponse)
var resp callbackSSOResponse
if err != nil {
// redirect to login page on front end if there was some problem,
// errors should still be logged
session = &fleet.SSOSession{
RedirectURL: urlPrefix + "/login",
Token: "",
}
resp.Err = err
}
relayStateLoadPage := ` <html>
<script type='text/javascript'>
var redirectURL = {{ .RedirectURL }};
window.localStorage.setItem('FLEET::auth_token', '{{ .Token }}');
window.location = redirectURL;
</script>
<body>
Redirecting to Fleet at {{ .RedirectURL }} ...
</body>
</html>
`
tmpl, err := template.New("relayStateLoader").Parse(relayStateLoadPage)
if err != nil {
return nil, err
}
var writer bytes.Buffer
err = tmpl.Execute(&writer, session)
if err != nil {
return nil, err
}
resp.content = writer.String()
return resp, nil
}
}
type ssoSettingsResponse struct {
Settings *fleet.SessionSSOSettings `json:"settings,omitempty"`
Err error `json:"error,omitempty"`
}
func (r ssoSettingsResponse) error() error { return r.Err }
func makeSSOSettingsEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, unused interface{}) (interface{}, error) {
settings, err := svc.SSOSettings(ctx)
if err != nil {
return ssoSettingsResponse{Err: err}, nil
}
return ssoSettingsResponse{Settings: settings}, nil
}
}

View file

@ -1,100 +0,0 @@
package service
import (
"context"
"net/http"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/go-kit/kit/endpoint"
)
////////////////////////////////////////////////////////////////////////////////
// Create User With Invite
////////////////////////////////////////////////////////////////////////////////
func makeCreateUserFromInviteEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(createUserRequest)
user, err := svc.CreateUserFromInvite(ctx, req.UserPayload)
if err != nil {
return createUserResponse{Err: err}, nil
}
return createUserResponse{User: user}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Reset Password
////////////////////////////////////////////////////////////////////////////////
type resetPasswordRequest struct {
PasswordResetToken string `json:"password_reset_token"`
NewPassword string `json:"new_password"`
}
type resetPasswordResponse struct {
Err error `json:"error,omitempty"`
}
func (r resetPasswordResponse) error() error { return r.Err }
func makeResetPasswordEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(resetPasswordRequest)
err := svc.ResetPassword(ctx, req.PasswordResetToken, req.NewPassword)
return resetPasswordResponse{Err: err}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Perform Required Password Reset
////////////////////////////////////////////////////////////////////////////////
type performRequiredPasswordResetRequest struct {
Password string `json:"new_password"`
ID uint `json:"id"`
}
type performRequiredPasswordResetResponse struct {
User *fleet.User `json:"user,omitempty"`
Err error `json:"error,omitempty"`
}
func (r performRequiredPasswordResetResponse) error() error { return r.Err }
func makePerformRequiredPasswordResetEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(performRequiredPasswordResetRequest)
user, err := svc.PerformRequiredPasswordReset(ctx, req.Password)
if err != nil {
return performRequiredPasswordResetResponse{Err: err}, nil
}
return performRequiredPasswordResetResponse{User: user}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Forgot Password
////////////////////////////////////////////////////////////////////////////////
type forgotPasswordRequest struct {
Email string `json:"email"`
}
type forgotPasswordResponse struct {
Err error `json:"error,omitempty"`
}
func (r forgotPasswordResponse) error() error { return r.Err }
func (r forgotPasswordResponse) status() int { return http.StatusAccepted }
func makeForgotPasswordEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(forgotPasswordRequest)
// Any error returned by the service should not be returned to the
// client to prevent information disclosure (it will be logged in the
// server logs).
_ = svc.RequestPasswordReset(ctx, req.Email)
return forgotPasswordResponse{}, nil
}
}

View file

@ -67,6 +67,10 @@ func allFields(ifv reflect.Value) []reflect.StructField {
return fields
}
type requestDecoder interface {
DecodeRequest(ctx context.Context, r *http.Request) (interface{}, error)
}
// makeDecoder creates a decoder for the type for the struct passed on. If the
// struct has at least 1 json tag it'll unmarshall the body. If the struct has
// a `url` tag with value list_options it'll gather fleet.ListOptions from the
@ -79,12 +83,22 @@ func allFields(ifv reflect.Value) []reflect.StructField {
// follows: `url:"some-id,optional"`.
// The "list_options" are optional by default and it'll ignore the optional
// portion of the tag.
//
// If iface implements the requestDecoder interface, it returns a function that
// calls iface.DecodeRequest(ctx, r) - i.e. the value itself fully controls its
// own decoding.
func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
if iface == nil {
return func(ctx context.Context, r *http.Request) (interface{}, error) {
return nil, nil
}
}
if rd, ok := iface.(requestDecoder); ok {
return func(ctx context.Context, r *http.Request) (interface{}, error) {
return rd.DecodeRequest(ctx, r)
}
}
t := reflect.TypeOf(iface)
if t.Kind() != reflect.Struct {
panic(fmt.Sprintf("makeDecoder only understands structs, not %T", iface))
@ -272,6 +286,7 @@ type authEndpointer struct {
startingAtVersion string
endingAtVersion string
alternativePaths []string
customMiddleware []endpoint.Middleware
}
func newUserAuthenticatedEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer {
@ -297,6 +312,16 @@ func newHostAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts [
}
}
func newNoAuthEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer {
return &authEndpointer{
svc: svc,
opts: opts,
r: r,
authFunc: unauthenticatedRequest,
versions: versions,
}
}
var pathReplacer = strings.NewReplacer(
"/", "_",
"{", "_",
@ -374,7 +399,15 @@ func (e *authEndpointer) makeEndpoint(f handlerFunc, v interface{}) http.Handler
next := func(ctx context.Context, request interface{}) (interface{}, error) {
return f(ctx, request, e.svc)
}
return newServer(e.authFunc(e.svc, next), makeDecoder(v), e.opts)
endp := e.authFunc(e.svc, next)
// apply middleware in reverse order so that the first wraps the second
// wraps the third etc.
for i := len(e.customMiddleware) - 1; i >= 0; i-- {
mw := e.customMiddleware[i]
endp = mw(endp)
}
return newServer(endp, makeDecoder(v), e.opts)
}
func (e *authEndpointer) StartingAtVersion(version string) *authEndpointer {
@ -394,3 +427,9 @@ func (e *authEndpointer) WithAltPaths(paths ...string) *authEndpointer {
ae.alternativePaths = paths
return &ae
}
func (e *authEndpointer) WithCustomMiddleware(mws ...endpoint.Middleware) *authEndpointer {
ae := *e
ae.customMiddleware = mws
return &ae
}

View file

@ -1,6 +1,7 @@
package service
import (
"bytes"
"context"
"io"
"net/http"
@ -14,6 +15,7 @@ import (
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/go-kit/kit/endpoint"
kitlog "github.com/go-kit/kit/log"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/gorilla/mux"
@ -251,7 +253,6 @@ func TestUniversalDecoderQueryAndListPlayNice(t *testing.T) {
}
func TestEndpointer(t *testing.T) {
r := mux.NewRouter()
ds := new(mock.Store)
ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) {
@ -395,3 +396,72 @@ func TestEndpointer(t *testing.T) {
require.False(t, doesItMatch(route.method, route.path, false), route)
}
}
func TestEndpointerCustomMiddleware(t *testing.T) {
r := mux.NewRouter()
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
fleetAPIOptions := []kithttp.ServerOption{
kithttp.ServerBefore(
kithttp.PopulateRequestContext,
setRequestsContexts(svc),
),
kithttp.ServerErrorHandler(&errorHandler{kitlog.NewNopLogger()}),
kithttp.ServerErrorEncoder(encodeError),
kithttp.ServerAfter(
kithttp.SetContentType("application/json; charset=utf-8"),
logRequestEnd(kitlog.NewNopLogger()),
checkLicenseExpiration(svc),
),
}
var buf bytes.Buffer
e := newNoAuthEndpointer(svc, fleetAPIOptions, r, "v1")
e.GET("/none/", func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
buf.WriteString("H1")
return nil, nil
}, nil)
e.WithCustomMiddleware(
func(e endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
buf.WriteString("A")
return e(ctx, request)
}
},
func(e endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
buf.WriteString("B")
return e(ctx, request)
}
},
func(e endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
buf.WriteString("C")
return e(ctx, request)
}
},
).
GET("/mw/", func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
buf.WriteString("H2")
return nil, nil
}, nil)
req := httptest.NewRequest("GET", "/none/", nil)
var m1 mux.RouteMatch
require.True(t, r.Match(req, &m1))
rec := httptest.NewRecorder()
m1.Handler.ServeHTTP(rec, req)
require.Equal(t, "H1", buf.String())
buf.Reset()
req = httptest.NewRequest("GET", "/mw/", nil)
var m2 mux.RouteMatch
require.True(t, r.Match(req, &m2))
rec = httptest.NewRecorder()
m2.Handler.ServeHTTP(rec, req)
require.Equal(t, "ABCH2", buf.String())
}

View file

@ -4,7 +4,7 @@ import (
"context"
"errors"
"net/http"
"strings"
"regexp"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
@ -22,92 +22,6 @@ import (
otmiddleware "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux"
)
// FleetEndpoints is a collection of RPC endpoints implemented by the Fleet API.
type FleetEndpoints struct {
Login endpoint.Endpoint
Logout endpoint.Endpoint
ForgotPassword endpoint.Endpoint
ResetPassword endpoint.Endpoint
CreateUserWithInvite endpoint.Endpoint
PerformRequiredPasswordReset endpoint.Endpoint
VerifyInvite endpoint.Endpoint
EnrollAgent endpoint.Endpoint
CarveBlock endpoint.Endpoint
InitiateSSO endpoint.Endpoint
CallbackSSO endpoint.Endpoint
SSOSettings endpoint.Endpoint
}
// MakeFleetServerEndpoints creates the Fleet API endpoints.
func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore throttled.GCRAStore, logger kitlog.Logger) FleetEndpoints {
limiter := ratelimit.NewMiddleware(limitStore)
return FleetEndpoints{
Login: limiter.Limit(
throttled.RateQuota{MaxRate: throttled.PerMin(10), MaxBurst: 9})(
makeLoginEndpoint(svc),
),
Logout: logged(makeLogoutEndpoint(svc)),
ForgotPassword: limiter.Limit(
throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})(
logged(makeForgotPasswordEndpoint(svc)),
),
ResetPassword: logged(makeResetPasswordEndpoint(svc)),
CreateUserWithInvite: logged(makeCreateUserFromInviteEndpoint(svc)),
VerifyInvite: logged(makeVerifyInviteEndpoint(svc)),
InitiateSSO: logged(makeInitiateSSOEndpoint(svc)),
CallbackSSO: logged(makeCallbackSSOEndpoint(svc, urlPrefix)),
SSOSettings: logged(makeSSOSettingsEndpoint(svc)),
// PerformRequiredPasswordReset needs only to authenticate the
// logged in user
PerformRequiredPasswordReset: logged(canPerformPasswordReset(makePerformRequiredPasswordResetEndpoint(svc))),
// Osquery endpoints
EnrollAgent: logged(makeEnrollAgentEndpoint(svc)),
// For some reason osquery does not provide a node key with the block
// data. Instead the carve session ID should be verified in the service
// method.
CarveBlock: logged(makeCarveBlockEndpoint(svc)),
}
}
type fleetHandlers struct {
Login http.Handler
Logout http.Handler
ForgotPassword http.Handler
ResetPassword http.Handler
CreateUserWithInvite http.Handler
PerformRequiredPasswordReset http.Handler
VerifyInvite http.Handler
EnrollAgent http.Handler
CarveBlock http.Handler
InitiateSSO http.Handler
CallbackSSO http.Handler
SettingsSSO http.Handler
}
func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandlers {
newServer := func(e endpoint.Endpoint, decodeFn kithttp.DecodeRequestFunc) http.Handler {
e = authzcheck.NewMiddleware().AuthzCheck()(e)
return kithttp.NewServer(e, decodeFn, encodeResponse, opts...)
}
return &fleetHandlers{
Login: newServer(e.Login, decodeLoginRequest),
Logout: newServer(e.Logout, decodeNoParamsRequest),
ForgotPassword: newServer(e.ForgotPassword, decodeForgotPasswordRequest),
ResetPassword: newServer(e.ResetPassword, decodeResetPasswordRequest),
CreateUserWithInvite: newServer(e.CreateUserWithInvite, decodeCreateUserRequest),
PerformRequiredPasswordReset: newServer(e.PerformRequiredPasswordReset, decodePerformRequiredPasswordResetRequest),
VerifyInvite: newServer(e.VerifyInvite, decodeVerifyInviteRequest),
EnrollAgent: newServer(e.EnrollAgent, decodeEnrollAgentRequest),
CarveBlock: newServer(e.CarveBlock, decodeCarveBlockRequest),
InitiateSSO: newServer(e.InitiateSSO, decodeInitiateSSORequest),
CallbackSSO: newServer(e.CallbackSSO, decodeCallbackSSORequest),
SettingsSSO: newServer(e.SSOSettings, decodeNoParamsRequest),
}
}
type errorHandler struct {
logger kitlog.Logger
}
@ -176,18 +90,16 @@ func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Log
),
}
fleetEndpoints := MakeFleetServerEndpoints(svc, config.Server.URLPrefix, limitStore, logger)
fleetHandlers := makeKitHandlers(fleetEndpoints, fleetAPIOptions)
r := mux.NewRouter()
if config.Logging.TracingEnabled && config.Logging.TracingType == "opentelemetry" {
r.Use(otmiddleware.Middleware("fleet"))
}
attachFleetAPIRoutes(r, fleetHandlers)
attachNewStyleFleetAPIRoutes(r, svc, logger, fleetAPIOptions)
attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions)
// Results endpoint is handled different due to websockets use
// TODO: this would not work once v1 is deprecated - note that the handler too uses the /v1/ path
// and this routes on path prefix, not exact path (unlike the authendpointer struct).
r.PathPrefix("/api/v1/fleet/results/").
Handler(makeStreamDistributedQueryCampaignResultsHandler(svc, logger)).
Name("distributed_query_results")
@ -277,22 +189,9 @@ func addMetrics(r *mux.Router) {
r.Walk(walkFn)
}
func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
r.Handle("/api/v1/fleet/login", h.Login).Methods("POST").Name("login")
r.Handle("/api/v1/fleet/logout", h.Logout).Methods("POST").Name("logout")
r.Handle("/api/v1/fleet/forgot_password", h.ForgotPassword).Methods("POST").Name("forgot_password")
r.Handle("/api/v1/fleet/reset_password", h.ResetPassword).Methods("POST").Name("reset_password")
r.Handle("/api/v1/fleet/perform_required_password_reset", h.PerformRequiredPasswordReset).Methods("POST").Name("perform_required_password_reset")
r.Handle("/api/v1/fleet/sso", h.InitiateSSO).Methods("POST").Name("intiate_sso")
r.Handle("/api/v1/fleet/sso", h.SettingsSSO).Methods("GET").Name("sso_config")
r.Handle("/api/v1/fleet/sso/callback", h.CallbackSSO).Methods("POST").Name("callback_sso")
r.Handle("/api/v1/fleet/users", h.CreateUserWithInvite).Methods("POST").Name("create_user_with_invite")
r.Handle("/api/v1/fleet/invites/{token}", h.VerifyInvite).Methods("GET").Name("verify_invite")
r.Handle("/api/v1/osquery/enroll", h.EnrollAgent).Methods("POST").Name("enroll_agent")
r.Handle("/api/v1/osquery/carve/block", h.CarveBlock).Methods("POST").Name("carve_block")
}
func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetConfig,
logger kitlog.Logger, limitStore throttled.GCRAStore, opts []kithttp.ServerOption) {
func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, logger kitlog.Logger, opts []kithttp.ServerOption) {
// user-authenticated endpoints
ue := newUserAuthenticatedEndpointer(svc, opts, r, "v1")
@ -441,10 +340,42 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, logger kitlo
he.POST("/api/_version_/osquery/distributed/write", submitDistributedQueryResultsEndpoint, submitDistributedQueryResultsRequestShim{})
he.POST("/api/_version_/osquery/carve/begin", carveBeginEndpoint, carveBeginRequest{})
he.POST("/api/_version_/osquery/log", submitLogsEndpoint, submitLogsRequest{})
// unauthenticated endpoints - most of those are either login-related,
// invite-related or host-enrolling. So they typically do some kind of
// one-time authentication by verifying that a valid secret token is provided
// with the request.
ne := newNoAuthEndpointer(svc, opts, r, "v1")
ne.POST("/api/_version_/osquery/enroll", enrollAgentEndpoint, enrollAgentRequest{})
// For some reason osquery does not provide a node key with the block data.
// Instead the carve session ID should be verified in the service method.
ne.POST("/api/_version_/osquery/carve/block", carveBlockEndpoint, carveBlockRequest{})
ne.POST("/api/_version_/fleet/perform_required_password_reset", performRequiredPasswordResetEndpoint, performRequiredPasswordResetRequest{})
ne.POST("/api/_version_/fleet/users", createUserFromInviteEndpoint, createUserRequest{})
ne.GET("/api/_version_/fleet/invites/{token}", verifyInviteEndpoint, verifyInviteRequest{})
ne.POST("/api/_version_/fleet/reset_password", resetPasswordEndpoint, resetPasswordRequest{})
ne.POST("/api/_version_/fleet/logout", logoutEndpoint, nil)
ne.POST("/api/_version_/fleet/sso", initiateSSOEndpoint, initiateSSORequest{})
ne.POST("/api/_version_/fleet/sso/callback", makeCallbackSSOEndpoint(config.Server.URLPrefix), callbackSSORequest{})
ne.GET("/api/_version_/fleet/sso", settingsSSOEndpoint, nil)
limiter := ratelimit.NewMiddleware(limitStore)
ne.
WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})).
POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{})
ne.
WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: throttled.PerMin(10), MaxBurst: 9})).
POST("/api/_version_/fleet/login", loginEndpoint, loginRequest{})
}
// TODO: this duplicates the one in makeKitHandler
func newServer(e endpoint.Endpoint, decodeFn kithttp.DecodeRequestFunc, opts []kithttp.ServerOption) http.Handler {
// TODO: some handlers don't have authz checks, and because the SkipAuth call is done only in the
// endpoint handler, any middleware that raises errors before the handler is reached will end up
// returning authz check missing instead of the more relevant error. Should be addressed as part
// of #4406.
e = authzcheck.NewMiddleware().AuthzCheck()(e)
return kithttp.NewServer(e, decodeFn, encodeResponse, opts...)
}
@ -453,15 +384,19 @@ func newServer(e endpoint.Endpoint, decodeFn kithttp.DecodeRequestFunc, opts []k
// If setup hasn't been completed it serves the API with a setup middleware.
// If the server is already configured, the default API handler is exposed.
func WithSetup(svc fleet.Service, logger kitlog.Logger, next http.Handler) http.HandlerFunc {
rxOsquery := regexp.MustCompile(`^/api/[^/]+/osquery`)
return func(w http.ResponseWriter, r *http.Request) {
configRouter := http.NewServeMux()
// TODO: hard-codes v1 as a path fragment, which would probably not work once we
// deprecate it for newer versions, unless we want to treat the setup differently (not versioned?)
configRouter.Handle("/api/v1/setup", kithttp.NewServer(
makeSetupEndpoint(svc),
decodeSetupRequest,
encodeResponse,
))
// whitelist osqueryd endpoints
if strings.HasPrefix(r.URL.Path, "/api/v1/osquery") {
if rxOsquery.MatchString(r.URL.Path) {
next.ServeHTTP(w, r)
return
}

View file

@ -16,63 +16,10 @@ import (
"github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/throttled/throttled/v2/store/memstore"
)
func TestAPIRoutes(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
r := mux.NewRouter()
limitStore, _ := memstore.New(0)
ke := MakeFleetServerEndpoints(svc, "", limitStore, kitlog.NewNopLogger())
kh := makeKitHandlers(ke, nil)
attachFleetAPIRoutes(r, kh)
handler := mux.NewRouter()
handler.PathPrefix("/").Handler(r)
routes := []struct {
verb string
uri string
}{
{
verb: "POST",
uri: "/api/v1/fleet/users",
},
{
verb: "POST",
uri: "/api/v1/fleet/login",
},
{
verb: "POST",
uri: "/api/v1/fleet/forgot_password",
},
{
verb: "POST",
uri: "/api/v1/fleet/reset_password",
},
{
verb: "POST",
uri: "/api/v1/osquery/enroll",
},
}
for _, route := range routes {
t.Run(fmt.Sprintf(": %v", route.uri), func(st *testing.T) {
recorder := httptest.NewRecorder()
handler.ServeHTTP(
recorder,
httptest.NewRequest(route.verb, route.uri, nil),
)
assert.NotEqual(st, 404, recorder.Code)
assert.NotEqual(st, 405, recorder.Code, route.verb) // if it matches a path but with wrong verb
})
}
}
func TestAPIRoutesConflicts(t *testing.T) {
ds := new(mock.Store)

View file

@ -15,11 +15,13 @@ import (
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/ghodss/yaml"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@ -978,6 +980,20 @@ func (s *integrationTestSuite) TestInvites() {
require.NotZero(t, createInviteResp.Invite.ID)
validInvite := *createInviteResp.Invite
// create user from valid invite - the token was not returned via the
// response's json, must get it from the db
inv, err := s.ds.Invite(context.Background(), validInvite.ID)
require.NoError(t, err)
validInviteToken := inv.Token
// verify the token with valid invite
var verifyInvResp verifyInviteResponse
s.DoJSON("GET", "/api/v1/fleet/invites/"+validInviteToken, nil, http.StatusOK, &verifyInvResp)
require.Equal(t, validInvite.ID, verifyInvResp.Invite.ID)
// verify the token with an invalid invite
s.DoJSON("GET", "/api/v1/fleet/invites/invalid", nil, http.StatusNotFound, &verifyInvResp)
// create invite without an email
createInviteReq = createInviteRequest{InvitePayload: fleet.InvitePayload{
Email: nil,
@ -1076,9 +1092,21 @@ func (s *integrationTestSuite) TestInvites() {
require.Len(t, verify.Teams, 1)
assert.Equal(t, team.ID, verify.Teams[0].ID)
var createFromInviteResp createUserResponse
s.DoJSON("POST", "/api/v1/fleet/users", fleet.UserPayload{
Name: ptr.String("Full Name"),
Password: ptr.String("pass1word!"),
Email: ptr.String("a@b.c"),
InviteToken: ptr.String(validInviteToken),
}, http.StatusOK, &createFromInviteResp)
// keep the invite token from the other valid invite (before deleting it)
inv, err = s.ds.Invite(context.Background(), createInviteResp.Invite.ID)
require.NoError(t, err)
deletedInviteToken := inv.Token
// delete an existing invite
var delResp deleteInviteResponse
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/invites/%d", validInvite.ID), nil, http.StatusOK, &delResp)
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/invites/%d", createInviteResp.Invite.ID), nil, http.StatusOK, &delResp)
// list invites, is now empty
@ -1088,6 +1116,111 @@ func (s *integrationTestSuite) TestInvites() {
// delete a now non-existing invite
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/invites/%d", validInvite.ID), nil, http.StatusNotFound, &delResp)
// create user from never used but deleted invite
s.DoJSON("POST", "/api/v1/fleet/users", fleet.UserPayload{
Name: ptr.String("Full Name"),
Password: ptr.String("pass1word!"),
Email: ptr.String("a@b.c"),
InviteToken: ptr.String(deletedInviteToken),
}, http.StatusNotFound, &createFromInviteResp)
}
func (s *integrationTestSuite) TestCreateUserFromInviteErrors() {
t := s.T()
// create a valid invite
createInviteReq := createInviteRequest{InvitePayload: fleet.InvitePayload{
Email: ptr.String("a@b.c"),
Name: ptr.String("A"),
GlobalRole: null.StringFrom(fleet.RoleObserver),
}}
createInviteResp := createInviteResponse{}
s.DoJSON("POST", "/api/v1/fleet/invites", createInviteReq, http.StatusOK, &createInviteResp)
// make sure to delete it on exit
defer func() {
var delResp deleteInviteResponse
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/invites/%d", createInviteResp.Invite.ID), nil, http.StatusOK, &delResp)
}()
// the token is not returned via the response's json, must get it from the db
invite, err := s.ds.Invite(context.Background(), createInviteResp.Invite.ID)
require.NoError(t, err)
cases := []struct {
desc string
pld fleet.UserPayload
want int
}{
{
"empty name",
fleet.UserPayload{
Name: ptr.String(""),
Password: ptr.String("pass1word!"),
Email: ptr.String("a@b.c"),
InviteToken: ptr.String(invite.Token),
},
http.StatusUnprocessableEntity,
},
{
"empty email",
fleet.UserPayload{
Name: ptr.String("Name"),
Password: ptr.String("pass1word!"),
Email: ptr.String(""),
InviteToken: ptr.String(invite.Token),
},
http.StatusUnprocessableEntity,
},
{
"empty password",
fleet.UserPayload{
Name: ptr.String("Name"),
Password: ptr.String(""),
Email: ptr.String("a@b.c"),
InviteToken: ptr.String(invite.Token),
},
http.StatusUnprocessableEntity,
},
{
"empty token",
fleet.UserPayload{
Name: ptr.String("Name"),
Password: ptr.String("pass1word!"),
Email: ptr.String("a@b.c"),
InviteToken: ptr.String(""),
},
http.StatusUnprocessableEntity,
},
{
"invalid token",
fleet.UserPayload{
Name: ptr.String("Name"),
Password: ptr.String("pass1word!"),
Email: ptr.String("a@b.c"),
InviteToken: ptr.String("invalid"),
},
http.StatusNotFound,
},
{
"invalid password",
fleet.UserPayload{
Name: ptr.String("Name"),
Password: ptr.String("password"), // no number or symbol
Email: ptr.String("a@b.c"),
InviteToken: ptr.String(invite.Token),
},
http.StatusUnprocessableEntity,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
var resp createUserResponse
s.DoJSON("POST", "/api/v1/fleet/users", c.pld, c.want, &resp)
})
}
}
func (s *integrationTestSuite) TestGetHostSummary() {
@ -2302,6 +2435,9 @@ func (s *integrationTestSuite) TestLabelSpecs() {
}
func (s *integrationTestSuite) TestUsers() {
// ensure that on exit, the admin token is used
defer func() { s.token = s.getTestAdminToken() }()
t := s.T()
// list existing users
@ -2324,14 +2460,16 @@ func (s *integrationTestSuite) TestUsers() {
// create a new user
var createResp createUserResponse
userRawPwd := "pass"
params := fleet.UserPayload{
Name: ptr.String("extra"),
Email: ptr.String("extra@asd.com"),
Password: ptr.String("pass"),
Password: ptr.String(userRawPwd),
GlobalRole: ptr.String(fleet.RoleObserver),
}
s.DoJSON("POST", "/api/v1/fleet/users/admin", params, http.StatusOK, &createResp)
assert.NotZero(t, createResp.User.ID)
assert.True(t, createResp.User.AdminForcedPasswordReset)
u := *createResp.User
// login as that user and check that teams info is empty
@ -2407,6 +2545,46 @@ func (s *integrationTestSuite) TestUsers() {
}
s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID+1), params, http.StatusNotFound, &modResp)
// perform a required password change as the user themselves
s.token = s.getTestToken(u.Email, userRawPwd)
var perfPwdResetResp performRequiredPasswordResetResponse
newRawPwd := "new_password!"
s.DoJSON("POST", "/api/v1/fleet/perform_required_password_reset", performRequiredPasswordResetRequest{
Password: newRawPwd,
ID: u.ID,
}, http.StatusOK, &perfPwdResetResp)
assert.False(t, perfPwdResetResp.User.AdminForcedPasswordReset)
oldUserRawPwd := userRawPwd
userRawPwd = newRawPwd
// perform a required password change again, this time it fails as there is no request pending
perfPwdResetResp = performRequiredPasswordResetResponse{}
newRawPwd = "new_password2!"
s.DoJSON("POST", "/api/v1/fleet/perform_required_password_reset", performRequiredPasswordResetRequest{
Password: newRawPwd,
ID: u.ID,
}, http.StatusInternalServerError, &perfPwdResetResp) // TODO: should be 40?, see #4406
s.token = s.getTestAdminToken()
// login as that user to verify that the new password is active (userRawPwd was updated to the new pwd)
loginResp = loginResponse{}
s.DoJSON("POST", "/api/v1/fleet/login", loginRequest{Email: u.Email, Password: userRawPwd}, http.StatusOK, &loginResp)
require.Equal(t, loginResp.User.ID, u.ID)
// logout for that user
s.token = loginResp.Token
var logoutResp logoutResponse
s.DoJSON("POST", "/api/v1/fleet/logout", nil, http.StatusOK, &logoutResp)
// logout again, even though not logged in
s.DoJSON("POST", "/api/v1/fleet/logout", nil, http.StatusInternalServerError, &logoutResp) // TODO: should be OK even if not logged in, see #4406.
s.token = s.getTestAdminToken()
// login as that user with previous pwd fails
loginResp = loginResponse{}
s.DoJSON("POST", "/api/v1/fleet/login", loginRequest{Email: u.Email, Password: oldUserRawPwd}, http.StatusUnauthorized, &loginResp)
// require a password reset
var reqResetResp requirePasswordResetResponse
s.DoJSON("POST", fmt.Sprintf("/api/v1/fleet/users/%d/require_password_reset", u.ID), map[string]bool{"require": true}, http.StatusOK, &reqResetResp)
@ -3094,6 +3272,245 @@ func (s *integrationTestSuite) TestOsqueryConfig() {
assert.Contains(t, errRes["error"], "invalid node key")
}
func (s *integrationTestSuite) TestEnrollHost() {
t := s.T()
// set the enroll secret
var applyResp applyEnrollSecretSpecResponse
s.DoJSON("POST", "/api/v1/fleet/spec/enroll_secret", applyEnrollSecretSpecRequest{
Spec: &fleet.EnrollSecretSpec{
Secrets: []*fleet.EnrollSecret{{Secret: t.Name()}},
},
}, http.StatusOK, &applyResp)
// invalid enroll secret fails
j, err := json.Marshal(&enrollAgentRequest{
EnrollSecret: "nosuchsecret",
HostIdentifier: "abcd",
})
require.NoError(t, err)
s.DoRawNoAuth("POST", "/api/v1/osquery/enroll", j, http.StatusUnauthorized)
// valid enroll secret succeeds
j, err = json.Marshal(&enrollAgentRequest{
EnrollSecret: t.Name(),
HostIdentifier: t.Name(),
})
require.NoError(t, err)
var resp enrollAgentResponse
hres := s.DoRawNoAuth("POST", "/api/v1/osquery/enroll", j, http.StatusOK)
defer hres.Body.Close()
require.NoError(t, json.NewDecoder(hres.Body).Decode(&resp))
require.NotEmpty(t, resp.NodeKey)
}
func (s *integrationTestSuite) TestCarve() {
t := s.T()
hosts := s.createHosts(t)
// begin a carve with an invalid node key
var errRes map[string]interface{}
s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{
NodeKey: hosts[0].NodeKey + "zzz",
BlockCount: 1,
BlockSize: 1,
CarveSize: 1,
CarveId: "c1",
}, http.StatusUnauthorized, &errRes)
assert.Contains(t, errRes["error"], "invalid node key")
// invalid carve size
s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{
NodeKey: hosts[0].NodeKey,
BlockCount: 3,
BlockSize: 3,
CarveSize: 0,
CarveId: "c1",
}, http.StatusInternalServerError, &errRes) // TODO: should be 4xx, see #4406
assert.Contains(t, errRes["error"], "carve_size must be greater")
// invalid block size too big
s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{
NodeKey: hosts[0].NodeKey,
BlockCount: 3,
BlockSize: maxBlockSize + 1,
CarveSize: maxCarveSize,
CarveId: "c1",
}, http.StatusInternalServerError, &errRes) // TODO: should be 4xx, see #4406
assert.Contains(t, errRes["error"], "block_size exceeds max")
// invalid carve size too big
s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{
NodeKey: hosts[0].NodeKey,
BlockCount: 3,
BlockSize: maxBlockSize,
CarveSize: maxCarveSize + 1,
CarveId: "c1",
}, http.StatusInternalServerError, &errRes) // TODO: should be 4xx, see #4406
assert.Contains(t, errRes["error"], "carve_size exceeds max")
// invalid carve size, does not match blocks
s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{
NodeKey: hosts[0].NodeKey,
BlockCount: 3,
BlockSize: 3,
CarveSize: 1,
CarveId: "c1",
}, http.StatusInternalServerError, &errRes) // TODO: should be 4xx, see #4406
assert.Contains(t, errRes["error"], "carve_size does not match")
// valid carve begin
var beginResp carveBeginResponse
s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{
NodeKey: hosts[0].NodeKey,
BlockCount: 3,
BlockSize: 3,
CarveSize: 8,
CarveId: "c1",
RequestId: "r1",
}, http.StatusOK, &beginResp)
require.NotEmpty(t, beginResp.SessionId)
sid := beginResp.SessionId
// sending a block with invalid session id
var blockResp carveBlockResponse
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
BlockId: 1,
SessionId: sid + "zz",
RequestId: "??",
Data: []byte("p1."),
}, http.StatusNotFound, &blockResp)
// sending a block with valid session id but invalid request id
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
BlockId: 1,
SessionId: sid,
RequestId: "??",
Data: []byte("p1."),
}, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406
// sending a block with unexpected block id (expects 0, got 1)
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
BlockId: 1,
SessionId: sid,
RequestId: "r1",
Data: []byte("p1."),
}, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406
// sending a block with valid payload, block 0
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
BlockId: 0,
SessionId: sid,
RequestId: "r1",
Data: []byte("p1."),
}, http.StatusOK, &blockResp)
require.True(t, blockResp.Success)
// sending next block
blockResp = carveBlockResponse{}
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
BlockId: 1,
SessionId: sid,
RequestId: "r1",
Data: []byte("p2."),
}, http.StatusOK, &blockResp)
require.True(t, blockResp.Success)
// sending already-sent block again
blockResp = carveBlockResponse{}
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
BlockId: 1,
SessionId: sid,
RequestId: "r1",
Data: []byte("p2."),
}, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406
// sending final block with too many bytes
blockResp = carveBlockResponse{}
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
BlockId: 2,
SessionId: sid,
RequestId: "r1",
Data: []byte("p3extra"),
}, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406
// sending actual final block
blockResp = carveBlockResponse{}
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
BlockId: 2,
SessionId: sid,
RequestId: "r1",
Data: []byte("p3"),
}, http.StatusOK, &blockResp)
require.True(t, blockResp.Success)
// sending unexpected block
blockResp = carveBlockResponse{}
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
BlockId: 3,
SessionId: sid,
RequestId: "r1",
Data: []byte("p4."),
}, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406
}
func (s *integrationTestSuite) TestPasswordReset() {
t := s.T()
// create a new user
var createResp createUserResponse
userRawPwd := "passw0rd!"
params := fleet.UserPayload{
Name: ptr.String("forgotpwd"),
Email: ptr.String("forgotpwd@example.com"),
Password: ptr.String(userRawPwd),
GlobalRole: ptr.String(fleet.RoleObserver),
}
s.DoJSON("POST", "/api/v1/fleet/users/admin", params, http.StatusOK, &createResp)
require.NotZero(t, createResp.User.ID)
u := *createResp.User
_ = u
// request forgot password, invalid email
res := s.DoRawNoAuth("POST", "/api/v1/fleet/forgot_password", jsonMustMarshal(t, forgotPasswordRequest{Email: "invalid@asd.com"}), http.StatusAccepted)
res.Body.Close()
// TODO: tested manually (adds too much time to the test), works but hitting the rate
// limit returns 500 instead of 429, see #4406. We get the authz check missing error instead.
//// trigger the rate limit with a batch of requests in a short burst
//for i := 0; i < 20; i++ {
// s.DoJSON("POST", "/api/v1/fleet/forgot_password", forgotPasswordRequest{Email: "invalid@asd.com"}, http.StatusAccepted, &forgotResp)
//}
// request forgot password, valid email
res = s.DoRawNoAuth("POST", "/api/v1/fleet/forgot_password", jsonMustMarshal(t, forgotPasswordRequest{Email: u.Email}), http.StatusAccepted)
res.Body.Close()
var token string
mysql.ExecAdhocSQL(t, s.ds, func(db sqlx.ExtContext) error {
return sqlx.GetContext(context.Background(), db, &token, "SELECT token FROM password_reset_requests WHERE user_id = ?", u.ID)
})
// proceed with reset password
userNewPwd := "newpassw0rd!"
res = s.DoRawNoAuth("POST", "/api/v1/fleet/reset_password", jsonMustMarshal(t, resetPasswordRequest{PasswordResetToken: token, NewPassword: userNewPwd}), http.StatusOK)
res.Body.Close()
// attempt it again with already-used token
userUnusedPwd := "unusedpassw0rd!"
res = s.DoRawNoAuth("POST", "/api/v1/fleet/reset_password", jsonMustMarshal(t, resetPasswordRequest{PasswordResetToken: token, NewPassword: userUnusedPwd}), http.StatusInternalServerError) // TODO: should be 40x, see #4406
res.Body.Close()
// login with the old password, should not succeed
res = s.DoRawNoAuth("POST", "/api/v1/fleet/login", jsonMustMarshal(t, loginRequest{Email: u.Email, Password: userRawPwd}), http.StatusUnauthorized)
res.Body.Close()
// login with the new password, should succeed
res = s.DoRawNoAuth("POST", "/api/v1/fleet/login", jsonMustMarshal(t, loginRequest{Email: u.Email, Password: userNewPwd}), http.StatusOK)
res.Body.Close()
}
// creates a session and returns it, its key is to be passed as authorization header.
func createSession(t *testing.T, uid uint, ds fleet.Datastore) *fleet.Session {
key := make([]byte, 64)
@ -3116,3 +3533,9 @@ func cleanupQuery(s *integrationTestSuite, queryID uint) {
var delResp deleteQueryByIDResponse
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/queries/id/%d", queryID), nil, http.StatusOK, &delResp)
}
func jsonMustMarshal(t testing.TB, v interface{}) []byte {
b, err := json.Marshal(v)
require.NoError(t, err)
return b
}

View file

@ -10,6 +10,7 @@ import (
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mail"
@ -264,3 +265,51 @@ func (svc *Service) DeleteInvite(ctx context.Context, id uint) error {
}
return svc.ds.DeleteInvite(ctx, id)
}
////////////////////////////////////////////////////////////////////////////////
// Verify invite
////////////////////////////////////////////////////////////////////////////////
type verifyInviteRequest struct {
Token string `url:"token"`
}
type verifyInviteResponse struct {
Invite *fleet.Invite `json:"invite"`
Err error `json:"error,omitempty"`
}
func (r verifyInviteResponse) error() error { return r.Err }
func verifyInviteEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*verifyInviteRequest)
invite, err := svc.VerifyInvite(ctx, req.Token)
if err != nil {
return verifyInviteResponse{Err: err}, nil
}
return verifyInviteResponse{Invite: invite}, nil
}
func (svc *Service) VerifyInvite(ctx context.Context, token string) (*fleet.Invite, error) {
// skipauth: There is no viewer context at this point. We rely on verifying
// the invite for authNZ.
svc.authz.SkipAuthorization(ctx)
logging.WithExtras(ctx, "token", token)
invite, err := svc.ds.InviteByToken(ctx, token)
if err != nil {
return nil, err
}
if invite.Token != token {
return nil, fleet.NewInvalidArgumentError("invite_token", "Invite Token does not match Email Address.")
}
expiresAt := invite.CreatedAt.Add(svc.config.App.InviteTokenValidityPeriod)
if svc.clock.Now().After(expiresAt) {
return nil, fleet.NewInvalidArgumentError("invite_token", "Invite token has expired.")
}
return invite, nil
}

84
server/service/jitter.go Normal file
View file

@ -0,0 +1,84 @@
package service
import (
"sync"
"time"
)
// jitterHashTable implements a data structure that allows a fleet to generate a static jitter value
// that is properly balanced. Balance in this context means that hosts would be distributed uniformly
// across the total jitter time so there are no spikes.
// The way this structure works is as follows:
// Given an amount of buckets, we want to place hosts in buckets evenly. So we don't want bucket 0 to
// have 1000 hosts, and all the other buckets 0. If there were 1000 buckets, and 1000 hosts, we should
// end up with 1 per bucket.
// The total amount of online hosts is unknown, so first it assumes that amount of buckets >= amount
// of total hosts (maxCapacity of 1 per bucket). Once we have more hosts than buckets, then we
// increase the maxCapacity by 1 for all buckets, and start placing hosts.
// Hosts that have been placed in a bucket remain in that bucket for as long as the fleet instance is
// running.
// The preferred bucket for a host is the one at (host id % bucketCount). If that bucket is full, the
// next one will be tried. If all buckets are full, then capacity gets increased and the bucket
// selection process restarts.
// Once a bucket is found, the index for the bucket (going from 0 to bucketCount) will be the amount of
// minutes added to the host check in time.
// For example: at a 1hr interval, and the default 10% max jitter percent. That allows hosts to
// distribute within 6 minutes around the hour mark. We would have 6 buckets in that case.
// In the worst possible case that all hosts start at the same time, max jitter percent can be set to
// 100, and this method will distribute hosts evenly.
// The main caveat of this approach is that it works at the fleet instance. So depending on what
// instance gets chosen by the load balancer, the jitter might be different. However, load tests have
// shown that the distribution in practice is pretty balance even when all hosts try to check in at
// the same time.
type jitterHashTable struct {
mu sync.Mutex
maxCapacity int
bucketCount int
buckets map[int]int
cache map[uint]time.Duration
}
func newJitterHashTable(bucketCount int) *jitterHashTable {
if bucketCount == 0 {
bucketCount = 1
}
return &jitterHashTable{
maxCapacity: 1,
bucketCount: bucketCount,
buckets: make(map[int]int),
cache: make(map[uint]time.Duration),
}
}
func (jh *jitterHashTable) jitterForHost(hostID uint) time.Duration {
// if no jitter is configured just return 0
if jh.bucketCount <= 1 {
return 0
}
jh.mu.Lock()
if jitter, ok := jh.cache[hostID]; ok {
jh.mu.Unlock()
return jitter
}
for i := 0; i < jh.bucketCount; i++ {
possibleBucket := (int(hostID) + i) % jh.bucketCount
// if the next bucket has capacity, great!
if jh.buckets[possibleBucket] < jh.maxCapacity {
jh.buckets[possibleBucket]++
jitter := time.Duration(possibleBucket) * time.Minute
jh.cache[hostID] = jitter
jh.mu.Unlock()
return jitter
}
}
// otherwise, bump the capacity and restart the process
jh.maxCapacity++
jh.mu.Unlock()
return jh.jitterForHost(hostID)
}

View file

@ -0,0 +1,52 @@
package service
import (
crand "crypto/rand"
"math"
"math/big"
"testing"
"github.com/stretchr/testify/require"
)
func TestJitterForHost(t *testing.T) {
jh := newJitterHashTable(30)
histogram := make(map[int64]int)
hostCount := 3000
for i := 0; i < hostCount; i++ {
hostID, err := crand.Int(crand.Reader, big.NewInt(10000))
require.NoError(t, err)
jitter := jh.jitterForHost(uint(hostID.Int64() + 10000))
jitterMinutes := int64(jitter.Minutes())
histogram[jitterMinutes]++
}
min, max := math.MaxInt, 0
for jitterMinutes, count := range histogram {
if count < min {
min = count
}
if count > max {
max = count
}
t.Logf("jitterMinutes=%d \t count=%d\n", jitterMinutes, count)
}
variation := max - min
t.Logf("min=%d \t max=%d \t variation=%d\n", min, max, variation)
// check that variation is below 1% of the total amount of hosts
require.Less(t, variation, int(float32(hostCount)/0.01))
}
func TestNoJitter(t *testing.T) {
jh := newJitterHashTable(0)
hostCount := 3000
for i := 0; i < hostCount; i++ {
hostID, err := crand.Int(crand.Reader, big.NewInt(10000))
require.NoError(t, err)
jitter := jh.jitterForHost(uint(hostID.Int64() + 10000))
jitterMinutes := int64(jitter.Minutes())
require.Equal(t, int64(0), jitterMinutes)
}
}

View file

@ -4,8 +4,6 @@ import (
"context"
"fmt"
"net/http"
"reflect"
"runtime"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/go-kit/kit/endpoint"
@ -28,19 +26,15 @@ func NewMiddleware(store throttled.GCRAStore) *Middleware {
}
// Limit returns a new middleware function enforcing the provided quota.
func (m *Middleware) Limit(quota throttled.RateQuota) endpoint.Middleware {
func (m *Middleware) Limit(keyName string, quota throttled.RateQuota) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
// Get function name to use as a key for rate limiting (each wrapped function
// gets a separate quota)
funcName := runtime.FuncForPC(reflect.ValueOf(next).Pointer()).Name()
limiter, err := throttled.NewGCRARateLimiter(m.store, quota)
if err != nil {
panic(err)
}
return func(ctx context.Context, req interface{}) (response interface{}, err error) {
limited, result, err := limiter.RateLimit(funcName, 1)
limited, result, err := limiter.RateLimit(keyName, 1)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "check rate limit")
}

View file

@ -20,6 +20,7 @@ func TestLimit(t *testing.T) {
limiter := NewMiddleware(store)
endpoint := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
wrapped := limiter.Limit(
"test_limit",
throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0},
)(endpoint)

View file

@ -8,8 +8,10 @@ import (
"regexp"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
hostctx "github.com/fleetdm/fleet/v4/server/contexts/host"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
@ -22,6 +24,255 @@ import (
"github.com/spf13/cast"
)
type osqueryError struct {
message string
nodeInvalid bool
}
func (e osqueryError) Error() string {
return e.message
}
func (e osqueryError) NodeInvalid() bool {
return e.nodeInvalid
}
func (svc *Service) AuthenticateHost(ctx context.Context, nodeKey string) (*fleet.Host, bool, error) {
// skipauth: Authorization is currently for user endpoints only.
svc.authz.SkipAuthorization(ctx)
if nodeKey == "" {
return nil, false, osqueryError{
message: "authentication error: missing node key",
nodeInvalid: true,
}
}
host, err := svc.ds.LoadHostByNodeKey(ctx, nodeKey)
switch {
case err == nil:
// OK
case fleet.IsNotFound(err):
return nil, false, osqueryError{
message: "authentication error: invalid node key: " + nodeKey,
nodeInvalid: true,
}
default:
return nil, false, osqueryError{
message: "authentication error: " + err.Error(),
}
}
// Update the "seen" time used to calculate online status. These updates are
// batched for MySQL performance reasons. Because this is done
// asynchronously, it is possible for the server to shut down before
// updating the seen time for these hosts. This seems to be an acceptable
// tradeoff as an online host will continue to check in and quickly be
// marked online again.
svc.seenHostSet.addHostID(host.ID)
host.SeenTime = svc.clock.Now()
return host, svc.debugEnabledForHost(ctx, host.ID), nil
}
////////////////////////////////////////////////////////////////////////////////
// Enroll Agent
////////////////////////////////////////////////////////////////////////////////
type enrollAgentRequest struct {
EnrollSecret string `json:"enroll_secret"`
HostIdentifier string `json:"host_identifier"`
HostDetails map[string](map[string]string) `json:"host_details"`
}
type enrollAgentResponse struct {
NodeKey string `json:"node_key,omitempty"`
Err error `json:"error,omitempty"`
}
func (r enrollAgentResponse) error() error { return r.Err }
func enrollAgentEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*enrollAgentRequest)
nodeKey, err := svc.EnrollAgent(ctx, req.EnrollSecret, req.HostIdentifier, req.HostDetails)
if err != nil {
return enrollAgentResponse{Err: err}, nil
}
return enrollAgentResponse{NodeKey: nodeKey}, nil
}
func (svc *Service) EnrollAgent(ctx context.Context, enrollSecret, hostIdentifier string, hostDetails map[string](map[string]string)) (string, error) {
// skipauth: Authorization is currently for user endpoints only.
svc.authz.SkipAuthorization(ctx)
logging.WithExtras(ctx, "hostIdentifier", hostIdentifier)
secret, err := svc.ds.VerifyEnrollSecret(ctx, enrollSecret)
if err != nil {
return "", osqueryError{
message: "enroll failed: " + err.Error(),
nodeInvalid: true,
}
}
nodeKey, err := server.GenerateRandomText(svc.config.Osquery.NodeKeySize)
if err != nil {
return "", osqueryError{
message: "generate node key failed: " + err.Error(),
nodeInvalid: true,
}
}
hostIdentifier = getHostIdentifier(svc.logger, svc.config.Osquery.HostIdentifier, hostIdentifier, hostDetails)
host, err := svc.ds.EnrollHost(ctx, hostIdentifier, nodeKey, secret.TeamID, svc.config.Osquery.EnrollCooldown)
if err != nil {
return "", osqueryError{message: "save enroll failed: " + err.Error(), nodeInvalid: true}
}
appConfig, err := svc.ds.AppConfig(ctx)
if err != nil {
return "", osqueryError{message: "app config load failed: " + err.Error(), nodeInvalid: true}
}
// Save enrollment details if provided
detailQueries := osquery_utils.GetDetailQueries(appConfig, svc.config)
save := false
if r, ok := hostDetails["os_version"]; ok {
err := detailQueries["os_version"].IngestFunc(svc.logger, host, []map[string]string{r})
if err != nil {
return "", ctxerr.Wrap(ctx, err, "Ingesting os_version")
}
save = true
}
if r, ok := hostDetails["osquery_info"]; ok {
err := detailQueries["osquery_info"].IngestFunc(svc.logger, host, []map[string]string{r})
if err != nil {
return "", ctxerr.Wrap(ctx, err, "Ingesting osquery_info")
}
save = true
}
if r, ok := hostDetails["system_info"]; ok {
err := detailQueries["system_info"].IngestFunc(svc.logger, host, []map[string]string{r})
if err != nil {
return "", ctxerr.Wrap(ctx, err, "Ingesting system_info")
}
save = true
}
if save {
if appConfig.ServerSettings.DeferredSaveHost {
go svc.serialUpdateHost(host)
} else {
if err := svc.ds.UpdateHost(ctx, host); err != nil {
return "", ctxerr.Wrap(ctx, err, "save host in enroll agent")
}
}
}
return nodeKey, nil
}
var counter = int64(0)
func (svc *Service) serialUpdateHost(host *fleet.Host) {
newVal := atomic.AddInt64(&counter, 1)
defer func() {
atomic.AddInt64(&counter, -1)
}()
level.Debug(svc.logger).Log("background", newVal)
ctx, cancelFunc := context.WithTimeout(context.Background(), 30*time.Second)
defer cancelFunc()
err := svc.ds.SerialUpdateHost(ctx, host)
if err != nil {
level.Error(svc.logger).Log("background-err", err)
}
}
func getHostIdentifier(logger log.Logger, identifierOption, providedIdentifier string, details map[string](map[string]string)) string {
switch identifierOption {
case "provided":
// Use the host identifier already provided in the request.
return providedIdentifier
case "instance":
r, ok := details["osquery_info"]
if !ok {
level.Info(logger).Log(
"msg", "could not get host identifier",
"reason", "missing osquery_info",
"identifier", "instance",
)
} else if r["instance_id"] == "" {
level.Info(logger).Log(
"msg", "could not get host identifier",
"reason", "missing instance_id in osquery_info",
"identifier", "instance",
)
} else {
return r["instance_id"]
}
case "uuid":
r, ok := details["osquery_info"]
if !ok {
level.Info(logger).Log(
"msg", "could not get host identifier",
"reason", "missing osquery_info",
"identifier", "uuid",
)
} else if r["uuid"] == "" {
level.Info(logger).Log(
"msg", "could not get host identifier",
"reason", "missing instance_id in osquery_info",
"identifier", "uuid",
)
} else {
return r["uuid"]
}
case "hostname":
r, ok := details["system_info"]
if !ok {
level.Info(logger).Log(
"msg", "could not get host identifier",
"reason", "missing system_info",
"identifier", "hostname",
)
} else if r["hostname"] == "" {
level.Info(logger).Log(
"msg", "could not get host identifier",
"reason", "missing instance_id in system_info",
"identifier", "hostname",
)
} else {
return r["hostname"]
}
default:
panic("Unknown option for host_identifier: " + identifierOption)
}
return providedIdentifier
}
func (svc *Service) debugEnabledForHost(ctx context.Context, id uint) bool {
hlogger := log.With(svc.logger, "host-id", id)
ac, err := svc.ds.AppConfig(ctx)
if err != nil {
level.Debug(hlogger).Log("err", ctxerr.Wrap(ctx, err, "getting app config for host debug"))
return false
}
for _, hostID := range ac.ServerSettings.DebugHostIDs {
if hostID == id {
return true
}
}
return false
}
////////////////////////////////////////////////////////////////////////////////
// Get Client Config
////////////////////////////////////////////////////////////////////////////////

File diff suppressed because it is too large Load diff

View file

@ -90,7 +90,7 @@ func NewService(
return validationMiddleware{svc, ds, sso}, nil
}
func (s Service) SendEmail(mail fleet.Email) error {
func (s *Service) SendEmail(mail fleet.Email) error {
return s.mailService.SendEmail(mail)
}

View file

@ -1,46 +0,0 @@
package service
import (
"context"
"errors"
"fmt"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
)
func (svc *Service) CarveBlock(ctx context.Context, payload fleet.CarveBlockPayload) error {
// skipauth: Authorization is currently for user endpoints only.
svc.authz.SkipAuthorization(ctx)
// Note host did not authenticate via node key. We need to authenticate them
// by the session ID and request ID
carve, err := svc.carveStore.CarveBySessionId(ctx, payload.SessionId)
if err != nil {
return ctxerr.Wrap(ctx, err, "find carve by session_id")
}
if payload.RequestId != carve.RequestId {
return errors.New("request_id does not match")
}
// Request is now authenticated
if payload.BlockId > carve.BlockCount-1 {
return fmt.Errorf("block_id exceeds expected max (%d): %d", carve.BlockCount-1, payload.BlockId)
}
if payload.BlockId != carve.MaxBlock+1 {
return fmt.Errorf("block_id does not match expected block (%d): %d", carve.MaxBlock+1, payload.BlockId)
}
if int64(len(payload.Data)) > carve.BlockSize {
return fmt.Errorf("exceeded declared block size %d: %d", carve.BlockSize, len(payload.Data))
}
if err := svc.carveStore.NewBlock(ctx, carve, payload.BlockId, payload.Data); err != nil {
return ctxerr.Wrap(ctx, err, "save block data")
}
return nil
}

View file

@ -1,432 +0,0 @@
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/fleet"
hostctx "github.com/fleetdm/fleet/v4/server/contexts/host"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCarveBegin(t *testing.T) {
host := fleet.Host{ID: 3}
payload := fleet.CarveBeginPayload{
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
}
ms := new(mock.Store)
ds := new(mock.Store)
svc := &Service{
carveStore: ms,
ds: ds,
}
expectedMetadata := fleet.CarveMetadata{
ID: 7,
HostId: host.ID,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
}
ms.NewCarveFunc = func(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) {
metadata.ID = 7
return metadata, nil
}
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
if host.ID != id {
return nil, errors.New("not found")
}
return &fleet.Host{
Hostname: host.Hostname,
}, nil
}
ctx := hostctx.NewContext(context.Background(), &host)
metadata, err := svc.CarveBegin(ctx, payload)
require.NoError(t, err)
assert.NotEmpty(t, metadata.SessionId)
metadata.SessionId = "" // Clear this before comparison
metadata.Name = "" // Clear this before comparison
metadata.CreatedAt = time.Time{} // Clear this before comparison
assert.Equal(t, expectedMetadata, *metadata)
}
func TestCarveBeginNewCarveError(t *testing.T) {
host := fleet.Host{ID: 3}
payload := fleet.CarveBeginPayload{
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
}
ms := new(mock.Store)
ds := new(mock.Store)
svc := &Service{
carveStore: ms,
ds: ds,
}
ms.NewCarveFunc = func(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) {
return nil, errors.New("ouch!")
}
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
if host.ID != id {
return nil, errors.New("not found")
}
return &fleet.Host{
Hostname: host.Hostname,
}, nil
}
ctx := hostctx.NewContext(context.Background(), &host)
_, err := svc.CarveBegin(ctx, payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "ouch!")
}
func TestCarveBeginEmptyError(t *testing.T) {
ms := new(mock.Store)
ds := new(mock.Store)
svc := &Service{
carveStore: ms,
ds: ds,
}
ctx := hostctx.NewContext(context.Background(), &fleet.Host{ID: 1})
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
if id != 1 {
return nil, errors.New("not found")
}
return &fleet.Host{}, nil
}
_, err := svc.CarveBegin(ctx, fleet.CarveBeginPayload{})
require.Error(t, err)
assert.Contains(t, err.Error(), "carve_size must be greater than 0")
}
func TestCarveBeginMissingHostError(t *testing.T) {
ms := new(mock.Store)
svc := &Service{carveStore: ms}
_, err := svc.CarveBegin(context.Background(), fleet.CarveBeginPayload{})
require.Error(t, err)
assert.Contains(t, err.Error(), "missing host")
}
func TestCarveBeginBlockSizeMaxError(t *testing.T) {
host := fleet.Host{ID: 3}
payload := fleet.CarveBeginPayload{
BlockCount: 10,
BlockSize: 1024 * 1024 * 1024 * 1024, // 1TB
CarveSize: 10 * 1024 * 1024 * 1024 * 1024, // 10TB
RequestId: "carve_request",
}
ms := new(mock.Store)
ds := new(mock.Store)
svc := &Service{
carveStore: ms,
ds: ds,
}
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
if host.ID != id {
return nil, errors.New("not found")
}
return &fleet.Host{
Hostname: host.Hostname,
}, nil
}
ctx := hostctx.NewContext(context.Background(), &host)
_, err := svc.CarveBegin(ctx, payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "block_size exceeds max")
}
func TestCarveBeginCarveSizeMaxError(t *testing.T) {
host := fleet.Host{ID: 3}
payload := fleet.CarveBeginPayload{
BlockCount: 1024 * 1024,
BlockSize: 10 * 1024 * 1024, // 1TB
CarveSize: 10 * 1024 * 1024 * 1024 * 1024, // 10TB
RequestId: "carve_request",
}
ms := new(mock.Store)
ds := new(mock.Store)
svc := &Service{
carveStore: ms,
ds: ds,
}
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
if host.ID != id {
return nil, errors.New("not found")
}
return &fleet.Host{
Hostname: host.Hostname,
}, nil
}
ctx := hostctx.NewContext(context.Background(), &host)
_, err := svc.CarveBegin(ctx, payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "carve_size exceeds max")
}
func TestCarveBeginCarveSizeError(t *testing.T) {
host := fleet.Host{ID: 3}
payload := fleet.CarveBeginPayload{
BlockCount: 7,
BlockSize: 13,
CarveSize: 7*13 + 1,
RequestId: "carve_request",
}
ms := new(mock.Store)
ds := new(mock.Store)
svc := &Service{
carveStore: ms,
ds: ds,
}
ctx := hostctx.NewContext(context.Background(), &host)
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
if host.ID != id {
return nil, errors.New("not found")
}
return &fleet.Host{
Hostname: host.Hostname,
}, nil
}
// Too big
_, err := svc.CarveBegin(ctx, payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "carve_size does not match")
// Too small
payload.CarveSize = 6 * 13
_, err = svc.CarveBegin(ctx, payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "carve_size does not match")
}
func TestCarveCarveBlockGetCarveError(t *testing.T) {
sessionId := "foobar"
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
return nil, errors.New("ouch!")
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :)"),
SessionId: sessionId,
}
err := svc.CarveBlock(context.Background(), payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "ouch!")
}
func TestCarveCarveBlockRequestIdError(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.SessionId, sessionId)
return metadata, nil
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :)"),
RequestId: "not_matching",
SessionId: sessionId,
}
err := svc.CarveBlock(context.Background(), payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "request_id does not match")
}
func TestCarveCarveBlockBlockCountExceedError(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.SessionId, sessionId)
return metadata, nil
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :)"),
RequestId: "carve_request",
SessionId: sessionId,
BlockId: 23,
}
err := svc.CarveBlock(context.Background(), payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "block_id exceeds expected max")
}
func TestCarveCarveBlockBlockCountMatchError(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
MaxBlock: 3,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.SessionId, sessionId)
return metadata, nil
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :)"),
RequestId: "carve_request",
SessionId: sessionId,
BlockId: 7,
}
err := svc.CarveBlock(context.Background(), payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "block_id does not match")
}
func TestCarveCarveBlockBlockSizeError(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 16,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
MaxBlock: 3,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.SessionId, sessionId)
return metadata, nil
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :) TOO LONG!!!"),
RequestId: "carve_request",
SessionId: sessionId,
BlockId: 4,
}
err := svc.CarveBlock(context.Background(), payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "exceeded declared block size")
}
func TestCarveCarveBlockNewBlockError(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
MaxBlock: 3,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.SessionId, sessionId)
return metadata, nil
}
ms.NewBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64, data []byte) error {
return errors.New("kaboom!")
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :)"),
RequestId: "carve_request",
SessionId: sessionId,
BlockId: 4,
}
err := svc.CarveBlock(context.Background(), payload)
require.Error(t, err)
assert.Contains(t, err.Error(), "kaboom!")
}
func TestCarveCarveBlock(t *testing.T) {
sessionId := "foobar"
metadata := &fleet.CarveMetadata{
ID: 2,
HostId: 3,
BlockCount: 23,
BlockSize: 64,
CarveSize: 23 * 64,
RequestId: "carve_request",
SessionId: sessionId,
MaxBlock: 3,
}
payload := fleet.CarveBlockPayload{
Data: []byte("this is the carve data :)"),
RequestId: "carve_request",
SessionId: sessionId,
BlockId: 4,
}
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
assert.Equal(t, metadata.SessionId, sessionId)
return metadata, nil
}
ms.NewBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64, data []byte) error {
assert.Equal(t, metadata, carve)
assert.Equal(t, int64(4), blockId)
assert.Equal(t, payload.Data, data)
return nil
}
err := svc.CarveBlock(context.Background(), payload)
require.NoError(t, err)
assert.True(t, ms.NewBlockFuncInvoked)
}

View file

@ -1,34 +0,0 @@
package service
import (
"context"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/fleet"
)
func (svc *Service) VerifyInvite(ctx context.Context, token string) (*fleet.Invite, error) {
// skipauth: There is no viewer context at this point. We rely on verifying
// the invite for authNZ.
svc.authz.SkipAuthorization(ctx)
logging.WithExtras(ctx, "token", token)
invite, err := svc.ds.InviteByToken(ctx, token)
if err != nil {
return nil, err
}
if invite.Token != token {
return nil, fleet.NewInvalidArgumentError("invite_token", "Invite Token does not match Email Address.")
}
expiresAt := invite.CreatedAt.Add(svc.config.App.InviteTokenValidityPeriod)
if svc.clock.Now().After(expiresAt) {
return nil, fleet.NewInvalidArgumentError("invite_token", "Invite token has expired.")
}
return invite, nil
}

View file

@ -1,317 +0,0 @@
package service
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/service/osquery_utils"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
)
type osqueryError struct {
message string
nodeInvalid bool
}
func (e osqueryError) Error() string {
return e.message
}
func (e osqueryError) NodeInvalid() bool {
return e.nodeInvalid
}
var counter = int64(0)
func (svc Service) AuthenticateHost(ctx context.Context, nodeKey string) (*fleet.Host, bool, error) {
// skipauth: Authorization is currently for user endpoints only.
svc.authz.SkipAuthorization(ctx)
if nodeKey == "" {
return nil, false, osqueryError{
message: "authentication error: missing node key",
nodeInvalid: true,
}
}
host, err := svc.ds.LoadHostByNodeKey(ctx, nodeKey)
switch {
case err == nil:
// OK
case fleet.IsNotFound(err):
return nil, false, osqueryError{
message: "authentication error: invalid node key: " + nodeKey,
nodeInvalid: true,
}
default:
return nil, false, osqueryError{
message: "authentication error: " + err.Error(),
}
}
// Update the "seen" time used to calculate online status. These updates are
// batched for MySQL performance reasons. Because this is done
// asynchronously, it is possible for the server to shut down before
// updating the seen time for these hosts. This seems to be an acceptable
// tradeoff as an online host will continue to check in and quickly be
// marked online again.
svc.seenHostSet.addHostID(host.ID)
host.SeenTime = svc.clock.Now()
return host, svc.debugEnabledForHost(ctx, host.ID), nil
}
func (svc Service) debugEnabledForHost(ctx context.Context, id uint) bool {
hlogger := log.With(svc.logger, "host-id", id)
ac, err := svc.ds.AppConfig(ctx)
if err != nil {
level.Debug(hlogger).Log("err", ctxerr.Wrap(ctx, err, "getting app config for host debug"))
return false
}
for _, hostID := range ac.ServerSettings.DebugHostIDs {
if hostID == id {
return true
}
}
return false
}
func (svc Service) EnrollAgent(ctx context.Context, enrollSecret, hostIdentifier string, hostDetails map[string](map[string]string)) (string, error) {
// skipauth: Authorization is currently for user endpoints only.
svc.authz.SkipAuthorization(ctx)
logging.WithExtras(ctx, "hostIdentifier", hostIdentifier)
secret, err := svc.ds.VerifyEnrollSecret(ctx, enrollSecret)
if err != nil {
return "", osqueryError{
message: "enroll failed: " + err.Error(),
nodeInvalid: true,
}
}
nodeKey, err := server.GenerateRandomText(svc.config.Osquery.NodeKeySize)
if err != nil {
return "", osqueryError{
message: "generate node key failed: " + err.Error(),
nodeInvalid: true,
}
}
hostIdentifier = getHostIdentifier(svc.logger, svc.config.Osquery.HostIdentifier, hostIdentifier, hostDetails)
host, err := svc.ds.EnrollHost(ctx, hostIdentifier, nodeKey, secret.TeamID, svc.config.Osquery.EnrollCooldown)
if err != nil {
return "", osqueryError{message: "save enroll failed: " + err.Error(), nodeInvalid: true}
}
appConfig, err := svc.ds.AppConfig(ctx)
if err != nil {
return "", osqueryError{message: "app config load failed: " + err.Error(), nodeInvalid: true}
}
// Save enrollment details if provided
detailQueries := osquery_utils.GetDetailQueries(appConfig, svc.config)
save := false
if r, ok := hostDetails["os_version"]; ok {
err := detailQueries["os_version"].IngestFunc(svc.logger, host, []map[string]string{r})
if err != nil {
return "", ctxerr.Wrap(ctx, err, "Ingesting os_version")
}
save = true
}
if r, ok := hostDetails["osquery_info"]; ok {
err := detailQueries["osquery_info"].IngestFunc(svc.logger, host, []map[string]string{r})
if err != nil {
return "", ctxerr.Wrap(ctx, err, "Ingesting osquery_info")
}
save = true
}
if r, ok := hostDetails["system_info"]; ok {
err := detailQueries["system_info"].IngestFunc(svc.logger, host, []map[string]string{r})
if err != nil {
return "", ctxerr.Wrap(ctx, err, "Ingesting system_info")
}
save = true
}
if save {
if appConfig.ServerSettings.DeferredSaveHost {
go svc.serialUpdateHost(host)
} else {
if err := svc.ds.UpdateHost(ctx, host); err != nil {
return "", ctxerr.Wrap(ctx, err, "save host in enroll agent")
}
}
}
return nodeKey, nil
}
func (svc Service) serialUpdateHost(host *fleet.Host) {
newVal := atomic.AddInt64(&counter, 1)
defer func() {
atomic.AddInt64(&counter, -1)
}()
level.Debug(svc.logger).Log("background", newVal)
ctx, cancelFunc := context.WithTimeout(context.Background(), 30*time.Second)
defer cancelFunc()
err := svc.ds.SerialUpdateHost(ctx, host)
if err != nil {
level.Error(svc.logger).Log("background-err", err)
}
}
func getHostIdentifier(logger log.Logger, identifierOption, providedIdentifier string, details map[string](map[string]string)) string {
switch identifierOption {
case "provided":
// Use the host identifier already provided in the request.
return providedIdentifier
case "instance":
r, ok := details["osquery_info"]
if !ok {
level.Info(logger).Log(
"msg", "could not get host identifier",
"reason", "missing osquery_info",
"identifier", "instance",
)
} else if r["instance_id"] == "" {
level.Info(logger).Log(
"msg", "could not get host identifier",
"reason", "missing instance_id in osquery_info",
"identifier", "instance",
)
} else {
return r["instance_id"]
}
case "uuid":
r, ok := details["osquery_info"]
if !ok {
level.Info(logger).Log(
"msg", "could not get host identifier",
"reason", "missing osquery_info",
"identifier", "uuid",
)
} else if r["uuid"] == "" {
level.Info(logger).Log(
"msg", "could not get host identifier",
"reason", "missing instance_id in osquery_info",
"identifier", "uuid",
)
} else {
return r["uuid"]
}
case "hostname":
r, ok := details["system_info"]
if !ok {
level.Info(logger).Log(
"msg", "could not get host identifier",
"reason", "missing system_info",
"identifier", "hostname",
)
} else if r["hostname"] == "" {
level.Info(logger).Log(
"msg", "could not get host identifier",
"reason", "missing instance_id in system_info",
"identifier", "hostname",
)
} else {
return r["hostname"]
}
default:
panic("Unknown option for host_identifier: " + identifierOption)
}
return providedIdentifier
}
// jitterHashTable implements a data structure that allows a fleet to generate a static jitter value
// that is properly balanced. Balance in this context means that hosts would be distributed uniformly
// across the total jitter time so there are no spikes.
// The way this structure works is as follows:
// Given an amount of buckets, we want to place hosts in buckets evenly. So we don't want bucket 0 to
// have 1000 hosts, and all the other buckets 0. If there were 1000 buckets, and 1000 hosts, we should
// end up with 1 per bucket.
// The total amount of online hosts is unknown, so first it assumes that amount of buckets >= amount
// of total hosts (maxCapacity of 1 per bucket). Once we have more hosts than buckets, then we
// increase the maxCapacity by 1 for all buckets, and start placing hosts.
// Hosts that have been placed in a bucket remain in that bucket for as long as the fleet instance is
// running.
// The preferred bucket for a host is the one at (host id % bucketCount). If that bucket is full, the
// next one will be tried. If all buckets are full, then capacity gets increased and the bucket
// selection process restarts.
// Once a bucket is found, the index for the bucket (going from 0 to bucketCount) will be the amount of
// minutes added to the host check in time.
// For example: at a 1hr interval, and the default 10% max jitter percent. That allows hosts to
// distribute within 6 minutes around the hour mark. We would have 6 buckets in that case.
// In the worst possible case that all hosts start at the same time, max jitter percent can be set to
// 100, and this method will distribute hosts evenly.
// The main caveat of this approach is that it works at the fleet instance. So depending on what
// instance gets chosen by the load balancer, the jitter might be different. However, load tests have
// shown that the distribution in practice is pretty balance even when all hosts try to check in at
// the same time.
type jitterHashTable struct {
mu sync.Mutex
maxCapacity int
bucketCount int
buckets map[int]int
cache map[uint]time.Duration
}
func newJitterHashTable(bucketCount int) *jitterHashTable {
if bucketCount == 0 {
bucketCount = 1
}
return &jitterHashTable{
maxCapacity: 1,
bucketCount: bucketCount,
buckets: make(map[int]int),
cache: make(map[uint]time.Duration),
}
}
func (jh *jitterHashTable) jitterForHost(hostID uint) time.Duration {
// if no jitter is configured just return 0
if jh.bucketCount <= 1 {
return 0
}
jh.mu.Lock()
if jitter, ok := jh.cache[hostID]; ok {
jh.mu.Unlock()
return jitter
}
for i := 0; i < jh.bucketCount; i++ {
possibleBucket := (int(hostID) + i) % jh.bucketCount
// if the next bucket has capacity, great!
if jh.buckets[possibleBucket] < jh.maxCapacity {
jh.buckets[possibleBucket]++
jitter := time.Duration(possibleBucket) * time.Minute
jh.cache[hostID] = jitter
jh.mu.Unlock()
return jitter
}
}
// otherwise, bump the capacity and restart the process
jh.maxCapacity++
jh.mu.Unlock()
return jh.jitterForHost(hostID)
}

File diff suppressed because it is too large Load diff

View file

@ -1,325 +0,0 @@
package service
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/xml"
"errors"
"fmt"
"net/url"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/sso"
"github.com/go-kit/kit/log/level"
)
// SSOSettings returns a subset of the Single Sign-On settings as configured in
// the app config. Those can be exposed e.g. via the response to an HTTP request,
// and as such should not contain sensitive information.
func (svc *Service) SSOSettings(ctx context.Context) (*fleet.SessionSSOSettings, error) {
// skipauth: Basic SSO settings are available to unauthenticated users (so
// that they have the necessary information to initiate SSO).
svc.authz.SkipAuthorization(ctx)
logging.WithLevel(ctx, level.Info)
appConfig, err := svc.ds.AppConfig(ctx)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "SessionSSOSettings getting app config")
}
settings := &fleet.SessionSSOSettings{
IDPName: appConfig.SSOSettings.IDPName,
IDPImageURL: appConfig.SSOSettings.IDPImageURL,
SSOEnabled: appConfig.SSOSettings.EnableSSO,
}
return settings, nil
}
// InitiateSSO initiates a Single Sign-On flow for a request to visit the
// protected URL identified by redirectURL. It returns the URL of the identity
// provider to make a request to to proceed with the authentication via that
// external service, and stores ephemeral session state to validate the
// callback from the identity provider to finalize the SSO flow.
func (svc *Service) InitiateSSO(ctx context.Context, redirectURL string) (string, error) {
// skipauth: User context does not yet exist. Unauthenticated users may
// initiate SSO.
svc.authz.SkipAuthorization(ctx)
logging.WithLevel(ctx, level.Info)
appConfig, err := svc.ds.AppConfig(ctx)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting app config")
}
metadata, err := svc.getMetadata(appConfig)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting metadata")
}
serverURL := appConfig.ServerSettings.ServerURL
settings := sso.Settings{
Metadata: metadata,
// Construct call back url to send to idp
AssertionConsumerServiceURL: serverURL + svc.config.Server.URLPrefix + "/api/v1/fleet/sso/callback",
SessionStore: svc.ssoSessionStore,
OriginalURL: redirectURL,
}
// If issuer is not explicitly set, default to host name.
var issuer string
entityID := appConfig.SSOSettings.EntityID
if entityID == "" {
u, err := url.Parse(serverURL)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "parse server url")
}
issuer = u.Hostname()
} else {
issuer = entityID
}
idpURL, err := sso.CreateAuthorizationRequest(&settings, issuer)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "InitiateSSO creating authorization")
}
return idpURL, nil
}
func (svc *Service) getMetadata(config *fleet.AppConfig) (*sso.Metadata, error) {
if config.SSOSettings.MetadataURL != "" {
metadata, err := sso.GetMetadata(config.SSOSettings.MetadataURL)
if err != nil {
return nil, err
}
return metadata, nil
}
if config.SSOSettings.Metadata != "" {
metadata, err := sso.ParseMetadata(config.SSOSettings.Metadata)
if err != nil {
return nil, err
}
return metadata, nil
}
return nil, fmt.Errorf("missing metadata for idp %s", config.SSOSettings.IDPName)
}
func (svc *Service) CallbackSSO(ctx context.Context, auth fleet.Auth) (*fleet.SSOSession, error) {
// skipauth: User context does not yet exist. Unauthenticated users may
// hit the SSO callback.
svc.authz.SkipAuthorization(ctx)
logging.WithLevel(ctx, level.Info)
appConfig, err := svc.ds.AppConfig(ctx)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get config for sso")
}
// Load the request metadata if available
// localhost:9080/simplesaml/saml2/idp/SSOService.php?spentityid=https://localhost:8080
var metadata *sso.Metadata
var redirectURL string
if appConfig.SSOSettings.EnableSSOIdPLogin && auth.RequestID() == "" {
// Missing request ID indicates this was IdP-initiated. Only allow if
// configured to do so.
metadata, err = svc.getMetadata(appConfig)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get sso metadata")
}
redirectURL = "/"
} else {
session, err := svc.ssoSessionStore.Get(auth.RequestID())
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "sso request invalid")
}
// Remove session to so that is can't be reused before it expires.
err = svc.ssoSessionStore.Expire(auth.RequestID())
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "remove sso request")
}
if err := xml.Unmarshal([]byte(session.Metadata), &metadata); err != nil {
return nil, ctxerr.Wrap(ctx, err, "unmarshal metadata")
}
redirectURL = session.OriginalURL
}
// Validate response
validator, err := sso.NewValidator(*metadata, sso.WithExpectedAudience(
appConfig.SSOSettings.EntityID,
appConfig.ServerSettings.ServerURL,
appConfig.ServerSettings.ServerURL+svc.config.Server.URLPrefix+"/api/v1/fleet/sso/callback", // ACS
))
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "create validator from metadata")
}
// make sure the response hasn't been tampered with
auth, err = validator.ValidateSignature(auth)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "signature validation failed")
}
// make sure the response isn't stale
err = validator.ValidateResponse(auth)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "response validation failed")
}
// Get and log in user
user, err := svc.ds.UserByEmail(ctx, auth.UserID())
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "find user in sso callback")
}
// if the user is not sso enabled they are not authorized
if !user.SSOEnabled {
return nil, ctxerr.New(ctx, "user not configured to use sso")
}
token, err := svc.makeSession(ctx, user.ID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "make session in sso callback")
}
result := &fleet.SSOSession{
Token: token,
RedirectURL: redirectURL,
}
return result, nil
}
func (svc *Service) Login(ctx context.Context, email, password string) (*fleet.User, string, error) {
// skipauth: No user context available yet to authorize against.
svc.authz.SkipAuthorization(ctx)
logging.WithLevel(logging.WithNoUser(ctx), level.Info)
// If there is an error, sleep until the request has taken at least 1
// second. This means that generally a login failure for any reason will
// take ~1s and frustrate a timing attack.
var err error
defer func(start time.Time) {
if err != nil {
time.Sleep(time.Until(start.Add(1 * time.Second)))
}
}(time.Now())
user, err := svc.ds.UserByEmail(ctx, email)
var nfe fleet.NotFoundError
if errors.As(err, &nfe) {
return nil, "", fleet.NewAuthFailedError("user not found")
}
if err != nil {
return nil, "", fleet.NewAuthFailedError(err.Error())
}
if err = user.ValidatePassword(password); err != nil {
return nil, "", fleet.NewAuthFailedError("invalid password")
}
if user.SSOEnabled {
return nil, "", fleet.NewAuthFailedError("password login disabled for sso users")
}
token, err := svc.makeSession(ctx, user.ID)
if err != nil {
return nil, "", fleet.NewAuthFailedError(err.Error())
}
return user, token, nil
}
// makeSession is a helper that creates a new session after authentication
func (svc *Service) makeSession(ctx context.Context, id uint) (string, error) {
sessionKeySize := svc.config.Session.KeySize
key := make([]byte, sessionKeySize)
_, err := rand.Read(key)
if err != nil {
return "", err
}
sessionKey := base64.StdEncoding.EncodeToString(key)
session := &fleet.Session{
UserID: id,
Key: sessionKey,
AccessedAt: time.Now().UTC(),
}
_, err = svc.ds.NewSession(ctx, session)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "creating new session")
}
return sessionKey, nil
}
func (svc *Service) Logout(ctx context.Context) error {
// skipauth: Any user can always log out of their own session.
svc.authz.SkipAuthorization(ctx)
logging.WithLevel(ctx, level.Info)
// TODO: this should not return an error if the user wasn't logged in
return svc.DestroySession(ctx)
}
func (svc *Service) DestroySession(ctx context.Context) error {
vc, ok := viewer.FromContext(ctx)
if !ok {
return fleet.ErrNoContext
}
session, err := svc.ds.SessionByID(ctx, vc.SessionID())
if err != nil {
return err
}
if err := svc.authz.Authorize(ctx, session, fleet.ActionWrite); err != nil {
return err
}
return svc.ds.DestroySession(ctx, session)
}
func (svc *Service) GetSessionByKey(ctx context.Context, key string) (*fleet.Session, error) {
session, err := svc.ds.SessionByKey(ctx, key)
if err != nil {
return nil, err
}
err = svc.validateSession(ctx, session)
if err != nil {
return nil, err
}
return session, nil
}
func (svc *Service) validateSession(ctx context.Context, session *fleet.Session) error {
if session == nil {
return fleet.NewAuthRequiredError("active session not present")
}
sessionDuration := svc.config.Session.Duration
if session.APIOnly != nil && *session.APIOnly {
sessionDuration = 0 // make API-only tokens unlimited
}
// duration 0 = unlimited
if sessionDuration != 0 && time.Since(session.AccessedAt) >= sessionDuration {
err := svc.ds.DestroySession(ctx, session)
if err != nil {
return ctxerr.Wrap(ctx, err, "destroying session")
}
return fleet.NewAuthRequiredError("expired session")
}
return svc.ds.MarkSessionAccessed(ctx, session)
}

View file

@ -1,111 +0,0 @@
package service
import (
"context"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAuthenticate(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
defer ds.Close()
svc := newTestService(ds, nil, nil)
createTestUsers(t, ds)
var loginTests = []struct {
name string
email string
password string
wantErr error
}{
{
name: "admin1",
email: testUsers["admin1"].Email,
password: testUsers["admin1"].PlaintextPassword,
},
{
name: "user1",
email: testUsers["user1"].Email,
password: testUsers["user1"].PlaintextPassword,
},
}
for _, tt := range loginTests {
t.Run(tt.email, func(st *testing.T) {
loggedIn, token, err := svc.Login(test.UserContext(test.UserAdmin), tt.email, tt.password)
require.Nil(st, err, "login unsuccessful")
assert.Equal(st, tt.email, loggedIn.Email)
assert.NotEmpty(st, token)
sessions, err := svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), loggedIn.ID)
require.Nil(st, err)
require.Len(st, sessions, 1, "user should have one session")
session := sessions[0]
assert.NotZero(st, session.UserID)
assert.WithinDuration(st, time.Now(), session.AccessedAt, 3*time.Second,
"access time should be set with current time at session creation")
})
}
}
func TestGetSessionByKey(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
cfg := config.TestConfig()
theSession := &fleet.Session{UserID: 123, Key: "abc"}
ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) {
return theSession, nil
}
ds.DestroySessionFunc = func(ctx context.Context, ssn *fleet.Session) error {
return nil
}
ds.MarkSessionAccessedFunc = func(ctx context.Context, ssn *fleet.Session) error {
return nil
}
cases := []struct {
desc string
accessed time.Duration
apiOnly bool
fail bool
}{
{"real user, accessed recently", -1 * time.Hour, false, false},
{"real user, accessed too long ago", -(cfg.Session.Duration + time.Hour), false, true},
{"api-only, accessed recently", -1 * time.Hour, true, false},
{"api-only, accessed long ago", -(cfg.Session.Duration + time.Hour), true, false},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
var authErr *fleet.AuthRequiredError
ds.SessionByKeyFuncInvoked, ds.DestroySessionFuncInvoked, ds.MarkSessionAccessedFuncInvoked = false, false, false
theSession.AccessedAt = time.Now().Add(tc.accessed)
theSession.APIOnly = ptr.Bool(tc.apiOnly)
_, err := svc.GetSessionByKey(context.Background(), theSession.Key)
if tc.fail {
require.Error(t, err)
require.ErrorAs(t, err, &authErr)
require.True(t, ds.SessionByKeyFuncInvoked)
require.True(t, ds.DestroySessionFuncInvoked)
require.False(t, ds.MarkSessionAccessedFuncInvoked)
} else {
require.NoError(t, err)
require.True(t, ds.SessionByKeyFuncInvoked)
require.False(t, ds.DestroySessionFuncInvoked)
require.True(t, ds.MarkSessionAccessedFuncInvoked)
}
})
}
}

View file

@ -1,15 +0,0 @@
package service
import (
"context"
"github.com/fleetdm/fleet/v4/server/fleet"
)
func (svc *Service) ListAvailableTeamsForUser(ctx context.Context, user *fleet.User) ([]*fleet.TeamSummary, error) {
// skipauth: No authorization check needed due to implementation returning
// only license error.
svc.authz.SkipAuthorization(ctx)
return nil, fleet.ErrMissingLicense
}

View file

@ -2,45 +2,14 @@ package service
import (
"context"
"encoding/base64"
"html/template"
"time"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mail"
"github.com/fleetdm/fleet/v4/server/ptr"
)
func (svc *Service) CreateUserFromInvite(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
// skipauth: There is no viewer context at this point. We rely on verifying
// the invite for authNZ.
svc.authz.SkipAuthorization(ctx)
invite, err := svc.VerifyInvite(ctx, *p.InviteToken)
if err != nil {
return nil, err
}
// set the payload role property based on an existing invite.
p.GlobalRole = invite.GlobalRole.Ptr()
p.Teams = &invite.Teams
user, err := svc.newUser(ctx, p)
if err != nil {
return nil, err
}
err = svc.ds.DeleteInvite(ctx, invite.ID)
if err != nil {
return nil, err
}
return user, nil
}
func (svc *Service) CreateInitialUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
// skipauth: Only the initial user creation should be allowed to skip
// authorization (because there is not yet a user context to check against).
@ -88,154 +57,3 @@ func (svc *Service) UserUnauthorized(ctx context.Context, id uint) (*fleet.User,
// Explicitly no authorization check. Should only be used by middleware.
return svc.ds.UserByID(ctx, id)
}
// setNewPassword is a helper for changing a user's password. It should be
// called to set the new password after proper authorization has been
// performed.
func (svc *Service) setNewPassword(ctx context.Context, user *fleet.User, password string) error {
err := user.SetPassword(password, svc.config.Auth.SaltKeySize, svc.config.Auth.BcryptCost)
if err != nil {
return ctxerr.Wrap(ctx, err, "setting new password")
}
if user.SSOEnabled {
return ctxerr.New(ctx, "set password for single sign on user not allowed")
}
err = svc.saveUser(ctx, user)
if err != nil {
return ctxerr.Wrap(ctx, err, "saving changed password")
}
return nil
}
func (svc *Service) ResetPassword(ctx context.Context, token, password string) error {
// skipauth: No viewer context available. The user is locked out of their
// account and authNZ is performed entirely by providing a valid password
// reset token.
svc.authz.SkipAuthorization(ctx)
reset, err := svc.ds.FindPassswordResetByToken(ctx, token)
if err != nil {
return ctxerr.Wrap(ctx, err, "looking up reset by token")
}
user, err := svc.ds.UserByID(ctx, reset.UserID)
if err != nil {
return ctxerr.Wrap(ctx, err, "retrieving user")
}
if user.SSOEnabled {
return ctxerr.New(ctx, "password reset for single sign on user not allowed")
}
// prevent setting the same password
if err := user.ValidatePassword(password); err == nil {
return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
}
err = svc.setNewPassword(ctx, user, password)
if err != nil {
return ctxerr.Wrap(ctx, err, "setting new password")
}
// delete password reset tokens for user
if err := svc.ds.DeletePasswordResetRequestsForUser(ctx, user.ID); err != nil {
return ctxerr.Wrap(ctx, err, "delete password reset requests")
}
// Clear sessions so that any other browsers will have to log in with
// the new password
if err := svc.ds.DestroyAllSessionsForUser(ctx, user.ID); err != nil {
return ctxerr.Wrap(ctx, err, "delete user sessions")
}
return nil
}
func (svc *Service) PerformRequiredPasswordReset(ctx context.Context, password string) (*fleet.User, error) {
vc, ok := viewer.FromContext(ctx)
if !ok {
return nil, fleet.ErrNoContext
}
user := vc.User
if err := svc.authz.Authorize(ctx, user, fleet.ActionWrite); err != nil {
return nil, err
}
if user.SSOEnabled {
return nil, ctxerr.New(ctx, "password reset for single sign on user not allowed")
}
if !user.IsAdminForcedPasswordReset() {
return nil, ctxerr.New(ctx, "user does not require password reset")
}
// prevent setting the same password
if err := user.ValidatePassword(password); err == nil {
return nil, fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
}
user.AdminForcedPasswordReset = false
err := svc.setNewPassword(ctx, user, password)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "setting new password")
}
// Sessions should already have been cleared when the reset was
// required
return user, nil
}
func (svc *Service) RequestPasswordReset(ctx context.Context, email string) error {
// skipauth: No viewer context available. The user is locked out of their
// account and trying to reset their password.
svc.authz.SkipAuthorization(ctx)
// Regardless of error, sleep until the request has taken at least 1 second.
// This means that any request to this method will take ~1s and frustrate a timing attack.
defer func(start time.Time) {
time.Sleep(time.Until(start.Add(1 * time.Second)))
}(time.Now())
user, err := svc.ds.UserByEmail(ctx, email)
if err != nil {
return err
}
if user.SSOEnabled {
return ctxerr.New(ctx, "password reset for single sign on user not allowed")
}
random, err := server.GenerateRandomText(svc.config.App.TokenKeySize)
if err != nil {
return err
}
token := base64.URLEncoding.EncodeToString([]byte(random))
request := &fleet.PasswordResetRequest{
ExpiresAt: time.Now().Add(time.Hour * 24),
UserID: user.ID,
Token: token,
}
_, err = svc.ds.NewPasswordResetRequest(ctx, request)
if err != nil {
return err
}
config, err := svc.ds.AppConfig(ctx)
if err != nil {
return err
}
resetEmail := fleet.Email{
Subject: "Reset Your Fleet Password",
To: []string{user.Email},
Config: config,
Mailer: &mail.PasswordResetMailer{
BaseURL: template.URL(config.ServerSettings.ServerURL + svc.config.Server.URLPrefix),
AssetURL: getAssetURL(),
Token: token,
},
}
return svc.mailService.SendEmail(resetEmail)
}

View file

@ -1,188 +0,0 @@
package service
import (
"context"
"database/sql"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAuthenticatedUser(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
createTestUsers(t, ds)
svc := newTestService(ds, nil, nil)
admin1, err := ds.UserByEmail(context.Background(), "admin1@example.com")
assert.Nil(t, err)
admin1Session, err := ds.NewSession(context.Background(), &fleet.Session{
UserID: admin1.ID,
Key: "admin1",
})
assert.Nil(t, err)
ctx := context.Background()
ctx = viewer.NewContext(ctx, viewer.Viewer{User: admin1, Session: admin1Session})
user, err := svc.AuthenticatedUser(ctx)
assert.Nil(t, err)
assert.Equal(t, user, admin1)
}
func TestResetPassword(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
svc := newTestService(ds, nil, nil)
createTestUsers(t, ds)
passwordResetTests := []struct {
token string
newPassword string
wantErr error
}{
{ // all good
token: "abcd",
newPassword: "123cat!",
},
{ // prevent reuse
token: "abcd",
newPassword: "123cat!",
wantErr: fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"),
},
{ // bad token
token: "dcbaz",
newPassword: "123cat!",
wantErr: sql.ErrNoRows,
},
{ // missing token
newPassword: "123cat!",
wantErr: fleet.NewInvalidArgumentError("token", "Token cannot be empty field"),
},
}
for _, tt := range passwordResetTests {
t.Run("", func(t *testing.T) {
request := &fleet.PasswordResetRequest{
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
CreateTimestamp: fleet.CreateTimestamp{
CreatedAt: time.Now(),
},
UpdateTimestamp: fleet.UpdateTimestamp{
UpdatedAt: time.Now(),
},
},
ExpiresAt: time.Now().Add(time.Hour * 24),
UserID: 1,
Token: "abcd",
}
_, err := ds.NewPasswordResetRequest(context.Background(), request)
assert.Nil(t, err)
serr := svc.ResetPassword(test.UserContext(&fleet.User{ID: 1}), tt.token, tt.newPassword)
if tt.wantErr != nil {
assert.Equal(t, tt.wantErr.Error(), ctxerr.Cause(serr).Error())
} else {
assert.Nil(t, serr)
}
})
}
}
func refreshCtx(t *testing.T, ctx context.Context, user *fleet.User, ds fleet.Datastore, session *fleet.Session) context.Context {
reloadedUser, err := ds.UserByEmail(ctx, user.Email)
require.NoError(t, err)
return viewer.NewContext(ctx, viewer.Viewer{User: reloadedUser, Session: session})
}
func TestPerformRequiredPasswordReset(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
svc := newTestService(ds, nil, nil)
createTestUsers(t, ds)
for _, tt := range testUsers {
t.Run(tt.Email, func(t *testing.T) {
user, err := ds.UserByEmail(context.Background(), tt.Email)
require.Nil(t, err)
ctx := context.Background()
_, err = svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, true)
require.Nil(t, err)
ctx = refreshCtx(t, ctx, user, ds, nil)
session, err := ds.NewSession(context.Background(), &fleet.Session{UserID: user.ID})
require.Nil(t, err)
ctx = refreshCtx(t, ctx, user, ds, session)
// should error when reset not required
_, err = svc.RequirePasswordReset(ctx, user.ID, false)
require.Nil(t, err)
ctx = refreshCtx(t, ctx, user, ds, session)
_, err = svc.PerformRequiredPasswordReset(ctx, "new_pass")
require.NotNil(t, err)
_, err = svc.RequirePasswordReset(ctx, user.ID, true)
require.Nil(t, err)
ctx = refreshCtx(t, ctx, user, ds, session)
// should error when using same password
_, err = svc.PerformRequiredPasswordReset(ctx, tt.PlaintextPassword)
require.Equal(t, "validation failed: new_password cannot reuse old password", err.Error())
// should succeed with good new password
u, err := svc.PerformRequiredPasswordReset(ctx, "new_pass")
require.Nil(t, err)
assert.False(t, u.AdminForcedPasswordReset)
ctx = context.Background()
// Now user should be able to login with new password
u, _, err = svc.Login(ctx, tt.Email, "new_pass")
require.Nil(t, err)
assert.False(t, u.AdminForcedPasswordReset)
})
}
}
func TestUserPasswordRequirements(t *testing.T) {
passwordTests := []struct {
password string
wantErr bool
}{
{
password: "foobar",
wantErr: true,
},
{
password: "foobarbaz",
wantErr: true,
},
{
password: "foobarbaz!",
wantErr: true,
},
{
password: "foobarbaz!3",
},
}
for _, tt := range passwordTests {
t.Run(tt.password, func(t *testing.T) {
err := validatePasswordRequirements(tt.password)
if tt.wantErr {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
}
})
}
}

View file

@ -1,10 +1,25 @@
package service
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/xml"
"errors"
"fmt"
"html/template"
"net/http"
"net/url"
"strings"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/sso"
"github.com/go-kit/kit/log/level"
)
////////////////////////////////////////////////////////////////////////////////
@ -91,3 +106,480 @@ func (svc *Service) DeleteSession(ctx context.Context, id uint) error {
return svc.ds.DestroySession(ctx, session)
}
////////////////////////////////////////////////////////////////////////////////
// Login
////////////////////////////////////////////////////////////////////////////////
type loginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}
type loginResponse struct {
User *fleet.User `json:"user,omitempty"`
AvailableTeams []*fleet.TeamSummary `json:"available_teams"`
Token string `json:"token,omitempty"`
Err error `json:"error,omitempty"`
}
func (r loginResponse) error() error { return r.Err }
func loginEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*loginRequest)
req.Email = strings.ToLower(req.Email)
user, token, err := svc.Login(ctx, req.Email, req.Password)
if err != nil {
return loginResponse{Err: err}, nil
}
// Add viewer context allow access to service teams for list of available teams
v, err := authViewer(ctx, token, svc)
if err != nil {
return loginResponse{Err: err}, nil
}
ctx = viewer.NewContext(ctx, *v)
availableTeams, err := svc.ListAvailableTeamsForUser(ctx, user)
if err != nil {
if errors.Is(err, fleet.ErrMissingLicense) {
availableTeams = []*fleet.TeamSummary{}
} else {
return loginResponse{Err: err}, nil
}
}
return loginResponse{user, availableTeams, token, nil}, nil
}
func (svc *Service) Login(ctx context.Context, email, password string) (*fleet.User, string, error) {
// skipauth: No user context available yet to authorize against.
svc.authz.SkipAuthorization(ctx)
logging.WithLevel(logging.WithNoUser(ctx), level.Info)
// If there is an error, sleep until the request has taken at least 1
// second. This means that generally a login failure for any reason will
// take ~1s and frustrate a timing attack.
var err error
defer func(start time.Time) {
if err != nil {
time.Sleep(time.Until(start.Add(1 * time.Second)))
}
}(time.Now())
user, err := svc.ds.UserByEmail(ctx, email)
var nfe fleet.NotFoundError
if errors.As(err, &nfe) {
return nil, "", fleet.NewAuthFailedError("user not found")
}
if err != nil {
return nil, "", fleet.NewAuthFailedError(err.Error())
}
if err = user.ValidatePassword(password); err != nil {
return nil, "", fleet.NewAuthFailedError("invalid password")
}
if user.SSOEnabled {
return nil, "", fleet.NewAuthFailedError("password login disabled for sso users")
}
token, err := svc.makeSession(ctx, user.ID)
if err != nil {
return nil, "", fleet.NewAuthFailedError(err.Error())
}
return user, token, nil
}
////////////////////////////////////////////////////////////////////////////////
// Logout
////////////////////////////////////////////////////////////////////////////////
type logoutResponse struct {
Err error `json:"error,omitempty"`
}
func (r logoutResponse) error() error { return r.Err }
func logoutEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
err := svc.Logout(ctx)
if err != nil {
return logoutResponse{Err: err}, nil
}
return logoutResponse{}, nil
}
func (svc *Service) Logout(ctx context.Context) error {
// skipauth: Any user can always log out of their own session.
svc.authz.SkipAuthorization(ctx)
logging.WithLevel(ctx, level.Info)
// TODO: this should not return an error if the user wasn't logged in
return svc.DestroySession(ctx)
}
func (svc *Service) DestroySession(ctx context.Context) error {
vc, ok := viewer.FromContext(ctx)
if !ok {
return fleet.ErrNoContext
}
session, err := svc.ds.SessionByID(ctx, vc.SessionID())
if err != nil {
return err
}
if err := svc.authz.Authorize(ctx, session, fleet.ActionWrite); err != nil {
return err
}
return svc.ds.DestroySession(ctx, session)
}
////////////////////////////////////////////////////////////////////////////////
// Initiate SSO
////////////////////////////////////////////////////////////////////////////////
type initiateSSORequest struct {
RelayURL string `json:"relay_url"`
}
type initiateSSOResponse struct {
URL string `json:"url,omitempty"`
Err error `json:"error,omitempty"`
}
func (r initiateSSOResponse) error() error { return r.Err }
func initiateSSOEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*initiateSSORequest)
idProviderURL, err := svc.InitiateSSO(ctx, req.RelayURL)
if err != nil {
return initiateSSOResponse{Err: err}, nil
}
return initiateSSOResponse{URL: idProviderURL}, nil
}
// InitiateSSO initiates a Single Sign-On flow for a request to visit the
// protected URL identified by redirectURL. It returns the URL of the identity
// provider to make a request to to proceed with the authentication via that
// external service, and stores ephemeral session state to validate the
// callback from the identity provider to finalize the SSO flow.
func (svc *Service) InitiateSSO(ctx context.Context, redirectURL string) (string, error) {
// skipauth: User context does not yet exist. Unauthenticated users may
// initiate SSO.
svc.authz.SkipAuthorization(ctx)
logging.WithLevel(ctx, level.Info)
appConfig, err := svc.ds.AppConfig(ctx)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting app config")
}
metadata, err := svc.getMetadata(appConfig)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting metadata")
}
serverURL := appConfig.ServerSettings.ServerURL
settings := sso.Settings{
Metadata: metadata,
// Construct call back url to send to idp
AssertionConsumerServiceURL: serverURL + svc.config.Server.URLPrefix + "/api/v1/fleet/sso/callback",
SessionStore: svc.ssoSessionStore,
OriginalURL: redirectURL,
}
// If issuer is not explicitly set, default to host name.
var issuer string
entityID := appConfig.SSOSettings.EntityID
if entityID == "" {
u, err := url.Parse(serverURL)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "parse server url")
}
issuer = u.Hostname()
} else {
issuer = entityID
}
idpURL, err := sso.CreateAuthorizationRequest(&settings, issuer)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "InitiateSSO creating authorization")
}
return idpURL, nil
}
////////////////////////////////////////////////////////////////////////////////
// Callback SSO
////////////////////////////////////////////////////////////////////////////////
type callbackSSORequest struct{}
func (callbackSSORequest) DecodeRequest(ctx context.Context, r *http.Request) (interface{}, error) {
err := r.ParseForm()
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "decode sso callback")
}
authResponse, err := sso.DecodeAuthResponse(r.FormValue("SAMLResponse"))
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "decoding sso callback")
}
return authResponse, nil
}
type callbackSSOResponse struct {
content string
Err error `json:"error,omitempty"`
}
func (r callbackSSOResponse) error() error { return r.Err }
// If html is present we return a web page
func (r callbackSSOResponse) html() string { return r.content }
func makeCallbackSSOEndpoint(urlPrefix string) handlerFunc {
return func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
authResponse := request.(fleet.Auth)
session, err := svc.CallbackSSO(ctx, authResponse)
var resp callbackSSOResponse
if err != nil {
// redirect to login page on front end if there was some problem,
// errors should still be logged
session = &fleet.SSOSession{
RedirectURL: urlPrefix + "/login",
Token: "",
}
resp.Err = err
}
relayStateLoadPage := ` <html>
<script type='text/javascript'>
var redirectURL = {{ .RedirectURL }};
window.localStorage.setItem('FLEET::auth_token', '{{ .Token }}');
window.location = redirectURL;
</script>
<body>
Redirecting to Fleet at {{ .RedirectURL }} ...
</body>
</html>
`
tmpl, err := template.New("relayStateLoader").Parse(relayStateLoadPage)
if err != nil {
return nil, err
}
var writer bytes.Buffer
err = tmpl.Execute(&writer, session)
if err != nil {
return nil, err
}
resp.content = writer.String()
return resp, nil
}
}
func (svc *Service) CallbackSSO(ctx context.Context, auth fleet.Auth) (*fleet.SSOSession, error) {
// skipauth: User context does not yet exist. Unauthenticated users may
// hit the SSO callback.
svc.authz.SkipAuthorization(ctx)
logging.WithLevel(ctx, level.Info)
appConfig, err := svc.ds.AppConfig(ctx)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get config for sso")
}
// Load the request metadata if available
// localhost:9080/simplesaml/saml2/idp/SSOService.php?spentityid=https://localhost:8080
var metadata *sso.Metadata
var redirectURL string
if appConfig.SSOSettings.EnableSSOIdPLogin && auth.RequestID() == "" {
// Missing request ID indicates this was IdP-initiated. Only allow if
// configured to do so.
metadata, err = svc.getMetadata(appConfig)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get sso metadata")
}
redirectURL = "/"
} else {
session, err := svc.ssoSessionStore.Get(auth.RequestID())
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "sso request invalid")
}
// Remove session to so that is can't be reused before it expires.
err = svc.ssoSessionStore.Expire(auth.RequestID())
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "remove sso request")
}
if err := xml.Unmarshal([]byte(session.Metadata), &metadata); err != nil {
return nil, ctxerr.Wrap(ctx, err, "unmarshal metadata")
}
redirectURL = session.OriginalURL
}
// Validate response
validator, err := sso.NewValidator(*metadata, sso.WithExpectedAudience(
appConfig.SSOSettings.EntityID,
appConfig.ServerSettings.ServerURL,
appConfig.ServerSettings.ServerURL+svc.config.Server.URLPrefix+"/api/v1/fleet/sso/callback", // ACS
))
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "create validator from metadata")
}
// make sure the response hasn't been tampered with
auth, err = validator.ValidateSignature(auth)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "signature validation failed")
}
// make sure the response isn't stale
err = validator.ValidateResponse(auth)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "response validation failed")
}
// Get and log in user
user, err := svc.ds.UserByEmail(ctx, auth.UserID())
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "find user in sso callback")
}
// if the user is not sso enabled they are not authorized
if !user.SSOEnabled {
return nil, ctxerr.New(ctx, "user not configured to use sso")
}
token, err := svc.makeSession(ctx, user.ID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "make session in sso callback")
}
result := &fleet.SSOSession{
Token: token,
RedirectURL: redirectURL,
}
return result, nil
}
////////////////////////////////////////////////////////////////////////////////
// SSO Settings
////////////////////////////////////////////////////////////////////////////////
type ssoSettingsResponse struct {
Settings *fleet.SessionSSOSettings `json:"settings,omitempty"`
Err error `json:"error,omitempty"`
}
func (r ssoSettingsResponse) error() error { return r.Err }
func settingsSSOEndpoint(ctx context.Context, _ interface{}, svc fleet.Service) (interface{}, error) {
settings, err := svc.SSOSettings(ctx)
if err != nil {
return ssoSettingsResponse{Err: err}, nil
}
return ssoSettingsResponse{Settings: settings}, nil
}
// SSOSettings returns a subset of the Single Sign-On settings as configured in
// the app config. Those can be exposed e.g. via the response to an HTTP request,
// and as such should not contain sensitive information.
func (svc *Service) SSOSettings(ctx context.Context) (*fleet.SessionSSOSettings, error) {
// skipauth: Basic SSO settings are available to unauthenticated users (so
// that they have the necessary information to initiate SSO).
svc.authz.SkipAuthorization(ctx)
logging.WithLevel(ctx, level.Info)
appConfig, err := svc.ds.AppConfig(ctx)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "SessionSSOSettings getting app config")
}
settings := &fleet.SessionSSOSettings{
IDPName: appConfig.SSOSettings.IDPName,
IDPImageURL: appConfig.SSOSettings.IDPImageURL,
SSOEnabled: appConfig.SSOSettings.EnableSSO,
}
return settings, nil
}
// makeSession is a helper that creates a new session after authentication
func (svc *Service) makeSession(ctx context.Context, id uint) (string, error) {
sessionKeySize := svc.config.Session.KeySize
key := make([]byte, sessionKeySize)
_, err := rand.Read(key)
if err != nil {
return "", err
}
sessionKey := base64.StdEncoding.EncodeToString(key)
session := &fleet.Session{
UserID: id,
Key: sessionKey,
AccessedAt: time.Now().UTC(),
}
_, err = svc.ds.NewSession(ctx, session)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "creating new session")
}
return sessionKey, nil
}
func (svc *Service) getMetadata(config *fleet.AppConfig) (*sso.Metadata, error) {
if config.SSOSettings.MetadataURL != "" {
metadata, err := sso.GetMetadata(config.SSOSettings.MetadataURL)
if err != nil {
return nil, err
}
return metadata, nil
}
if config.SSOSettings.Metadata != "" {
metadata, err := sso.ParseMetadata(config.SSOSettings.Metadata)
if err != nil {
return nil, err
}
return metadata, nil
}
return nil, fmt.Errorf("missing metadata for idp %s", config.SSOSettings.IDPName)
}
func (svc *Service) GetSessionByKey(ctx context.Context, key string) (*fleet.Session, error) {
session, err := svc.ds.SessionByKey(ctx, key)
if err != nil {
return nil, err
}
err = svc.validateSession(ctx, session)
if err != nil {
return nil, err
}
return session, nil
}
func (svc *Service) validateSession(ctx context.Context, session *fleet.Session) error {
if session == nil {
return fleet.NewAuthRequiredError("active session not present")
}
sessionDuration := svc.config.Session.Duration
if session.APIOnly != nil && *session.APIOnly {
sessionDuration = 0 // make API-only tokens unlimited
}
// duration 0 = unlimited
if sessionDuration != 0 && time.Since(session.AccessedAt) >= sessionDuration {
err := svc.ds.DestroySession(ctx, session)
if err != nil {
return ctxerr.Wrap(ctx, err, "destroying session")
}
return fleet.NewAuthRequiredError("expired session")
}
return svc.ds.MarkSessionAccessed(ctx, session)
}

View file

@ -5,10 +5,15 @@ import (
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSessionAuth(t *testing.T) {
@ -85,3 +90,98 @@ func TestSessionAuth(t *testing.T) {
})
}
}
func TestAuthenticate(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
defer ds.Close()
svc := newTestService(ds, nil, nil)
createTestUsers(t, ds)
var loginTests = []struct {
name string
email string
password string
wantErr error
}{
{
name: "admin1",
email: testUsers["admin1"].Email,
password: testUsers["admin1"].PlaintextPassword,
},
{
name: "user1",
email: testUsers["user1"].Email,
password: testUsers["user1"].PlaintextPassword,
},
}
for _, tt := range loginTests {
t.Run(tt.email, func(st *testing.T) {
loggedIn, token, err := svc.Login(test.UserContext(test.UserAdmin), tt.email, tt.password)
require.Nil(st, err, "login unsuccessful")
assert.Equal(st, tt.email, loggedIn.Email)
assert.NotEmpty(st, token)
sessions, err := svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), loggedIn.ID)
require.Nil(st, err)
require.Len(st, sessions, 1, "user should have one session")
session := sessions[0]
assert.NotZero(st, session.UserID)
assert.WithinDuration(st, time.Now(), session.AccessedAt, 3*time.Second,
"access time should be set with current time at session creation")
})
}
}
func TestGetSessionByKey(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
cfg := config.TestConfig()
theSession := &fleet.Session{UserID: 123, Key: "abc"}
ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) {
return theSession, nil
}
ds.DestroySessionFunc = func(ctx context.Context, ssn *fleet.Session) error {
return nil
}
ds.MarkSessionAccessedFunc = func(ctx context.Context, ssn *fleet.Session) error {
return nil
}
cases := []struct {
desc string
accessed time.Duration
apiOnly bool
fail bool
}{
{"real user, accessed recently", -1 * time.Hour, false, false},
{"real user, accessed too long ago", -(cfg.Session.Duration + time.Hour), false, true},
{"api-only, accessed recently", -1 * time.Hour, true, false},
{"api-only, accessed long ago", -(cfg.Session.Duration + time.Hour), true, false},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
var authErr *fleet.AuthRequiredError
ds.SessionByKeyFuncInvoked, ds.DestroySessionFuncInvoked, ds.MarkSessionAccessedFuncInvoked = false, false, false
theSession.AccessedAt = time.Now().Add(tc.accessed)
theSession.APIOnly = ptr.Bool(tc.apiOnly)
_, err := svc.GetSessionByKey(context.Background(), theSession.Key)
if tc.fail {
require.Error(t, err)
require.ErrorAs(t, err, &authErr)
require.True(t, ds.SessionByKeyFuncInvoked)
require.True(t, ds.DestroySessionFuncInvoked)
require.False(t, ds.MarkSessionAccessedFuncInvoked)
} else {
require.NoError(t, err)
require.True(t, ds.SessionByKeyFuncInvoked)
require.False(t, ds.DestroySessionFuncInvoked)
require.True(t, ds.MarkSessionAccessedFuncInvoked)
}
})
}
}

View file

@ -36,9 +36,10 @@ func (ts *withDS) TearDownSuite() {
type withServer struct {
withDS
server *httptest.Server
users map[string]fleet.User
token string
server *httptest.Server
users map[string]fleet.User
token string
cachedAdminToken string
}
func (ts *withServer) SetupSuite(dbName string) {
@ -49,6 +50,7 @@ func (ts *withServer) SetupSuite(dbName string) {
ts.server = server
ts.users = users
ts.token = ts.getTestAdminToken()
ts.cachedAdminToken = ts.token
}
func (ts *withServer) TearDownSuite() {
@ -122,7 +124,13 @@ func (ts *withServer) DoJSON(verb, path string, params interface{}, expectedStat
func (ts *withServer) getTestAdminToken() string {
testUser := testUsers["admin1"]
return ts.getTestToken(testUser.Email, testUser.PlaintextPassword)
// because the login endpoint is rate-limited, use the cached admin token
// if available (if for some reason a test needs to logout the admin user,
// then set cachedAdminToken = "" so that a new token is retrieved).
if ts.cachedAdminToken == "" {
ts.cachedAdminToken = ts.getTestToken(testUser.Email, testUser.PlaintextPassword)
}
return ts.cachedAdminToken
}
func (ts *withServer) getTestToken(email string, password string) string {

View file

@ -60,7 +60,7 @@ func translateHostToID(ctx context.Context, ds fleet.Datastore, identifier strin
return host.ID, nil
}
func (svc Service) Translate(ctx context.Context, payloads []fleet.TranslatePayload) ([]fleet.TranslatePayload, error) {
func (svc *Service) Translate(ctx context.Context, payloads []fleet.TranslatePayload) ([]fleet.TranslatePayload, error) {
var finalPayload []fleet.TranslatePayload
for _, payload := range payloads {

View file

@ -294,10 +294,6 @@ func userListOptionsFromRequest(r *http.Request) (fleet.UserListOptions, error)
return uopt, nil
}
func decodeNoParamsRequest(ctx context.Context, r *http.Request) (interface{}, error) {
return nil, nil
}
type getGenericSpecRequest struct {
Name string `url:"name"`
}

View file

@ -1,20 +0,0 @@
package service
import (
"context"
"encoding/json"
"net/http"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
)
func decodeCarveBlockRequest(ctx context.Context, r *http.Request) (interface{}, error) {
defer r.Body.Close()
var req carveBlockRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, ctxerr.Wrap(ctx, err, "decoding JSON")
}
return req, nil
}

View file

@ -1,17 +0,0 @@
package service
import (
"context"
"net/http"
"github.com/gorilla/mux"
)
func decodeVerifyInviteRequest(ctx context.Context, r *http.Request) (interface{}, error) {
vars := mux.Vars(r)
token, ok := vars["token"]
if !ok {
return 0, errBadRoute
}
return verifyInviteRequest{Token: token}, nil
}

View file

@ -1,17 +0,0 @@
package service
import (
"context"
"encoding/json"
"net/http"
)
func decodeEnrollAgentRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var req enrollAgentRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, err
}
defer r.Body.Close()
return req, nil
}

View file

@ -1,36 +0,0 @@
package service
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDecodeEnrollAgentRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
r, err := decodeEnrollAgentRequest(context.Background(), request)
require.Nil(t, err)
params := r.(enrollAgentRequest)
assert.Equal(t, "secret", params.EnrollSecret)
assert.Equal(t, "uuid", params.HostIdentifier)
}).Methods("POST")
var body bytes.Buffer
body.Write([]byte(`{
"enroll_secret": "secret",
"host_identifier": "uuid"
}`))
router.ServeHTTP(
httptest.NewRecorder(),
httptest.NewRequest("POST", "/", &body),
)
}

View file

@ -1,41 +0,0 @@
package service
import (
"context"
"encoding/json"
"net/http"
"strings"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/sso"
)
func decodeLoginRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var req loginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, err
}
req.Email = strings.ToLower(req.Email)
return req, nil
}
func decodeInitiateSSORequest(ctx context.Context, r *http.Request) (interface{}, error) {
var req initiateSSORequest
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
return nil, err
}
return req, nil
}
func decodeCallbackSSORequest(ctx context.Context, r *http.Request) (interface{}, error) {
err := r.ParseForm()
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "decode sso callback")
}
authResponse, err := sso.DecodeAuthResponse(r.FormValue("SAMLResponse"))
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "decoding sso callback")
}
return authResponse, nil
}

View file

@ -1,49 +0,0 @@
package service
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
)
func TestDecodeLoginRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/login", func(writer http.ResponseWriter, request *http.Request) {
r, err := decodeLoginRequest(context.Background(), request)
assert.Nil(t, err)
params := r.(loginRequest)
assert.Equal(t, "foo", params.Email)
assert.Equal(t, "bar", params.Password)
}).Methods("POST")
t.Run("lowercase email", func(t *testing.T) {
var body bytes.Buffer
body.Write([]byte(`{
"email": "foo",
"password": "bar"
}`))
router.ServeHTTP(
httptest.NewRecorder(),
httptest.NewRequest("POST", "/api/v1/fleet/login", &body),
)
})
t.Run("uppercase email", func(t *testing.T) {
var body bytes.Buffer
body.Write([]byte(`{
"email": "Foo",
"password": "bar"
}`))
router.ServeHTTP(
httptest.NewRecorder(),
httptest.NewRequest("POST", "/api/v1/fleet/login", &body),
)
})
}

View file

@ -1,42 +0,0 @@
package service
import (
"context"
"encoding/json"
"net/http"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
)
func decodeCreateUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var req createUserRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, err
}
return req, nil
}
func decodePerformRequiredPasswordResetRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var req performRequiredPasswordResetRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, ctxerr.Wrap(ctx, err, "decoding JSON")
}
return req, nil
}
func decodeForgotPasswordRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var req forgotPasswordRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, err
}
return req, nil
}
func decodeResetPasswordRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var req resetPasswordRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, err
}
return req, nil
}

View file

@ -1,35 +0,0 @@
package service
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
)
func TestDecodeResetPasswordRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/users/{id}/password", func(writer http.ResponseWriter, request *http.Request) {
r, err := decodeResetPasswordRequest(context.Background(), request)
assert.Nil(t, err)
params := r.(resetPasswordRequest)
assert.Equal(t, "bar", params.NewPassword)
assert.Equal(t, "baz", params.PasswordResetToken)
}).Methods("POST")
var body bytes.Buffer
body.Write([]byte(`{
"new_password": "bar",
"password_reset_token": "baz"
}`))
router.ServeHTTP(
httptest.NewRecorder(),
httptest.NewRequest("POST", "/api/v1/fleet/users/1/password", &body),
)
}

View file

@ -26,7 +26,7 @@ func applyUserRoleSpecsEndpoint(ctx context.Context, request interface{}, svc fl
return applyUserRoleSpecsResponse{}, nil
}
func (svc Service) ApplyUserRolesSpecs(ctx context.Context, specs fleet.UsersRoleSpec) error {
func (svc *Service) ApplyUserRolesSpecs(ctx context.Context, specs fleet.UsersRoleSpec) error {
if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionWrite); err != nil {
return err
}
@ -61,7 +61,7 @@ func (svc Service) ApplyUserRolesSpecs(ctx context.Context, specs fleet.UsersRol
return svc.ds.SaveUsers(ctx, users)
}
func (svc Service) checkAtLeastOneAdmin(ctx context.Context, user *fleet.User, spec *fleet.UserRoleSpec, email string) error {
func (svc *Service) checkAtLeastOneAdmin(ctx context.Context, user *fleet.User, spec *fleet.UserRoleSpec, email string) error {
if null.StringFromPtr(user.GlobalRole).ValueOrZero() == fleet.RoleAdmin &&
null.StringFromPtr(spec.GlobalRole).ValueOrZero() != fleet.RoleAdmin {
users, err := svc.ds.ListUsers(ctx, fleet.UserListOptions{})

View file

@ -6,6 +6,8 @@ import (
"encoding/base64"
"errors"
"html/template"
"net/http"
"time"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/authz"
@ -49,6 +51,10 @@ func (svc *Service) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet
return nil, err
}
if err := p.VerifyAdminCreate(); err != nil {
return nil, ctxerr.Wrap(ctx, err, "verify user payload")
}
if invite, err := svc.ds.InviteByEmail(ctx, *p.Email); err == nil && invite != nil {
return nil, ctxerr.Errorf(ctx, "%s already invited", *p.Email)
}
@ -61,6 +67,49 @@ func (svc *Service) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet
return svc.newUser(ctx, p)
}
////////////////////////////////////////////////////////////////////////////////
// Create User From Invite
////////////////////////////////////////////////////////////////////////////////
func createUserFromInviteEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*createUserRequest)
user, err := svc.CreateUserFromInvite(ctx, req.UserPayload)
if err != nil {
return createUserResponse{Err: err}, nil
}
return createUserResponse{User: user}, nil
}
func (svc *Service) CreateUserFromInvite(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
// skipauth: There is no viewer context at this point. We rely on verifying
// the invite for authNZ.
svc.authz.SkipAuthorization(ctx)
if err := p.VerifyInviteCreate(); err != nil {
return nil, ctxerr.Wrap(ctx, err, "verify user payload")
}
invite, err := svc.VerifyInvite(ctx, *p.InviteToken)
if err != nil {
return nil, err
}
// set the payload role property based on an existing invite.
p.GlobalRole = invite.GlobalRole.Ptr()
p.Teams = &invite.Teams
user, err := svc.newUser(ctx, p)
if err != nil {
return nil, err
}
err = svc.ds.DeleteInvite(ctx, invite.ID)
if err != nil {
return nil, err
}
return user, nil
}
////////////////////////////////////////////////////////////////////////////////
// List Users
////////////////////////////////////////////////////////////////////////////////
@ -215,6 +264,14 @@ func (svc *Service) ModifyUser(ctx context.Context, userID uint, p fleet.UserPay
return nil, err
}
vc, ok := viewer.FromContext(ctx)
if !ok {
return nil, ctxerr.New(ctx, "viewer not present") // should never happen, authorize would've failed
}
if err := p.VerifyModify(vc.UserID() == userID); err != nil {
return nil, ctxerr.Wrap(ctx, err, "verify user payload")
}
if p.GlobalRole != nil || p.Teams != nil {
if err := svc.authz.Authorize(ctx, user, fleet.ActionWriteRole); err != nil {
return nil, err
@ -386,13 +443,21 @@ func (svc *Service) ChangePassword(ctx context.Context, oldPass, newPass string)
return err
}
if oldPass == "" {
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("old_password", "Old password cannot be empty"))
}
if newPass == "" {
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", "New password cannot be empty"))
}
if err := fleet.ValidatePasswordRequirements(newPass); err != nil {
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", err.Error()))
}
if vc.User.SSOEnabled {
return ctxerr.New(ctx, "change password for single sign on user not allowed")
}
if err := vc.User.ValidatePassword(newPass); err == nil {
return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"))
}
if err := vc.User.ValidatePassword(oldPass); err != nil {
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("old_password", "old password does not match"))
}
@ -622,3 +687,245 @@ func (svc *Service) modifyEmailAddress(ctx context.Context, user *fleet.User, em
func (svc *Service) saveUser(ctx context.Context, user *fleet.User) error {
return svc.ds.SaveUser(ctx, user)
}
////////////////////////////////////////////////////////////////////////////////
// Perform Required Password Reset
////////////////////////////////////////////////////////////////////////////////
type performRequiredPasswordResetRequest struct {
Password string `json:"new_password"`
ID uint `json:"id"`
}
type performRequiredPasswordResetResponse struct {
User *fleet.User `json:"user,omitempty"`
Err error `json:"error,omitempty"`
}
func (r performRequiredPasswordResetResponse) error() error { return r.Err }
func performRequiredPasswordResetEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*performRequiredPasswordResetRequest)
user, err := svc.PerformRequiredPasswordReset(ctx, req.Password)
if err != nil {
return performRequiredPasswordResetResponse{Err: err}, nil
}
return performRequiredPasswordResetResponse{User: user}, nil
}
func (svc *Service) PerformRequiredPasswordReset(ctx context.Context, password string) (*fleet.User, error) {
vc, ok := viewer.FromContext(ctx)
if !ok {
return nil, fleet.ErrNoContext
}
if !vc.CanPerformPasswordReset() {
return nil, fleet.NewPermissionError("cannot reset password")
}
user := vc.User
if err := svc.authz.Authorize(ctx, user, fleet.ActionWrite); err != nil {
return nil, err
}
if user.SSOEnabled {
return nil, ctxerr.New(ctx, "password reset for single sign on user not allowed")
}
if !user.IsAdminForcedPasswordReset() {
return nil, ctxerr.New(ctx, "user does not require password reset")
}
// prevent setting the same password
if err := user.ValidatePassword(password); err == nil {
return nil, fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
}
user.AdminForcedPasswordReset = false
err := svc.setNewPassword(ctx, user, password)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "setting new password")
}
// Sessions should already have been cleared when the reset was
// required
return user, nil
}
// setNewPassword is a helper for changing a user's password. It should be
// called to set the new password after proper authorization has been
// performed.
func (svc *Service) setNewPassword(ctx context.Context, user *fleet.User, password string) error {
err := user.SetPassword(password, svc.config.Auth.SaltKeySize, svc.config.Auth.BcryptCost)
if err != nil {
return ctxerr.Wrap(ctx, err, "setting new password")
}
if user.SSOEnabled {
return ctxerr.New(ctx, "set password for single sign on user not allowed")
}
err = svc.saveUser(ctx, user)
if err != nil {
return ctxerr.Wrap(ctx, err, "saving changed password")
}
return nil
}
////////////////////////////////////////////////////////////////////////////////
// Reset Password
////////////////////////////////////////////////////////////////////////////////
type resetPasswordRequest struct {
PasswordResetToken string `json:"password_reset_token"`
NewPassword string `json:"new_password"`
}
type resetPasswordResponse struct {
Err error `json:"error,omitempty"`
}
func (r resetPasswordResponse) error() error { return r.Err }
func resetPasswordEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*resetPasswordRequest)
err := svc.ResetPassword(ctx, req.PasswordResetToken, req.NewPassword)
return resetPasswordResponse{Err: err}, nil
}
func (svc *Service) ResetPassword(ctx context.Context, token, password string) error {
// skipauth: No viewer context available. The user is locked out of their
// account and authNZ is performed entirely by providing a valid password
// reset token.
svc.authz.SkipAuthorization(ctx)
if token == "" {
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("token", "Token cannot be empty field"))
}
if password == "" {
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", "New password cannot be empty field"))
}
if err := fleet.ValidatePasswordRequirements(password); err != nil {
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", err.Error()))
}
reset, err := svc.ds.FindPasswordResetByToken(ctx, token)
if err != nil {
return ctxerr.Wrap(ctx, err, "looking up reset by token")
}
user, err := svc.ds.UserByID(ctx, reset.UserID)
if err != nil {
return ctxerr.Wrap(ctx, err, "retrieving user")
}
if user.SSOEnabled {
return ctxerr.New(ctx, "password reset for single sign on user not allowed")
}
// prevent setting the same password
if err := user.ValidatePassword(password); err == nil {
return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
}
err = svc.setNewPassword(ctx, user, password)
if err != nil {
return ctxerr.Wrap(ctx, err, "setting new password")
}
// delete password reset tokens for user
if err := svc.ds.DeletePasswordResetRequestsForUser(ctx, user.ID); err != nil {
return ctxerr.Wrap(ctx, err, "delete password reset requests")
}
// Clear sessions so that any other browsers will have to log in with
// the new password
if err := svc.ds.DestroyAllSessionsForUser(ctx, user.ID); err != nil {
return ctxerr.Wrap(ctx, err, "delete user sessions")
}
return nil
}
////////////////////////////////////////////////////////////////////////////////
// Forgot Password
////////////////////////////////////////////////////////////////////////////////
type forgotPasswordRequest struct {
Email string `json:"email"`
}
type forgotPasswordResponse struct {
Err error `json:"error,omitempty"`
}
func (r forgotPasswordResponse) error() error { return r.Err }
func (r forgotPasswordResponse) status() int { return http.StatusAccepted }
func forgotPasswordEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*forgotPasswordRequest)
// Any error returned by the service should not be returned to the
// client to prevent information disclosure (it will be logged in the
// server logs).
_ = svc.RequestPasswordReset(ctx, req.Email)
return forgotPasswordResponse{}, nil
}
func (svc *Service) RequestPasswordReset(ctx context.Context, email string) error {
// skipauth: No viewer context available. The user is locked out of their
// account and trying to reset their password.
svc.authz.SkipAuthorization(ctx)
// Regardless of error, sleep until the request has taken at least 1 second.
// This means that any request to this method will take ~1s and frustrate a timing attack.
defer func(start time.Time) {
time.Sleep(time.Until(start.Add(1 * time.Second)))
}(time.Now())
user, err := svc.ds.UserByEmail(ctx, email)
if err != nil {
return err
}
if user.SSOEnabled {
return ctxerr.New(ctx, "password reset for single sign on user not allowed")
}
random, err := server.GenerateRandomText(svc.config.App.TokenKeySize)
if err != nil {
return err
}
token := base64.URLEncoding.EncodeToString([]byte(random))
request := &fleet.PasswordResetRequest{
ExpiresAt: time.Now().Add(time.Hour * 24),
UserID: user.ID,
Token: token,
}
_, err = svc.ds.NewPasswordResetRequest(ctx, request)
if err != nil {
return err
}
config, err := svc.ds.AppConfig(ctx)
if err != nil {
return err
}
resetEmail := fleet.Email{
Subject: "Reset Your Fleet Password",
To: []string{user.Email},
Config: config,
Mailer: &mail.PasswordResetMailer{
BaseURL: template.URL(config.ServerSettings.ServerURL + svc.config.Server.URLPrefix),
AssetURL: getAssetURL(),
Token: token,
},
}
return svc.mailService.SendEmail(resetEmail)
}
func (svc *Service) ListAvailableTeamsForUser(ctx context.Context, user *fleet.User) ([]*fleet.TeamSummary, error) {
// skipauth: No authorization check needed due to implementation returning
// only license error.
svc.authz.SkipAuthorization(ctx)
return nil, fleet.ErrMissingLicense
}

View file

@ -2,6 +2,7 @@ package service
import (
"context"
"database/sql"
"errors"
"testing"
"time"
@ -471,6 +472,7 @@ func testUsersChangePassword(t *testing.T, ds *mysql.Datastore) {
anyErr: true,
},
{ // missing old password
user: users["user1@example.com"],
newPassword: "123cataaa!",
wantErr: fleet.NewInvalidArgumentError("old_password", "Old password cannot be empty"),
},
@ -540,3 +542,141 @@ func testUsersRequirePasswordReset(t *testing.T, ds *mysql.Datastore) {
})
}
}
func TestPerformRequiredPasswordReset(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
svc := newTestService(ds, nil, nil)
createTestUsers(t, ds)
for _, tt := range testUsers {
t.Run(tt.Email, func(t *testing.T) {
user, err := ds.UserByEmail(context.Background(), tt.Email)
require.Nil(t, err)
ctx := context.Background()
_, err = svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, true)
require.Nil(t, err)
ctx = refreshCtx(t, ctx, user, ds, nil)
session, err := ds.NewSession(context.Background(), &fleet.Session{UserID: user.ID})
require.Nil(t, err)
ctx = refreshCtx(t, ctx, user, ds, session)
// should error when reset not required
_, err = svc.RequirePasswordReset(ctx, user.ID, false)
require.Nil(t, err)
ctx = refreshCtx(t, ctx, user, ds, session)
_, err = svc.PerformRequiredPasswordReset(ctx, "new_pass")
require.NotNil(t, err)
_, err = svc.RequirePasswordReset(ctx, user.ID, true)
require.Nil(t, err)
ctx = refreshCtx(t, ctx, user, ds, session)
// should error when using same password
_, err = svc.PerformRequiredPasswordReset(ctx, tt.PlaintextPassword)
require.Equal(t, "validation failed: new_password cannot reuse old password", err.Error())
// should succeed with good new password
u, err := svc.PerformRequiredPasswordReset(ctx, "new_pass")
require.Nil(t, err)
assert.False(t, u.AdminForcedPasswordReset)
ctx = context.Background()
// Now user should be able to login with new password
u, _, err = svc.Login(ctx, tt.Email, "new_pass")
require.Nil(t, err)
assert.False(t, u.AdminForcedPasswordReset)
})
}
}
func TestResetPassword(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
svc := newTestService(ds, nil, nil)
createTestUsers(t, ds)
passwordResetTests := []struct {
token string
newPassword string
wantErr error
}{
{ // all good
token: "abcd",
newPassword: "123cat!",
},
{ // prevent reuse
token: "abcd",
newPassword: "123cat!",
wantErr: fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"),
},
{ // bad token
token: "dcbaz",
newPassword: "123cat!",
wantErr: sql.ErrNoRows,
},
{ // missing token
newPassword: "123cat!",
wantErr: fleet.NewInvalidArgumentError("token", "Token cannot be empty field"),
},
}
for _, tt := range passwordResetTests {
t.Run("", func(t *testing.T) {
request := &fleet.PasswordResetRequest{
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
CreateTimestamp: fleet.CreateTimestamp{
CreatedAt: time.Now(),
},
UpdateTimestamp: fleet.UpdateTimestamp{
UpdatedAt: time.Now(),
},
},
ExpiresAt: time.Now().Add(time.Hour * 24),
UserID: 1,
Token: "abcd",
}
_, err := ds.NewPasswordResetRequest(context.Background(), request)
assert.Nil(t, err)
serr := svc.ResetPassword(test.UserContext(&fleet.User{ID: 1}), tt.token, tt.newPassword)
if tt.wantErr != nil {
assert.Equal(t, tt.wantErr.Error(), ctxerr.Cause(serr).Error())
} else {
assert.Nil(t, serr)
}
})
}
}
func refreshCtx(t *testing.T, ctx context.Context, user *fleet.User, ds fleet.Datastore, session *fleet.Session) context.Context {
reloadedUser, err := ds.UserByEmail(ctx, user.Email)
require.NoError(t, err)
return viewer.NewContext(ctx, viewer.Viewer{User: reloadedUser, Session: session})
}
func TestAuthenticatedUser(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
createTestUsers(t, ds)
svc := newTestService(ds, nil, nil)
admin1, err := ds.UserByEmail(context.Background(), "admin1@example.com")
assert.Nil(t, err)
admin1Session, err := ds.NewSession(context.Background(), &fleet.Session{
UserID: admin1.ID,
Key: "admin1",
})
assert.Nil(t, err)
ctx := context.Background()
ctx = viewer.NewContext(ctx, viewer.Viewer{User: admin1, Session: admin1Session})
user, err := svc.AuthenticatedUser(ctx)
assert.Nil(t, err)
assert.Equal(t, user, admin1)
}

View file

@ -1,202 +0,0 @@
package service
import (
"context"
"errors"
"unicode"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
)
func (mw validationMiddleware) CreateUserFromInvite(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
invalid := &fleet.InvalidArgumentError{}
if p.Name == nil {
invalid.Append("name", "Full name missing required argument")
} else {
if *p.Name == "" {
invalid.Append("name", "Full name cannot be empty")
}
}
// we don't need a password for single sign on
if p.SSOInvite == nil || !*p.SSOInvite {
if p.Password == nil {
invalid.Append("password", "Password missing required argument")
} else {
if *p.Password == "" {
invalid.Append("password", "Password cannot be empty")
}
if err := validatePasswordRequirements(*p.Password); err != nil {
invalid.Append("password", err.Error())
}
}
}
if p.Email == nil {
invalid.Append("email", "Email missing required argument")
} else {
if *p.Email == "" {
invalid.Append("email", "Email cannot be empty")
}
}
if p.InviteToken == nil {
invalid.Append("invite_token", "Invite token missing required argument")
} else {
if *p.InviteToken == "" {
invalid.Append("invite_token", "Invite token cannot be empty")
}
}
if invalid.HasErrors() {
return nil, ctxerr.Wrap(ctx, invalid)
}
return mw.Service.CreateUserFromInvite(ctx, p)
}
func (mw validationMiddleware) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
invalid := &fleet.InvalidArgumentError{}
if p.Name == nil {
invalid.Append("name", "Full name missing required argument")
} else {
if *p.Name == "" {
invalid.Append("name", "Full name cannot be empty")
}
}
// we don't need a password for single sign on
if (p.SSOInvite == nil || !*p.SSOInvite) && (p.SSOEnabled == nil || !*p.SSOEnabled) {
if p.Password == nil {
invalid.Append("password", "Password missing required argument")
} else {
if *p.Password == "" {
invalid.Append("password", "Password cannot be empty")
}
// Skip password validation in the case of admin created users
}
}
if p.SSOEnabled != nil && *p.SSOEnabled && p.Password != nil && len(*p.Password) > 0 {
invalid.Append("password", "not allowed for SSO users")
}
if p.Email == nil {
invalid.Append("email", "Email missing required argument")
} else {
if *p.Email == "" {
invalid.Append("email", "Email cannot be empty")
}
}
if p.InviteToken != nil {
invalid.Append("invite_token", "Invite token should not be specified with admin user creation")
}
if invalid.HasErrors() {
return nil, ctxerr.Wrap(ctx, invalid)
}
return mw.Service.CreateUser(ctx, p)
}
func (mw validationMiddleware) ModifyUser(ctx context.Context, userID uint, p fleet.UserPayload) (*fleet.User, error) {
invalid := &fleet.InvalidArgumentError{}
if p.Name != nil {
if *p.Name == "" {
invalid.Append("name", "Full name cannot be empty")
}
}
if p.Email != nil {
if *p.Email == "" {
invalid.Append("email", "Email cannot be empty")
}
// if the user is not an admin, or if an admin is changing their own email
// address a password is required,
if passwordRequiredForEmailChange(ctx, userID, invalid) {
if p.Password == nil {
invalid.Append("password", "Password cannot be empty if email is changed")
}
}
}
if invalid.HasErrors() {
return nil, ctxerr.Wrap(ctx, invalid)
}
return mw.Service.ModifyUser(ctx, userID, p)
}
func passwordRequiredForEmailChange(ctx context.Context, uid uint, invalid *fleet.InvalidArgumentError) bool {
vc, ok := viewer.FromContext(ctx)
if !ok {
invalid.Append("viewer", "Viewer not present")
return false
}
// if a user is changing own email need a password no matter what
return vc.UserID() == uid
}
func (mw validationMiddleware) ChangePassword(ctx context.Context, oldPass, newPass string) error {
invalid := &fleet.InvalidArgumentError{}
if oldPass == "" {
invalid.Append("old_password", "Old password cannot be empty")
}
if newPass == "" {
invalid.Append("new_password", "New password cannot be empty")
}
if err := validatePasswordRequirements(newPass); err != nil {
invalid.Append("new_password", err.Error())
}
if invalid.HasErrors() {
return ctxerr.Wrap(ctx, invalid)
}
return mw.Service.ChangePassword(ctx, oldPass, newPass)
}
func (mw validationMiddleware) ResetPassword(ctx context.Context, token, password string) error {
invalid := &fleet.InvalidArgumentError{}
if token == "" {
invalid.Append("token", "Token cannot be empty field")
}
if password == "" {
invalid.Append("new_password", "New password cannot be empty field")
}
if err := validatePasswordRequirements(password); err != nil {
invalid.Append("new_password", err.Error())
}
if invalid.HasErrors() {
return ctxerr.Wrap(ctx, invalid)
}
return mw.Service.ResetPassword(ctx, token, password)
}
// Requirements for user password:
// at least 7 character length
// at least 1 symbol
// at least 1 number
func validatePasswordRequirements(password string) error {
var (
number bool
symbol bool
)
for _, s := range password {
switch {
case unicode.IsNumber(s):
number = true
case unicode.IsPunct(s) || unicode.IsSymbol(s):
symbol = true
}
}
if len(password) >= 7 &&
number &&
symbol {
return nil
}
return errors.New("Password does not meet validation requirements")
}