mirror of
https://github.com/fleetdm/fleet
synced 2026-05-06 06:48:54 +00:00
Migrate special-case endpoints to new pattern (#4511)
This commit is contained in:
parent
c14640ca84
commit
c8bc026d6f
55 changed files with 5091 additions and 5100 deletions
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
|
|
@ -74,7 +75,10 @@ func (ds *Datastore) VerifyEnrollSecret(ctx context.Context, secret string) (*fl
|
|||
var s fleet.EnrollSecret
|
||||
err := sqlx.GetContext(ctx, ds.reader, &s, "SELECT team_id FROM enroll_secrets WHERE secret = ?", secret)
|
||||
if err != nil {
|
||||
return nil, ctxerr.New(ctx, "no matching secret found")
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ctxerr.New(ctx, "no matching secret found")
|
||||
}
|
||||
return nil, ctxerr.Wrap(ctx, err, "verify enroll secret")
|
||||
}
|
||||
|
||||
return &s, nil
|
||||
|
|
|
|||
|
|
@ -36,7 +36,8 @@ func (ds *Datastore) DeletePasswordResetRequestsForUser(ctx context.Context, use
|
|||
|
||||
return nil
|
||||
}
|
||||
func (ds *Datastore) FindPassswordResetByToken(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) {
|
||||
|
||||
func (ds *Datastore) FindPasswordResetByToken(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) {
|
||||
sqlStatement := `
|
||||
SELECT * FROM password_reset_requests
|
||||
WHERE token = ? LIMIT 1
|
||||
|
|
@ -48,5 +49,4 @@ func (ds *Datastore) FindPassswordResetByToken(ctx context.Context, token string
|
|||
}
|
||||
|
||||
return passwordResetRequest, nil
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -226,7 +226,7 @@ type Datastore interface {
|
|||
|
||||
NewPasswordResetRequest(ctx context.Context, req *PasswordResetRequest) (*PasswordResetRequest, error)
|
||||
DeletePasswordResetRequestsForUser(ctx context.Context, userID uint) error
|
||||
FindPassswordResetByToken(ctx context.Context, token string) (*PasswordResetRequest, error)
|
||||
FindPasswordResetByToken(ctx context.Context, token string) (*PasswordResetRequest, error)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// SessionStore is the abstract interface that all session backends must conform to.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
package fleet
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"unicode"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
|
@ -69,6 +71,104 @@ type UserPayload struct {
|
|||
Teams *[]UserTeam `json:"teams,omitempty"`
|
||||
}
|
||||
|
||||
func (p *UserPayload) VerifyInviteCreate() error {
|
||||
invalid := &InvalidArgumentError{}
|
||||
if p.Name == nil {
|
||||
invalid.Append("name", "Full name missing required argument")
|
||||
} else if *p.Name == "" {
|
||||
invalid.Append("name", "Full name cannot be empty")
|
||||
}
|
||||
|
||||
// we don't need a password for single sign on
|
||||
if p.SSOInvite == nil || !*p.SSOInvite {
|
||||
if p.Password == nil {
|
||||
invalid.Append("password", "Password missing required argument")
|
||||
} else if *p.Password == "" {
|
||||
invalid.Append("password", "Password cannot be empty")
|
||||
} else if err := ValidatePasswordRequirements(*p.Password); err != nil {
|
||||
invalid.Append("password", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if p.Email == nil {
|
||||
invalid.Append("email", "Email missing required argument")
|
||||
} else if *p.Email == "" {
|
||||
invalid.Append("email", "Email cannot be empty")
|
||||
}
|
||||
|
||||
if p.InviteToken == nil {
|
||||
invalid.Append("invite_token", "Invite token missing required argument")
|
||||
} else if *p.InviteToken == "" {
|
||||
invalid.Append("invite_token", "Invite token cannot be empty")
|
||||
}
|
||||
|
||||
if invalid.HasErrors() {
|
||||
return invalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *UserPayload) VerifyAdminCreate() error {
|
||||
invalid := &InvalidArgumentError{}
|
||||
if p.Name == nil {
|
||||
invalid.Append("name", "Full name missing required argument")
|
||||
} else if *p.Name == "" {
|
||||
invalid.Append("name", "Full name cannot be empty")
|
||||
}
|
||||
|
||||
// we don't need a password for single sign on
|
||||
if (p.SSOInvite == nil || !*p.SSOInvite) && (p.SSOEnabled == nil || !*p.SSOEnabled) {
|
||||
if p.Password == nil {
|
||||
invalid.Append("password", "Password missing required argument")
|
||||
} else if *p.Password == "" {
|
||||
invalid.Append("password", "Password cannot be empty")
|
||||
}
|
||||
// Skip password validation in the case of admin created users
|
||||
}
|
||||
|
||||
if p.SSOEnabled != nil && *p.SSOEnabled && p.Password != nil && len(*p.Password) > 0 {
|
||||
invalid.Append("password", "not allowed for SSO users")
|
||||
}
|
||||
|
||||
if p.Email == nil {
|
||||
invalid.Append("email", "Email missing required argument")
|
||||
} else if *p.Email == "" {
|
||||
invalid.Append("email", "Email cannot be empty")
|
||||
}
|
||||
|
||||
if p.InviteToken != nil {
|
||||
invalid.Append("invite_token", "Invite token should not be specified with admin user creation")
|
||||
}
|
||||
|
||||
if invalid.HasErrors() {
|
||||
return invalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *UserPayload) VerifyModify(ownUser bool) error {
|
||||
invalid := &InvalidArgumentError{}
|
||||
if p.Name != nil && *p.Name == "" {
|
||||
invalid.Append("name", "Full name cannot be empty")
|
||||
}
|
||||
|
||||
if p.Email != nil {
|
||||
if *p.Email == "" {
|
||||
invalid.Append("email", "Email cannot be empty")
|
||||
}
|
||||
// if the user is not an admin, or if an admin is changing their own email
|
||||
// address a password is required,
|
||||
if ownUser && p.Password == nil {
|
||||
invalid.Append("password", "Password cannot be empty if email is changed")
|
||||
}
|
||||
}
|
||||
|
||||
if invalid.HasErrors() {
|
||||
return invalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// User creates a user from payload.
|
||||
func (p UserPayload) User(keySize, cost int) (*User, error) {
|
||||
user := &User{
|
||||
|
|
@ -130,3 +230,31 @@ func (u *User) SetPassword(plaintext string, keySize, cost int) error {
|
|||
u.Password = hashed
|
||||
return nil
|
||||
}
|
||||
|
||||
// Requirements for user password:
|
||||
// at least 7 character length
|
||||
// at least 1 symbol
|
||||
// at least 1 number
|
||||
func ValidatePasswordRequirements(password string) error {
|
||||
var (
|
||||
number bool
|
||||
symbol bool
|
||||
)
|
||||
|
||||
for _, s := range password {
|
||||
switch {
|
||||
case unicode.IsNumber(s):
|
||||
number = true
|
||||
case unicode.IsPunct(s) || unicode.IsSymbol(s):
|
||||
symbol = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(password) >= 7 &&
|
||||
number &&
|
||||
symbol {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("Password does not meet validation requirements")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,3 +42,37 @@ func newTestUser(t *testing.T, password, email string) *User {
|
|||
Email: email,
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPasswordRequirements(t *testing.T) {
|
||||
passwordTests := []struct {
|
||||
password string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
password: "foobar",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
password: "foobarbaz",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
password: "foobarbaz!",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
password: "foobarbaz!3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range passwordTests {
|
||||
t.Run(tt.password, func(t *testing.T) {
|
||||
err := ValidatePasswordRequirements(tt.password)
|
||||
if tt.wantErr {
|
||||
assert.NotNil(t, err)
|
||||
} else {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -190,7 +190,7 @@ type NewPasswordResetRequestFunc func(ctx context.Context, req *fleet.PasswordRe
|
|||
|
||||
type DeletePasswordResetRequestsForUserFunc func(ctx context.Context, userID uint) error
|
||||
|
||||
type FindPassswordResetByTokenFunc func(ctx context.Context, token string) (*fleet.PasswordResetRequest, error)
|
||||
type FindPasswordResetByTokenFunc func(ctx context.Context, token string) (*fleet.PasswordResetRequest, error)
|
||||
|
||||
type SessionByKeyFunc func(ctx context.Context, key string) (*fleet.Session, error)
|
||||
|
||||
|
|
@ -650,8 +650,8 @@ type DataStore struct {
|
|||
DeletePasswordResetRequestsForUserFunc DeletePasswordResetRequestsForUserFunc
|
||||
DeletePasswordResetRequestsForUserFuncInvoked bool
|
||||
|
||||
FindPassswordResetByTokenFunc FindPassswordResetByTokenFunc
|
||||
FindPassswordResetByTokenFuncInvoked bool
|
||||
FindPasswordResetByTokenFunc FindPasswordResetByTokenFunc
|
||||
FindPasswordResetByTokenFuncInvoked bool
|
||||
|
||||
SessionByKeyFunc SessionByKeyFunc
|
||||
SessionByKeyFuncInvoked bool
|
||||
|
|
@ -1384,9 +1384,9 @@ func (s *DataStore) DeletePasswordResetRequestsForUser(ctx context.Context, user
|
|||
return s.DeletePasswordResetRequestsForUserFunc(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *DataStore) FindPassswordResetByToken(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) {
|
||||
s.FindPassswordResetByTokenFuncInvoked = true
|
||||
return s.FindPassswordResetByTokenFunc(ctx, token)
|
||||
func (s *DataStore) FindPasswordResetByToken(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) {
|
||||
s.FindPasswordResetByTokenFuncInvoked = true
|
||||
return s.FindPasswordResetByTokenFunc(ctx, token)
|
||||
}
|
||||
|
||||
func (s *DataStore) SessionByKey(ctx context.Context, key string) (*fleet.Session, error) {
|
||||
|
|
|
|||
|
|
@ -232,3 +232,75 @@ func (svc *Service) CarveBegin(ctx context.Context, payload fleet.CarveBeginPayl
|
|||
|
||||
return carve, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Receive Block for File Carve
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type carveBlockRequest struct {
|
||||
BlockId int64 `json:"block_id"`
|
||||
SessionId string `json:"session_id"`
|
||||
RequestId string `json:"request_id"`
|
||||
Data []byte `json:"data"`
|
||||
}
|
||||
|
||||
type carveBlockResponse struct {
|
||||
Success bool `json:"success,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r carveBlockResponse) error() error { return r.Err }
|
||||
|
||||
func carveBlockEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*carveBlockRequest)
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
SessionId: req.SessionId,
|
||||
RequestId: req.RequestId,
|
||||
BlockId: req.BlockId,
|
||||
Data: req.Data,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(ctx, payload)
|
||||
if err != nil {
|
||||
return carveBlockResponse{Err: err}, nil
|
||||
}
|
||||
|
||||
return carveBlockResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) CarveBlock(ctx context.Context, payload fleet.CarveBlockPayload) error {
|
||||
// skipauth: Authorization is currently for user endpoints only.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
// Note host did not authenticate via node key. We need to authenticate them
|
||||
// by the session ID and request ID
|
||||
carve, err := svc.carveStore.CarveBySessionId(ctx, payload.SessionId)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "find carve by session_id")
|
||||
}
|
||||
|
||||
if payload.RequestId != carve.RequestId {
|
||||
return errors.New("request_id does not match")
|
||||
}
|
||||
|
||||
// Request is now authenticated
|
||||
|
||||
if payload.BlockId > carve.BlockCount-1 {
|
||||
return fmt.Errorf("block_id exceeds expected max (%d): %d", carve.BlockCount-1, payload.BlockId)
|
||||
}
|
||||
|
||||
if payload.BlockId != carve.MaxBlock+1 {
|
||||
return fmt.Errorf("block_id does not match expected block (%d): %d", carve.MaxBlock+1, payload.BlockId)
|
||||
}
|
||||
|
||||
if int64(len(payload.Data)) > carve.BlockSize {
|
||||
return fmt.Errorf("exceeded declared block size %d: %d", carve.BlockSize, len(payload.Data))
|
||||
}
|
||||
|
||||
if err := svc.carveStore.NewBlock(ctx, carve, payload.BlockId, payload.Data); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "save block data")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,8 +4,10 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/authz"
|
||||
hostctx "github.com/fleetdm/fleet/v4/server/contexts/host"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mock"
|
||||
"github.com/fleetdm/fleet/v4/server/test"
|
||||
|
|
@ -184,3 +186,420 @@ func TestCarveGetBlockExpired(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expired carve")
|
||||
}
|
||||
|
||||
func TestCarveBegin(t *testing.T) {
|
||||
host := fleet.Host{ID: 3}
|
||||
payload := fleet.CarveBeginPayload{
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
ds := new(mock.Store)
|
||||
svc := &Service{
|
||||
carveStore: ms,
|
||||
ds: ds,
|
||||
}
|
||||
expectedMetadata := fleet.CarveMetadata{
|
||||
ID: 7,
|
||||
HostId: host.ID,
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
}
|
||||
ms.NewCarveFunc = func(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) {
|
||||
metadata.ID = 7
|
||||
return metadata, nil
|
||||
}
|
||||
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
||||
if host.ID != id {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &fleet.Host{
|
||||
Hostname: host.Hostname,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := hostctx.NewContext(context.Background(), &host)
|
||||
|
||||
metadata, err := svc.CarveBegin(ctx, payload)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, metadata.SessionId)
|
||||
metadata.SessionId = "" // Clear this before comparison
|
||||
metadata.Name = "" // Clear this before comparison
|
||||
metadata.CreatedAt = time.Time{} // Clear this before comparison
|
||||
assert.Equal(t, expectedMetadata, *metadata)
|
||||
}
|
||||
|
||||
func TestCarveBeginNewCarveError(t *testing.T) {
|
||||
host := fleet.Host{ID: 3}
|
||||
payload := fleet.CarveBeginPayload{
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
ds := new(mock.Store)
|
||||
svc := &Service{
|
||||
carveStore: ms,
|
||||
ds: ds,
|
||||
}
|
||||
ms.NewCarveFunc = func(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) {
|
||||
return nil, errors.New("ouch!")
|
||||
}
|
||||
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
||||
if host.ID != id {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &fleet.Host{
|
||||
Hostname: host.Hostname,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := hostctx.NewContext(context.Background(), &host)
|
||||
|
||||
_, err := svc.CarveBegin(ctx, payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "ouch!")
|
||||
}
|
||||
|
||||
func TestCarveBeginEmptyError(t *testing.T) {
|
||||
ms := new(mock.Store)
|
||||
ds := new(mock.Store)
|
||||
svc := &Service{
|
||||
carveStore: ms,
|
||||
ds: ds,
|
||||
}
|
||||
ctx := hostctx.NewContext(context.Background(), &fleet.Host{ID: 1})
|
||||
|
||||
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
||||
if id != 1 {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &fleet.Host{}, nil
|
||||
}
|
||||
|
||||
_, err := svc.CarveBegin(ctx, fleet.CarveBeginPayload{})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "carve_size must be greater than 0")
|
||||
}
|
||||
|
||||
func TestCarveBeginMissingHostError(t *testing.T) {
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
|
||||
_, err := svc.CarveBegin(context.Background(), fleet.CarveBeginPayload{})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "missing host")
|
||||
}
|
||||
|
||||
func TestCarveBeginBlockSizeMaxError(t *testing.T) {
|
||||
host := fleet.Host{ID: 3}
|
||||
payload := fleet.CarveBeginPayload{
|
||||
BlockCount: 10,
|
||||
BlockSize: 1024 * 1024 * 1024 * 1024, // 1TB
|
||||
CarveSize: 10 * 1024 * 1024 * 1024 * 1024, // 10TB
|
||||
RequestId: "carve_request",
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
ds := new(mock.Store)
|
||||
svc := &Service{
|
||||
carveStore: ms,
|
||||
ds: ds,
|
||||
}
|
||||
|
||||
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
||||
if host.ID != id {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &fleet.Host{
|
||||
Hostname: host.Hostname,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := hostctx.NewContext(context.Background(), &host)
|
||||
|
||||
_, err := svc.CarveBegin(ctx, payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "block_size exceeds max")
|
||||
}
|
||||
|
||||
func TestCarveBeginCarveSizeMaxError(t *testing.T) {
|
||||
host := fleet.Host{ID: 3}
|
||||
payload := fleet.CarveBeginPayload{
|
||||
BlockCount: 1024 * 1024,
|
||||
BlockSize: 10 * 1024 * 1024, // 1TB
|
||||
CarveSize: 10 * 1024 * 1024 * 1024 * 1024, // 10TB
|
||||
RequestId: "carve_request",
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
ds := new(mock.Store)
|
||||
svc := &Service{
|
||||
carveStore: ms,
|
||||
ds: ds,
|
||||
}
|
||||
|
||||
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
||||
if host.ID != id {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &fleet.Host{
|
||||
Hostname: host.Hostname,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := hostctx.NewContext(context.Background(), &host)
|
||||
|
||||
_, err := svc.CarveBegin(ctx, payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "carve_size exceeds max")
|
||||
}
|
||||
|
||||
func TestCarveBeginCarveSizeError(t *testing.T) {
|
||||
host := fleet.Host{ID: 3}
|
||||
payload := fleet.CarveBeginPayload{
|
||||
BlockCount: 7,
|
||||
BlockSize: 13,
|
||||
CarveSize: 7*13 + 1,
|
||||
RequestId: "carve_request",
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
ds := new(mock.Store)
|
||||
svc := &Service{
|
||||
carveStore: ms,
|
||||
ds: ds,
|
||||
}
|
||||
ctx := hostctx.NewContext(context.Background(), &host)
|
||||
|
||||
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
||||
if host.ID != id {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &fleet.Host{
|
||||
Hostname: host.Hostname,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Too big
|
||||
_, err := svc.CarveBegin(ctx, payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "carve_size does not match")
|
||||
|
||||
// Too small
|
||||
payload.CarveSize = 6 * 13
|
||||
_, err = svc.CarveBegin(ctx, payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "carve_size does not match")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlockGetCarveError(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
return nil, errors.New("ouch!")
|
||||
}
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :)"),
|
||||
SessionId: sessionId,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "ouch!")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlockRequestIdError(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
metadata := &fleet.CarveMetadata{
|
||||
ID: 2,
|
||||
HostId: 3,
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
assert.Equal(t, metadata.SessionId, sessionId)
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :)"),
|
||||
RequestId: "not_matching",
|
||||
SessionId: sessionId,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "request_id does not match")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlockBlockCountExceedError(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
metadata := &fleet.CarveMetadata{
|
||||
ID: 2,
|
||||
HostId: 3,
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
assert.Equal(t, metadata.SessionId, sessionId)
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :)"),
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
BlockId: 23,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "block_id exceeds expected max")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlockBlockCountMatchError(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
metadata := &fleet.CarveMetadata{
|
||||
ID: 2,
|
||||
HostId: 3,
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
MaxBlock: 3,
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
assert.Equal(t, metadata.SessionId, sessionId)
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :)"),
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
BlockId: 7,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "block_id does not match")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlockBlockSizeError(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
metadata := &fleet.CarveMetadata{
|
||||
ID: 2,
|
||||
HostId: 3,
|
||||
BlockCount: 23,
|
||||
BlockSize: 16,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
MaxBlock: 3,
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
assert.Equal(t, metadata.SessionId, sessionId)
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :) TOO LONG!!!"),
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
BlockId: 4,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "exceeded declared block size")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlockNewBlockError(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
metadata := &fleet.CarveMetadata{
|
||||
ID: 2,
|
||||
HostId: 3,
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
MaxBlock: 3,
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
assert.Equal(t, metadata.SessionId, sessionId)
|
||||
return metadata, nil
|
||||
}
|
||||
ms.NewBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64, data []byte) error {
|
||||
return errors.New("kaboom!")
|
||||
}
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :)"),
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
BlockId: 4,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "kaboom!")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlock(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
metadata := &fleet.CarveMetadata{
|
||||
ID: 2,
|
||||
HostId: 3,
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
MaxBlock: 3,
|
||||
}
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :)"),
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
BlockId: 4,
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
assert.Equal(t, metadata.SessionId, sessionId)
|
||||
return metadata, nil
|
||||
}
|
||||
ms.NewBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64, data []byte) error {
|
||||
assert.Equal(t, metadata, carve)
|
||||
assert.Equal(t, int64(4), blockId)
|
||||
assert.Equal(t, payload.Data, data)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ms.NewBlockFuncInvoked)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Receive Block for File Carve
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type carveBlockRequest struct {
|
||||
NodeKey string `json:"node_key"`
|
||||
BlockId int64 `json:"block_id"`
|
||||
SessionId string `json:"session_id"`
|
||||
RequestId string `json:"request_id"`
|
||||
Data []byte `json:"data"`
|
||||
}
|
||||
|
||||
type carveBlockResponse struct {
|
||||
Success bool `json:"success,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r carveBlockResponse) error() error { return r.Err }
|
||||
|
||||
func makeCarveBlockEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(carveBlockRequest)
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
SessionId: req.SessionId,
|
||||
RequestId: req.RequestId,
|
||||
BlockId: req.BlockId,
|
||||
Data: req.Data,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(ctx, payload)
|
||||
if err != nil {
|
||||
return carveBlockResponse{Err: err}, nil
|
||||
}
|
||||
|
||||
return carveBlockResponse{Success: true}, nil
|
||||
}
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
)
|
||||
|
||||
type verifyInviteRequest struct {
|
||||
Token string
|
||||
}
|
||||
|
||||
type verifyInviteResponse struct {
|
||||
Invite *fleet.Invite `json:"invite"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r verifyInviteResponse) error() error { return r.Err }
|
||||
|
||||
func makeVerifyInviteEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(verifyInviteRequest)
|
||||
invite, err := svc.VerifyInvite(ctx, req.Token)
|
||||
if err != nil {
|
||||
return verifyInviteResponse{Err: err}, nil
|
||||
}
|
||||
return verifyInviteResponse{Invite: invite}, nil
|
||||
}
|
||||
}
|
||||
|
|
@ -137,6 +137,10 @@ func authenticatedUser(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpo
|
|||
return logged(authUserFunc)
|
||||
}
|
||||
|
||||
func unauthenticatedRequest(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint {
|
||||
return logged(next)
|
||||
}
|
||||
|
||||
// logged wraps an endpoint and adds the error if the context supports it
|
||||
func logged(next endpoint.Endpoint) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
|
||||
|
|
@ -167,16 +171,3 @@ func authViewer(ctx context.Context, sessionKey string, svc fleet.Service) (*vie
|
|||
}
|
||||
return &viewer.Viewer{User: user, Session: session}, nil
|
||||
}
|
||||
|
||||
func canPerformPasswordReset(next endpoint.Endpoint) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, fleet.ErrNoContext
|
||||
}
|
||||
if !vc.CanPerformPasswordReset() {
|
||||
return nil, fleet.NewPermissionError("cannot reset password")
|
||||
}
|
||||
return next(ctx, request)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,36 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Enroll Agent
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type enrollAgentRequest struct {
|
||||
EnrollSecret string `json:"enroll_secret"`
|
||||
HostIdentifier string `json:"host_identifier"`
|
||||
HostDetails map[string](map[string]string) `json:"host_details"`
|
||||
}
|
||||
|
||||
type enrollAgentResponse struct {
|
||||
NodeKey string `json:"node_key,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r enrollAgentResponse) error() error { return r.Err }
|
||||
|
||||
func makeEnrollAgentEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(enrollAgentRequest)
|
||||
nodeKey, err := svc.EnrollAgent(ctx, req.EnrollSecret, req.HostIdentifier, req.HostDetails)
|
||||
if err != nil {
|
||||
return enrollAgentResponse{Err: err}, nil
|
||||
}
|
||||
return enrollAgentResponse{NodeKey: nodeKey}, nil
|
||||
}
|
||||
}
|
||||
|
|
@ -1,163 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"html/template"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Login
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type loginRequest struct {
|
||||
Email string
|
||||
Password string
|
||||
}
|
||||
|
||||
type loginResponse struct {
|
||||
User *fleet.User `json:"user,omitempty"`
|
||||
AvailableTeams []*fleet.TeamSummary `json:"available_teams"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r loginResponse) error() error { return r.Err }
|
||||
|
||||
func makeLoginEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(loginRequest)
|
||||
user, token, err := svc.Login(ctx, req.Email, req.Password)
|
||||
if err != nil {
|
||||
return loginResponse{Err: err}, nil
|
||||
}
|
||||
// Add viewer context allow access to service teams for list of available teams
|
||||
v, err := authViewer(ctx, token, svc)
|
||||
if err != nil {
|
||||
return loginResponse{Err: err}, nil
|
||||
}
|
||||
ctx = viewer.NewContext(ctx, *v)
|
||||
availableTeams, err := svc.ListAvailableTeamsForUser(ctx, user)
|
||||
if err != nil {
|
||||
if errors.Is(err, fleet.ErrMissingLicense) {
|
||||
availableTeams = []*fleet.TeamSummary{}
|
||||
} else {
|
||||
return loginResponse{Err: err}, nil
|
||||
}
|
||||
}
|
||||
return loginResponse{user, availableTeams, token, nil}, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Logout
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type logoutResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r logoutResponse) error() error { return r.Err }
|
||||
|
||||
func makeLogoutEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
err := svc.Logout(ctx)
|
||||
if err != nil {
|
||||
return logoutResponse{Err: err}, nil
|
||||
}
|
||||
return logoutResponse{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
type initiateSSORequest struct {
|
||||
RelayURL string `json:"relay_url"`
|
||||
}
|
||||
|
||||
type initiateSSOResponse struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r initiateSSOResponse) error() error { return r.Err }
|
||||
|
||||
func makeInitiateSSOEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(initiateSSORequest)
|
||||
idProviderURL, err := svc.InitiateSSO(ctx, req.RelayURL)
|
||||
if err != nil {
|
||||
return initiateSSOResponse{Err: err}, nil
|
||||
}
|
||||
return initiateSSOResponse{URL: idProviderURL}, nil
|
||||
}
|
||||
}
|
||||
|
||||
type callbackSSOResponse struct {
|
||||
content string
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r callbackSSOResponse) error() error { return r.Err }
|
||||
|
||||
// If html is present we return a web page
|
||||
func (r callbackSSOResponse) html() string { return r.content }
|
||||
|
||||
func makeCallbackSSOEndpoint(svc fleet.Service, urlPrefix string) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
authResponse := request.(fleet.Auth)
|
||||
session, err := svc.CallbackSSO(ctx, authResponse)
|
||||
var resp callbackSSOResponse
|
||||
if err != nil {
|
||||
// redirect to login page on front end if there was some problem,
|
||||
// errors should still be logged
|
||||
session = &fleet.SSOSession{
|
||||
RedirectURL: urlPrefix + "/login",
|
||||
Token: "",
|
||||
}
|
||||
resp.Err = err
|
||||
}
|
||||
relayStateLoadPage := ` <html>
|
||||
<script type='text/javascript'>
|
||||
var redirectURL = {{ .RedirectURL }};
|
||||
window.localStorage.setItem('FLEET::auth_token', '{{ .Token }}');
|
||||
window.location = redirectURL;
|
||||
</script>
|
||||
<body>
|
||||
Redirecting to Fleet at {{ .RedirectURL }} ...
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
tmpl, err := template.New("relayStateLoader").Parse(relayStateLoadPage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var writer bytes.Buffer
|
||||
err = tmpl.Execute(&writer, session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.content = writer.String()
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
type ssoSettingsResponse struct {
|
||||
Settings *fleet.SessionSSOSettings `json:"settings,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r ssoSettingsResponse) error() error { return r.Err }
|
||||
|
||||
func makeSSOSettingsEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, unused interface{}) (interface{}, error) {
|
||||
settings, err := svc.SSOSettings(ctx)
|
||||
if err != nil {
|
||||
return ssoSettingsResponse{Err: err}, nil
|
||||
}
|
||||
return ssoSettingsResponse{Settings: settings}, nil
|
||||
}
|
||||
}
|
||||
|
|
@ -1,100 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Create User With Invite
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func makeCreateUserFromInviteEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(createUserRequest)
|
||||
user, err := svc.CreateUserFromInvite(ctx, req.UserPayload)
|
||||
if err != nil {
|
||||
return createUserResponse{Err: err}, nil
|
||||
}
|
||||
return createUserResponse{User: user}, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Reset Password
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type resetPasswordRequest struct {
|
||||
PasswordResetToken string `json:"password_reset_token"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
type resetPasswordResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r resetPasswordResponse) error() error { return r.Err }
|
||||
|
||||
func makeResetPasswordEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(resetPasswordRequest)
|
||||
err := svc.ResetPassword(ctx, req.PasswordResetToken, req.NewPassword)
|
||||
return resetPasswordResponse{Err: err}, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Perform Required Password Reset
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type performRequiredPasswordResetRequest struct {
|
||||
Password string `json:"new_password"`
|
||||
ID uint `json:"id"`
|
||||
}
|
||||
|
||||
type performRequiredPasswordResetResponse struct {
|
||||
User *fleet.User `json:"user,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r performRequiredPasswordResetResponse) error() error { return r.Err }
|
||||
|
||||
func makePerformRequiredPasswordResetEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(performRequiredPasswordResetRequest)
|
||||
user, err := svc.PerformRequiredPasswordReset(ctx, req.Password)
|
||||
if err != nil {
|
||||
return performRequiredPasswordResetResponse{Err: err}, nil
|
||||
}
|
||||
return performRequiredPasswordResetResponse{User: user}, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Forgot Password
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type forgotPasswordRequest struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type forgotPasswordResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r forgotPasswordResponse) error() error { return r.Err }
|
||||
func (r forgotPasswordResponse) status() int { return http.StatusAccepted }
|
||||
|
||||
func makeForgotPasswordEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(forgotPasswordRequest)
|
||||
// Any error returned by the service should not be returned to the
|
||||
// client to prevent information disclosure (it will be logged in the
|
||||
// server logs).
|
||||
_ = svc.RequestPasswordReset(ctx, req.Email)
|
||||
return forgotPasswordResponse{}, nil
|
||||
}
|
||||
}
|
||||
|
|
@ -67,6 +67,10 @@ func allFields(ifv reflect.Value) []reflect.StructField {
|
|||
return fields
|
||||
}
|
||||
|
||||
type requestDecoder interface {
|
||||
DecodeRequest(ctx context.Context, r *http.Request) (interface{}, error)
|
||||
}
|
||||
|
||||
// makeDecoder creates a decoder for the type for the struct passed on. If the
|
||||
// struct has at least 1 json tag it'll unmarshall the body. If the struct has
|
||||
// a `url` tag with value list_options it'll gather fleet.ListOptions from the
|
||||
|
|
@ -79,12 +83,22 @@ func allFields(ifv reflect.Value) []reflect.StructField {
|
|||
// follows: `url:"some-id,optional"`.
|
||||
// The "list_options" are optional by default and it'll ignore the optional
|
||||
// portion of the tag.
|
||||
//
|
||||
// If iface implements the requestDecoder interface, it returns a function that
|
||||
// calls iface.DecodeRequest(ctx, r) - i.e. the value itself fully controls its
|
||||
// own decoding.
|
||||
func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
|
||||
if iface == nil {
|
||||
return func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
if rd, ok := iface.(requestDecoder); ok {
|
||||
return func(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
return rd.DecodeRequest(ctx, r)
|
||||
}
|
||||
}
|
||||
|
||||
t := reflect.TypeOf(iface)
|
||||
if t.Kind() != reflect.Struct {
|
||||
panic(fmt.Sprintf("makeDecoder only understands structs, not %T", iface))
|
||||
|
|
@ -272,6 +286,7 @@ type authEndpointer struct {
|
|||
startingAtVersion string
|
||||
endingAtVersion string
|
||||
alternativePaths []string
|
||||
customMiddleware []endpoint.Middleware
|
||||
}
|
||||
|
||||
func newUserAuthenticatedEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer {
|
||||
|
|
@ -297,6 +312,16 @@ func newHostAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts [
|
|||
}
|
||||
}
|
||||
|
||||
func newNoAuthEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer {
|
||||
return &authEndpointer{
|
||||
svc: svc,
|
||||
opts: opts,
|
||||
r: r,
|
||||
authFunc: unauthenticatedRequest,
|
||||
versions: versions,
|
||||
}
|
||||
}
|
||||
|
||||
var pathReplacer = strings.NewReplacer(
|
||||
"/", "_",
|
||||
"{", "_",
|
||||
|
|
@ -374,7 +399,15 @@ func (e *authEndpointer) makeEndpoint(f handlerFunc, v interface{}) http.Handler
|
|||
next := func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
return f(ctx, request, e.svc)
|
||||
}
|
||||
return newServer(e.authFunc(e.svc, next), makeDecoder(v), e.opts)
|
||||
endp := e.authFunc(e.svc, next)
|
||||
|
||||
// apply middleware in reverse order so that the first wraps the second
|
||||
// wraps the third etc.
|
||||
for i := len(e.customMiddleware) - 1; i >= 0; i-- {
|
||||
mw := e.customMiddleware[i]
|
||||
endp = mw(endp)
|
||||
}
|
||||
return newServer(endp, makeDecoder(v), e.opts)
|
||||
}
|
||||
|
||||
func (e *authEndpointer) StartingAtVersion(version string) *authEndpointer {
|
||||
|
|
@ -394,3 +427,9 @@ func (e *authEndpointer) WithAltPaths(paths ...string) *authEndpointer {
|
|||
ae.alternativePaths = paths
|
||||
return &ae
|
||||
}
|
||||
|
||||
func (e *authEndpointer) WithCustomMiddleware(mws ...endpoint.Middleware) *authEndpointer {
|
||||
ae := *e
|
||||
ae.customMiddleware = mws
|
||||
return &ae
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
|
|
@ -14,6 +15,7 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mock"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
kitlog "github.com/go-kit/kit/log"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
"github.com/gorilla/mux"
|
||||
|
|
@ -251,7 +253,6 @@ func TestUniversalDecoderQueryAndListPlayNice(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestEndpointer(t *testing.T) {
|
||||
|
||||
r := mux.NewRouter()
|
||||
ds := new(mock.Store)
|
||||
ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) {
|
||||
|
|
@ -395,3 +396,72 @@ func TestEndpointer(t *testing.T) {
|
|||
require.False(t, doesItMatch(route.method, route.path, false), route)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndpointerCustomMiddleware(t *testing.T) {
|
||||
r := mux.NewRouter()
|
||||
ds := new(mock.Store)
|
||||
svc := newTestService(ds, nil, nil)
|
||||
|
||||
fleetAPIOptions := []kithttp.ServerOption{
|
||||
kithttp.ServerBefore(
|
||||
kithttp.PopulateRequestContext,
|
||||
setRequestsContexts(svc),
|
||||
),
|
||||
kithttp.ServerErrorHandler(&errorHandler{kitlog.NewNopLogger()}),
|
||||
kithttp.ServerErrorEncoder(encodeError),
|
||||
kithttp.ServerAfter(
|
||||
kithttp.SetContentType("application/json; charset=utf-8"),
|
||||
logRequestEnd(kitlog.NewNopLogger()),
|
||||
checkLicenseExpiration(svc),
|
||||
),
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
e := newNoAuthEndpointer(svc, fleetAPIOptions, r, "v1")
|
||||
e.GET("/none/", func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
buf.WriteString("H1")
|
||||
return nil, nil
|
||||
}, nil)
|
||||
|
||||
e.WithCustomMiddleware(
|
||||
func(e endpoint.Endpoint) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
|
||||
buf.WriteString("A")
|
||||
return e(ctx, request)
|
||||
}
|
||||
},
|
||||
func(e endpoint.Endpoint) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
|
||||
buf.WriteString("B")
|
||||
return e(ctx, request)
|
||||
}
|
||||
},
|
||||
func(e endpoint.Endpoint) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
|
||||
buf.WriteString("C")
|
||||
return e(ctx, request)
|
||||
}
|
||||
},
|
||||
).
|
||||
GET("/mw/", func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
buf.WriteString("H2")
|
||||
return nil, nil
|
||||
}, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/none/", nil)
|
||||
var m1 mux.RouteMatch
|
||||
|
||||
require.True(t, r.Match(req, &m1))
|
||||
rec := httptest.NewRecorder()
|
||||
m1.Handler.ServeHTTP(rec, req)
|
||||
require.Equal(t, "H1", buf.String())
|
||||
|
||||
buf.Reset()
|
||||
req = httptest.NewRequest("GET", "/mw/", nil)
|
||||
var m2 mux.RouteMatch
|
||||
|
||||
require.True(t, r.Match(req, &m2))
|
||||
rec = httptest.NewRecorder()
|
||||
m2.Handler.ServeHTTP(rec, req)
|
||||
require.Equal(t, "ABCH2", buf.String())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"regexp"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
||||
|
|
@ -22,92 +22,6 @@ import (
|
|||
otmiddleware "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux"
|
||||
)
|
||||
|
||||
// FleetEndpoints is a collection of RPC endpoints implemented by the Fleet API.
|
||||
type FleetEndpoints struct {
|
||||
Login endpoint.Endpoint
|
||||
Logout endpoint.Endpoint
|
||||
ForgotPassword endpoint.Endpoint
|
||||
ResetPassword endpoint.Endpoint
|
||||
CreateUserWithInvite endpoint.Endpoint
|
||||
PerformRequiredPasswordReset endpoint.Endpoint
|
||||
VerifyInvite endpoint.Endpoint
|
||||
EnrollAgent endpoint.Endpoint
|
||||
CarveBlock endpoint.Endpoint
|
||||
InitiateSSO endpoint.Endpoint
|
||||
CallbackSSO endpoint.Endpoint
|
||||
SSOSettings endpoint.Endpoint
|
||||
}
|
||||
|
||||
// MakeFleetServerEndpoints creates the Fleet API endpoints.
|
||||
func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore throttled.GCRAStore, logger kitlog.Logger) FleetEndpoints {
|
||||
limiter := ratelimit.NewMiddleware(limitStore)
|
||||
|
||||
return FleetEndpoints{
|
||||
Login: limiter.Limit(
|
||||
throttled.RateQuota{MaxRate: throttled.PerMin(10), MaxBurst: 9})(
|
||||
makeLoginEndpoint(svc),
|
||||
),
|
||||
Logout: logged(makeLogoutEndpoint(svc)),
|
||||
ForgotPassword: limiter.Limit(
|
||||
throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})(
|
||||
logged(makeForgotPasswordEndpoint(svc)),
|
||||
),
|
||||
ResetPassword: logged(makeResetPasswordEndpoint(svc)),
|
||||
CreateUserWithInvite: logged(makeCreateUserFromInviteEndpoint(svc)),
|
||||
VerifyInvite: logged(makeVerifyInviteEndpoint(svc)),
|
||||
InitiateSSO: logged(makeInitiateSSOEndpoint(svc)),
|
||||
CallbackSSO: logged(makeCallbackSSOEndpoint(svc, urlPrefix)),
|
||||
SSOSettings: logged(makeSSOSettingsEndpoint(svc)),
|
||||
|
||||
// PerformRequiredPasswordReset needs only to authenticate the
|
||||
// logged in user
|
||||
PerformRequiredPasswordReset: logged(canPerformPasswordReset(makePerformRequiredPasswordResetEndpoint(svc))),
|
||||
|
||||
// Osquery endpoints
|
||||
EnrollAgent: logged(makeEnrollAgentEndpoint(svc)),
|
||||
// For some reason osquery does not provide a node key with the block
|
||||
// data. Instead the carve session ID should be verified in the service
|
||||
// method.
|
||||
CarveBlock: logged(makeCarveBlockEndpoint(svc)),
|
||||
}
|
||||
}
|
||||
|
||||
type fleetHandlers struct {
|
||||
Login http.Handler
|
||||
Logout http.Handler
|
||||
ForgotPassword http.Handler
|
||||
ResetPassword http.Handler
|
||||
CreateUserWithInvite http.Handler
|
||||
PerformRequiredPasswordReset http.Handler
|
||||
VerifyInvite http.Handler
|
||||
EnrollAgent http.Handler
|
||||
CarveBlock http.Handler
|
||||
InitiateSSO http.Handler
|
||||
CallbackSSO http.Handler
|
||||
SettingsSSO http.Handler
|
||||
}
|
||||
|
||||
func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandlers {
|
||||
newServer := func(e endpoint.Endpoint, decodeFn kithttp.DecodeRequestFunc) http.Handler {
|
||||
e = authzcheck.NewMiddleware().AuthzCheck()(e)
|
||||
return kithttp.NewServer(e, decodeFn, encodeResponse, opts...)
|
||||
}
|
||||
return &fleetHandlers{
|
||||
Login: newServer(e.Login, decodeLoginRequest),
|
||||
Logout: newServer(e.Logout, decodeNoParamsRequest),
|
||||
ForgotPassword: newServer(e.ForgotPassword, decodeForgotPasswordRequest),
|
||||
ResetPassword: newServer(e.ResetPassword, decodeResetPasswordRequest),
|
||||
CreateUserWithInvite: newServer(e.CreateUserWithInvite, decodeCreateUserRequest),
|
||||
PerformRequiredPasswordReset: newServer(e.PerformRequiredPasswordReset, decodePerformRequiredPasswordResetRequest),
|
||||
VerifyInvite: newServer(e.VerifyInvite, decodeVerifyInviteRequest),
|
||||
EnrollAgent: newServer(e.EnrollAgent, decodeEnrollAgentRequest),
|
||||
CarveBlock: newServer(e.CarveBlock, decodeCarveBlockRequest),
|
||||
InitiateSSO: newServer(e.InitiateSSO, decodeInitiateSSORequest),
|
||||
CallbackSSO: newServer(e.CallbackSSO, decodeCallbackSSORequest),
|
||||
SettingsSSO: newServer(e.SSOSettings, decodeNoParamsRequest),
|
||||
}
|
||||
}
|
||||
|
||||
type errorHandler struct {
|
||||
logger kitlog.Logger
|
||||
}
|
||||
|
|
@ -176,18 +90,16 @@ func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Log
|
|||
),
|
||||
}
|
||||
|
||||
fleetEndpoints := MakeFleetServerEndpoints(svc, config.Server.URLPrefix, limitStore, logger)
|
||||
fleetHandlers := makeKitHandlers(fleetEndpoints, fleetAPIOptions)
|
||||
|
||||
r := mux.NewRouter()
|
||||
if config.Logging.TracingEnabled && config.Logging.TracingType == "opentelemetry" {
|
||||
r.Use(otmiddleware.Middleware("fleet"))
|
||||
}
|
||||
|
||||
attachFleetAPIRoutes(r, fleetHandlers)
|
||||
attachNewStyleFleetAPIRoutes(r, svc, logger, fleetAPIOptions)
|
||||
attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions)
|
||||
|
||||
// Results endpoint is handled different due to websockets use
|
||||
// TODO: this would not work once v1 is deprecated - note that the handler too uses the /v1/ path
|
||||
// and this routes on path prefix, not exact path (unlike the authendpointer struct).
|
||||
r.PathPrefix("/api/v1/fleet/results/").
|
||||
Handler(makeStreamDistributedQueryCampaignResultsHandler(svc, logger)).
|
||||
Name("distributed_query_results")
|
||||
|
|
@ -277,22 +189,9 @@ func addMetrics(r *mux.Router) {
|
|||
r.Walk(walkFn)
|
||||
}
|
||||
|
||||
func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
|
||||
r.Handle("/api/v1/fleet/login", h.Login).Methods("POST").Name("login")
|
||||
r.Handle("/api/v1/fleet/logout", h.Logout).Methods("POST").Name("logout")
|
||||
r.Handle("/api/v1/fleet/forgot_password", h.ForgotPassword).Methods("POST").Name("forgot_password")
|
||||
r.Handle("/api/v1/fleet/reset_password", h.ResetPassword).Methods("POST").Name("reset_password")
|
||||
r.Handle("/api/v1/fleet/perform_required_password_reset", h.PerformRequiredPasswordReset).Methods("POST").Name("perform_required_password_reset")
|
||||
r.Handle("/api/v1/fleet/sso", h.InitiateSSO).Methods("POST").Name("intiate_sso")
|
||||
r.Handle("/api/v1/fleet/sso", h.SettingsSSO).Methods("GET").Name("sso_config")
|
||||
r.Handle("/api/v1/fleet/sso/callback", h.CallbackSSO).Methods("POST").Name("callback_sso")
|
||||
r.Handle("/api/v1/fleet/users", h.CreateUserWithInvite).Methods("POST").Name("create_user_with_invite")
|
||||
r.Handle("/api/v1/fleet/invites/{token}", h.VerifyInvite).Methods("GET").Name("verify_invite")
|
||||
r.Handle("/api/v1/osquery/enroll", h.EnrollAgent).Methods("POST").Name("enroll_agent")
|
||||
r.Handle("/api/v1/osquery/carve/block", h.CarveBlock).Methods("POST").Name("carve_block")
|
||||
}
|
||||
func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetConfig,
|
||||
logger kitlog.Logger, limitStore throttled.GCRAStore, opts []kithttp.ServerOption) {
|
||||
|
||||
func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, logger kitlog.Logger, opts []kithttp.ServerOption) {
|
||||
// user-authenticated endpoints
|
||||
ue := newUserAuthenticatedEndpointer(svc, opts, r, "v1")
|
||||
|
||||
|
|
@ -441,10 +340,42 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, logger kitlo
|
|||
he.POST("/api/_version_/osquery/distributed/write", submitDistributedQueryResultsEndpoint, submitDistributedQueryResultsRequestShim{})
|
||||
he.POST("/api/_version_/osquery/carve/begin", carveBeginEndpoint, carveBeginRequest{})
|
||||
he.POST("/api/_version_/osquery/log", submitLogsEndpoint, submitLogsRequest{})
|
||||
|
||||
// unauthenticated endpoints - most of those are either login-related,
|
||||
// invite-related or host-enrolling. So they typically do some kind of
|
||||
// one-time authentication by verifying that a valid secret token is provided
|
||||
// with the request.
|
||||
ne := newNoAuthEndpointer(svc, opts, r, "v1")
|
||||
ne.POST("/api/_version_/osquery/enroll", enrollAgentEndpoint, enrollAgentRequest{})
|
||||
|
||||
// For some reason osquery does not provide a node key with the block data.
|
||||
// Instead the carve session ID should be verified in the service method.
|
||||
ne.POST("/api/_version_/osquery/carve/block", carveBlockEndpoint, carveBlockRequest{})
|
||||
|
||||
ne.POST("/api/_version_/fleet/perform_required_password_reset", performRequiredPasswordResetEndpoint, performRequiredPasswordResetRequest{})
|
||||
ne.POST("/api/_version_/fleet/users", createUserFromInviteEndpoint, createUserRequest{})
|
||||
ne.GET("/api/_version_/fleet/invites/{token}", verifyInviteEndpoint, verifyInviteRequest{})
|
||||
ne.POST("/api/_version_/fleet/reset_password", resetPasswordEndpoint, resetPasswordRequest{})
|
||||
ne.POST("/api/_version_/fleet/logout", logoutEndpoint, nil)
|
||||
ne.POST("/api/_version_/fleet/sso", initiateSSOEndpoint, initiateSSORequest{})
|
||||
ne.POST("/api/_version_/fleet/sso/callback", makeCallbackSSOEndpoint(config.Server.URLPrefix), callbackSSORequest{})
|
||||
ne.GET("/api/_version_/fleet/sso", settingsSSOEndpoint, nil)
|
||||
|
||||
limiter := ratelimit.NewMiddleware(limitStore)
|
||||
ne.
|
||||
WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})).
|
||||
POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{})
|
||||
|
||||
ne.
|
||||
WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: throttled.PerMin(10), MaxBurst: 9})).
|
||||
POST("/api/_version_/fleet/login", loginEndpoint, loginRequest{})
|
||||
}
|
||||
|
||||
// TODO: this duplicates the one in makeKitHandler
|
||||
func newServer(e endpoint.Endpoint, decodeFn kithttp.DecodeRequestFunc, opts []kithttp.ServerOption) http.Handler {
|
||||
// TODO: some handlers don't have authz checks, and because the SkipAuth call is done only in the
|
||||
// endpoint handler, any middleware that raises errors before the handler is reached will end up
|
||||
// returning authz check missing instead of the more relevant error. Should be addressed as part
|
||||
// of #4406.
|
||||
e = authzcheck.NewMiddleware().AuthzCheck()(e)
|
||||
return kithttp.NewServer(e, decodeFn, encodeResponse, opts...)
|
||||
}
|
||||
|
|
@ -453,15 +384,19 @@ func newServer(e endpoint.Endpoint, decodeFn kithttp.DecodeRequestFunc, opts []k
|
|||
// If setup hasn't been completed it serves the API with a setup middleware.
|
||||
// If the server is already configured, the default API handler is exposed.
|
||||
func WithSetup(svc fleet.Service, logger kitlog.Logger, next http.Handler) http.HandlerFunc {
|
||||
|
||||
rxOsquery := regexp.MustCompile(`^/api/[^/]+/osquery`)
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
configRouter := http.NewServeMux()
|
||||
// TODO: hard-codes v1 as a path fragment, which would probably not work once we
|
||||
// deprecate it for newer versions, unless we want to treat the setup differently (not versioned?)
|
||||
configRouter.Handle("/api/v1/setup", kithttp.NewServer(
|
||||
makeSetupEndpoint(svc),
|
||||
decodeSetupRequest,
|
||||
encodeResponse,
|
||||
))
|
||||
// whitelist osqueryd endpoints
|
||||
if strings.HasPrefix(r.URL.Path, "/api/v1/osquery") {
|
||||
if rxOsquery.MatchString(r.URL.Path) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,63 +16,10 @@ import (
|
|||
"github.com/gorilla/mux"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/throttled/throttled/v2/store/memstore"
|
||||
)
|
||||
|
||||
func TestAPIRoutes(t *testing.T) {
|
||||
ds := new(mock.Store)
|
||||
|
||||
svc := newTestService(ds, nil, nil)
|
||||
|
||||
r := mux.NewRouter()
|
||||
limitStore, _ := memstore.New(0)
|
||||
ke := MakeFleetServerEndpoints(svc, "", limitStore, kitlog.NewNopLogger())
|
||||
kh := makeKitHandlers(ke, nil)
|
||||
attachFleetAPIRoutes(r, kh)
|
||||
handler := mux.NewRouter()
|
||||
handler.PathPrefix("/").Handler(r)
|
||||
|
||||
routes := []struct {
|
||||
verb string
|
||||
uri string
|
||||
}{
|
||||
{
|
||||
verb: "POST",
|
||||
uri: "/api/v1/fleet/users",
|
||||
},
|
||||
{
|
||||
verb: "POST",
|
||||
uri: "/api/v1/fleet/login",
|
||||
},
|
||||
{
|
||||
verb: "POST",
|
||||
uri: "/api/v1/fleet/forgot_password",
|
||||
},
|
||||
{
|
||||
verb: "POST",
|
||||
uri: "/api/v1/fleet/reset_password",
|
||||
},
|
||||
{
|
||||
verb: "POST",
|
||||
uri: "/api/v1/osquery/enroll",
|
||||
},
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
t.Run(fmt.Sprintf(": %v", route.uri), func(st *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(
|
||||
recorder,
|
||||
httptest.NewRequest(route.verb, route.uri, nil),
|
||||
)
|
||||
assert.NotEqual(st, 404, recorder.Code)
|
||||
assert.NotEqual(st, 405, recorder.Code, route.verb) // if it matches a path but with wrong verb
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIRoutesConflicts(t *testing.T) {
|
||||
ds := new(mock.Store)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,11 +15,13 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/test"
|
||||
"github.com/ghodss/yaml"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
|
@ -978,6 +980,20 @@ func (s *integrationTestSuite) TestInvites() {
|
|||
require.NotZero(t, createInviteResp.Invite.ID)
|
||||
validInvite := *createInviteResp.Invite
|
||||
|
||||
// create user from valid invite - the token was not returned via the
|
||||
// response's json, must get it from the db
|
||||
inv, err := s.ds.Invite(context.Background(), validInvite.ID)
|
||||
require.NoError(t, err)
|
||||
validInviteToken := inv.Token
|
||||
|
||||
// verify the token with valid invite
|
||||
var verifyInvResp verifyInviteResponse
|
||||
s.DoJSON("GET", "/api/v1/fleet/invites/"+validInviteToken, nil, http.StatusOK, &verifyInvResp)
|
||||
require.Equal(t, validInvite.ID, verifyInvResp.Invite.ID)
|
||||
|
||||
// verify the token with an invalid invite
|
||||
s.DoJSON("GET", "/api/v1/fleet/invites/invalid", nil, http.StatusNotFound, &verifyInvResp)
|
||||
|
||||
// create invite without an email
|
||||
createInviteReq = createInviteRequest{InvitePayload: fleet.InvitePayload{
|
||||
Email: nil,
|
||||
|
|
@ -1076,9 +1092,21 @@ func (s *integrationTestSuite) TestInvites() {
|
|||
require.Len(t, verify.Teams, 1)
|
||||
assert.Equal(t, team.ID, verify.Teams[0].ID)
|
||||
|
||||
var createFromInviteResp createUserResponse
|
||||
s.DoJSON("POST", "/api/v1/fleet/users", fleet.UserPayload{
|
||||
Name: ptr.String("Full Name"),
|
||||
Password: ptr.String("pass1word!"),
|
||||
Email: ptr.String("a@b.c"),
|
||||
InviteToken: ptr.String(validInviteToken),
|
||||
}, http.StatusOK, &createFromInviteResp)
|
||||
|
||||
// keep the invite token from the other valid invite (before deleting it)
|
||||
inv, err = s.ds.Invite(context.Background(), createInviteResp.Invite.ID)
|
||||
require.NoError(t, err)
|
||||
deletedInviteToken := inv.Token
|
||||
|
||||
// delete an existing invite
|
||||
var delResp deleteInviteResponse
|
||||
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/invites/%d", validInvite.ID), nil, http.StatusOK, &delResp)
|
||||
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/invites/%d", createInviteResp.Invite.ID), nil, http.StatusOK, &delResp)
|
||||
|
||||
// list invites, is now empty
|
||||
|
|
@ -1088,6 +1116,111 @@ func (s *integrationTestSuite) TestInvites() {
|
|||
|
||||
// delete a now non-existing invite
|
||||
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/invites/%d", validInvite.ID), nil, http.StatusNotFound, &delResp)
|
||||
|
||||
// create user from never used but deleted invite
|
||||
s.DoJSON("POST", "/api/v1/fleet/users", fleet.UserPayload{
|
||||
Name: ptr.String("Full Name"),
|
||||
Password: ptr.String("pass1word!"),
|
||||
Email: ptr.String("a@b.c"),
|
||||
InviteToken: ptr.String(deletedInviteToken),
|
||||
}, http.StatusNotFound, &createFromInviteResp)
|
||||
}
|
||||
|
||||
func (s *integrationTestSuite) TestCreateUserFromInviteErrors() {
|
||||
t := s.T()
|
||||
|
||||
// create a valid invite
|
||||
createInviteReq := createInviteRequest{InvitePayload: fleet.InvitePayload{
|
||||
Email: ptr.String("a@b.c"),
|
||||
Name: ptr.String("A"),
|
||||
GlobalRole: null.StringFrom(fleet.RoleObserver),
|
||||
}}
|
||||
createInviteResp := createInviteResponse{}
|
||||
s.DoJSON("POST", "/api/v1/fleet/invites", createInviteReq, http.StatusOK, &createInviteResp)
|
||||
|
||||
// make sure to delete it on exit
|
||||
defer func() {
|
||||
var delResp deleteInviteResponse
|
||||
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/invites/%d", createInviteResp.Invite.ID), nil, http.StatusOK, &delResp)
|
||||
}()
|
||||
|
||||
// the token is not returned via the response's json, must get it from the db
|
||||
invite, err := s.ds.Invite(context.Background(), createInviteResp.Invite.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
pld fleet.UserPayload
|
||||
want int
|
||||
}{
|
||||
{
|
||||
"empty name",
|
||||
fleet.UserPayload{
|
||||
Name: ptr.String(""),
|
||||
Password: ptr.String("pass1word!"),
|
||||
Email: ptr.String("a@b.c"),
|
||||
InviteToken: ptr.String(invite.Token),
|
||||
},
|
||||
http.StatusUnprocessableEntity,
|
||||
},
|
||||
{
|
||||
"empty email",
|
||||
fleet.UserPayload{
|
||||
Name: ptr.String("Name"),
|
||||
Password: ptr.String("pass1word!"),
|
||||
Email: ptr.String(""),
|
||||
InviteToken: ptr.String(invite.Token),
|
||||
},
|
||||
http.StatusUnprocessableEntity,
|
||||
},
|
||||
{
|
||||
"empty password",
|
||||
fleet.UserPayload{
|
||||
Name: ptr.String("Name"),
|
||||
Password: ptr.String(""),
|
||||
Email: ptr.String("a@b.c"),
|
||||
InviteToken: ptr.String(invite.Token),
|
||||
},
|
||||
http.StatusUnprocessableEntity,
|
||||
},
|
||||
{
|
||||
"empty token",
|
||||
fleet.UserPayload{
|
||||
Name: ptr.String("Name"),
|
||||
Password: ptr.String("pass1word!"),
|
||||
Email: ptr.String("a@b.c"),
|
||||
InviteToken: ptr.String(""),
|
||||
},
|
||||
http.StatusUnprocessableEntity,
|
||||
},
|
||||
{
|
||||
"invalid token",
|
||||
fleet.UserPayload{
|
||||
Name: ptr.String("Name"),
|
||||
Password: ptr.String("pass1word!"),
|
||||
Email: ptr.String("a@b.c"),
|
||||
InviteToken: ptr.String("invalid"),
|
||||
},
|
||||
http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
"invalid password",
|
||||
fleet.UserPayload{
|
||||
Name: ptr.String("Name"),
|
||||
Password: ptr.String("password"), // no number or symbol
|
||||
Email: ptr.String("a@b.c"),
|
||||
InviteToken: ptr.String(invite.Token),
|
||||
},
|
||||
http.StatusUnprocessableEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
var resp createUserResponse
|
||||
s.DoJSON("POST", "/api/v1/fleet/users", c.pld, c.want, &resp)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *integrationTestSuite) TestGetHostSummary() {
|
||||
|
|
@ -2302,6 +2435,9 @@ func (s *integrationTestSuite) TestLabelSpecs() {
|
|||
}
|
||||
|
||||
func (s *integrationTestSuite) TestUsers() {
|
||||
// ensure that on exit, the admin token is used
|
||||
defer func() { s.token = s.getTestAdminToken() }()
|
||||
|
||||
t := s.T()
|
||||
|
||||
// list existing users
|
||||
|
|
@ -2324,14 +2460,16 @@ func (s *integrationTestSuite) TestUsers() {
|
|||
|
||||
// create a new user
|
||||
var createResp createUserResponse
|
||||
userRawPwd := "pass"
|
||||
params := fleet.UserPayload{
|
||||
Name: ptr.String("extra"),
|
||||
Email: ptr.String("extra@asd.com"),
|
||||
Password: ptr.String("pass"),
|
||||
Password: ptr.String(userRawPwd),
|
||||
GlobalRole: ptr.String(fleet.RoleObserver),
|
||||
}
|
||||
s.DoJSON("POST", "/api/v1/fleet/users/admin", params, http.StatusOK, &createResp)
|
||||
assert.NotZero(t, createResp.User.ID)
|
||||
assert.True(t, createResp.User.AdminForcedPasswordReset)
|
||||
u := *createResp.User
|
||||
|
||||
// login as that user and check that teams info is empty
|
||||
|
|
@ -2407,6 +2545,46 @@ func (s *integrationTestSuite) TestUsers() {
|
|||
}
|
||||
s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID+1), params, http.StatusNotFound, &modResp)
|
||||
|
||||
// perform a required password change as the user themselves
|
||||
s.token = s.getTestToken(u.Email, userRawPwd)
|
||||
var perfPwdResetResp performRequiredPasswordResetResponse
|
||||
newRawPwd := "new_password!"
|
||||
s.DoJSON("POST", "/api/v1/fleet/perform_required_password_reset", performRequiredPasswordResetRequest{
|
||||
Password: newRawPwd,
|
||||
ID: u.ID,
|
||||
}, http.StatusOK, &perfPwdResetResp)
|
||||
assert.False(t, perfPwdResetResp.User.AdminForcedPasswordReset)
|
||||
oldUserRawPwd := userRawPwd
|
||||
userRawPwd = newRawPwd
|
||||
|
||||
// perform a required password change again, this time it fails as there is no request pending
|
||||
perfPwdResetResp = performRequiredPasswordResetResponse{}
|
||||
newRawPwd = "new_password2!"
|
||||
s.DoJSON("POST", "/api/v1/fleet/perform_required_password_reset", performRequiredPasswordResetRequest{
|
||||
Password: newRawPwd,
|
||||
ID: u.ID,
|
||||
}, http.StatusInternalServerError, &perfPwdResetResp) // TODO: should be 40?, see #4406
|
||||
s.token = s.getTestAdminToken()
|
||||
|
||||
// login as that user to verify that the new password is active (userRawPwd was updated to the new pwd)
|
||||
loginResp = loginResponse{}
|
||||
s.DoJSON("POST", "/api/v1/fleet/login", loginRequest{Email: u.Email, Password: userRawPwd}, http.StatusOK, &loginResp)
|
||||
require.Equal(t, loginResp.User.ID, u.ID)
|
||||
|
||||
// logout for that user
|
||||
s.token = loginResp.Token
|
||||
var logoutResp logoutResponse
|
||||
s.DoJSON("POST", "/api/v1/fleet/logout", nil, http.StatusOK, &logoutResp)
|
||||
|
||||
// logout again, even though not logged in
|
||||
s.DoJSON("POST", "/api/v1/fleet/logout", nil, http.StatusInternalServerError, &logoutResp) // TODO: should be OK even if not logged in, see #4406.
|
||||
|
||||
s.token = s.getTestAdminToken()
|
||||
|
||||
// login as that user with previous pwd fails
|
||||
loginResp = loginResponse{}
|
||||
s.DoJSON("POST", "/api/v1/fleet/login", loginRequest{Email: u.Email, Password: oldUserRawPwd}, http.StatusUnauthorized, &loginResp)
|
||||
|
||||
// require a password reset
|
||||
var reqResetResp requirePasswordResetResponse
|
||||
s.DoJSON("POST", fmt.Sprintf("/api/v1/fleet/users/%d/require_password_reset", u.ID), map[string]bool{"require": true}, http.StatusOK, &reqResetResp)
|
||||
|
|
@ -3094,6 +3272,245 @@ func (s *integrationTestSuite) TestOsqueryConfig() {
|
|||
assert.Contains(t, errRes["error"], "invalid node key")
|
||||
}
|
||||
|
||||
func (s *integrationTestSuite) TestEnrollHost() {
|
||||
t := s.T()
|
||||
|
||||
// set the enroll secret
|
||||
var applyResp applyEnrollSecretSpecResponse
|
||||
s.DoJSON("POST", "/api/v1/fleet/spec/enroll_secret", applyEnrollSecretSpecRequest{
|
||||
Spec: &fleet.EnrollSecretSpec{
|
||||
Secrets: []*fleet.EnrollSecret{{Secret: t.Name()}},
|
||||
},
|
||||
}, http.StatusOK, &applyResp)
|
||||
|
||||
// invalid enroll secret fails
|
||||
j, err := json.Marshal(&enrollAgentRequest{
|
||||
EnrollSecret: "nosuchsecret",
|
||||
HostIdentifier: "abcd",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
s.DoRawNoAuth("POST", "/api/v1/osquery/enroll", j, http.StatusUnauthorized)
|
||||
|
||||
// valid enroll secret succeeds
|
||||
j, err = json.Marshal(&enrollAgentRequest{
|
||||
EnrollSecret: t.Name(),
|
||||
HostIdentifier: t.Name(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var resp enrollAgentResponse
|
||||
hres := s.DoRawNoAuth("POST", "/api/v1/osquery/enroll", j, http.StatusOK)
|
||||
defer hres.Body.Close()
|
||||
require.NoError(t, json.NewDecoder(hres.Body).Decode(&resp))
|
||||
require.NotEmpty(t, resp.NodeKey)
|
||||
}
|
||||
|
||||
func (s *integrationTestSuite) TestCarve() {
|
||||
t := s.T()
|
||||
hosts := s.createHosts(t)
|
||||
|
||||
// begin a carve with an invalid node key
|
||||
var errRes map[string]interface{}
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{
|
||||
NodeKey: hosts[0].NodeKey + "zzz",
|
||||
BlockCount: 1,
|
||||
BlockSize: 1,
|
||||
CarveSize: 1,
|
||||
CarveId: "c1",
|
||||
}, http.StatusUnauthorized, &errRes)
|
||||
assert.Contains(t, errRes["error"], "invalid node key")
|
||||
|
||||
// invalid carve size
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{
|
||||
NodeKey: hosts[0].NodeKey,
|
||||
BlockCount: 3,
|
||||
BlockSize: 3,
|
||||
CarveSize: 0,
|
||||
CarveId: "c1",
|
||||
}, http.StatusInternalServerError, &errRes) // TODO: should be 4xx, see #4406
|
||||
assert.Contains(t, errRes["error"], "carve_size must be greater")
|
||||
|
||||
// invalid block size too big
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{
|
||||
NodeKey: hosts[0].NodeKey,
|
||||
BlockCount: 3,
|
||||
BlockSize: maxBlockSize + 1,
|
||||
CarveSize: maxCarveSize,
|
||||
CarveId: "c1",
|
||||
}, http.StatusInternalServerError, &errRes) // TODO: should be 4xx, see #4406
|
||||
assert.Contains(t, errRes["error"], "block_size exceeds max")
|
||||
|
||||
// invalid carve size too big
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{
|
||||
NodeKey: hosts[0].NodeKey,
|
||||
BlockCount: 3,
|
||||
BlockSize: maxBlockSize,
|
||||
CarveSize: maxCarveSize + 1,
|
||||
CarveId: "c1",
|
||||
}, http.StatusInternalServerError, &errRes) // TODO: should be 4xx, see #4406
|
||||
assert.Contains(t, errRes["error"], "carve_size exceeds max")
|
||||
|
||||
// invalid carve size, does not match blocks
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{
|
||||
NodeKey: hosts[0].NodeKey,
|
||||
BlockCount: 3,
|
||||
BlockSize: 3,
|
||||
CarveSize: 1,
|
||||
CarveId: "c1",
|
||||
}, http.StatusInternalServerError, &errRes) // TODO: should be 4xx, see #4406
|
||||
assert.Contains(t, errRes["error"], "carve_size does not match")
|
||||
|
||||
// valid carve begin
|
||||
var beginResp carveBeginResponse
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{
|
||||
NodeKey: hosts[0].NodeKey,
|
||||
BlockCount: 3,
|
||||
BlockSize: 3,
|
||||
CarveSize: 8,
|
||||
CarveId: "c1",
|
||||
RequestId: "r1",
|
||||
}, http.StatusOK, &beginResp)
|
||||
require.NotEmpty(t, beginResp.SessionId)
|
||||
sid := beginResp.SessionId
|
||||
|
||||
// sending a block with invalid session id
|
||||
var blockResp carveBlockResponse
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
|
||||
BlockId: 1,
|
||||
SessionId: sid + "zz",
|
||||
RequestId: "??",
|
||||
Data: []byte("p1."),
|
||||
}, http.StatusNotFound, &blockResp)
|
||||
|
||||
// sending a block with valid session id but invalid request id
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
|
||||
BlockId: 1,
|
||||
SessionId: sid,
|
||||
RequestId: "??",
|
||||
Data: []byte("p1."),
|
||||
}, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406
|
||||
|
||||
// sending a block with unexpected block id (expects 0, got 1)
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
|
||||
BlockId: 1,
|
||||
SessionId: sid,
|
||||
RequestId: "r1",
|
||||
Data: []byte("p1."),
|
||||
}, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406
|
||||
|
||||
// sending a block with valid payload, block 0
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
|
||||
BlockId: 0,
|
||||
SessionId: sid,
|
||||
RequestId: "r1",
|
||||
Data: []byte("p1."),
|
||||
}, http.StatusOK, &blockResp)
|
||||
require.True(t, blockResp.Success)
|
||||
|
||||
// sending next block
|
||||
blockResp = carveBlockResponse{}
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
|
||||
BlockId: 1,
|
||||
SessionId: sid,
|
||||
RequestId: "r1",
|
||||
Data: []byte("p2."),
|
||||
}, http.StatusOK, &blockResp)
|
||||
require.True(t, blockResp.Success)
|
||||
|
||||
// sending already-sent block again
|
||||
blockResp = carveBlockResponse{}
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
|
||||
BlockId: 1,
|
||||
SessionId: sid,
|
||||
RequestId: "r1",
|
||||
Data: []byte("p2."),
|
||||
}, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406
|
||||
|
||||
// sending final block with too many bytes
|
||||
blockResp = carveBlockResponse{}
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
|
||||
BlockId: 2,
|
||||
SessionId: sid,
|
||||
RequestId: "r1",
|
||||
Data: []byte("p3extra"),
|
||||
}, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406
|
||||
|
||||
// sending actual final block
|
||||
blockResp = carveBlockResponse{}
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
|
||||
BlockId: 2,
|
||||
SessionId: sid,
|
||||
RequestId: "r1",
|
||||
Data: []byte("p3"),
|
||||
}, http.StatusOK, &blockResp)
|
||||
require.True(t, blockResp.Success)
|
||||
|
||||
// sending unexpected block
|
||||
blockResp = carveBlockResponse{}
|
||||
s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{
|
||||
BlockId: 3,
|
||||
SessionId: sid,
|
||||
RequestId: "r1",
|
||||
Data: []byte("p4."),
|
||||
}, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406
|
||||
}
|
||||
|
||||
func (s *integrationTestSuite) TestPasswordReset() {
|
||||
t := s.T()
|
||||
|
||||
// create a new user
|
||||
var createResp createUserResponse
|
||||
userRawPwd := "passw0rd!"
|
||||
params := fleet.UserPayload{
|
||||
Name: ptr.String("forgotpwd"),
|
||||
Email: ptr.String("forgotpwd@example.com"),
|
||||
Password: ptr.String(userRawPwd),
|
||||
GlobalRole: ptr.String(fleet.RoleObserver),
|
||||
}
|
||||
s.DoJSON("POST", "/api/v1/fleet/users/admin", params, http.StatusOK, &createResp)
|
||||
require.NotZero(t, createResp.User.ID)
|
||||
u := *createResp.User
|
||||
_ = u
|
||||
|
||||
// request forgot password, invalid email
|
||||
res := s.DoRawNoAuth("POST", "/api/v1/fleet/forgot_password", jsonMustMarshal(t, forgotPasswordRequest{Email: "invalid@asd.com"}), http.StatusAccepted)
|
||||
res.Body.Close()
|
||||
|
||||
// TODO: tested manually (adds too much time to the test), works but hitting the rate
|
||||
// limit returns 500 instead of 429, see #4406. We get the authz check missing error instead.
|
||||
//// trigger the rate limit with a batch of requests in a short burst
|
||||
//for i := 0; i < 20; i++ {
|
||||
// s.DoJSON("POST", "/api/v1/fleet/forgot_password", forgotPasswordRequest{Email: "invalid@asd.com"}, http.StatusAccepted, &forgotResp)
|
||||
//}
|
||||
|
||||
// request forgot password, valid email
|
||||
res = s.DoRawNoAuth("POST", "/api/v1/fleet/forgot_password", jsonMustMarshal(t, forgotPasswordRequest{Email: u.Email}), http.StatusAccepted)
|
||||
res.Body.Close()
|
||||
|
||||
var token string
|
||||
mysql.ExecAdhocSQL(t, s.ds, func(db sqlx.ExtContext) error {
|
||||
return sqlx.GetContext(context.Background(), db, &token, "SELECT token FROM password_reset_requests WHERE user_id = ?", u.ID)
|
||||
})
|
||||
|
||||
// proceed with reset password
|
||||
userNewPwd := "newpassw0rd!"
|
||||
res = s.DoRawNoAuth("POST", "/api/v1/fleet/reset_password", jsonMustMarshal(t, resetPasswordRequest{PasswordResetToken: token, NewPassword: userNewPwd}), http.StatusOK)
|
||||
res.Body.Close()
|
||||
|
||||
// attempt it again with already-used token
|
||||
userUnusedPwd := "unusedpassw0rd!"
|
||||
res = s.DoRawNoAuth("POST", "/api/v1/fleet/reset_password", jsonMustMarshal(t, resetPasswordRequest{PasswordResetToken: token, NewPassword: userUnusedPwd}), http.StatusInternalServerError) // TODO: should be 40x, see #4406
|
||||
res.Body.Close()
|
||||
|
||||
// login with the old password, should not succeed
|
||||
res = s.DoRawNoAuth("POST", "/api/v1/fleet/login", jsonMustMarshal(t, loginRequest{Email: u.Email, Password: userRawPwd}), http.StatusUnauthorized)
|
||||
res.Body.Close()
|
||||
|
||||
// login with the new password, should succeed
|
||||
res = s.DoRawNoAuth("POST", "/api/v1/fleet/login", jsonMustMarshal(t, loginRequest{Email: u.Email, Password: userNewPwd}), http.StatusOK)
|
||||
res.Body.Close()
|
||||
}
|
||||
|
||||
// creates a session and returns it, its key is to be passed as authorization header.
|
||||
func createSession(t *testing.T, uid uint, ds fleet.Datastore) *fleet.Session {
|
||||
key := make([]byte, 64)
|
||||
|
|
@ -3116,3 +3533,9 @@ func cleanupQuery(s *integrationTestSuite, queryID uint) {
|
|||
var delResp deleteQueryByIDResponse
|
||||
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/queries/id/%d", queryID), nil, http.StatusOK, &delResp)
|
||||
}
|
||||
|
||||
func jsonMustMarshal(t testing.TB, v interface{}) []byte {
|
||||
b, err := json.Marshal(v)
|
||||
require.NoError(t, err)
|
||||
return b
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/fleetdm/fleet/v4/server"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mail"
|
||||
|
|
@ -264,3 +265,51 @@ func (svc *Service) DeleteInvite(ctx context.Context, id uint) error {
|
|||
}
|
||||
return svc.ds.DeleteInvite(ctx, id)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Verify invite
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type verifyInviteRequest struct {
|
||||
Token string `url:"token"`
|
||||
}
|
||||
|
||||
type verifyInviteResponse struct {
|
||||
Invite *fleet.Invite `json:"invite"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r verifyInviteResponse) error() error { return r.Err }
|
||||
|
||||
func verifyInviteEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*verifyInviteRequest)
|
||||
invite, err := svc.VerifyInvite(ctx, req.Token)
|
||||
if err != nil {
|
||||
return verifyInviteResponse{Err: err}, nil
|
||||
}
|
||||
return verifyInviteResponse{Invite: invite}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) VerifyInvite(ctx context.Context, token string) (*fleet.Invite, error) {
|
||||
// skipauth: There is no viewer context at this point. We rely on verifying
|
||||
// the invite for authNZ.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithExtras(ctx, "token", token)
|
||||
|
||||
invite, err := svc.ds.InviteByToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if invite.Token != token {
|
||||
return nil, fleet.NewInvalidArgumentError("invite_token", "Invite Token does not match Email Address.")
|
||||
}
|
||||
|
||||
expiresAt := invite.CreatedAt.Add(svc.config.App.InviteTokenValidityPeriod)
|
||||
if svc.clock.Now().After(expiresAt) {
|
||||
return nil, fleet.NewInvalidArgumentError("invite_token", "Invite token has expired.")
|
||||
}
|
||||
|
||||
return invite, nil
|
||||
}
|
||||
|
|
|
|||
84
server/service/jitter.go
Normal file
84
server/service/jitter.go
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// jitterHashTable implements a data structure that allows a fleet to generate a static jitter value
|
||||
// that is properly balanced. Balance in this context means that hosts would be distributed uniformly
|
||||
// across the total jitter time so there are no spikes.
|
||||
// The way this structure works is as follows:
|
||||
// Given an amount of buckets, we want to place hosts in buckets evenly. So we don't want bucket 0 to
|
||||
// have 1000 hosts, and all the other buckets 0. If there were 1000 buckets, and 1000 hosts, we should
|
||||
// end up with 1 per bucket.
|
||||
// The total amount of online hosts is unknown, so first it assumes that amount of buckets >= amount
|
||||
// of total hosts (maxCapacity of 1 per bucket). Once we have more hosts than buckets, then we
|
||||
// increase the maxCapacity by 1 for all buckets, and start placing hosts.
|
||||
// Hosts that have been placed in a bucket remain in that bucket for as long as the fleet instance is
|
||||
// running.
|
||||
// The preferred bucket for a host is the one at (host id % bucketCount). If that bucket is full, the
|
||||
// next one will be tried. If all buckets are full, then capacity gets increased and the bucket
|
||||
// selection process restarts.
|
||||
// Once a bucket is found, the index for the bucket (going from 0 to bucketCount) will be the amount of
|
||||
// minutes added to the host check in time.
|
||||
// For example: at a 1hr interval, and the default 10% max jitter percent. That allows hosts to
|
||||
// distribute within 6 minutes around the hour mark. We would have 6 buckets in that case.
|
||||
// In the worst possible case that all hosts start at the same time, max jitter percent can be set to
|
||||
// 100, and this method will distribute hosts evenly.
|
||||
// The main caveat of this approach is that it works at the fleet instance. So depending on what
|
||||
// instance gets chosen by the load balancer, the jitter might be different. However, load tests have
|
||||
// shown that the distribution in practice is pretty balance even when all hosts try to check in at
|
||||
// the same time.
|
||||
type jitterHashTable struct {
|
||||
mu sync.Mutex
|
||||
maxCapacity int
|
||||
bucketCount int
|
||||
buckets map[int]int
|
||||
cache map[uint]time.Duration
|
||||
}
|
||||
|
||||
func newJitterHashTable(bucketCount int) *jitterHashTable {
|
||||
if bucketCount == 0 {
|
||||
bucketCount = 1
|
||||
}
|
||||
return &jitterHashTable{
|
||||
maxCapacity: 1,
|
||||
bucketCount: bucketCount,
|
||||
buckets: make(map[int]int),
|
||||
cache: make(map[uint]time.Duration),
|
||||
}
|
||||
}
|
||||
|
||||
func (jh *jitterHashTable) jitterForHost(hostID uint) time.Duration {
|
||||
// if no jitter is configured just return 0
|
||||
if jh.bucketCount <= 1 {
|
||||
return 0
|
||||
}
|
||||
|
||||
jh.mu.Lock()
|
||||
if jitter, ok := jh.cache[hostID]; ok {
|
||||
jh.mu.Unlock()
|
||||
return jitter
|
||||
}
|
||||
|
||||
for i := 0; i < jh.bucketCount; i++ {
|
||||
possibleBucket := (int(hostID) + i) % jh.bucketCount
|
||||
|
||||
// if the next bucket has capacity, great!
|
||||
if jh.buckets[possibleBucket] < jh.maxCapacity {
|
||||
jh.buckets[possibleBucket]++
|
||||
jitter := time.Duration(possibleBucket) * time.Minute
|
||||
jh.cache[hostID] = jitter
|
||||
|
||||
jh.mu.Unlock()
|
||||
return jitter
|
||||
}
|
||||
}
|
||||
|
||||
// otherwise, bump the capacity and restart the process
|
||||
jh.maxCapacity++
|
||||
|
||||
jh.mu.Unlock()
|
||||
return jh.jitterForHost(hostID)
|
||||
}
|
||||
52
server/service/jitter_test.go
Normal file
52
server/service/jitter_test.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"math"
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestJitterForHost(t *testing.T) {
|
||||
jh := newJitterHashTable(30)
|
||||
|
||||
histogram := make(map[int64]int)
|
||||
hostCount := 3000
|
||||
for i := 0; i < hostCount; i++ {
|
||||
hostID, err := crand.Int(crand.Reader, big.NewInt(10000))
|
||||
require.NoError(t, err)
|
||||
jitter := jh.jitterForHost(uint(hostID.Int64() + 10000))
|
||||
jitterMinutes := int64(jitter.Minutes())
|
||||
histogram[jitterMinutes]++
|
||||
}
|
||||
min, max := math.MaxInt, 0
|
||||
for jitterMinutes, count := range histogram {
|
||||
if count < min {
|
||||
min = count
|
||||
}
|
||||
if count > max {
|
||||
max = count
|
||||
}
|
||||
t.Logf("jitterMinutes=%d \t count=%d\n", jitterMinutes, count)
|
||||
}
|
||||
variation := max - min
|
||||
t.Logf("min=%d \t max=%d \t variation=%d\n", min, max, variation)
|
||||
|
||||
// check that variation is below 1% of the total amount of hosts
|
||||
require.Less(t, variation, int(float32(hostCount)/0.01))
|
||||
}
|
||||
|
||||
func TestNoJitter(t *testing.T) {
|
||||
jh := newJitterHashTable(0)
|
||||
|
||||
hostCount := 3000
|
||||
for i := 0; i < hostCount; i++ {
|
||||
hostID, err := crand.Int(crand.Reader, big.NewInt(10000))
|
||||
require.NoError(t, err)
|
||||
jitter := jh.jitterForHost(uint(hostID.Int64() + 10000))
|
||||
jitterMinutes := int64(jitter.Minutes())
|
||||
require.Equal(t, int64(0), jitterMinutes)
|
||||
}
|
||||
}
|
||||
|
|
@ -4,8 +4,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"runtime"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
|
|
@ -28,19 +26,15 @@ func NewMiddleware(store throttled.GCRAStore) *Middleware {
|
|||
}
|
||||
|
||||
// Limit returns a new middleware function enforcing the provided quota.
|
||||
func (m *Middleware) Limit(quota throttled.RateQuota) endpoint.Middleware {
|
||||
func (m *Middleware) Limit(keyName string, quota throttled.RateQuota) endpoint.Middleware {
|
||||
return func(next endpoint.Endpoint) endpoint.Endpoint {
|
||||
// Get function name to use as a key for rate limiting (each wrapped function
|
||||
// gets a separate quota)
|
||||
funcName := runtime.FuncForPC(reflect.ValueOf(next).Pointer()).Name()
|
||||
|
||||
limiter, err := throttled.NewGCRARateLimiter(m.store, quota)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return func(ctx context.Context, req interface{}) (response interface{}, err error) {
|
||||
limited, result, err := limiter.RateLimit(funcName, 1)
|
||||
limited, result, err := limiter.RateLimit(keyName, 1)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "check rate limit")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ func TestLimit(t *testing.T) {
|
|||
limiter := NewMiddleware(store)
|
||||
endpoint := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
|
||||
wrapped := limiter.Limit(
|
||||
"test_limit",
|
||||
throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0},
|
||||
)(endpoint)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,8 +8,10 @@ import (
|
|||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
hostctx "github.com/fleetdm/fleet/v4/server/contexts/host"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
||||
|
|
@ -22,6 +24,255 @@ import (
|
|||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
type osqueryError struct {
|
||||
message string
|
||||
nodeInvalid bool
|
||||
}
|
||||
|
||||
func (e osqueryError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func (e osqueryError) NodeInvalid() bool {
|
||||
return e.nodeInvalid
|
||||
}
|
||||
|
||||
func (svc *Service) AuthenticateHost(ctx context.Context, nodeKey string) (*fleet.Host, bool, error) {
|
||||
// skipauth: Authorization is currently for user endpoints only.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
if nodeKey == "" {
|
||||
return nil, false, osqueryError{
|
||||
message: "authentication error: missing node key",
|
||||
nodeInvalid: true,
|
||||
}
|
||||
}
|
||||
|
||||
host, err := svc.ds.LoadHostByNodeKey(ctx, nodeKey)
|
||||
switch {
|
||||
case err == nil:
|
||||
// OK
|
||||
case fleet.IsNotFound(err):
|
||||
return nil, false, osqueryError{
|
||||
message: "authentication error: invalid node key: " + nodeKey,
|
||||
nodeInvalid: true,
|
||||
}
|
||||
default:
|
||||
return nil, false, osqueryError{
|
||||
message: "authentication error: " + err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// Update the "seen" time used to calculate online status. These updates are
|
||||
// batched for MySQL performance reasons. Because this is done
|
||||
// asynchronously, it is possible for the server to shut down before
|
||||
// updating the seen time for these hosts. This seems to be an acceptable
|
||||
// tradeoff as an online host will continue to check in and quickly be
|
||||
// marked online again.
|
||||
svc.seenHostSet.addHostID(host.ID)
|
||||
host.SeenTime = svc.clock.Now()
|
||||
|
||||
return host, svc.debugEnabledForHost(ctx, host.ID), nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Enroll Agent
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type enrollAgentRequest struct {
|
||||
EnrollSecret string `json:"enroll_secret"`
|
||||
HostIdentifier string `json:"host_identifier"`
|
||||
HostDetails map[string](map[string]string) `json:"host_details"`
|
||||
}
|
||||
|
||||
type enrollAgentResponse struct {
|
||||
NodeKey string `json:"node_key,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r enrollAgentResponse) error() error { return r.Err }
|
||||
|
||||
func enrollAgentEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*enrollAgentRequest)
|
||||
nodeKey, err := svc.EnrollAgent(ctx, req.EnrollSecret, req.HostIdentifier, req.HostDetails)
|
||||
if err != nil {
|
||||
return enrollAgentResponse{Err: err}, nil
|
||||
}
|
||||
return enrollAgentResponse{NodeKey: nodeKey}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) EnrollAgent(ctx context.Context, enrollSecret, hostIdentifier string, hostDetails map[string](map[string]string)) (string, error) {
|
||||
// skipauth: Authorization is currently for user endpoints only.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithExtras(ctx, "hostIdentifier", hostIdentifier)
|
||||
|
||||
secret, err := svc.ds.VerifyEnrollSecret(ctx, enrollSecret)
|
||||
if err != nil {
|
||||
return "", osqueryError{
|
||||
message: "enroll failed: " + err.Error(),
|
||||
nodeInvalid: true,
|
||||
}
|
||||
}
|
||||
|
||||
nodeKey, err := server.GenerateRandomText(svc.config.Osquery.NodeKeySize)
|
||||
if err != nil {
|
||||
return "", osqueryError{
|
||||
message: "generate node key failed: " + err.Error(),
|
||||
nodeInvalid: true,
|
||||
}
|
||||
}
|
||||
|
||||
hostIdentifier = getHostIdentifier(svc.logger, svc.config.Osquery.HostIdentifier, hostIdentifier, hostDetails)
|
||||
|
||||
host, err := svc.ds.EnrollHost(ctx, hostIdentifier, nodeKey, secret.TeamID, svc.config.Osquery.EnrollCooldown)
|
||||
if err != nil {
|
||||
return "", osqueryError{message: "save enroll failed: " + err.Error(), nodeInvalid: true}
|
||||
}
|
||||
|
||||
appConfig, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return "", osqueryError{message: "app config load failed: " + err.Error(), nodeInvalid: true}
|
||||
}
|
||||
|
||||
// Save enrollment details if provided
|
||||
detailQueries := osquery_utils.GetDetailQueries(appConfig, svc.config)
|
||||
save := false
|
||||
if r, ok := hostDetails["os_version"]; ok {
|
||||
err := detailQueries["os_version"].IngestFunc(svc.logger, host, []map[string]string{r})
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "Ingesting os_version")
|
||||
}
|
||||
save = true
|
||||
}
|
||||
if r, ok := hostDetails["osquery_info"]; ok {
|
||||
err := detailQueries["osquery_info"].IngestFunc(svc.logger, host, []map[string]string{r})
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "Ingesting osquery_info")
|
||||
}
|
||||
save = true
|
||||
}
|
||||
if r, ok := hostDetails["system_info"]; ok {
|
||||
err := detailQueries["system_info"].IngestFunc(svc.logger, host, []map[string]string{r})
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "Ingesting system_info")
|
||||
}
|
||||
save = true
|
||||
}
|
||||
|
||||
if save {
|
||||
if appConfig.ServerSettings.DeferredSaveHost {
|
||||
go svc.serialUpdateHost(host)
|
||||
} else {
|
||||
if err := svc.ds.UpdateHost(ctx, host); err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "save host in enroll agent")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nodeKey, nil
|
||||
}
|
||||
|
||||
var counter = int64(0)
|
||||
|
||||
func (svc *Service) serialUpdateHost(host *fleet.Host) {
|
||||
newVal := atomic.AddInt64(&counter, 1)
|
||||
defer func() {
|
||||
atomic.AddInt64(&counter, -1)
|
||||
}()
|
||||
level.Debug(svc.logger).Log("background", newVal)
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancelFunc()
|
||||
err := svc.ds.SerialUpdateHost(ctx, host)
|
||||
if err != nil {
|
||||
level.Error(svc.logger).Log("background-err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func getHostIdentifier(logger log.Logger, identifierOption, providedIdentifier string, details map[string](map[string]string)) string {
|
||||
switch identifierOption {
|
||||
case "provided":
|
||||
// Use the host identifier already provided in the request.
|
||||
return providedIdentifier
|
||||
|
||||
case "instance":
|
||||
r, ok := details["osquery_info"]
|
||||
if !ok {
|
||||
level.Info(logger).Log(
|
||||
"msg", "could not get host identifier",
|
||||
"reason", "missing osquery_info",
|
||||
"identifier", "instance",
|
||||
)
|
||||
} else if r["instance_id"] == "" {
|
||||
level.Info(logger).Log(
|
||||
"msg", "could not get host identifier",
|
||||
"reason", "missing instance_id in osquery_info",
|
||||
"identifier", "instance",
|
||||
)
|
||||
} else {
|
||||
return r["instance_id"]
|
||||
}
|
||||
|
||||
case "uuid":
|
||||
r, ok := details["osquery_info"]
|
||||
if !ok {
|
||||
level.Info(logger).Log(
|
||||
"msg", "could not get host identifier",
|
||||
"reason", "missing osquery_info",
|
||||
"identifier", "uuid",
|
||||
)
|
||||
} else if r["uuid"] == "" {
|
||||
level.Info(logger).Log(
|
||||
"msg", "could not get host identifier",
|
||||
"reason", "missing instance_id in osquery_info",
|
||||
"identifier", "uuid",
|
||||
)
|
||||
} else {
|
||||
return r["uuid"]
|
||||
}
|
||||
|
||||
case "hostname":
|
||||
r, ok := details["system_info"]
|
||||
if !ok {
|
||||
level.Info(logger).Log(
|
||||
"msg", "could not get host identifier",
|
||||
"reason", "missing system_info",
|
||||
"identifier", "hostname",
|
||||
)
|
||||
} else if r["hostname"] == "" {
|
||||
level.Info(logger).Log(
|
||||
"msg", "could not get host identifier",
|
||||
"reason", "missing instance_id in system_info",
|
||||
"identifier", "hostname",
|
||||
)
|
||||
} else {
|
||||
return r["hostname"]
|
||||
}
|
||||
|
||||
default:
|
||||
panic("Unknown option for host_identifier: " + identifierOption)
|
||||
}
|
||||
|
||||
return providedIdentifier
|
||||
}
|
||||
|
||||
func (svc *Service) debugEnabledForHost(ctx context.Context, id uint) bool {
|
||||
hlogger := log.With(svc.logger, "host-id", id)
|
||||
ac, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
level.Debug(hlogger).Log("err", ctxerr.Wrap(ctx, err, "getting app config for host debug"))
|
||||
return false
|
||||
}
|
||||
|
||||
for _, hostID := range ac.ServerSettings.DebugHostIDs {
|
||||
if hostID == id {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Get Client Config
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -90,7 +90,7 @@ func NewService(
|
|||
return validationMiddleware{svc, ds, sso}, nil
|
||||
}
|
||||
|
||||
func (s Service) SendEmail(mail fleet.Email) error {
|
||||
func (s *Service) SendEmail(mail fleet.Email) error {
|
||||
return s.mailService.SendEmail(mail)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,46 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
)
|
||||
|
||||
func (svc *Service) CarveBlock(ctx context.Context, payload fleet.CarveBlockPayload) error {
|
||||
// skipauth: Authorization is currently for user endpoints only.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
// Note host did not authenticate via node key. We need to authenticate them
|
||||
// by the session ID and request ID
|
||||
carve, err := svc.carveStore.CarveBySessionId(ctx, payload.SessionId)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "find carve by session_id")
|
||||
}
|
||||
|
||||
if payload.RequestId != carve.RequestId {
|
||||
return errors.New("request_id does not match")
|
||||
}
|
||||
|
||||
// Request is now authenticated
|
||||
|
||||
if payload.BlockId > carve.BlockCount-1 {
|
||||
return fmt.Errorf("block_id exceeds expected max (%d): %d", carve.BlockCount-1, payload.BlockId)
|
||||
}
|
||||
|
||||
if payload.BlockId != carve.MaxBlock+1 {
|
||||
return fmt.Errorf("block_id does not match expected block (%d): %d", carve.MaxBlock+1, payload.BlockId)
|
||||
}
|
||||
|
||||
if int64(len(payload.Data)) > carve.BlockSize {
|
||||
return fmt.Errorf("exceeded declared block size %d: %d", carve.BlockSize, len(payload.Data))
|
||||
}
|
||||
|
||||
if err := svc.carveStore.NewBlock(ctx, carve, payload.BlockId, payload.Data); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "save block data")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,432 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
|
||||
hostctx "github.com/fleetdm/fleet/v4/server/contexts/host"
|
||||
"github.com/fleetdm/fleet/v4/server/mock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCarveBegin(t *testing.T) {
|
||||
host := fleet.Host{ID: 3}
|
||||
payload := fleet.CarveBeginPayload{
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
ds := new(mock.Store)
|
||||
svc := &Service{
|
||||
carveStore: ms,
|
||||
ds: ds,
|
||||
}
|
||||
expectedMetadata := fleet.CarveMetadata{
|
||||
ID: 7,
|
||||
HostId: host.ID,
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
}
|
||||
ms.NewCarveFunc = func(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) {
|
||||
metadata.ID = 7
|
||||
return metadata, nil
|
||||
}
|
||||
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
||||
if host.ID != id {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &fleet.Host{
|
||||
Hostname: host.Hostname,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := hostctx.NewContext(context.Background(), &host)
|
||||
|
||||
metadata, err := svc.CarveBegin(ctx, payload)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, metadata.SessionId)
|
||||
metadata.SessionId = "" // Clear this before comparison
|
||||
metadata.Name = "" // Clear this before comparison
|
||||
metadata.CreatedAt = time.Time{} // Clear this before comparison
|
||||
assert.Equal(t, expectedMetadata, *metadata)
|
||||
}
|
||||
|
||||
func TestCarveBeginNewCarveError(t *testing.T) {
|
||||
host := fleet.Host{ID: 3}
|
||||
payload := fleet.CarveBeginPayload{
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
ds := new(mock.Store)
|
||||
svc := &Service{
|
||||
carveStore: ms,
|
||||
ds: ds,
|
||||
}
|
||||
ms.NewCarveFunc = func(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) {
|
||||
return nil, errors.New("ouch!")
|
||||
}
|
||||
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
||||
if host.ID != id {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &fleet.Host{
|
||||
Hostname: host.Hostname,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := hostctx.NewContext(context.Background(), &host)
|
||||
|
||||
_, err := svc.CarveBegin(ctx, payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "ouch!")
|
||||
}
|
||||
|
||||
func TestCarveBeginEmptyError(t *testing.T) {
|
||||
ms := new(mock.Store)
|
||||
ds := new(mock.Store)
|
||||
svc := &Service{
|
||||
carveStore: ms,
|
||||
ds: ds,
|
||||
}
|
||||
ctx := hostctx.NewContext(context.Background(), &fleet.Host{ID: 1})
|
||||
|
||||
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
||||
if id != 1 {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &fleet.Host{}, nil
|
||||
}
|
||||
|
||||
_, err := svc.CarveBegin(ctx, fleet.CarveBeginPayload{})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "carve_size must be greater than 0")
|
||||
}
|
||||
|
||||
func TestCarveBeginMissingHostError(t *testing.T) {
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
|
||||
_, err := svc.CarveBegin(context.Background(), fleet.CarveBeginPayload{})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "missing host")
|
||||
}
|
||||
|
||||
func TestCarveBeginBlockSizeMaxError(t *testing.T) {
|
||||
host := fleet.Host{ID: 3}
|
||||
payload := fleet.CarveBeginPayload{
|
||||
BlockCount: 10,
|
||||
BlockSize: 1024 * 1024 * 1024 * 1024, // 1TB
|
||||
CarveSize: 10 * 1024 * 1024 * 1024 * 1024, // 10TB
|
||||
RequestId: "carve_request",
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
ds := new(mock.Store)
|
||||
svc := &Service{
|
||||
carveStore: ms,
|
||||
ds: ds,
|
||||
}
|
||||
|
||||
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
||||
if host.ID != id {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &fleet.Host{
|
||||
Hostname: host.Hostname,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := hostctx.NewContext(context.Background(), &host)
|
||||
|
||||
_, err := svc.CarveBegin(ctx, payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "block_size exceeds max")
|
||||
}
|
||||
|
||||
func TestCarveBeginCarveSizeMaxError(t *testing.T) {
|
||||
host := fleet.Host{ID: 3}
|
||||
payload := fleet.CarveBeginPayload{
|
||||
BlockCount: 1024 * 1024,
|
||||
BlockSize: 10 * 1024 * 1024, // 1TB
|
||||
CarveSize: 10 * 1024 * 1024 * 1024 * 1024, // 10TB
|
||||
RequestId: "carve_request",
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
ds := new(mock.Store)
|
||||
svc := &Service{
|
||||
carveStore: ms,
|
||||
ds: ds,
|
||||
}
|
||||
|
||||
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
||||
if host.ID != id {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &fleet.Host{
|
||||
Hostname: host.Hostname,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := hostctx.NewContext(context.Background(), &host)
|
||||
|
||||
_, err := svc.CarveBegin(ctx, payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "carve_size exceeds max")
|
||||
}
|
||||
|
||||
func TestCarveBeginCarveSizeError(t *testing.T) {
|
||||
host := fleet.Host{ID: 3}
|
||||
payload := fleet.CarveBeginPayload{
|
||||
BlockCount: 7,
|
||||
BlockSize: 13,
|
||||
CarveSize: 7*13 + 1,
|
||||
RequestId: "carve_request",
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
ds := new(mock.Store)
|
||||
svc := &Service{
|
||||
carveStore: ms,
|
||||
ds: ds,
|
||||
}
|
||||
ctx := hostctx.NewContext(context.Background(), &host)
|
||||
|
||||
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
||||
if host.ID != id {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &fleet.Host{
|
||||
Hostname: host.Hostname,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Too big
|
||||
_, err := svc.CarveBegin(ctx, payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "carve_size does not match")
|
||||
|
||||
// Too small
|
||||
payload.CarveSize = 6 * 13
|
||||
_, err = svc.CarveBegin(ctx, payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "carve_size does not match")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlockGetCarveError(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
return nil, errors.New("ouch!")
|
||||
}
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :)"),
|
||||
SessionId: sessionId,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "ouch!")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlockRequestIdError(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
metadata := &fleet.CarveMetadata{
|
||||
ID: 2,
|
||||
HostId: 3,
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
assert.Equal(t, metadata.SessionId, sessionId)
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :)"),
|
||||
RequestId: "not_matching",
|
||||
SessionId: sessionId,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "request_id does not match")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlockBlockCountExceedError(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
metadata := &fleet.CarveMetadata{
|
||||
ID: 2,
|
||||
HostId: 3,
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
assert.Equal(t, metadata.SessionId, sessionId)
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :)"),
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
BlockId: 23,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "block_id exceeds expected max")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlockBlockCountMatchError(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
metadata := &fleet.CarveMetadata{
|
||||
ID: 2,
|
||||
HostId: 3,
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
MaxBlock: 3,
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
assert.Equal(t, metadata.SessionId, sessionId)
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :)"),
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
BlockId: 7,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "block_id does not match")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlockBlockSizeError(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
metadata := &fleet.CarveMetadata{
|
||||
ID: 2,
|
||||
HostId: 3,
|
||||
BlockCount: 23,
|
||||
BlockSize: 16,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
MaxBlock: 3,
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
assert.Equal(t, metadata.SessionId, sessionId)
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :) TOO LONG!!!"),
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
BlockId: 4,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "exceeded declared block size")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlockNewBlockError(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
metadata := &fleet.CarveMetadata{
|
||||
ID: 2,
|
||||
HostId: 3,
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
MaxBlock: 3,
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
assert.Equal(t, metadata.SessionId, sessionId)
|
||||
return metadata, nil
|
||||
}
|
||||
ms.NewBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64, data []byte) error {
|
||||
return errors.New("kaboom!")
|
||||
}
|
||||
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :)"),
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
BlockId: 4,
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "kaboom!")
|
||||
}
|
||||
|
||||
func TestCarveCarveBlock(t *testing.T) {
|
||||
sessionId := "foobar"
|
||||
metadata := &fleet.CarveMetadata{
|
||||
ID: 2,
|
||||
HostId: 3,
|
||||
BlockCount: 23,
|
||||
BlockSize: 64,
|
||||
CarveSize: 23 * 64,
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
MaxBlock: 3,
|
||||
}
|
||||
payload := fleet.CarveBlockPayload{
|
||||
Data: []byte("this is the carve data :)"),
|
||||
RequestId: "carve_request",
|
||||
SessionId: sessionId,
|
||||
BlockId: 4,
|
||||
}
|
||||
ms := new(mock.Store)
|
||||
svc := &Service{carveStore: ms}
|
||||
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
|
||||
assert.Equal(t, metadata.SessionId, sessionId)
|
||||
return metadata, nil
|
||||
}
|
||||
ms.NewBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64, data []byte) error {
|
||||
assert.Equal(t, metadata, carve)
|
||||
assert.Equal(t, int64(4), blockId)
|
||||
assert.Equal(t, payload.Data, data)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := svc.CarveBlock(context.Background(), payload)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ms.NewBlockFuncInvoked)
|
||||
}
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
)
|
||||
|
||||
func (svc *Service) VerifyInvite(ctx context.Context, token string) (*fleet.Invite, error) {
|
||||
// skipauth: There is no viewer context at this point. We rely on verifying
|
||||
// the invite for authNZ.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithExtras(ctx, "token", token)
|
||||
|
||||
invite, err := svc.ds.InviteByToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if invite.Token != token {
|
||||
return nil, fleet.NewInvalidArgumentError("invite_token", "Invite Token does not match Email Address.")
|
||||
}
|
||||
|
||||
expiresAt := invite.CreatedAt.Add(svc.config.App.InviteTokenValidityPeriod)
|
||||
if svc.clock.Now().After(expiresAt) {
|
||||
return nil, fleet.NewInvalidArgumentError("invite_token", "Invite token has expired.")
|
||||
}
|
||||
|
||||
return invite, nil
|
||||
|
||||
}
|
||||
|
|
@ -1,317 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/service/osquery_utils"
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
)
|
||||
|
||||
type osqueryError struct {
|
||||
message string
|
||||
nodeInvalid bool
|
||||
}
|
||||
|
||||
func (e osqueryError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func (e osqueryError) NodeInvalid() bool {
|
||||
return e.nodeInvalid
|
||||
}
|
||||
|
||||
var counter = int64(0)
|
||||
|
||||
func (svc Service) AuthenticateHost(ctx context.Context, nodeKey string) (*fleet.Host, bool, error) {
|
||||
// skipauth: Authorization is currently for user endpoints only.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
if nodeKey == "" {
|
||||
return nil, false, osqueryError{
|
||||
message: "authentication error: missing node key",
|
||||
nodeInvalid: true,
|
||||
}
|
||||
}
|
||||
|
||||
host, err := svc.ds.LoadHostByNodeKey(ctx, nodeKey)
|
||||
switch {
|
||||
case err == nil:
|
||||
// OK
|
||||
case fleet.IsNotFound(err):
|
||||
return nil, false, osqueryError{
|
||||
message: "authentication error: invalid node key: " + nodeKey,
|
||||
nodeInvalid: true,
|
||||
}
|
||||
default:
|
||||
return nil, false, osqueryError{
|
||||
message: "authentication error: " + err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// Update the "seen" time used to calculate online status. These updates are
|
||||
// batched for MySQL performance reasons. Because this is done
|
||||
// asynchronously, it is possible for the server to shut down before
|
||||
// updating the seen time for these hosts. This seems to be an acceptable
|
||||
// tradeoff as an online host will continue to check in and quickly be
|
||||
// marked online again.
|
||||
svc.seenHostSet.addHostID(host.ID)
|
||||
host.SeenTime = svc.clock.Now()
|
||||
|
||||
return host, svc.debugEnabledForHost(ctx, host.ID), nil
|
||||
}
|
||||
|
||||
func (svc Service) debugEnabledForHost(ctx context.Context, id uint) bool {
|
||||
hlogger := log.With(svc.logger, "host-id", id)
|
||||
ac, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
level.Debug(hlogger).Log("err", ctxerr.Wrap(ctx, err, "getting app config for host debug"))
|
||||
return false
|
||||
}
|
||||
|
||||
for _, hostID := range ac.ServerSettings.DebugHostIDs {
|
||||
if hostID == id {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (svc Service) EnrollAgent(ctx context.Context, enrollSecret, hostIdentifier string, hostDetails map[string](map[string]string)) (string, error) {
|
||||
// skipauth: Authorization is currently for user endpoints only.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithExtras(ctx, "hostIdentifier", hostIdentifier)
|
||||
|
||||
secret, err := svc.ds.VerifyEnrollSecret(ctx, enrollSecret)
|
||||
if err != nil {
|
||||
return "", osqueryError{
|
||||
message: "enroll failed: " + err.Error(),
|
||||
nodeInvalid: true,
|
||||
}
|
||||
}
|
||||
|
||||
nodeKey, err := server.GenerateRandomText(svc.config.Osquery.NodeKeySize)
|
||||
if err != nil {
|
||||
return "", osqueryError{
|
||||
message: "generate node key failed: " + err.Error(),
|
||||
nodeInvalid: true,
|
||||
}
|
||||
}
|
||||
|
||||
hostIdentifier = getHostIdentifier(svc.logger, svc.config.Osquery.HostIdentifier, hostIdentifier, hostDetails)
|
||||
|
||||
host, err := svc.ds.EnrollHost(ctx, hostIdentifier, nodeKey, secret.TeamID, svc.config.Osquery.EnrollCooldown)
|
||||
if err != nil {
|
||||
return "", osqueryError{message: "save enroll failed: " + err.Error(), nodeInvalid: true}
|
||||
}
|
||||
|
||||
appConfig, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return "", osqueryError{message: "app config load failed: " + err.Error(), nodeInvalid: true}
|
||||
}
|
||||
|
||||
// Save enrollment details if provided
|
||||
detailQueries := osquery_utils.GetDetailQueries(appConfig, svc.config)
|
||||
save := false
|
||||
if r, ok := hostDetails["os_version"]; ok {
|
||||
err := detailQueries["os_version"].IngestFunc(svc.logger, host, []map[string]string{r})
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "Ingesting os_version")
|
||||
}
|
||||
save = true
|
||||
}
|
||||
if r, ok := hostDetails["osquery_info"]; ok {
|
||||
err := detailQueries["osquery_info"].IngestFunc(svc.logger, host, []map[string]string{r})
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "Ingesting osquery_info")
|
||||
}
|
||||
save = true
|
||||
}
|
||||
if r, ok := hostDetails["system_info"]; ok {
|
||||
err := detailQueries["system_info"].IngestFunc(svc.logger, host, []map[string]string{r})
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "Ingesting system_info")
|
||||
}
|
||||
save = true
|
||||
}
|
||||
|
||||
if save {
|
||||
if appConfig.ServerSettings.DeferredSaveHost {
|
||||
go svc.serialUpdateHost(host)
|
||||
} else {
|
||||
if err := svc.ds.UpdateHost(ctx, host); err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "save host in enroll agent")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nodeKey, nil
|
||||
}
|
||||
|
||||
func (svc Service) serialUpdateHost(host *fleet.Host) {
|
||||
newVal := atomic.AddInt64(&counter, 1)
|
||||
defer func() {
|
||||
atomic.AddInt64(&counter, -1)
|
||||
}()
|
||||
level.Debug(svc.logger).Log("background", newVal)
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancelFunc()
|
||||
err := svc.ds.SerialUpdateHost(ctx, host)
|
||||
if err != nil {
|
||||
level.Error(svc.logger).Log("background-err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func getHostIdentifier(logger log.Logger, identifierOption, providedIdentifier string, details map[string](map[string]string)) string {
|
||||
switch identifierOption {
|
||||
case "provided":
|
||||
// Use the host identifier already provided in the request.
|
||||
return providedIdentifier
|
||||
|
||||
case "instance":
|
||||
r, ok := details["osquery_info"]
|
||||
if !ok {
|
||||
level.Info(logger).Log(
|
||||
"msg", "could not get host identifier",
|
||||
"reason", "missing osquery_info",
|
||||
"identifier", "instance",
|
||||
)
|
||||
} else if r["instance_id"] == "" {
|
||||
level.Info(logger).Log(
|
||||
"msg", "could not get host identifier",
|
||||
"reason", "missing instance_id in osquery_info",
|
||||
"identifier", "instance",
|
||||
)
|
||||
} else {
|
||||
return r["instance_id"]
|
||||
}
|
||||
|
||||
case "uuid":
|
||||
r, ok := details["osquery_info"]
|
||||
if !ok {
|
||||
level.Info(logger).Log(
|
||||
"msg", "could not get host identifier",
|
||||
"reason", "missing osquery_info",
|
||||
"identifier", "uuid",
|
||||
)
|
||||
} else if r["uuid"] == "" {
|
||||
level.Info(logger).Log(
|
||||
"msg", "could not get host identifier",
|
||||
"reason", "missing instance_id in osquery_info",
|
||||
"identifier", "uuid",
|
||||
)
|
||||
} else {
|
||||
return r["uuid"]
|
||||
}
|
||||
|
||||
case "hostname":
|
||||
r, ok := details["system_info"]
|
||||
if !ok {
|
||||
level.Info(logger).Log(
|
||||
"msg", "could not get host identifier",
|
||||
"reason", "missing system_info",
|
||||
"identifier", "hostname",
|
||||
)
|
||||
} else if r["hostname"] == "" {
|
||||
level.Info(logger).Log(
|
||||
"msg", "could not get host identifier",
|
||||
"reason", "missing instance_id in system_info",
|
||||
"identifier", "hostname",
|
||||
)
|
||||
} else {
|
||||
return r["hostname"]
|
||||
}
|
||||
|
||||
default:
|
||||
panic("Unknown option for host_identifier: " + identifierOption)
|
||||
}
|
||||
|
||||
return providedIdentifier
|
||||
}
|
||||
|
||||
// jitterHashTable implements a data structure that allows a fleet to generate a static jitter value
|
||||
// that is properly balanced. Balance in this context means that hosts would be distributed uniformly
|
||||
// across the total jitter time so there are no spikes.
|
||||
// The way this structure works is as follows:
|
||||
// Given an amount of buckets, we want to place hosts in buckets evenly. So we don't want bucket 0 to
|
||||
// have 1000 hosts, and all the other buckets 0. If there were 1000 buckets, and 1000 hosts, we should
|
||||
// end up with 1 per bucket.
|
||||
// The total amount of online hosts is unknown, so first it assumes that amount of buckets >= amount
|
||||
// of total hosts (maxCapacity of 1 per bucket). Once we have more hosts than buckets, then we
|
||||
// increase the maxCapacity by 1 for all buckets, and start placing hosts.
|
||||
// Hosts that have been placed in a bucket remain in that bucket for as long as the fleet instance is
|
||||
// running.
|
||||
// The preferred bucket for a host is the one at (host id % bucketCount). If that bucket is full, the
|
||||
// next one will be tried. If all buckets are full, then capacity gets increased and the bucket
|
||||
// selection process restarts.
|
||||
// Once a bucket is found, the index for the bucket (going from 0 to bucketCount) will be the amount of
|
||||
// minutes added to the host check in time.
|
||||
// For example: at a 1hr interval, and the default 10% max jitter percent. That allows hosts to
|
||||
// distribute within 6 minutes around the hour mark. We would have 6 buckets in that case.
|
||||
// In the worst possible case that all hosts start at the same time, max jitter percent can be set to
|
||||
// 100, and this method will distribute hosts evenly.
|
||||
// The main caveat of this approach is that it works at the fleet instance. So depending on what
|
||||
// instance gets chosen by the load balancer, the jitter might be different. However, load tests have
|
||||
// shown that the distribution in practice is pretty balance even when all hosts try to check in at
|
||||
// the same time.
|
||||
type jitterHashTable struct {
|
||||
mu sync.Mutex
|
||||
maxCapacity int
|
||||
bucketCount int
|
||||
buckets map[int]int
|
||||
cache map[uint]time.Duration
|
||||
}
|
||||
|
||||
func newJitterHashTable(bucketCount int) *jitterHashTable {
|
||||
if bucketCount == 0 {
|
||||
bucketCount = 1
|
||||
}
|
||||
return &jitterHashTable{
|
||||
maxCapacity: 1,
|
||||
bucketCount: bucketCount,
|
||||
buckets: make(map[int]int),
|
||||
cache: make(map[uint]time.Duration),
|
||||
}
|
||||
}
|
||||
|
||||
func (jh *jitterHashTable) jitterForHost(hostID uint) time.Duration {
|
||||
// if no jitter is configured just return 0
|
||||
if jh.bucketCount <= 1 {
|
||||
return 0
|
||||
}
|
||||
|
||||
jh.mu.Lock()
|
||||
if jitter, ok := jh.cache[hostID]; ok {
|
||||
jh.mu.Unlock()
|
||||
return jitter
|
||||
}
|
||||
|
||||
for i := 0; i < jh.bucketCount; i++ {
|
||||
possibleBucket := (int(hostID) + i) % jh.bucketCount
|
||||
|
||||
// if the next bucket has capacity, great!
|
||||
if jh.buckets[possibleBucket] < jh.maxCapacity {
|
||||
jh.buckets[possibleBucket]++
|
||||
jitter := time.Duration(possibleBucket) * time.Minute
|
||||
jh.cache[hostID] = jitter
|
||||
|
||||
jh.mu.Unlock()
|
||||
return jitter
|
||||
}
|
||||
}
|
||||
|
||||
// otherwise, bump the capacity and restart the process
|
||||
jh.maxCapacity++
|
||||
|
||||
jh.mu.Unlock()
|
||||
return jh.jitterForHost(hostID)
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,325 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/sso"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
)
|
||||
|
||||
// SSOSettings returns a subset of the Single Sign-On settings as configured in
|
||||
// the app config. Those can be exposed e.g. via the response to an HTTP request,
|
||||
// and as such should not contain sensitive information.
|
||||
func (svc *Service) SSOSettings(ctx context.Context) (*fleet.SessionSSOSettings, error) {
|
||||
// skipauth: Basic SSO settings are available to unauthenticated users (so
|
||||
// that they have the necessary information to initiate SSO).
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithLevel(ctx, level.Info)
|
||||
|
||||
appConfig, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "SessionSSOSettings getting app config")
|
||||
}
|
||||
|
||||
settings := &fleet.SessionSSOSettings{
|
||||
IDPName: appConfig.SSOSettings.IDPName,
|
||||
IDPImageURL: appConfig.SSOSettings.IDPImageURL,
|
||||
SSOEnabled: appConfig.SSOSettings.EnableSSO,
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// InitiateSSO initiates a Single Sign-On flow for a request to visit the
|
||||
// protected URL identified by redirectURL. It returns the URL of the identity
|
||||
// provider to make a request to to proceed with the authentication via that
|
||||
// external service, and stores ephemeral session state to validate the
|
||||
// callback from the identity provider to finalize the SSO flow.
|
||||
func (svc *Service) InitiateSSO(ctx context.Context, redirectURL string) (string, error) {
|
||||
// skipauth: User context does not yet exist. Unauthenticated users may
|
||||
// initiate SSO.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithLevel(ctx, level.Info)
|
||||
|
||||
appConfig, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting app config")
|
||||
}
|
||||
|
||||
metadata, err := svc.getMetadata(appConfig)
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting metadata")
|
||||
}
|
||||
|
||||
serverURL := appConfig.ServerSettings.ServerURL
|
||||
settings := sso.Settings{
|
||||
Metadata: metadata,
|
||||
// Construct call back url to send to idp
|
||||
AssertionConsumerServiceURL: serverURL + svc.config.Server.URLPrefix + "/api/v1/fleet/sso/callback",
|
||||
SessionStore: svc.ssoSessionStore,
|
||||
OriginalURL: redirectURL,
|
||||
}
|
||||
|
||||
// If issuer is not explicitly set, default to host name.
|
||||
var issuer string
|
||||
entityID := appConfig.SSOSettings.EntityID
|
||||
if entityID == "" {
|
||||
u, err := url.Parse(serverURL)
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "parse server url")
|
||||
}
|
||||
issuer = u.Hostname()
|
||||
} else {
|
||||
issuer = entityID
|
||||
}
|
||||
|
||||
idpURL, err := sso.CreateAuthorizationRequest(&settings, issuer)
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "InitiateSSO creating authorization")
|
||||
}
|
||||
|
||||
return idpURL, nil
|
||||
}
|
||||
|
||||
func (svc *Service) getMetadata(config *fleet.AppConfig) (*sso.Metadata, error) {
|
||||
if config.SSOSettings.MetadataURL != "" {
|
||||
metadata, err := sso.GetMetadata(config.SSOSettings.MetadataURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
if config.SSOSettings.Metadata != "" {
|
||||
metadata, err := sso.ParseMetadata(config.SSOSettings.Metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("missing metadata for idp %s", config.SSOSettings.IDPName)
|
||||
}
|
||||
|
||||
func (svc *Service) CallbackSSO(ctx context.Context, auth fleet.Auth) (*fleet.SSOSession, error) {
|
||||
// skipauth: User context does not yet exist. Unauthenticated users may
|
||||
// hit the SSO callback.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithLevel(ctx, level.Info)
|
||||
|
||||
appConfig, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "get config for sso")
|
||||
}
|
||||
|
||||
// Load the request metadata if available
|
||||
|
||||
// localhost:9080/simplesaml/saml2/idp/SSOService.php?spentityid=https://localhost:8080
|
||||
var metadata *sso.Metadata
|
||||
var redirectURL string
|
||||
|
||||
if appConfig.SSOSettings.EnableSSOIdPLogin && auth.RequestID() == "" {
|
||||
// Missing request ID indicates this was IdP-initiated. Only allow if
|
||||
// configured to do so.
|
||||
metadata, err = svc.getMetadata(appConfig)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "get sso metadata")
|
||||
}
|
||||
redirectURL = "/"
|
||||
} else {
|
||||
session, err := svc.ssoSessionStore.Get(auth.RequestID())
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "sso request invalid")
|
||||
}
|
||||
// Remove session to so that is can't be reused before it expires.
|
||||
err = svc.ssoSessionStore.Expire(auth.RequestID())
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "remove sso request")
|
||||
}
|
||||
if err := xml.Unmarshal([]byte(session.Metadata), &metadata); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "unmarshal metadata")
|
||||
}
|
||||
redirectURL = session.OriginalURL
|
||||
}
|
||||
|
||||
// Validate response
|
||||
validator, err := sso.NewValidator(*metadata, sso.WithExpectedAudience(
|
||||
appConfig.SSOSettings.EntityID,
|
||||
appConfig.ServerSettings.ServerURL,
|
||||
appConfig.ServerSettings.ServerURL+svc.config.Server.URLPrefix+"/api/v1/fleet/sso/callback", // ACS
|
||||
))
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "create validator from metadata")
|
||||
}
|
||||
// make sure the response hasn't been tampered with
|
||||
auth, err = validator.ValidateSignature(auth)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "signature validation failed")
|
||||
}
|
||||
// make sure the response isn't stale
|
||||
err = validator.ValidateResponse(auth)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "response validation failed")
|
||||
}
|
||||
|
||||
// Get and log in user
|
||||
user, err := svc.ds.UserByEmail(ctx, auth.UserID())
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "find user in sso callback")
|
||||
}
|
||||
// if the user is not sso enabled they are not authorized
|
||||
if !user.SSOEnabled {
|
||||
return nil, ctxerr.New(ctx, "user not configured to use sso")
|
||||
}
|
||||
token, err := svc.makeSession(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "make session in sso callback")
|
||||
}
|
||||
result := &fleet.SSOSession{
|
||||
Token: token,
|
||||
RedirectURL: redirectURL,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (svc *Service) Login(ctx context.Context, email, password string) (*fleet.User, string, error) {
|
||||
// skipauth: No user context available yet to authorize against.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithLevel(logging.WithNoUser(ctx), level.Info)
|
||||
|
||||
// If there is an error, sleep until the request has taken at least 1
|
||||
// second. This means that generally a login failure for any reason will
|
||||
// take ~1s and frustrate a timing attack.
|
||||
var err error
|
||||
defer func(start time.Time) {
|
||||
if err != nil {
|
||||
time.Sleep(time.Until(start.Add(1 * time.Second)))
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
user, err := svc.ds.UserByEmail(ctx, email)
|
||||
var nfe fleet.NotFoundError
|
||||
if errors.As(err, &nfe) {
|
||||
return nil, "", fleet.NewAuthFailedError("user not found")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", fleet.NewAuthFailedError(err.Error())
|
||||
}
|
||||
|
||||
if err = user.ValidatePassword(password); err != nil {
|
||||
return nil, "", fleet.NewAuthFailedError("invalid password")
|
||||
}
|
||||
|
||||
if user.SSOEnabled {
|
||||
return nil, "", fleet.NewAuthFailedError("password login disabled for sso users")
|
||||
}
|
||||
|
||||
token, err := svc.makeSession(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, "", fleet.NewAuthFailedError(err.Error())
|
||||
}
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
// makeSession is a helper that creates a new session after authentication
|
||||
func (svc *Service) makeSession(ctx context.Context, id uint) (string, error) {
|
||||
sessionKeySize := svc.config.Session.KeySize
|
||||
key := make([]byte, sessionKeySize)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
sessionKey := base64.StdEncoding.EncodeToString(key)
|
||||
session := &fleet.Session{
|
||||
UserID: id,
|
||||
Key: sessionKey,
|
||||
AccessedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
_, err = svc.ds.NewSession(ctx, session)
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "creating new session")
|
||||
}
|
||||
|
||||
return sessionKey, nil
|
||||
}
|
||||
|
||||
func (svc *Service) Logout(ctx context.Context) error {
|
||||
// skipauth: Any user can always log out of their own session.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithLevel(ctx, level.Info)
|
||||
|
||||
// TODO: this should not return an error if the user wasn't logged in
|
||||
return svc.DestroySession(ctx)
|
||||
}
|
||||
|
||||
func (svc *Service) DestroySession(ctx context.Context) error {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
return fleet.ErrNoContext
|
||||
}
|
||||
|
||||
session, err := svc.ds.SessionByID(ctx, vc.SessionID())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, session, fleet.ActionWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return svc.ds.DestroySession(ctx, session)
|
||||
}
|
||||
|
||||
func (svc *Service) GetSessionByKey(ctx context.Context, key string) (*fleet.Session, error) {
|
||||
session, err := svc.ds.SessionByKey(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = svc.validateSession(ctx, session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (svc *Service) validateSession(ctx context.Context, session *fleet.Session) error {
|
||||
if session == nil {
|
||||
return fleet.NewAuthRequiredError("active session not present")
|
||||
}
|
||||
|
||||
sessionDuration := svc.config.Session.Duration
|
||||
if session.APIOnly != nil && *session.APIOnly {
|
||||
sessionDuration = 0 // make API-only tokens unlimited
|
||||
}
|
||||
|
||||
// duration 0 = unlimited
|
||||
if sessionDuration != 0 && time.Since(session.AccessedAt) >= sessionDuration {
|
||||
err := svc.ds.DestroySession(ctx, session)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "destroying session")
|
||||
}
|
||||
return fleet.NewAuthRequiredError("expired session")
|
||||
}
|
||||
|
||||
return svc.ds.MarkSessionAccessed(ctx, session)
|
||||
}
|
||||
|
|
@ -1,111 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mock"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAuthenticate(t *testing.T) {
|
||||
ds := mysql.CreateMySQLDS(t)
|
||||
defer ds.Close()
|
||||
|
||||
svc := newTestService(ds, nil, nil)
|
||||
createTestUsers(t, ds)
|
||||
|
||||
var loginTests = []struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "admin1",
|
||||
email: testUsers["admin1"].Email,
|
||||
password: testUsers["admin1"].PlaintextPassword,
|
||||
},
|
||||
{
|
||||
name: "user1",
|
||||
email: testUsers["user1"].Email,
|
||||
password: testUsers["user1"].PlaintextPassword,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range loginTests {
|
||||
t.Run(tt.email, func(st *testing.T) {
|
||||
loggedIn, token, err := svc.Login(test.UserContext(test.UserAdmin), tt.email, tt.password)
|
||||
require.Nil(st, err, "login unsuccessful")
|
||||
assert.Equal(st, tt.email, loggedIn.Email)
|
||||
assert.NotEmpty(st, token)
|
||||
|
||||
sessions, err := svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), loggedIn.ID)
|
||||
require.Nil(st, err)
|
||||
require.Len(st, sessions, 1, "user should have one session")
|
||||
session := sessions[0]
|
||||
assert.NotZero(st, session.UserID)
|
||||
assert.WithinDuration(st, time.Now(), session.AccessedAt, 3*time.Second,
|
||||
"access time should be set with current time at session creation")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSessionByKey(t *testing.T) {
|
||||
ds := new(mock.Store)
|
||||
svc := newTestService(ds, nil, nil)
|
||||
cfg := config.TestConfig()
|
||||
|
||||
theSession := &fleet.Session{UserID: 123, Key: "abc"}
|
||||
|
||||
ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) {
|
||||
return theSession, nil
|
||||
}
|
||||
ds.DestroySessionFunc = func(ctx context.Context, ssn *fleet.Session) error {
|
||||
return nil
|
||||
}
|
||||
ds.MarkSessionAccessedFunc = func(ctx context.Context, ssn *fleet.Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
accessed time.Duration
|
||||
apiOnly bool
|
||||
fail bool
|
||||
}{
|
||||
{"real user, accessed recently", -1 * time.Hour, false, false},
|
||||
{"real user, accessed too long ago", -(cfg.Session.Duration + time.Hour), false, true},
|
||||
{"api-only, accessed recently", -1 * time.Hour, true, false},
|
||||
{"api-only, accessed long ago", -(cfg.Session.Duration + time.Hour), true, false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
var authErr *fleet.AuthRequiredError
|
||||
ds.SessionByKeyFuncInvoked, ds.DestroySessionFuncInvoked, ds.MarkSessionAccessedFuncInvoked = false, false, false
|
||||
|
||||
theSession.AccessedAt = time.Now().Add(tc.accessed)
|
||||
theSession.APIOnly = ptr.Bool(tc.apiOnly)
|
||||
_, err := svc.GetSessionByKey(context.Background(), theSession.Key)
|
||||
if tc.fail {
|
||||
require.Error(t, err)
|
||||
require.ErrorAs(t, err, &authErr)
|
||||
require.True(t, ds.SessionByKeyFuncInvoked)
|
||||
require.True(t, ds.DestroySessionFuncInvoked)
|
||||
require.False(t, ds.MarkSessionAccessedFuncInvoked)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.True(t, ds.SessionByKeyFuncInvoked)
|
||||
require.False(t, ds.DestroySessionFuncInvoked)
|
||||
require.True(t, ds.MarkSessionAccessedFuncInvoked)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
)
|
||||
|
||||
func (svc *Service) ListAvailableTeamsForUser(ctx context.Context, user *fleet.User) ([]*fleet.TeamSummary, error) {
|
||||
// skipauth: No authorization check needed due to implementation returning
|
||||
// only license error.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
return nil, fleet.ErrMissingLicense
|
||||
}
|
||||
|
|
@ -2,45 +2,14 @@ package service
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"html/template"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mail"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
)
|
||||
|
||||
func (svc *Service) CreateUserFromInvite(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
|
||||
// skipauth: There is no viewer context at this point. We rely on verifying
|
||||
// the invite for authNZ.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
invite, err := svc.VerifyInvite(ctx, *p.InviteToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// set the payload role property based on an existing invite.
|
||||
p.GlobalRole = invite.GlobalRole.Ptr()
|
||||
p.Teams = &invite.Teams
|
||||
|
||||
user, err := svc.newUser(ctx, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = svc.ds.DeleteInvite(ctx, invite.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (svc *Service) CreateInitialUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
|
||||
// skipauth: Only the initial user creation should be allowed to skip
|
||||
// authorization (because there is not yet a user context to check against).
|
||||
|
|
@ -88,154 +57,3 @@ func (svc *Service) UserUnauthorized(ctx context.Context, id uint) (*fleet.User,
|
|||
// Explicitly no authorization check. Should only be used by middleware.
|
||||
return svc.ds.UserByID(ctx, id)
|
||||
}
|
||||
|
||||
// setNewPassword is a helper for changing a user's password. It should be
|
||||
// called to set the new password after proper authorization has been
|
||||
// performed.
|
||||
func (svc *Service) setNewPassword(ctx context.Context, user *fleet.User, password string) error {
|
||||
err := user.SetPassword(password, svc.config.Auth.SaltKeySize, svc.config.Auth.BcryptCost)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "setting new password")
|
||||
}
|
||||
if user.SSOEnabled {
|
||||
return ctxerr.New(ctx, "set password for single sign on user not allowed")
|
||||
}
|
||||
err = svc.saveUser(ctx, user)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "saving changed password")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *Service) ResetPassword(ctx context.Context, token, password string) error {
|
||||
// skipauth: No viewer context available. The user is locked out of their
|
||||
// account and authNZ is performed entirely by providing a valid password
|
||||
// reset token.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
reset, err := svc.ds.FindPassswordResetByToken(ctx, token)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "looking up reset by token")
|
||||
}
|
||||
user, err := svc.ds.UserByID(ctx, reset.UserID)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "retrieving user")
|
||||
}
|
||||
|
||||
if user.SSOEnabled {
|
||||
return ctxerr.New(ctx, "password reset for single sign on user not allowed")
|
||||
}
|
||||
|
||||
// prevent setting the same password
|
||||
if err := user.ValidatePassword(password); err == nil {
|
||||
return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
|
||||
}
|
||||
|
||||
err = svc.setNewPassword(ctx, user, password)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "setting new password")
|
||||
}
|
||||
|
||||
// delete password reset tokens for user
|
||||
if err := svc.ds.DeletePasswordResetRequestsForUser(ctx, user.ID); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "delete password reset requests")
|
||||
}
|
||||
|
||||
// Clear sessions so that any other browsers will have to log in with
|
||||
// the new password
|
||||
if err := svc.ds.DestroyAllSessionsForUser(ctx, user.ID); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "delete user sessions")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *Service) PerformRequiredPasswordReset(ctx context.Context, password string) (*fleet.User, error) {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, fleet.ErrNoContext
|
||||
}
|
||||
user := vc.User
|
||||
|
||||
if err := svc.authz.Authorize(ctx, user, fleet.ActionWrite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.SSOEnabled {
|
||||
return nil, ctxerr.New(ctx, "password reset for single sign on user not allowed")
|
||||
}
|
||||
if !user.IsAdminForcedPasswordReset() {
|
||||
return nil, ctxerr.New(ctx, "user does not require password reset")
|
||||
}
|
||||
|
||||
// prevent setting the same password
|
||||
if err := user.ValidatePassword(password); err == nil {
|
||||
return nil, fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
|
||||
}
|
||||
|
||||
user.AdminForcedPasswordReset = false
|
||||
err := svc.setNewPassword(ctx, user, password)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "setting new password")
|
||||
}
|
||||
|
||||
// Sessions should already have been cleared when the reset was
|
||||
// required
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (svc *Service) RequestPasswordReset(ctx context.Context, email string) error {
|
||||
// skipauth: No viewer context available. The user is locked out of their
|
||||
// account and trying to reset their password.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
// Regardless of error, sleep until the request has taken at least 1 second.
|
||||
// This means that any request to this method will take ~1s and frustrate a timing attack.
|
||||
defer func(start time.Time) {
|
||||
time.Sleep(time.Until(start.Add(1 * time.Second)))
|
||||
}(time.Now())
|
||||
|
||||
user, err := svc.ds.UserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if user.SSOEnabled {
|
||||
return ctxerr.New(ctx, "password reset for single sign on user not allowed")
|
||||
}
|
||||
|
||||
random, err := server.GenerateRandomText(svc.config.App.TokenKeySize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
token := base64.URLEncoding.EncodeToString([]byte(random))
|
||||
|
||||
request := &fleet.PasswordResetRequest{
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
UserID: user.ID,
|
||||
Token: token,
|
||||
}
|
||||
_, err = svc.ds.NewPasswordResetRequest(ctx, request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resetEmail := fleet.Email{
|
||||
Subject: "Reset Your Fleet Password",
|
||||
To: []string{user.Email},
|
||||
Config: config,
|
||||
Mailer: &mail.PasswordResetMailer{
|
||||
BaseURL: template.URL(config.ServerSettings.ServerURL + svc.config.Server.URLPrefix),
|
||||
AssetURL: getAssetURL(),
|
||||
Token: token,
|
||||
},
|
||||
}
|
||||
|
||||
return svc.mailService.SendEmail(resetEmail)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,188 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAuthenticatedUser(t *testing.T) {
|
||||
ds := mysql.CreateMySQLDS(t)
|
||||
|
||||
createTestUsers(t, ds)
|
||||
svc := newTestService(ds, nil, nil)
|
||||
admin1, err := ds.UserByEmail(context.Background(), "admin1@example.com")
|
||||
assert.Nil(t, err)
|
||||
admin1Session, err := ds.NewSession(context.Background(), &fleet.Session{
|
||||
UserID: admin1.ID,
|
||||
Key: "admin1",
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: admin1, Session: admin1Session})
|
||||
user, err := svc.AuthenticatedUser(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, user, admin1)
|
||||
}
|
||||
|
||||
func TestResetPassword(t *testing.T) {
|
||||
ds := mysql.CreateMySQLDS(t)
|
||||
|
||||
svc := newTestService(ds, nil, nil)
|
||||
createTestUsers(t, ds)
|
||||
passwordResetTests := []struct {
|
||||
token string
|
||||
newPassword string
|
||||
wantErr error
|
||||
}{
|
||||
{ // all good
|
||||
token: "abcd",
|
||||
newPassword: "123cat!",
|
||||
},
|
||||
{ // prevent reuse
|
||||
token: "abcd",
|
||||
newPassword: "123cat!",
|
||||
wantErr: fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"),
|
||||
},
|
||||
{ // bad token
|
||||
token: "dcbaz",
|
||||
newPassword: "123cat!",
|
||||
wantErr: sql.ErrNoRows,
|
||||
},
|
||||
{ // missing token
|
||||
newPassword: "123cat!",
|
||||
wantErr: fleet.NewInvalidArgumentError("token", "Token cannot be empty field"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range passwordResetTests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
request := &fleet.PasswordResetRequest{
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
CreateTimestamp: fleet.CreateTimestamp{
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
},
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
UserID: 1,
|
||||
Token: "abcd",
|
||||
}
|
||||
_, err := ds.NewPasswordResetRequest(context.Background(), request)
|
||||
assert.Nil(t, err)
|
||||
|
||||
serr := svc.ResetPassword(test.UserContext(&fleet.User{ID: 1}), tt.token, tt.newPassword)
|
||||
if tt.wantErr != nil {
|
||||
assert.Equal(t, tt.wantErr.Error(), ctxerr.Cause(serr).Error())
|
||||
} else {
|
||||
assert.Nil(t, serr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func refreshCtx(t *testing.T, ctx context.Context, user *fleet.User, ds fleet.Datastore, session *fleet.Session) context.Context {
|
||||
reloadedUser, err := ds.UserByEmail(ctx, user.Email)
|
||||
require.NoError(t, err)
|
||||
|
||||
return viewer.NewContext(ctx, viewer.Viewer{User: reloadedUser, Session: session})
|
||||
}
|
||||
|
||||
func TestPerformRequiredPasswordReset(t *testing.T) {
|
||||
ds := mysql.CreateMySQLDS(t)
|
||||
|
||||
svc := newTestService(ds, nil, nil)
|
||||
|
||||
createTestUsers(t, ds)
|
||||
|
||||
for _, tt := range testUsers {
|
||||
t.Run(tt.Email, func(t *testing.T) {
|
||||
user, err := ds.UserByEmail(context.Background(), tt.Email)
|
||||
require.Nil(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
_, err = svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, true)
|
||||
require.Nil(t, err)
|
||||
|
||||
ctx = refreshCtx(t, ctx, user, ds, nil)
|
||||
|
||||
session, err := ds.NewSession(context.Background(), &fleet.Session{UserID: user.ID})
|
||||
require.Nil(t, err)
|
||||
ctx = refreshCtx(t, ctx, user, ds, session)
|
||||
|
||||
// should error when reset not required
|
||||
_, err = svc.RequirePasswordReset(ctx, user.ID, false)
|
||||
require.Nil(t, err)
|
||||
ctx = refreshCtx(t, ctx, user, ds, session)
|
||||
_, err = svc.PerformRequiredPasswordReset(ctx, "new_pass")
|
||||
require.NotNil(t, err)
|
||||
|
||||
_, err = svc.RequirePasswordReset(ctx, user.ID, true)
|
||||
require.Nil(t, err)
|
||||
ctx = refreshCtx(t, ctx, user, ds, session)
|
||||
|
||||
// should error when using same password
|
||||
_, err = svc.PerformRequiredPasswordReset(ctx, tt.PlaintextPassword)
|
||||
require.Equal(t, "validation failed: new_password cannot reuse old password", err.Error())
|
||||
|
||||
// should succeed with good new password
|
||||
u, err := svc.PerformRequiredPasswordReset(ctx, "new_pass")
|
||||
require.Nil(t, err)
|
||||
assert.False(t, u.AdminForcedPasswordReset)
|
||||
|
||||
ctx = context.Background()
|
||||
|
||||
// Now user should be able to login with new password
|
||||
u, _, err = svc.Login(ctx, tt.Email, "new_pass")
|
||||
require.Nil(t, err)
|
||||
assert.False(t, u.AdminForcedPasswordReset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPasswordRequirements(t *testing.T) {
|
||||
passwordTests := []struct {
|
||||
password string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
password: "foobar",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
password: "foobarbaz",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
password: "foobarbaz!",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
password: "foobarbaz!3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range passwordTests {
|
||||
t.Run(tt.password, func(t *testing.T) {
|
||||
err := validatePasswordRequirements(tt.password)
|
||||
if tt.wantErr {
|
||||
assert.NotNil(t, err)
|
||||
} else {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -1,10 +1,25 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/logging"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/sso"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
@ -91,3 +106,480 @@ func (svc *Service) DeleteSession(ctx context.Context, id uint) error {
|
|||
|
||||
return svc.ds.DestroySession(ctx, session)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Login
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type loginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type loginResponse struct {
|
||||
User *fleet.User `json:"user,omitempty"`
|
||||
AvailableTeams []*fleet.TeamSummary `json:"available_teams"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r loginResponse) error() error { return r.Err }
|
||||
|
||||
func loginEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*loginRequest)
|
||||
req.Email = strings.ToLower(req.Email)
|
||||
|
||||
user, token, err := svc.Login(ctx, req.Email, req.Password)
|
||||
if err != nil {
|
||||
return loginResponse{Err: err}, nil
|
||||
}
|
||||
// Add viewer context allow access to service teams for list of available teams
|
||||
v, err := authViewer(ctx, token, svc)
|
||||
if err != nil {
|
||||
return loginResponse{Err: err}, nil
|
||||
}
|
||||
ctx = viewer.NewContext(ctx, *v)
|
||||
availableTeams, err := svc.ListAvailableTeamsForUser(ctx, user)
|
||||
if err != nil {
|
||||
if errors.Is(err, fleet.ErrMissingLicense) {
|
||||
availableTeams = []*fleet.TeamSummary{}
|
||||
} else {
|
||||
return loginResponse{Err: err}, nil
|
||||
}
|
||||
}
|
||||
return loginResponse{user, availableTeams, token, nil}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) Login(ctx context.Context, email, password string) (*fleet.User, string, error) {
|
||||
// skipauth: No user context available yet to authorize against.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithLevel(logging.WithNoUser(ctx), level.Info)
|
||||
|
||||
// If there is an error, sleep until the request has taken at least 1
|
||||
// second. This means that generally a login failure for any reason will
|
||||
// take ~1s and frustrate a timing attack.
|
||||
var err error
|
||||
defer func(start time.Time) {
|
||||
if err != nil {
|
||||
time.Sleep(time.Until(start.Add(1 * time.Second)))
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
user, err := svc.ds.UserByEmail(ctx, email)
|
||||
var nfe fleet.NotFoundError
|
||||
if errors.As(err, &nfe) {
|
||||
return nil, "", fleet.NewAuthFailedError("user not found")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", fleet.NewAuthFailedError(err.Error())
|
||||
}
|
||||
|
||||
if err = user.ValidatePassword(password); err != nil {
|
||||
return nil, "", fleet.NewAuthFailedError("invalid password")
|
||||
}
|
||||
|
||||
if user.SSOEnabled {
|
||||
return nil, "", fleet.NewAuthFailedError("password login disabled for sso users")
|
||||
}
|
||||
|
||||
token, err := svc.makeSession(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, "", fleet.NewAuthFailedError(err.Error())
|
||||
}
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Logout
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type logoutResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r logoutResponse) error() error { return r.Err }
|
||||
|
||||
func logoutEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
err := svc.Logout(ctx)
|
||||
if err != nil {
|
||||
return logoutResponse{Err: err}, nil
|
||||
}
|
||||
return logoutResponse{}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) Logout(ctx context.Context) error {
|
||||
// skipauth: Any user can always log out of their own session.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithLevel(ctx, level.Info)
|
||||
|
||||
// TODO: this should not return an error if the user wasn't logged in
|
||||
return svc.DestroySession(ctx)
|
||||
}
|
||||
|
||||
func (svc *Service) DestroySession(ctx context.Context) error {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
return fleet.ErrNoContext
|
||||
}
|
||||
|
||||
session, err := svc.ds.SessionByID(ctx, vc.SessionID())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, session, fleet.ActionWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return svc.ds.DestroySession(ctx, session)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Initiate SSO
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type initiateSSORequest struct {
|
||||
RelayURL string `json:"relay_url"`
|
||||
}
|
||||
|
||||
type initiateSSOResponse struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r initiateSSOResponse) error() error { return r.Err }
|
||||
|
||||
func initiateSSOEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*initiateSSORequest)
|
||||
idProviderURL, err := svc.InitiateSSO(ctx, req.RelayURL)
|
||||
if err != nil {
|
||||
return initiateSSOResponse{Err: err}, nil
|
||||
}
|
||||
return initiateSSOResponse{URL: idProviderURL}, nil
|
||||
}
|
||||
|
||||
// InitiateSSO initiates a Single Sign-On flow for a request to visit the
|
||||
// protected URL identified by redirectURL. It returns the URL of the identity
|
||||
// provider to make a request to to proceed with the authentication via that
|
||||
// external service, and stores ephemeral session state to validate the
|
||||
// callback from the identity provider to finalize the SSO flow.
|
||||
func (svc *Service) InitiateSSO(ctx context.Context, redirectURL string) (string, error) {
|
||||
// skipauth: User context does not yet exist. Unauthenticated users may
|
||||
// initiate SSO.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithLevel(ctx, level.Info)
|
||||
|
||||
appConfig, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting app config")
|
||||
}
|
||||
|
||||
metadata, err := svc.getMetadata(appConfig)
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting metadata")
|
||||
}
|
||||
|
||||
serverURL := appConfig.ServerSettings.ServerURL
|
||||
settings := sso.Settings{
|
||||
Metadata: metadata,
|
||||
// Construct call back url to send to idp
|
||||
AssertionConsumerServiceURL: serverURL + svc.config.Server.URLPrefix + "/api/v1/fleet/sso/callback",
|
||||
SessionStore: svc.ssoSessionStore,
|
||||
OriginalURL: redirectURL,
|
||||
}
|
||||
|
||||
// If issuer is not explicitly set, default to host name.
|
||||
var issuer string
|
||||
entityID := appConfig.SSOSettings.EntityID
|
||||
if entityID == "" {
|
||||
u, err := url.Parse(serverURL)
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "parse server url")
|
||||
}
|
||||
issuer = u.Hostname()
|
||||
} else {
|
||||
issuer = entityID
|
||||
}
|
||||
|
||||
idpURL, err := sso.CreateAuthorizationRequest(&settings, issuer)
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "InitiateSSO creating authorization")
|
||||
}
|
||||
|
||||
return idpURL, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Callback SSO
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type callbackSSORequest struct{}
|
||||
|
||||
func (callbackSSORequest) DecodeRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "decode sso callback")
|
||||
}
|
||||
authResponse, err := sso.DecodeAuthResponse(r.FormValue("SAMLResponse"))
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "decoding sso callback")
|
||||
}
|
||||
return authResponse, nil
|
||||
}
|
||||
|
||||
type callbackSSOResponse struct {
|
||||
content string
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r callbackSSOResponse) error() error { return r.Err }
|
||||
|
||||
// If html is present we return a web page
|
||||
func (r callbackSSOResponse) html() string { return r.content }
|
||||
|
||||
func makeCallbackSSOEndpoint(urlPrefix string) handlerFunc {
|
||||
return func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
authResponse := request.(fleet.Auth)
|
||||
session, err := svc.CallbackSSO(ctx, authResponse)
|
||||
var resp callbackSSOResponse
|
||||
if err != nil {
|
||||
// redirect to login page on front end if there was some problem,
|
||||
// errors should still be logged
|
||||
session = &fleet.SSOSession{
|
||||
RedirectURL: urlPrefix + "/login",
|
||||
Token: "",
|
||||
}
|
||||
resp.Err = err
|
||||
}
|
||||
relayStateLoadPage := ` <html>
|
||||
<script type='text/javascript'>
|
||||
var redirectURL = {{ .RedirectURL }};
|
||||
window.localStorage.setItem('FLEET::auth_token', '{{ .Token }}');
|
||||
window.location = redirectURL;
|
||||
</script>
|
||||
<body>
|
||||
Redirecting to Fleet at {{ .RedirectURL }} ...
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
tmpl, err := template.New("relayStateLoader").Parse(relayStateLoadPage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var writer bytes.Buffer
|
||||
err = tmpl.Execute(&writer, session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.content = writer.String()
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *Service) CallbackSSO(ctx context.Context, auth fleet.Auth) (*fleet.SSOSession, error) {
|
||||
// skipauth: User context does not yet exist. Unauthenticated users may
|
||||
// hit the SSO callback.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithLevel(ctx, level.Info)
|
||||
|
||||
appConfig, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "get config for sso")
|
||||
}
|
||||
|
||||
// Load the request metadata if available
|
||||
|
||||
// localhost:9080/simplesaml/saml2/idp/SSOService.php?spentityid=https://localhost:8080
|
||||
var metadata *sso.Metadata
|
||||
var redirectURL string
|
||||
|
||||
if appConfig.SSOSettings.EnableSSOIdPLogin && auth.RequestID() == "" {
|
||||
// Missing request ID indicates this was IdP-initiated. Only allow if
|
||||
// configured to do so.
|
||||
metadata, err = svc.getMetadata(appConfig)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "get sso metadata")
|
||||
}
|
||||
redirectURL = "/"
|
||||
} else {
|
||||
session, err := svc.ssoSessionStore.Get(auth.RequestID())
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "sso request invalid")
|
||||
}
|
||||
// Remove session to so that is can't be reused before it expires.
|
||||
err = svc.ssoSessionStore.Expire(auth.RequestID())
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "remove sso request")
|
||||
}
|
||||
if err := xml.Unmarshal([]byte(session.Metadata), &metadata); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "unmarshal metadata")
|
||||
}
|
||||
redirectURL = session.OriginalURL
|
||||
}
|
||||
|
||||
// Validate response
|
||||
validator, err := sso.NewValidator(*metadata, sso.WithExpectedAudience(
|
||||
appConfig.SSOSettings.EntityID,
|
||||
appConfig.ServerSettings.ServerURL,
|
||||
appConfig.ServerSettings.ServerURL+svc.config.Server.URLPrefix+"/api/v1/fleet/sso/callback", // ACS
|
||||
))
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "create validator from metadata")
|
||||
}
|
||||
// make sure the response hasn't been tampered with
|
||||
auth, err = validator.ValidateSignature(auth)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "signature validation failed")
|
||||
}
|
||||
// make sure the response isn't stale
|
||||
err = validator.ValidateResponse(auth)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "response validation failed")
|
||||
}
|
||||
|
||||
// Get and log in user
|
||||
user, err := svc.ds.UserByEmail(ctx, auth.UserID())
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "find user in sso callback")
|
||||
}
|
||||
// if the user is not sso enabled they are not authorized
|
||||
if !user.SSOEnabled {
|
||||
return nil, ctxerr.New(ctx, "user not configured to use sso")
|
||||
}
|
||||
token, err := svc.makeSession(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "make session in sso callback")
|
||||
}
|
||||
result := &fleet.SSOSession{
|
||||
Token: token,
|
||||
RedirectURL: redirectURL,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// SSO Settings
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type ssoSettingsResponse struct {
|
||||
Settings *fleet.SessionSSOSettings `json:"settings,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r ssoSettingsResponse) error() error { return r.Err }
|
||||
|
||||
func settingsSSOEndpoint(ctx context.Context, _ interface{}, svc fleet.Service) (interface{}, error) {
|
||||
settings, err := svc.SSOSettings(ctx)
|
||||
if err != nil {
|
||||
return ssoSettingsResponse{Err: err}, nil
|
||||
}
|
||||
return ssoSettingsResponse{Settings: settings}, nil
|
||||
}
|
||||
|
||||
// SSOSettings returns a subset of the Single Sign-On settings as configured in
|
||||
// the app config. Those can be exposed e.g. via the response to an HTTP request,
|
||||
// and as such should not contain sensitive information.
|
||||
func (svc *Service) SSOSettings(ctx context.Context) (*fleet.SessionSSOSettings, error) {
|
||||
// skipauth: Basic SSO settings are available to unauthenticated users (so
|
||||
// that they have the necessary information to initiate SSO).
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
logging.WithLevel(ctx, level.Info)
|
||||
|
||||
appConfig, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "SessionSSOSettings getting app config")
|
||||
}
|
||||
|
||||
settings := &fleet.SessionSSOSettings{
|
||||
IDPName: appConfig.SSOSettings.IDPName,
|
||||
IDPImageURL: appConfig.SSOSettings.IDPImageURL,
|
||||
SSOEnabled: appConfig.SSOSettings.EnableSSO,
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// makeSession is a helper that creates a new session after authentication
|
||||
func (svc *Service) makeSession(ctx context.Context, id uint) (string, error) {
|
||||
sessionKeySize := svc.config.Session.KeySize
|
||||
key := make([]byte, sessionKeySize)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
sessionKey := base64.StdEncoding.EncodeToString(key)
|
||||
session := &fleet.Session{
|
||||
UserID: id,
|
||||
Key: sessionKey,
|
||||
AccessedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
_, err = svc.ds.NewSession(ctx, session)
|
||||
if err != nil {
|
||||
return "", ctxerr.Wrap(ctx, err, "creating new session")
|
||||
}
|
||||
|
||||
return sessionKey, nil
|
||||
}
|
||||
|
||||
func (svc *Service) getMetadata(config *fleet.AppConfig) (*sso.Metadata, error) {
|
||||
if config.SSOSettings.MetadataURL != "" {
|
||||
metadata, err := sso.GetMetadata(config.SSOSettings.MetadataURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
if config.SSOSettings.Metadata != "" {
|
||||
metadata, err := sso.ParseMetadata(config.SSOSettings.Metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("missing metadata for idp %s", config.SSOSettings.IDPName)
|
||||
}
|
||||
|
||||
func (svc *Service) GetSessionByKey(ctx context.Context, key string) (*fleet.Session, error) {
|
||||
session, err := svc.ds.SessionByKey(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = svc.validateSession(ctx, session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (svc *Service) validateSession(ctx context.Context, session *fleet.Session) error {
|
||||
if session == nil {
|
||||
return fleet.NewAuthRequiredError("active session not present")
|
||||
}
|
||||
|
||||
sessionDuration := svc.config.Session.Duration
|
||||
if session.APIOnly != nil && *session.APIOnly {
|
||||
sessionDuration = 0 // make API-only tokens unlimited
|
||||
}
|
||||
|
||||
// duration 0 = unlimited
|
||||
if sessionDuration != 0 && time.Since(session.AccessedAt) >= sessionDuration {
|
||||
err := svc.ds.DestroySession(ctx, session)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "destroying session")
|
||||
}
|
||||
return fleet.NewAuthRequiredError("expired session")
|
||||
}
|
||||
|
||||
return svc.ds.MarkSessionAccessed(ctx, session)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,10 +5,15 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mock"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSessionAuth(t *testing.T) {
|
||||
|
|
@ -85,3 +90,98 @@ func TestSessionAuth(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticate(t *testing.T) {
|
||||
ds := mysql.CreateMySQLDS(t)
|
||||
defer ds.Close()
|
||||
|
||||
svc := newTestService(ds, nil, nil)
|
||||
createTestUsers(t, ds)
|
||||
|
||||
var loginTests = []struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "admin1",
|
||||
email: testUsers["admin1"].Email,
|
||||
password: testUsers["admin1"].PlaintextPassword,
|
||||
},
|
||||
{
|
||||
name: "user1",
|
||||
email: testUsers["user1"].Email,
|
||||
password: testUsers["user1"].PlaintextPassword,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range loginTests {
|
||||
t.Run(tt.email, func(st *testing.T) {
|
||||
loggedIn, token, err := svc.Login(test.UserContext(test.UserAdmin), tt.email, tt.password)
|
||||
require.Nil(st, err, "login unsuccessful")
|
||||
assert.Equal(st, tt.email, loggedIn.Email)
|
||||
assert.NotEmpty(st, token)
|
||||
|
||||
sessions, err := svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), loggedIn.ID)
|
||||
require.Nil(st, err)
|
||||
require.Len(st, sessions, 1, "user should have one session")
|
||||
session := sessions[0]
|
||||
assert.NotZero(st, session.UserID)
|
||||
assert.WithinDuration(st, time.Now(), session.AccessedAt, 3*time.Second,
|
||||
"access time should be set with current time at session creation")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSessionByKey(t *testing.T) {
|
||||
ds := new(mock.Store)
|
||||
svc := newTestService(ds, nil, nil)
|
||||
cfg := config.TestConfig()
|
||||
|
||||
theSession := &fleet.Session{UserID: 123, Key: "abc"}
|
||||
|
||||
ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) {
|
||||
return theSession, nil
|
||||
}
|
||||
ds.DestroySessionFunc = func(ctx context.Context, ssn *fleet.Session) error {
|
||||
return nil
|
||||
}
|
||||
ds.MarkSessionAccessedFunc = func(ctx context.Context, ssn *fleet.Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
accessed time.Duration
|
||||
apiOnly bool
|
||||
fail bool
|
||||
}{
|
||||
{"real user, accessed recently", -1 * time.Hour, false, false},
|
||||
{"real user, accessed too long ago", -(cfg.Session.Duration + time.Hour), false, true},
|
||||
{"api-only, accessed recently", -1 * time.Hour, true, false},
|
||||
{"api-only, accessed long ago", -(cfg.Session.Duration + time.Hour), true, false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
var authErr *fleet.AuthRequiredError
|
||||
ds.SessionByKeyFuncInvoked, ds.DestroySessionFuncInvoked, ds.MarkSessionAccessedFuncInvoked = false, false, false
|
||||
|
||||
theSession.AccessedAt = time.Now().Add(tc.accessed)
|
||||
theSession.APIOnly = ptr.Bool(tc.apiOnly)
|
||||
_, err := svc.GetSessionByKey(context.Background(), theSession.Key)
|
||||
if tc.fail {
|
||||
require.Error(t, err)
|
||||
require.ErrorAs(t, err, &authErr)
|
||||
require.True(t, ds.SessionByKeyFuncInvoked)
|
||||
require.True(t, ds.DestroySessionFuncInvoked)
|
||||
require.False(t, ds.MarkSessionAccessedFuncInvoked)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.True(t, ds.SessionByKeyFuncInvoked)
|
||||
require.False(t, ds.DestroySessionFuncInvoked)
|
||||
require.True(t, ds.MarkSessionAccessedFuncInvoked)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,9 +36,10 @@ func (ts *withDS) TearDownSuite() {
|
|||
type withServer struct {
|
||||
withDS
|
||||
|
||||
server *httptest.Server
|
||||
users map[string]fleet.User
|
||||
token string
|
||||
server *httptest.Server
|
||||
users map[string]fleet.User
|
||||
token string
|
||||
cachedAdminToken string
|
||||
}
|
||||
|
||||
func (ts *withServer) SetupSuite(dbName string) {
|
||||
|
|
@ -49,6 +50,7 @@ func (ts *withServer) SetupSuite(dbName string) {
|
|||
ts.server = server
|
||||
ts.users = users
|
||||
ts.token = ts.getTestAdminToken()
|
||||
ts.cachedAdminToken = ts.token
|
||||
}
|
||||
|
||||
func (ts *withServer) TearDownSuite() {
|
||||
|
|
@ -122,7 +124,13 @@ func (ts *withServer) DoJSON(verb, path string, params interface{}, expectedStat
|
|||
func (ts *withServer) getTestAdminToken() string {
|
||||
testUser := testUsers["admin1"]
|
||||
|
||||
return ts.getTestToken(testUser.Email, testUser.PlaintextPassword)
|
||||
// because the login endpoint is rate-limited, use the cached admin token
|
||||
// if available (if for some reason a test needs to logout the admin user,
|
||||
// then set cachedAdminToken = "" so that a new token is retrieved).
|
||||
if ts.cachedAdminToken == "" {
|
||||
ts.cachedAdminToken = ts.getTestToken(testUser.Email, testUser.PlaintextPassword)
|
||||
}
|
||||
return ts.cachedAdminToken
|
||||
}
|
||||
|
||||
func (ts *withServer) getTestToken(email string, password string) string {
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ func translateHostToID(ctx context.Context, ds fleet.Datastore, identifier strin
|
|||
return host.ID, nil
|
||||
}
|
||||
|
||||
func (svc Service) Translate(ctx context.Context, payloads []fleet.TranslatePayload) ([]fleet.TranslatePayload, error) {
|
||||
func (svc *Service) Translate(ctx context.Context, payloads []fleet.TranslatePayload) ([]fleet.TranslatePayload, error) {
|
||||
var finalPayload []fleet.TranslatePayload
|
||||
|
||||
for _, payload := range payloads {
|
||||
|
|
|
|||
|
|
@ -294,10 +294,6 @@ func userListOptionsFromRequest(r *http.Request) (fleet.UserListOptions, error)
|
|||
return uopt, nil
|
||||
}
|
||||
|
||||
func decodeNoParamsRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type getGenericSpecRequest struct {
|
||||
Name string `url:"name"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,20 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
)
|
||||
|
||||
func decodeCarveBlockRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
defer r.Body.Close()
|
||||
|
||||
var req carveBlockRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "decoding JSON")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func decodeVerifyInviteRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
vars := mux.Vars(r)
|
||||
token, ok := vars["token"]
|
||||
if !ok {
|
||||
return 0, errBadRoute
|
||||
}
|
||||
return verifyInviteRequest{Token: token}, nil
|
||||
}
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func decodeEnrollAgentRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
var req enrollAgentRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDecodeEnrollAgentRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
|
||||
r, err := decodeEnrollAgentRequest(context.Background(), request)
|
||||
require.Nil(t, err)
|
||||
|
||||
params := r.(enrollAgentRequest)
|
||||
assert.Equal(t, "secret", params.EnrollSecret)
|
||||
assert.Equal(t, "uuid", params.HostIdentifier)
|
||||
}).Methods("POST")
|
||||
|
||||
var body bytes.Buffer
|
||||
body.Write([]byte(`{
|
||||
"enroll_secret": "secret",
|
||||
"host_identifier": "uuid"
|
||||
}`))
|
||||
|
||||
router.ServeHTTP(
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("POST", "/", &body),
|
||||
)
|
||||
}
|
||||
|
|
@ -1,41 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/sso"
|
||||
)
|
||||
|
||||
func decodeLoginRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
var req loginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Email = strings.ToLower(req.Email)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeInitiateSSORequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
var req initiateSSORequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeCallbackSSORequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "decode sso callback")
|
||||
}
|
||||
authResponse, err := sso.DecodeAuthResponse(r.FormValue("SAMLResponse"))
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "decoding sso callback")
|
||||
}
|
||||
return authResponse, nil
|
||||
}
|
||||
|
|
@ -1,49 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDecodeLoginRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/login", func(writer http.ResponseWriter, request *http.Request) {
|
||||
r, err := decodeLoginRequest(context.Background(), request)
|
||||
assert.Nil(t, err)
|
||||
|
||||
params := r.(loginRequest)
|
||||
assert.Equal(t, "foo", params.Email)
|
||||
assert.Equal(t, "bar", params.Password)
|
||||
}).Methods("POST")
|
||||
t.Run("lowercase email", func(t *testing.T) {
|
||||
var body bytes.Buffer
|
||||
body.Write([]byte(`{
|
||||
"email": "foo",
|
||||
"password": "bar"
|
||||
}`))
|
||||
|
||||
router.ServeHTTP(
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("POST", "/api/v1/fleet/login", &body),
|
||||
)
|
||||
})
|
||||
t.Run("uppercase email", func(t *testing.T) {
|
||||
var body bytes.Buffer
|
||||
body.Write([]byte(`{
|
||||
"email": "Foo",
|
||||
"password": "bar"
|
||||
}`))
|
||||
|
||||
router.ServeHTTP(
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("POST", "/api/v1/fleet/login", &body),
|
||||
)
|
||||
})
|
||||
|
||||
}
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
)
|
||||
|
||||
func decodeCreateUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
var req createUserRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodePerformRequiredPasswordResetRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
var req performRequiredPasswordResetRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "decoding JSON")
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeForgotPasswordRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
var req forgotPasswordRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeResetPasswordRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
var req resetPasswordRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDecodeResetPasswordRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/users/{id}/password", func(writer http.ResponseWriter, request *http.Request) {
|
||||
r, err := decodeResetPasswordRequest(context.Background(), request)
|
||||
assert.Nil(t, err)
|
||||
|
||||
params := r.(resetPasswordRequest)
|
||||
assert.Equal(t, "bar", params.NewPassword)
|
||||
assert.Equal(t, "baz", params.PasswordResetToken)
|
||||
}).Methods("POST")
|
||||
|
||||
var body bytes.Buffer
|
||||
body.Write([]byte(`{
|
||||
"new_password": "bar",
|
||||
"password_reset_token": "baz"
|
||||
}`))
|
||||
|
||||
router.ServeHTTP(
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("POST", "/api/v1/fleet/users/1/password", &body),
|
||||
)
|
||||
}
|
||||
|
|
@ -26,7 +26,7 @@ func applyUserRoleSpecsEndpoint(ctx context.Context, request interface{}, svc fl
|
|||
return applyUserRoleSpecsResponse{}, nil
|
||||
}
|
||||
|
||||
func (svc Service) ApplyUserRolesSpecs(ctx context.Context, specs fleet.UsersRoleSpec) error {
|
||||
func (svc *Service) ApplyUserRolesSpecs(ctx context.Context, specs fleet.UsersRoleSpec) error {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -61,7 +61,7 @@ func (svc Service) ApplyUserRolesSpecs(ctx context.Context, specs fleet.UsersRol
|
|||
return svc.ds.SaveUsers(ctx, users)
|
||||
}
|
||||
|
||||
func (svc Service) checkAtLeastOneAdmin(ctx context.Context, user *fleet.User, spec *fleet.UserRoleSpec, email string) error {
|
||||
func (svc *Service) checkAtLeastOneAdmin(ctx context.Context, user *fleet.User, spec *fleet.UserRoleSpec, email string) error {
|
||||
if null.StringFromPtr(user.GlobalRole).ValueOrZero() == fleet.RoleAdmin &&
|
||||
null.StringFromPtr(spec.GlobalRole).ValueOrZero() != fleet.RoleAdmin {
|
||||
users, err := svc.ds.ListUsers(ctx, fleet.UserListOptions{})
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import (
|
|||
"encoding/base64"
|
||||
"errors"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server"
|
||||
"github.com/fleetdm/fleet/v4/server/authz"
|
||||
|
|
@ -49,6 +51,10 @@ func (svc *Service) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.VerifyAdminCreate(); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "verify user payload")
|
||||
}
|
||||
|
||||
if invite, err := svc.ds.InviteByEmail(ctx, *p.Email); err == nil && invite != nil {
|
||||
return nil, ctxerr.Errorf(ctx, "%s already invited", *p.Email)
|
||||
}
|
||||
|
|
@ -61,6 +67,49 @@ func (svc *Service) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet
|
|||
return svc.newUser(ctx, p)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Create User From Invite
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func createUserFromInviteEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*createUserRequest)
|
||||
user, err := svc.CreateUserFromInvite(ctx, req.UserPayload)
|
||||
if err != nil {
|
||||
return createUserResponse{Err: err}, nil
|
||||
}
|
||||
return createUserResponse{User: user}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) CreateUserFromInvite(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
|
||||
// skipauth: There is no viewer context at this point. We rely on verifying
|
||||
// the invite for authNZ.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
if err := p.VerifyInviteCreate(); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "verify user payload")
|
||||
}
|
||||
|
||||
invite, err := svc.VerifyInvite(ctx, *p.InviteToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// set the payload role property based on an existing invite.
|
||||
p.GlobalRole = invite.GlobalRole.Ptr()
|
||||
p.Teams = &invite.Teams
|
||||
|
||||
user, err := svc.newUser(ctx, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = svc.ds.DeleteInvite(ctx, invite.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// List Users
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
@ -215,6 +264,14 @@ func (svc *Service) ModifyUser(ctx context.Context, userID uint, p fleet.UserPay
|
|||
return nil, err
|
||||
}
|
||||
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, ctxerr.New(ctx, "viewer not present") // should never happen, authorize would've failed
|
||||
}
|
||||
if err := p.VerifyModify(vc.UserID() == userID); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "verify user payload")
|
||||
}
|
||||
|
||||
if p.GlobalRole != nil || p.Teams != nil {
|
||||
if err := svc.authz.Authorize(ctx, user, fleet.ActionWriteRole); err != nil {
|
||||
return nil, err
|
||||
|
|
@ -386,13 +443,21 @@ func (svc *Service) ChangePassword(ctx context.Context, oldPass, newPass string)
|
|||
return err
|
||||
}
|
||||
|
||||
if oldPass == "" {
|
||||
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("old_password", "Old password cannot be empty"))
|
||||
}
|
||||
if newPass == "" {
|
||||
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", "New password cannot be empty"))
|
||||
}
|
||||
if err := fleet.ValidatePasswordRequirements(newPass); err != nil {
|
||||
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", err.Error()))
|
||||
}
|
||||
if vc.User.SSOEnabled {
|
||||
return ctxerr.New(ctx, "change password for single sign on user not allowed")
|
||||
}
|
||||
if err := vc.User.ValidatePassword(newPass); err == nil {
|
||||
return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
|
||||
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"))
|
||||
}
|
||||
|
||||
if err := vc.User.ValidatePassword(oldPass); err != nil {
|
||||
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("old_password", "old password does not match"))
|
||||
}
|
||||
|
|
@ -622,3 +687,245 @@ func (svc *Service) modifyEmailAddress(ctx context.Context, user *fleet.User, em
|
|||
func (svc *Service) saveUser(ctx context.Context, user *fleet.User) error {
|
||||
return svc.ds.SaveUser(ctx, user)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Perform Required Password Reset
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type performRequiredPasswordResetRequest struct {
|
||||
Password string `json:"new_password"`
|
||||
ID uint `json:"id"`
|
||||
}
|
||||
|
||||
type performRequiredPasswordResetResponse struct {
|
||||
User *fleet.User `json:"user,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r performRequiredPasswordResetResponse) error() error { return r.Err }
|
||||
|
||||
func performRequiredPasswordResetEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*performRequiredPasswordResetRequest)
|
||||
user, err := svc.PerformRequiredPasswordReset(ctx, req.Password)
|
||||
if err != nil {
|
||||
return performRequiredPasswordResetResponse{Err: err}, nil
|
||||
}
|
||||
return performRequiredPasswordResetResponse{User: user}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) PerformRequiredPasswordReset(ctx context.Context, password string) (*fleet.User, error) {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, fleet.ErrNoContext
|
||||
}
|
||||
if !vc.CanPerformPasswordReset() {
|
||||
return nil, fleet.NewPermissionError("cannot reset password")
|
||||
}
|
||||
user := vc.User
|
||||
|
||||
if err := svc.authz.Authorize(ctx, user, fleet.ActionWrite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.SSOEnabled {
|
||||
return nil, ctxerr.New(ctx, "password reset for single sign on user not allowed")
|
||||
}
|
||||
if !user.IsAdminForcedPasswordReset() {
|
||||
return nil, ctxerr.New(ctx, "user does not require password reset")
|
||||
}
|
||||
|
||||
// prevent setting the same password
|
||||
if err := user.ValidatePassword(password); err == nil {
|
||||
return nil, fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
|
||||
}
|
||||
|
||||
user.AdminForcedPasswordReset = false
|
||||
err := svc.setNewPassword(ctx, user, password)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "setting new password")
|
||||
}
|
||||
|
||||
// Sessions should already have been cleared when the reset was
|
||||
// required
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// setNewPassword is a helper for changing a user's password. It should be
|
||||
// called to set the new password after proper authorization has been
|
||||
// performed.
|
||||
func (svc *Service) setNewPassword(ctx context.Context, user *fleet.User, password string) error {
|
||||
err := user.SetPassword(password, svc.config.Auth.SaltKeySize, svc.config.Auth.BcryptCost)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "setting new password")
|
||||
}
|
||||
if user.SSOEnabled {
|
||||
return ctxerr.New(ctx, "set password for single sign on user not allowed")
|
||||
}
|
||||
err = svc.saveUser(ctx, user)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "saving changed password")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Reset Password
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type resetPasswordRequest struct {
|
||||
PasswordResetToken string `json:"password_reset_token"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
type resetPasswordResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r resetPasswordResponse) error() error { return r.Err }
|
||||
|
||||
func resetPasswordEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*resetPasswordRequest)
|
||||
err := svc.ResetPassword(ctx, req.PasswordResetToken, req.NewPassword)
|
||||
return resetPasswordResponse{Err: err}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) ResetPassword(ctx context.Context, token, password string) error {
|
||||
// skipauth: No viewer context available. The user is locked out of their
|
||||
// account and authNZ is performed entirely by providing a valid password
|
||||
// reset token.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
if token == "" {
|
||||
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("token", "Token cannot be empty field"))
|
||||
}
|
||||
if password == "" {
|
||||
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", "New password cannot be empty field"))
|
||||
}
|
||||
if err := fleet.ValidatePasswordRequirements(password); err != nil {
|
||||
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", err.Error()))
|
||||
}
|
||||
|
||||
reset, err := svc.ds.FindPasswordResetByToken(ctx, token)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "looking up reset by token")
|
||||
}
|
||||
user, err := svc.ds.UserByID(ctx, reset.UserID)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "retrieving user")
|
||||
}
|
||||
|
||||
if user.SSOEnabled {
|
||||
return ctxerr.New(ctx, "password reset for single sign on user not allowed")
|
||||
}
|
||||
|
||||
// prevent setting the same password
|
||||
if err := user.ValidatePassword(password); err == nil {
|
||||
return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
|
||||
}
|
||||
|
||||
err = svc.setNewPassword(ctx, user, password)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "setting new password")
|
||||
}
|
||||
|
||||
// delete password reset tokens for user
|
||||
if err := svc.ds.DeletePasswordResetRequestsForUser(ctx, user.ID); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "delete password reset requests")
|
||||
}
|
||||
|
||||
// Clear sessions so that any other browsers will have to log in with
|
||||
// the new password
|
||||
if err := svc.ds.DestroyAllSessionsForUser(ctx, user.ID); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "delete user sessions")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Forgot Password
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type forgotPasswordRequest struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type forgotPasswordResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r forgotPasswordResponse) error() error { return r.Err }
|
||||
func (r forgotPasswordResponse) status() int { return http.StatusAccepted }
|
||||
|
||||
func forgotPasswordEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*forgotPasswordRequest)
|
||||
// Any error returned by the service should not be returned to the
|
||||
// client to prevent information disclosure (it will be logged in the
|
||||
// server logs).
|
||||
_ = svc.RequestPasswordReset(ctx, req.Email)
|
||||
return forgotPasswordResponse{}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) RequestPasswordReset(ctx context.Context, email string) error {
|
||||
// skipauth: No viewer context available. The user is locked out of their
|
||||
// account and trying to reset their password.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
// Regardless of error, sleep until the request has taken at least 1 second.
|
||||
// This means that any request to this method will take ~1s and frustrate a timing attack.
|
||||
defer func(start time.Time) {
|
||||
time.Sleep(time.Until(start.Add(1 * time.Second)))
|
||||
}(time.Now())
|
||||
|
||||
user, err := svc.ds.UserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if user.SSOEnabled {
|
||||
return ctxerr.New(ctx, "password reset for single sign on user not allowed")
|
||||
}
|
||||
|
||||
random, err := server.GenerateRandomText(svc.config.App.TokenKeySize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
token := base64.URLEncoding.EncodeToString([]byte(random))
|
||||
|
||||
request := &fleet.PasswordResetRequest{
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
UserID: user.ID,
|
||||
Token: token,
|
||||
}
|
||||
_, err = svc.ds.NewPasswordResetRequest(ctx, request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := svc.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resetEmail := fleet.Email{
|
||||
Subject: "Reset Your Fleet Password",
|
||||
To: []string{user.Email},
|
||||
Config: config,
|
||||
Mailer: &mail.PasswordResetMailer{
|
||||
BaseURL: template.URL(config.ServerSettings.ServerURL + svc.config.Server.URLPrefix),
|
||||
AssetURL: getAssetURL(),
|
||||
Token: token,
|
||||
},
|
||||
}
|
||||
|
||||
return svc.mailService.SendEmail(resetEmail)
|
||||
}
|
||||
|
||||
func (svc *Service) ListAvailableTeamsForUser(ctx context.Context, user *fleet.User) ([]*fleet.TeamSummary, error) {
|
||||
// skipauth: No authorization check needed due to implementation returning
|
||||
// only license error.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
return nil, fleet.ErrMissingLicense
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package service
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -471,6 +472,7 @@ func testUsersChangePassword(t *testing.T, ds *mysql.Datastore) {
|
|||
anyErr: true,
|
||||
},
|
||||
{ // missing old password
|
||||
user: users["user1@example.com"],
|
||||
newPassword: "123cataaa!",
|
||||
wantErr: fleet.NewInvalidArgumentError("old_password", "Old password cannot be empty"),
|
||||
},
|
||||
|
|
@ -540,3 +542,141 @@ func testUsersRequirePasswordReset(t *testing.T, ds *mysql.Datastore) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformRequiredPasswordReset(t *testing.T) {
|
||||
ds := mysql.CreateMySQLDS(t)
|
||||
|
||||
svc := newTestService(ds, nil, nil)
|
||||
|
||||
createTestUsers(t, ds)
|
||||
|
||||
for _, tt := range testUsers {
|
||||
t.Run(tt.Email, func(t *testing.T) {
|
||||
user, err := ds.UserByEmail(context.Background(), tt.Email)
|
||||
require.Nil(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
_, err = svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, true)
|
||||
require.Nil(t, err)
|
||||
|
||||
ctx = refreshCtx(t, ctx, user, ds, nil)
|
||||
|
||||
session, err := ds.NewSession(context.Background(), &fleet.Session{UserID: user.ID})
|
||||
require.Nil(t, err)
|
||||
ctx = refreshCtx(t, ctx, user, ds, session)
|
||||
|
||||
// should error when reset not required
|
||||
_, err = svc.RequirePasswordReset(ctx, user.ID, false)
|
||||
require.Nil(t, err)
|
||||
ctx = refreshCtx(t, ctx, user, ds, session)
|
||||
_, err = svc.PerformRequiredPasswordReset(ctx, "new_pass")
|
||||
require.NotNil(t, err)
|
||||
|
||||
_, err = svc.RequirePasswordReset(ctx, user.ID, true)
|
||||
require.Nil(t, err)
|
||||
ctx = refreshCtx(t, ctx, user, ds, session)
|
||||
|
||||
// should error when using same password
|
||||
_, err = svc.PerformRequiredPasswordReset(ctx, tt.PlaintextPassword)
|
||||
require.Equal(t, "validation failed: new_password cannot reuse old password", err.Error())
|
||||
|
||||
// should succeed with good new password
|
||||
u, err := svc.PerformRequiredPasswordReset(ctx, "new_pass")
|
||||
require.Nil(t, err)
|
||||
assert.False(t, u.AdminForcedPasswordReset)
|
||||
|
||||
ctx = context.Background()
|
||||
|
||||
// Now user should be able to login with new password
|
||||
u, _, err = svc.Login(ctx, tt.Email, "new_pass")
|
||||
require.Nil(t, err)
|
||||
assert.False(t, u.AdminForcedPasswordReset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetPassword(t *testing.T) {
|
||||
ds := mysql.CreateMySQLDS(t)
|
||||
|
||||
svc := newTestService(ds, nil, nil)
|
||||
createTestUsers(t, ds)
|
||||
passwordResetTests := []struct {
|
||||
token string
|
||||
newPassword string
|
||||
wantErr error
|
||||
}{
|
||||
{ // all good
|
||||
token: "abcd",
|
||||
newPassword: "123cat!",
|
||||
},
|
||||
{ // prevent reuse
|
||||
token: "abcd",
|
||||
newPassword: "123cat!",
|
||||
wantErr: fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"),
|
||||
},
|
||||
{ // bad token
|
||||
token: "dcbaz",
|
||||
newPassword: "123cat!",
|
||||
wantErr: sql.ErrNoRows,
|
||||
},
|
||||
{ // missing token
|
||||
newPassword: "123cat!",
|
||||
wantErr: fleet.NewInvalidArgumentError("token", "Token cannot be empty field"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range passwordResetTests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
request := &fleet.PasswordResetRequest{
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
CreateTimestamp: fleet.CreateTimestamp{
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
},
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
UserID: 1,
|
||||
Token: "abcd",
|
||||
}
|
||||
_, err := ds.NewPasswordResetRequest(context.Background(), request)
|
||||
assert.Nil(t, err)
|
||||
|
||||
serr := svc.ResetPassword(test.UserContext(&fleet.User{ID: 1}), tt.token, tt.newPassword)
|
||||
if tt.wantErr != nil {
|
||||
assert.Equal(t, tt.wantErr.Error(), ctxerr.Cause(serr).Error())
|
||||
} else {
|
||||
assert.Nil(t, serr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func refreshCtx(t *testing.T, ctx context.Context, user *fleet.User, ds fleet.Datastore, session *fleet.Session) context.Context {
|
||||
reloadedUser, err := ds.UserByEmail(ctx, user.Email)
|
||||
require.NoError(t, err)
|
||||
|
||||
return viewer.NewContext(ctx, viewer.Viewer{User: reloadedUser, Session: session})
|
||||
}
|
||||
|
||||
func TestAuthenticatedUser(t *testing.T) {
|
||||
ds := mysql.CreateMySQLDS(t)
|
||||
|
||||
createTestUsers(t, ds)
|
||||
svc := newTestService(ds, nil, nil)
|
||||
admin1, err := ds.UserByEmail(context.Background(), "admin1@example.com")
|
||||
assert.Nil(t, err)
|
||||
admin1Session, err := ds.NewSession(context.Background(), &fleet.Session{
|
||||
UserID: admin1.ID,
|
||||
Key: "admin1",
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: admin1, Session: admin1Session})
|
||||
user, err := svc.AuthenticatedUser(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, user, admin1)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,202 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"unicode"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
)
|
||||
|
||||
func (mw validationMiddleware) CreateUserFromInvite(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
|
||||
invalid := &fleet.InvalidArgumentError{}
|
||||
if p.Name == nil {
|
||||
invalid.Append("name", "Full name missing required argument")
|
||||
} else {
|
||||
if *p.Name == "" {
|
||||
invalid.Append("name", "Full name cannot be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// we don't need a password for single sign on
|
||||
if p.SSOInvite == nil || !*p.SSOInvite {
|
||||
if p.Password == nil {
|
||||
invalid.Append("password", "Password missing required argument")
|
||||
} else {
|
||||
if *p.Password == "" {
|
||||
invalid.Append("password", "Password cannot be empty")
|
||||
}
|
||||
if err := validatePasswordRequirements(*p.Password); err != nil {
|
||||
invalid.Append("password", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if p.Email == nil {
|
||||
invalid.Append("email", "Email missing required argument")
|
||||
} else {
|
||||
if *p.Email == "" {
|
||||
invalid.Append("email", "Email cannot be empty")
|
||||
}
|
||||
}
|
||||
|
||||
if p.InviteToken == nil {
|
||||
invalid.Append("invite_token", "Invite token missing required argument")
|
||||
} else {
|
||||
if *p.InviteToken == "" {
|
||||
invalid.Append("invite_token", "Invite token cannot be empty")
|
||||
}
|
||||
}
|
||||
|
||||
if invalid.HasErrors() {
|
||||
return nil, ctxerr.Wrap(ctx, invalid)
|
||||
}
|
||||
return mw.Service.CreateUserFromInvite(ctx, p)
|
||||
}
|
||||
|
||||
func (mw validationMiddleware) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
|
||||
invalid := &fleet.InvalidArgumentError{}
|
||||
if p.Name == nil {
|
||||
invalid.Append("name", "Full name missing required argument")
|
||||
} else {
|
||||
if *p.Name == "" {
|
||||
invalid.Append("name", "Full name cannot be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// we don't need a password for single sign on
|
||||
if (p.SSOInvite == nil || !*p.SSOInvite) && (p.SSOEnabled == nil || !*p.SSOEnabled) {
|
||||
if p.Password == nil {
|
||||
invalid.Append("password", "Password missing required argument")
|
||||
} else {
|
||||
if *p.Password == "" {
|
||||
invalid.Append("password", "Password cannot be empty")
|
||||
}
|
||||
// Skip password validation in the case of admin created users
|
||||
}
|
||||
}
|
||||
|
||||
if p.SSOEnabled != nil && *p.SSOEnabled && p.Password != nil && len(*p.Password) > 0 {
|
||||
invalid.Append("password", "not allowed for SSO users")
|
||||
}
|
||||
|
||||
if p.Email == nil {
|
||||
invalid.Append("email", "Email missing required argument")
|
||||
} else {
|
||||
if *p.Email == "" {
|
||||
invalid.Append("email", "Email cannot be empty")
|
||||
}
|
||||
}
|
||||
|
||||
if p.InviteToken != nil {
|
||||
invalid.Append("invite_token", "Invite token should not be specified with admin user creation")
|
||||
}
|
||||
|
||||
if invalid.HasErrors() {
|
||||
return nil, ctxerr.Wrap(ctx, invalid)
|
||||
}
|
||||
return mw.Service.CreateUser(ctx, p)
|
||||
}
|
||||
|
||||
func (mw validationMiddleware) ModifyUser(ctx context.Context, userID uint, p fleet.UserPayload) (*fleet.User, error) {
|
||||
invalid := &fleet.InvalidArgumentError{}
|
||||
if p.Name != nil {
|
||||
if *p.Name == "" {
|
||||
invalid.Append("name", "Full name cannot be empty")
|
||||
}
|
||||
}
|
||||
|
||||
if p.Email != nil {
|
||||
if *p.Email == "" {
|
||||
invalid.Append("email", "Email cannot be empty")
|
||||
}
|
||||
// if the user is not an admin, or if an admin is changing their own email
|
||||
// address a password is required,
|
||||
if passwordRequiredForEmailChange(ctx, userID, invalid) {
|
||||
if p.Password == nil {
|
||||
invalid.Append("password", "Password cannot be empty if email is changed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if invalid.HasErrors() {
|
||||
return nil, ctxerr.Wrap(ctx, invalid)
|
||||
}
|
||||
return mw.Service.ModifyUser(ctx, userID, p)
|
||||
}
|
||||
|
||||
func passwordRequiredForEmailChange(ctx context.Context, uid uint, invalid *fleet.InvalidArgumentError) bool {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
invalid.Append("viewer", "Viewer not present")
|
||||
return false
|
||||
}
|
||||
// if a user is changing own email need a password no matter what
|
||||
return vc.UserID() == uid
|
||||
}
|
||||
|
||||
func (mw validationMiddleware) ChangePassword(ctx context.Context, oldPass, newPass string) error {
|
||||
invalid := &fleet.InvalidArgumentError{}
|
||||
if oldPass == "" {
|
||||
invalid.Append("old_password", "Old password cannot be empty")
|
||||
}
|
||||
if newPass == "" {
|
||||
invalid.Append("new_password", "New password cannot be empty")
|
||||
}
|
||||
|
||||
if err := validatePasswordRequirements(newPass); err != nil {
|
||||
invalid.Append("new_password", err.Error())
|
||||
}
|
||||
|
||||
if invalid.HasErrors() {
|
||||
return ctxerr.Wrap(ctx, invalid)
|
||||
}
|
||||
return mw.Service.ChangePassword(ctx, oldPass, newPass)
|
||||
}
|
||||
|
||||
func (mw validationMiddleware) ResetPassword(ctx context.Context, token, password string) error {
|
||||
invalid := &fleet.InvalidArgumentError{}
|
||||
if token == "" {
|
||||
invalid.Append("token", "Token cannot be empty field")
|
||||
}
|
||||
if password == "" {
|
||||
invalid.Append("new_password", "New password cannot be empty field")
|
||||
}
|
||||
if err := validatePasswordRequirements(password); err != nil {
|
||||
invalid.Append("new_password", err.Error())
|
||||
}
|
||||
if invalid.HasErrors() {
|
||||
return ctxerr.Wrap(ctx, invalid)
|
||||
}
|
||||
return mw.Service.ResetPassword(ctx, token, password)
|
||||
}
|
||||
|
||||
// Requirements for user password:
|
||||
// at least 7 character length
|
||||
// at least 1 symbol
|
||||
// at least 1 number
|
||||
func validatePasswordRequirements(password string) error {
|
||||
var (
|
||||
number bool
|
||||
symbol bool
|
||||
)
|
||||
|
||||
for _, s := range password {
|
||||
switch {
|
||||
case unicode.IsNumber(s):
|
||||
number = true
|
||||
case unicode.IsPunct(s) || unicode.IsSymbol(s):
|
||||
symbol = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(password) >= 7 &&
|
||||
number &&
|
||||
symbol {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("Password does not meet validation requirements")
|
||||
}
|
||||
Loading…
Reference in a new issue