From c8bc026d6fe3cd9db8e095f8ccd6a24ce684653e Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Tue, 8 Mar 2022 11:27:38 -0500 Subject: [PATCH] Migrate special-case endpoints to new pattern (#4511) --- server/datastore/mysql/app_configs.go | 6 +- server/datastore/mysql/password_reset.go | 4 +- server/fleet/datastore.go | 2 +- server/fleet/users.go | 128 + server/fleet/users_test.go | 34 + server/mock/datastore_mock.go | 12 +- server/service/carves.go | 72 + server/service/carves_test.go | 419 +++ server/service/endpoint_carves.go | 47 - server/service/endpoint_invites.go | 30 - server/service/endpoint_middleware.go | 17 +- server/service/endpoint_osquery.go | 36 - server/service/endpoint_sessions.go | 163 -- server/service/endpoint_users.go | 100 - server/service/endpoint_utils.go | 41 +- server/service/endpoint_utils_test.go | 72 +- server/service/handler.go | 153 +- server/service/handler_test.go | 53 - server/service/integration_core_test.go | 427 ++- server/service/invites.go | 49 + ...ervice_invites_test.go => invites_test.go} | 0 server/service/jitter.go | 84 + server/service/jitter_test.go | 52 + .../service/middleware/ratelimit/ratelimit.go | 10 +- .../middleware/ratelimit/ratelimit_test.go | 1 + server/service/osquery.go | 251 ++ server/service/osquery_test.go | 2344 ++++++++++++++++ server/service/service.go | 2 +- server/service/service_carves.go | 46 - server/service/service_carves_test.go | 432 --- server/service/service_invites.go | 34 - server/service/service_osquery.go | 317 --- server/service/service_osquery_test.go | 2404 ----------------- server/service/service_sessions.go | 325 --- server/service/service_sessions_test.go | 111 - server/service/service_teams.go | 15 - server/service/service_users.go | 182 -- server/service/service_users_test.go | 188 -- server/service/sessions.go | 492 ++++ server/service/sessions_test.go | 100 + server/service/testing_client.go | 16 +- server/service/translator.go | 2 +- server/service/transport.go | 4 - 
server/service/transport_carves.go | 20 - server/service/transport_invites.go | 17 - server/service/transport_osquery.go | 17 - server/service/transport_osquery_test.go | 36 - server/service/transport_sessions.go | 41 - server/service/transport_sessions_test.go | 49 - server/service/transport_users.go | 42 - server/service/transport_users_test.go | 35 - server/service/user_roles.go | 4 +- server/service/users.go | 311 ++- server/service/users_test.go | 140 + server/service/validation_users.go | 202 -- 55 files changed, 5091 insertions(+), 5100 deletions(-) delete mode 100644 server/service/endpoint_carves.go delete mode 100644 server/service/endpoint_invites.go delete mode 100644 server/service/endpoint_osquery.go delete mode 100644 server/service/endpoint_sessions.go delete mode 100644 server/service/endpoint_users.go rename server/service/{service_invites_test.go => invites_test.go} (100%) create mode 100644 server/service/jitter.go create mode 100644 server/service/jitter_test.go delete mode 100644 server/service/service_carves.go delete mode 100644 server/service/service_carves_test.go delete mode 100644 server/service/service_invites.go delete mode 100644 server/service/service_osquery.go delete mode 100644 server/service/service_osquery_test.go delete mode 100644 server/service/service_sessions.go delete mode 100644 server/service/service_sessions_test.go delete mode 100644 server/service/service_teams.go delete mode 100644 server/service/service_users_test.go delete mode 100644 server/service/transport_carves.go delete mode 100644 server/service/transport_invites.go delete mode 100644 server/service/transport_osquery.go delete mode 100644 server/service/transport_osquery_test.go delete mode 100644 server/service/transport_sessions.go delete mode 100644 server/service/transport_sessions_test.go delete mode 100644 server/service/transport_users.go delete mode 100644 server/service/transport_users_test.go delete mode 100644 server/service/validation_users.go 
diff --git a/server/datastore/mysql/app_configs.go b/server/datastore/mysql/app_configs.go index bedd88f01a..670d5c5376 100644 --- a/server/datastore/mysql/app_configs.go +++ b/server/datastore/mysql/app_configs.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" @@ -74,7 +75,10 @@ func (ds *Datastore) VerifyEnrollSecret(ctx context.Context, secret string) (*fl var s fleet.EnrollSecret err := sqlx.GetContext(ctx, ds.reader, &s, "SELECT team_id FROM enroll_secrets WHERE secret = ?", secret) if err != nil { - return nil, ctxerr.New(ctx, "no matching secret found") + if errors.Is(err, sql.ErrNoRows) { + return nil, ctxerr.New(ctx, "no matching secret found") + } + return nil, ctxerr.Wrap(ctx, err, "verify enroll secret") } return &s, nil diff --git a/server/datastore/mysql/password_reset.go b/server/datastore/mysql/password_reset.go index 87eff64cba..13c523fbff 100644 --- a/server/datastore/mysql/password_reset.go +++ b/server/datastore/mysql/password_reset.go @@ -36,7 +36,8 @@ func (ds *Datastore) DeletePasswordResetRequestsForUser(ctx context.Context, use return nil } -func (ds *Datastore) FindPassswordResetByToken(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) { + +func (ds *Datastore) FindPasswordResetByToken(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) { sqlStatement := ` SELECT * FROM password_reset_requests WHERE token = ? 
LIMIT 1 @@ -48,5 +49,4 @@ func (ds *Datastore) FindPassswordResetByToken(ctx context.Context, token string } return passwordResetRequest, nil - } diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 3dfc42c98b..3eb2f34675 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -226,7 +226,7 @@ type Datastore interface { NewPasswordResetRequest(ctx context.Context, req *PasswordResetRequest) (*PasswordResetRequest, error) DeletePasswordResetRequestsForUser(ctx context.Context, userID uint) error - FindPassswordResetByToken(ctx context.Context, token string) (*PasswordResetRequest, error) + FindPasswordResetByToken(ctx context.Context, token string) (*PasswordResetRequest, error) /////////////////////////////////////////////////////////////////////////////// // SessionStore is the abstract interface that all session backends must conform to. diff --git a/server/fleet/users.go b/server/fleet/users.go index ecff0c6982..d7aa03c0c6 100644 --- a/server/fleet/users.go +++ b/server/fleet/users.go @@ -1,7 +1,9 @@ package fleet import ( + "errors" "fmt" + "unicode" "github.com/fleetdm/fleet/v4/server" "golang.org/x/crypto/bcrypt" @@ -69,6 +71,104 @@ type UserPayload struct { Teams *[]UserTeam `json:"teams,omitempty"` } +func (p *UserPayload) VerifyInviteCreate() error { + invalid := &InvalidArgumentError{} + if p.Name == nil { + invalid.Append("name", "Full name missing required argument") + } else if *p.Name == "" { + invalid.Append("name", "Full name cannot be empty") + } + + // we don't need a password for single sign on + if p.SSOInvite == nil || !*p.SSOInvite { + if p.Password == nil { + invalid.Append("password", "Password missing required argument") + } else if *p.Password == "" { + invalid.Append("password", "Password cannot be empty") + } else if err := ValidatePasswordRequirements(*p.Password); err != nil { + invalid.Append("password", err.Error()) + } + } + + if p.Email == nil { + invalid.Append("email", "Email missing required 
argument") + } else if *p.Email == "" { + invalid.Append("email", "Email cannot be empty") + } + + if p.InviteToken == nil { + invalid.Append("invite_token", "Invite token missing required argument") + } else if *p.InviteToken == "" { + invalid.Append("invite_token", "Invite token cannot be empty") + } + + if invalid.HasErrors() { + return invalid + } + return nil +} + +func (p *UserPayload) VerifyAdminCreate() error { + invalid := &InvalidArgumentError{} + if p.Name == nil { + invalid.Append("name", "Full name missing required argument") + } else if *p.Name == "" { + invalid.Append("name", "Full name cannot be empty") + } + + // we don't need a password for single sign on + if (p.SSOInvite == nil || !*p.SSOInvite) && (p.SSOEnabled == nil || !*p.SSOEnabled) { + if p.Password == nil { + invalid.Append("password", "Password missing required argument") + } else if *p.Password == "" { + invalid.Append("password", "Password cannot be empty") + } + // Skip password validation in the case of admin created users + } + + if p.SSOEnabled != nil && *p.SSOEnabled && p.Password != nil && len(*p.Password) > 0 { + invalid.Append("password", "not allowed for SSO users") + } + + if p.Email == nil { + invalid.Append("email", "Email missing required argument") + } else if *p.Email == "" { + invalid.Append("email", "Email cannot be empty") + } + + if p.InviteToken != nil { + invalid.Append("invite_token", "Invite token should not be specified with admin user creation") + } + + if invalid.HasErrors() { + return invalid + } + return nil +} + +func (p *UserPayload) VerifyModify(ownUser bool) error { + invalid := &InvalidArgumentError{} + if p.Name != nil && *p.Name == "" { + invalid.Append("name", "Full name cannot be empty") + } + + if p.Email != nil { + if *p.Email == "" { + invalid.Append("email", "Email cannot be empty") + } + // if the user is not an admin, or if an admin is changing their own email + // address a password is required, + if ownUser && p.Password == nil { + 
invalid.Append("password", "Password cannot be empty if email is changed") + } + } + + if invalid.HasErrors() { + return invalid + } + return nil +} + // User creates a user from payload. func (p UserPayload) User(keySize, cost int) (*User, error) { user := &User{ @@ -130,3 +230,31 @@ func (u *User) SetPassword(plaintext string, keySize, cost int) error { u.Password = hashed return nil } + +// Requirements for user password: +// at least 7 character length +// at least 1 symbol +// at least 1 number +func ValidatePasswordRequirements(password string) error { + var ( + number bool + symbol bool + ) + + for _, s := range password { + switch { + case unicode.IsNumber(s): + number = true + case unicode.IsPunct(s) || unicode.IsSymbol(s): + symbol = true + } + } + + if len(password) >= 7 && + number && + symbol { + return nil + } + + return errors.New("Password does not meet validation requirements") +} diff --git a/server/fleet/users_test.go b/server/fleet/users_test.go index 36b27a9b28..b72e55818e 100644 --- a/server/fleet/users_test.go +++ b/server/fleet/users_test.go @@ -42,3 +42,37 @@ func newTestUser(t *testing.T, password, email string) *User { Email: email, } } + +func TestUserPasswordRequirements(t *testing.T) { + passwordTests := []struct { + password string + wantErr bool + }{ + { + password: "foobar", + wantErr: true, + }, + { + password: "foobarbaz", + wantErr: true, + }, + { + password: "foobarbaz!", + wantErr: true, + }, + { + password: "foobarbaz!3", + }, + } + + for _, tt := range passwordTests { + t.Run(tt.password, func(t *testing.T) { + err := ValidatePasswordRequirements(tt.password) + if tt.wantErr { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } + }) + } +} diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index a350b5b933..b9a478e49e 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -190,7 +190,7 @@ type NewPasswordResetRequestFunc func(ctx context.Context, req *fleet.PasswordRe 
type DeletePasswordResetRequestsForUserFunc func(ctx context.Context, userID uint) error -type FindPassswordResetByTokenFunc func(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) +type FindPasswordResetByTokenFunc func(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) type SessionByKeyFunc func(ctx context.Context, key string) (*fleet.Session, error) @@ -650,8 +650,8 @@ type DataStore struct { DeletePasswordResetRequestsForUserFunc DeletePasswordResetRequestsForUserFunc DeletePasswordResetRequestsForUserFuncInvoked bool - FindPassswordResetByTokenFunc FindPassswordResetByTokenFunc - FindPassswordResetByTokenFuncInvoked bool + FindPasswordResetByTokenFunc FindPasswordResetByTokenFunc + FindPasswordResetByTokenFuncInvoked bool SessionByKeyFunc SessionByKeyFunc SessionByKeyFuncInvoked bool @@ -1384,9 +1384,9 @@ func (s *DataStore) DeletePasswordResetRequestsForUser(ctx context.Context, user return s.DeletePasswordResetRequestsForUserFunc(ctx, userID) } -func (s *DataStore) FindPassswordResetByToken(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) { - s.FindPassswordResetByTokenFuncInvoked = true - return s.FindPassswordResetByTokenFunc(ctx, token) +func (s *DataStore) FindPasswordResetByToken(ctx context.Context, token string) (*fleet.PasswordResetRequest, error) { + s.FindPasswordResetByTokenFuncInvoked = true + return s.FindPasswordResetByTokenFunc(ctx, token) } func (s *DataStore) SessionByKey(ctx context.Context, key string) (*fleet.Session, error) { diff --git a/server/service/carves.go b/server/service/carves.go index dfc9049e25..770229045a 100644 --- a/server/service/carves.go +++ b/server/service/carves.go @@ -232,3 +232,75 @@ func (svc *Service) CarveBegin(ctx context.Context, payload fleet.CarveBeginPayl return carve, nil } + +//////////////////////////////////////////////////////////////////////////////// +// Receive Block for File Carve 
+//////////////////////////////////////////////////////////////////////////////// + +type carveBlockRequest struct { + BlockId int64 `json:"block_id"` + SessionId string `json:"session_id"` + RequestId string `json:"request_id"` + Data []byte `json:"data"` +} + +type carveBlockResponse struct { + Success bool `json:"success,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r carveBlockResponse) error() error { return r.Err } + +func carveBlockEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*carveBlockRequest) + + payload := fleet.CarveBlockPayload{ + SessionId: req.SessionId, + RequestId: req.RequestId, + BlockId: req.BlockId, + Data: req.Data, + } + + err := svc.CarveBlock(ctx, payload) + if err != nil { + return carveBlockResponse{Err: err}, nil + } + + return carveBlockResponse{Success: true}, nil +} + +func (svc *Service) CarveBlock(ctx context.Context, payload fleet.CarveBlockPayload) error { + // skipauth: Authorization is currently for user endpoints only. + svc.authz.SkipAuthorization(ctx) + + // Note host did not authenticate via node key. 
We need to authenticate them + // by the session ID and request ID + carve, err := svc.carveStore.CarveBySessionId(ctx, payload.SessionId) + if err != nil { + return ctxerr.Wrap(ctx, err, "find carve by session_id") + } + + if payload.RequestId != carve.RequestId { + return errors.New("request_id does not match") + } + + // Request is now authenticated + + if payload.BlockId > carve.BlockCount-1 { + return fmt.Errorf("block_id exceeds expected max (%d): %d", carve.BlockCount-1, payload.BlockId) + } + + if payload.BlockId != carve.MaxBlock+1 { + return fmt.Errorf("block_id does not match expected block (%d): %d", carve.MaxBlock+1, payload.BlockId) + } + + if int64(len(payload.Data)) > carve.BlockSize { + return fmt.Errorf("exceeded declared block size %d: %d", carve.BlockSize, len(payload.Data)) + } + + if err := svc.carveStore.NewBlock(ctx, carve, payload.BlockId, payload.Data); err != nil { + return ctxerr.Wrap(ctx, err, "save block data") + } + + return nil +} diff --git a/server/service/carves_test.go b/server/service/carves_test.go index 987e1cefdb..35f912c48b 100644 --- a/server/service/carves_test.go +++ b/server/service/carves_test.go @@ -4,8 +4,10 @@ import ( "context" "errors" "testing" + "time" "github.com/fleetdm/fleet/v4/server/authz" + hostctx "github.com/fleetdm/fleet/v4/server/contexts/host" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mock" "github.com/fleetdm/fleet/v4/server/test" @@ -184,3 +186,420 @@ func TestCarveGetBlockExpired(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "expired carve") } + +func TestCarveBegin(t *testing.T) { + host := fleet.Host{ID: 3} + payload := fleet.CarveBeginPayload{ + BlockCount: 23, + BlockSize: 64, + CarveSize: 23 * 64, + RequestId: "carve_request", + } + ms := new(mock.Store) + ds := new(mock.Store) + svc := &Service{ + carveStore: ms, + ds: ds, + } + expectedMetadata := fleet.CarveMetadata{ + ID: 7, + HostId: host.ID, + BlockCount: 23, + BlockSize: 64, + 
CarveSize: 23 * 64, + RequestId: "carve_request", + } + ms.NewCarveFunc = func(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) { + metadata.ID = 7 + return metadata, nil + } + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + if host.ID != id { + return nil, errors.New("not found") + } + return &fleet.Host{ + Hostname: host.Hostname, + }, nil + } + + ctx := hostctx.NewContext(context.Background(), &host) + + metadata, err := svc.CarveBegin(ctx, payload) + require.NoError(t, err) + assert.NotEmpty(t, metadata.SessionId) + metadata.SessionId = "" // Clear this before comparison + metadata.Name = "" // Clear this before comparison + metadata.CreatedAt = time.Time{} // Clear this before comparison + assert.Equal(t, expectedMetadata, *metadata) +} + +func TestCarveBeginNewCarveError(t *testing.T) { + host := fleet.Host{ID: 3} + payload := fleet.CarveBeginPayload{ + BlockCount: 23, + BlockSize: 64, + CarveSize: 23 * 64, + RequestId: "carve_request", + } + ms := new(mock.Store) + ds := new(mock.Store) + svc := &Service{ + carveStore: ms, + ds: ds, + } + ms.NewCarveFunc = func(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) { + return nil, errors.New("ouch!") + } + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + if host.ID != id { + return nil, errors.New("not found") + } + return &fleet.Host{ + Hostname: host.Hostname, + }, nil + } + + ctx := hostctx.NewContext(context.Background(), &host) + + _, err := svc.CarveBegin(ctx, payload) + require.Error(t, err) + assert.Contains(t, err.Error(), "ouch!") +} + +func TestCarveBeginEmptyError(t *testing.T) { + ms := new(mock.Store) + ds := new(mock.Store) + svc := &Service{ + carveStore: ms, + ds: ds, + } + ctx := hostctx.NewContext(context.Background(), &fleet.Host{ID: 1}) + + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + if id != 1 { + return nil, errors.New("not found") + } 
+ return &fleet.Host{}, nil + } + + _, err := svc.CarveBegin(ctx, fleet.CarveBeginPayload{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "carve_size must be greater than 0") +} + +func TestCarveBeginMissingHostError(t *testing.T) { + ms := new(mock.Store) + svc := &Service{carveStore: ms} + + _, err := svc.CarveBegin(context.Background(), fleet.CarveBeginPayload{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing host") +} + +func TestCarveBeginBlockSizeMaxError(t *testing.T) { + host := fleet.Host{ID: 3} + payload := fleet.CarveBeginPayload{ + BlockCount: 10, + BlockSize: 1024 * 1024 * 1024 * 1024, // 1TB + CarveSize: 10 * 1024 * 1024 * 1024 * 1024, // 10TB + RequestId: "carve_request", + } + ms := new(mock.Store) + ds := new(mock.Store) + svc := &Service{ + carveStore: ms, + ds: ds, + } + + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + if host.ID != id { + return nil, errors.New("not found") + } + return &fleet.Host{ + Hostname: host.Hostname, + }, nil + } + + ctx := hostctx.NewContext(context.Background(), &host) + + _, err := svc.CarveBegin(ctx, payload) + require.Error(t, err) + assert.Contains(t, err.Error(), "block_size exceeds max") +} + +func TestCarveBeginCarveSizeMaxError(t *testing.T) { + host := fleet.Host{ID: 3} + payload := fleet.CarveBeginPayload{ + BlockCount: 1024 * 1024, + BlockSize: 10 * 1024 * 1024, // 1TB + CarveSize: 10 * 1024 * 1024 * 1024 * 1024, // 10TB + RequestId: "carve_request", + } + ms := new(mock.Store) + ds := new(mock.Store) + svc := &Service{ + carveStore: ms, + ds: ds, + } + + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + if host.ID != id { + return nil, errors.New("not found") + } + return &fleet.Host{ + Hostname: host.Hostname, + }, nil + } + + ctx := hostctx.NewContext(context.Background(), &host) + + _, err := svc.CarveBegin(ctx, payload) + require.Error(t, err) + assert.Contains(t, err.Error(), "carve_size exceeds max") +} + 
+func TestCarveBeginCarveSizeError(t *testing.T) { + host := fleet.Host{ID: 3} + payload := fleet.CarveBeginPayload{ + BlockCount: 7, + BlockSize: 13, + CarveSize: 7*13 + 1, + RequestId: "carve_request", + } + ms := new(mock.Store) + ds := new(mock.Store) + svc := &Service{ + carveStore: ms, + ds: ds, + } + ctx := hostctx.NewContext(context.Background(), &host) + + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + if host.ID != id { + return nil, errors.New("not found") + } + return &fleet.Host{ + Hostname: host.Hostname, + }, nil + } + + // Too big + _, err := svc.CarveBegin(ctx, payload) + require.Error(t, err) + assert.Contains(t, err.Error(), "carve_size does not match") + + // Too small + payload.CarveSize = 6 * 13 + _, err = svc.CarveBegin(ctx, payload) + require.Error(t, err) + assert.Contains(t, err.Error(), "carve_size does not match") +} + +func TestCarveCarveBlockGetCarveError(t *testing.T) { + sessionId := "foobar" + ms := new(mock.Store) + svc := &Service{carveStore: ms} + ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { + return nil, errors.New("ouch!") + } + + payload := fleet.CarveBlockPayload{ + Data: []byte("this is the carve data :)"), + SessionId: sessionId, + } + + err := svc.CarveBlock(context.Background(), payload) + require.Error(t, err) + assert.Contains(t, err.Error(), "ouch!") +} + +func TestCarveCarveBlockRequestIdError(t *testing.T) { + sessionId := "foobar" + metadata := &fleet.CarveMetadata{ + ID: 2, + HostId: 3, + BlockCount: 23, + BlockSize: 64, + CarveSize: 23 * 64, + RequestId: "carve_request", + SessionId: sessionId, + } + ms := new(mock.Store) + svc := &Service{carveStore: ms} + ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { + assert.Equal(t, metadata.SessionId, sessionId) + return metadata, nil + } + + payload := fleet.CarveBlockPayload{ + Data: []byte("this is the carve data :)"), + RequestId: 
"not_matching", + SessionId: sessionId, + } + + err := svc.CarveBlock(context.Background(), payload) + require.Error(t, err) + assert.Contains(t, err.Error(), "request_id does not match") +} + +func TestCarveCarveBlockBlockCountExceedError(t *testing.T) { + sessionId := "foobar" + metadata := &fleet.CarveMetadata{ + ID: 2, + HostId: 3, + BlockCount: 23, + BlockSize: 64, + CarveSize: 23 * 64, + RequestId: "carve_request", + SessionId: sessionId, + } + ms := new(mock.Store) + svc := &Service{carveStore: ms} + ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { + assert.Equal(t, metadata.SessionId, sessionId) + return metadata, nil + } + + payload := fleet.CarveBlockPayload{ + Data: []byte("this is the carve data :)"), + RequestId: "carve_request", + SessionId: sessionId, + BlockId: 23, + } + + err := svc.CarveBlock(context.Background(), payload) + require.Error(t, err) + assert.Contains(t, err.Error(), "block_id exceeds expected max") +} + +func TestCarveCarveBlockBlockCountMatchError(t *testing.T) { + sessionId := "foobar" + metadata := &fleet.CarveMetadata{ + ID: 2, + HostId: 3, + BlockCount: 23, + BlockSize: 64, + CarveSize: 23 * 64, + RequestId: "carve_request", + SessionId: sessionId, + MaxBlock: 3, + } + ms := new(mock.Store) + svc := &Service{carveStore: ms} + ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { + assert.Equal(t, metadata.SessionId, sessionId) + return metadata, nil + } + + payload := fleet.CarveBlockPayload{ + Data: []byte("this is the carve data :)"), + RequestId: "carve_request", + SessionId: sessionId, + BlockId: 7, + } + + err := svc.CarveBlock(context.Background(), payload) + require.Error(t, err) + assert.Contains(t, err.Error(), "block_id does not match") +} + +func TestCarveCarveBlockBlockSizeError(t *testing.T) { + sessionId := "foobar" + metadata := &fleet.CarveMetadata{ + ID: 2, + HostId: 3, + BlockCount: 23, + BlockSize: 16, + 
CarveSize: 23 * 64, + RequestId: "carve_request", + SessionId: sessionId, + MaxBlock: 3, + } + ms := new(mock.Store) + svc := &Service{carveStore: ms} + ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { + assert.Equal(t, metadata.SessionId, sessionId) + return metadata, nil + } + + payload := fleet.CarveBlockPayload{ + Data: []byte("this is the carve data :) TOO LONG!!!"), + RequestId: "carve_request", + SessionId: sessionId, + BlockId: 4, + } + + err := svc.CarveBlock(context.Background(), payload) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeded declared block size") +} + +func TestCarveCarveBlockNewBlockError(t *testing.T) { + sessionId := "foobar" + metadata := &fleet.CarveMetadata{ + ID: 2, + HostId: 3, + BlockCount: 23, + BlockSize: 64, + CarveSize: 23 * 64, + RequestId: "carve_request", + SessionId: sessionId, + MaxBlock: 3, + } + ms := new(mock.Store) + svc := &Service{carveStore: ms} + ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { + assert.Equal(t, metadata.SessionId, sessionId) + return metadata, nil + } + ms.NewBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64, data []byte) error { + return errors.New("kaboom!") + } + + payload := fleet.CarveBlockPayload{ + Data: []byte("this is the carve data :)"), + RequestId: "carve_request", + SessionId: sessionId, + BlockId: 4, + } + + err := svc.CarveBlock(context.Background(), payload) + require.Error(t, err) + assert.Contains(t, err.Error(), "kaboom!") +} + +func TestCarveCarveBlock(t *testing.T) { + sessionId := "foobar" + metadata := &fleet.CarveMetadata{ + ID: 2, + HostId: 3, + BlockCount: 23, + BlockSize: 64, + CarveSize: 23 * 64, + RequestId: "carve_request", + SessionId: sessionId, + MaxBlock: 3, + } + payload := fleet.CarveBlockPayload{ + Data: []byte("this is the carve data :)"), + RequestId: "carve_request", + SessionId: sessionId, + BlockId: 4, + } 
+ ms := new(mock.Store) + svc := &Service{carveStore: ms} + ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { + assert.Equal(t, metadata.SessionId, sessionId) + return metadata, nil + } + ms.NewBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64, data []byte) error { + assert.Equal(t, metadata, carve) + assert.Equal(t, int64(4), blockId) + assert.Equal(t, payload.Data, data) + return nil + } + + err := svc.CarveBlock(context.Background(), payload) + require.NoError(t, err) + assert.True(t, ms.NewBlockFuncInvoked) +} diff --git a/server/service/endpoint_carves.go b/server/service/endpoint_carves.go deleted file mode 100644 index 46f58c90a7..0000000000 --- a/server/service/endpoint_carves.go +++ /dev/null @@ -1,47 +0,0 @@ -package service - -import ( - "context" - - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/go-kit/kit/endpoint" -) - -//////////////////////////////////////////////////////////////////////////////// -// Receive Block for File Carve -//////////////////////////////////////////////////////////////////////////////// - -type carveBlockRequest struct { - NodeKey string `json:"node_key"` - BlockId int64 `json:"block_id"` - SessionId string `json:"session_id"` - RequestId string `json:"request_id"` - Data []byte `json:"data"` -} - -type carveBlockResponse struct { - Success bool `json:"success,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r carveBlockResponse) error() error { return r.Err } - -func makeCarveBlockEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(carveBlockRequest) - - payload := fleet.CarveBlockPayload{ - SessionId: req.SessionId, - RequestId: req.RequestId, - BlockId: req.BlockId, - Data: req.Data, - } - - err := svc.CarveBlock(ctx, payload) - if err != nil { - return carveBlockResponse{Err: err}, nil - } - - return carveBlockResponse{Success: 
true}, nil - } -} diff --git a/server/service/endpoint_invites.go b/server/service/endpoint_invites.go deleted file mode 100644 index 92a5daae29..0000000000 --- a/server/service/endpoint_invites.go +++ /dev/null @@ -1,30 +0,0 @@ -package service - -import ( - "context" - - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/go-kit/kit/endpoint" -) - -type verifyInviteRequest struct { - Token string -} - -type verifyInviteResponse struct { - Invite *fleet.Invite `json:"invite"` - Err error `json:"error,omitempty"` -} - -func (r verifyInviteResponse) error() error { return r.Err } - -func makeVerifyInviteEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(verifyInviteRequest) - invite, err := svc.VerifyInvite(ctx, req.Token) - if err != nil { - return verifyInviteResponse{Err: err}, nil - } - return verifyInviteResponse{Invite: invite}, nil - } -} diff --git a/server/service/endpoint_middleware.go b/server/service/endpoint_middleware.go index bdf7d50790..6853501829 100644 --- a/server/service/endpoint_middleware.go +++ b/server/service/endpoint_middleware.go @@ -137,6 +137,10 @@ func authenticatedUser(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpo return logged(authUserFunc) } +func unauthenticatedRequest(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint { + return logged(next) +} + // logged wraps an endpoint and adds the error if the context supports it func logged(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { @@ -167,16 +171,3 @@ func authViewer(ctx context.Context, sessionKey string, svc fleet.Service) (*vie } return &viewer.Viewer{User: user, Session: session}, nil } - -func canPerformPasswordReset(next endpoint.Endpoint) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - vc, ok := viewer.FromContext(ctx) - 
if !ok { - return nil, fleet.ErrNoContext - } - if !vc.CanPerformPasswordReset() { - return nil, fleet.NewPermissionError("cannot reset password") - } - return next(ctx, request) - } -} diff --git a/server/service/endpoint_osquery.go b/server/service/endpoint_osquery.go deleted file mode 100644 index c0ed9a9211..0000000000 --- a/server/service/endpoint_osquery.go +++ /dev/null @@ -1,36 +0,0 @@ -package service - -import ( - "context" - - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/go-kit/kit/endpoint" -) - -//////////////////////////////////////////////////////////////////////////////// -// Enroll Agent -//////////////////////////////////////////////////////////////////////////////// - -type enrollAgentRequest struct { - EnrollSecret string `json:"enroll_secret"` - HostIdentifier string `json:"host_identifier"` - HostDetails map[string](map[string]string) `json:"host_details"` -} - -type enrollAgentResponse struct { - NodeKey string `json:"node_key,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r enrollAgentResponse) error() error { return r.Err } - -func makeEnrollAgentEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(enrollAgentRequest) - nodeKey, err := svc.EnrollAgent(ctx, req.EnrollSecret, req.HostIdentifier, req.HostDetails) - if err != nil { - return enrollAgentResponse{Err: err}, nil - } - return enrollAgentResponse{NodeKey: nodeKey}, nil - } -} diff --git a/server/service/endpoint_sessions.go b/server/service/endpoint_sessions.go deleted file mode 100644 index 97430b1870..0000000000 --- a/server/service/endpoint_sessions.go +++ /dev/null @@ -1,163 +0,0 @@ -package service - -import ( - "bytes" - "context" - "errors" - "html/template" - - "github.com/fleetdm/fleet/v4/server/contexts/viewer" - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/go-kit/kit/endpoint" -) - 
-//////////////////////////////////////////////////////////////////////////////// -// Login -//////////////////////////////////////////////////////////////////////////////// - -type loginRequest struct { - Email string - Password string -} - -type loginResponse struct { - User *fleet.User `json:"user,omitempty"` - AvailableTeams []*fleet.TeamSummary `json:"available_teams"` - Token string `json:"token,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r loginResponse) error() error { return r.Err } - -func makeLoginEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(loginRequest) - user, token, err := svc.Login(ctx, req.Email, req.Password) - if err != nil { - return loginResponse{Err: err}, nil - } - // Add viewer context allow access to service teams for list of available teams - v, err := authViewer(ctx, token, svc) - if err != nil { - return loginResponse{Err: err}, nil - } - ctx = viewer.NewContext(ctx, *v) - availableTeams, err := svc.ListAvailableTeamsForUser(ctx, user) - if err != nil { - if errors.Is(err, fleet.ErrMissingLicense) { - availableTeams = []*fleet.TeamSummary{} - } else { - return loginResponse{Err: err}, nil - } - } - return loginResponse{user, availableTeams, token, nil}, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Logout -//////////////////////////////////////////////////////////////////////////////// - -type logoutResponse struct { - Err error `json:"error,omitempty"` -} - -func (r logoutResponse) error() error { return r.Err } - -func makeLogoutEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - err := svc.Logout(ctx) - if err != nil { - return logoutResponse{Err: err}, nil - } - return logoutResponse{}, nil - } -} - -type initiateSSORequest struct { - RelayURL string `json:"relay_url"` -} - -type 
initiateSSOResponse struct { - URL string `json:"url,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r initiateSSOResponse) error() error { return r.Err } - -func makeInitiateSSOEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(initiateSSORequest) - idProviderURL, err := svc.InitiateSSO(ctx, req.RelayURL) - if err != nil { - return initiateSSOResponse{Err: err}, nil - } - return initiateSSOResponse{URL: idProviderURL}, nil - } -} - -type callbackSSOResponse struct { - content string - Err error `json:"error,omitempty"` -} - -func (r callbackSSOResponse) error() error { return r.Err } - -// If html is present we return a web page -func (r callbackSSOResponse) html() string { return r.content } - -func makeCallbackSSOEndpoint(svc fleet.Service, urlPrefix string) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - authResponse := request.(fleet.Auth) - session, err := svc.CallbackSSO(ctx, authResponse) - var resp callbackSSOResponse - if err != nil { - // redirect to login page on front end if there was some problem, - // errors should still be logged - session = &fleet.SSOSession{ - RedirectURL: urlPrefix + "/login", - Token: "", - } - resp.Err = err - } - relayStateLoadPage := ` - - - Redirecting to Fleet at {{ .RedirectURL }} ... 
- - - ` - tmpl, err := template.New("relayStateLoader").Parse(relayStateLoadPage) - if err != nil { - return nil, err - } - var writer bytes.Buffer - err = tmpl.Execute(&writer, session) - if err != nil { - return nil, err - } - resp.content = writer.String() - return resp, nil - } -} - -type ssoSettingsResponse struct { - Settings *fleet.SessionSSOSettings `json:"settings,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r ssoSettingsResponse) error() error { return r.Err } - -func makeSSOSettingsEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, unused interface{}) (interface{}, error) { - settings, err := svc.SSOSettings(ctx) - if err != nil { - return ssoSettingsResponse{Err: err}, nil - } - return ssoSettingsResponse{Settings: settings}, nil - } -} diff --git a/server/service/endpoint_users.go b/server/service/endpoint_users.go deleted file mode 100644 index a3d61a7867..0000000000 --- a/server/service/endpoint_users.go +++ /dev/null @@ -1,100 +0,0 @@ -package service - -import ( - "context" - "net/http" - - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/go-kit/kit/endpoint" -) - -//////////////////////////////////////////////////////////////////////////////// -// Create User With Invite -//////////////////////////////////////////////////////////////////////////////// - -func makeCreateUserFromInviteEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(createUserRequest) - user, err := svc.CreateUserFromInvite(ctx, req.UserPayload) - if err != nil { - return createUserResponse{Err: err}, nil - } - return createUserResponse{User: user}, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Reset Password -//////////////////////////////////////////////////////////////////////////////// - -type resetPasswordRequest struct { - PasswordResetToken string 
`json:"password_reset_token"` - NewPassword string `json:"new_password"` -} - -type resetPasswordResponse struct { - Err error `json:"error,omitempty"` -} - -func (r resetPasswordResponse) error() error { return r.Err } - -func makeResetPasswordEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(resetPasswordRequest) - err := svc.ResetPassword(ctx, req.PasswordResetToken, req.NewPassword) - return resetPasswordResponse{Err: err}, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Perform Required Password Reset -//////////////////////////////////////////////////////////////////////////////// - -type performRequiredPasswordResetRequest struct { - Password string `json:"new_password"` - ID uint `json:"id"` -} - -type performRequiredPasswordResetResponse struct { - User *fleet.User `json:"user,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r performRequiredPasswordResetResponse) error() error { return r.Err } - -func makePerformRequiredPasswordResetEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(performRequiredPasswordResetRequest) - user, err := svc.PerformRequiredPasswordReset(ctx, req.Password) - if err != nil { - return performRequiredPasswordResetResponse{Err: err}, nil - } - return performRequiredPasswordResetResponse{User: user}, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Forgot Password -//////////////////////////////////////////////////////////////////////////////// - -type forgotPasswordRequest struct { - Email string `json:"email"` -} - -type forgotPasswordResponse struct { - Err error `json:"error,omitempty"` -} - -func (r forgotPasswordResponse) error() error { return r.Err } -func (r forgotPasswordResponse) status() int { return http.StatusAccepted } - 
-func makeForgotPasswordEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(forgotPasswordRequest) - // Any error returned by the service should not be returned to the - // client to prevent information disclosure (it will be logged in the - // server logs). - _ = svc.RequestPasswordReset(ctx, req.Email) - return forgotPasswordResponse{}, nil - } -} diff --git a/server/service/endpoint_utils.go b/server/service/endpoint_utils.go index 88ffb079da..07a74f9ee4 100644 --- a/server/service/endpoint_utils.go +++ b/server/service/endpoint_utils.go @@ -67,6 +67,10 @@ func allFields(ifv reflect.Value) []reflect.StructField { return fields } +type requestDecoder interface { + DecodeRequest(ctx context.Context, r *http.Request) (interface{}, error) +} + // makeDecoder creates a decoder for the type for the struct passed on. If the // struct has at least 1 json tag it'll unmarshall the body. If the struct has // a `url` tag with value list_options it'll gather fleet.ListOptions from the @@ -79,12 +83,22 @@ func allFields(ifv reflect.Value) []reflect.StructField { // follows: `url:"some-id,optional"`. // The "list_options" are optional by default and it'll ignore the optional // portion of the tag. +// +// If iface implements the requestDecoder interface, it returns a function that +// calls iface.DecodeRequest(ctx, r) - i.e. the value itself fully controls its +// own decoding. 
func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc { if iface == nil { return func(ctx context.Context, r *http.Request) (interface{}, error) { return nil, nil } } + if rd, ok := iface.(requestDecoder); ok { + return func(ctx context.Context, r *http.Request) (interface{}, error) { + return rd.DecodeRequest(ctx, r) + } + } + t := reflect.TypeOf(iface) if t.Kind() != reflect.Struct { panic(fmt.Sprintf("makeDecoder only understands structs, not %T", iface)) @@ -272,6 +286,7 @@ type authEndpointer struct { startingAtVersion string endingAtVersion string alternativePaths []string + customMiddleware []endpoint.Middleware } func newUserAuthenticatedEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer { @@ -297,6 +312,16 @@ func newHostAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts [ } } +func newNoAuthEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer { + return &authEndpointer{ + svc: svc, + opts: opts, + r: r, + authFunc: unauthenticatedRequest, + versions: versions, + } +} + var pathReplacer = strings.NewReplacer( "/", "_", "{", "_", @@ -374,7 +399,15 @@ func (e *authEndpointer) makeEndpoint(f handlerFunc, v interface{}) http.Handler next := func(ctx context.Context, request interface{}) (interface{}, error) { return f(ctx, request, e.svc) } - return newServer(e.authFunc(e.svc, next), makeDecoder(v), e.opts) + endp := e.authFunc(e.svc, next) + + // apply middleware in reverse order so that the first wraps the second + // wraps the third etc. 
+ for i := len(e.customMiddleware) - 1; i >= 0; i-- { + mw := e.customMiddleware[i] + endp = mw(endp) + } + return newServer(endp, makeDecoder(v), e.opts) } func (e *authEndpointer) StartingAtVersion(version string) *authEndpointer { @@ -394,3 +427,9 @@ func (e *authEndpointer) WithAltPaths(paths ...string) *authEndpointer { ae.alternativePaths = paths return &ae } + +func (e *authEndpointer) WithCustomMiddleware(mws ...endpoint.Middleware) *authEndpointer { + ae := *e + ae.customMiddleware = mws + return &ae +} diff --git a/server/service/endpoint_utils_test.go b/server/service/endpoint_utils_test.go index e925a38e98..1fd0ad90a3 100644 --- a/server/service/endpoint_utils_test.go +++ b/server/service/endpoint_utils_test.go @@ -1,6 +1,7 @@ package service import ( + "bytes" "context" "io" "net/http" @@ -14,6 +15,7 @@ import ( "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mock" "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/go-kit/kit/endpoint" kitlog "github.com/go-kit/kit/log" kithttp "github.com/go-kit/kit/transport/http" "github.com/gorilla/mux" @@ -251,7 +253,6 @@ func TestUniversalDecoderQueryAndListPlayNice(t *testing.T) { } func TestEndpointer(t *testing.T) { - r := mux.NewRouter() ds := new(mock.Store) ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) { @@ -395,3 +396,72 @@ func TestEndpointer(t *testing.T) { require.False(t, doesItMatch(route.method, route.path, false), route) } } + +func TestEndpointerCustomMiddleware(t *testing.T) { + r := mux.NewRouter() + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + fleetAPIOptions := []kithttp.ServerOption{ + kithttp.ServerBefore( + kithttp.PopulateRequestContext, + setRequestsContexts(svc), + ), + kithttp.ServerErrorHandler(&errorHandler{kitlog.NewNopLogger()}), + kithttp.ServerErrorEncoder(encodeError), + kithttp.ServerAfter( + kithttp.SetContentType("application/json; charset=utf-8"), + logRequestEnd(kitlog.NewNopLogger()), 
+ checkLicenseExpiration(svc), + ), + } + + var buf bytes.Buffer + e := newNoAuthEndpointer(svc, fleetAPIOptions, r, "v1") + e.GET("/none/", func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + buf.WriteString("H1") + return nil, nil + }, nil) + + e.WithCustomMiddleware( + func(e endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + buf.WriteString("A") + return e(ctx, request) + } + }, + func(e endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + buf.WriteString("B") + return e(ctx, request) + } + }, + func(e endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + buf.WriteString("C") + return e(ctx, request) + } + }, + ). + GET("/mw/", func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + buf.WriteString("H2") + return nil, nil + }, nil) + + req := httptest.NewRequest("GET", "/none/", nil) + var m1 mux.RouteMatch + + require.True(t, r.Match(req, &m1)) + rec := httptest.NewRecorder() + m1.Handler.ServeHTTP(rec, req) + require.Equal(t, "H1", buf.String()) + + buf.Reset() + req = httptest.NewRequest("GET", "/mw/", nil) + var m2 mux.RouteMatch + + require.True(t, r.Match(req, &m2)) + rec = httptest.NewRecorder() + m2.Handler.ServeHTTP(rec, req) + require.Equal(t, "ABCH2", buf.String()) +} diff --git a/server/service/handler.go b/server/service/handler.go index 869caa4cab..e7b6b0d1cc 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -4,7 +4,7 @@ import ( "context" "errors" "net/http" - "strings" + "regexp" "github.com/fleetdm/fleet/v4/server/config" "github.com/fleetdm/fleet/v4/server/contexts/logging" @@ -22,92 +22,6 @@ import ( otmiddleware "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux" ) -// FleetEndpoints 
is a collection of RPC endpoints implemented by the Fleet API. -type FleetEndpoints struct { - Login endpoint.Endpoint - Logout endpoint.Endpoint - ForgotPassword endpoint.Endpoint - ResetPassword endpoint.Endpoint - CreateUserWithInvite endpoint.Endpoint - PerformRequiredPasswordReset endpoint.Endpoint - VerifyInvite endpoint.Endpoint - EnrollAgent endpoint.Endpoint - CarveBlock endpoint.Endpoint - InitiateSSO endpoint.Endpoint - CallbackSSO endpoint.Endpoint - SSOSettings endpoint.Endpoint -} - -// MakeFleetServerEndpoints creates the Fleet API endpoints. -func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore throttled.GCRAStore, logger kitlog.Logger) FleetEndpoints { - limiter := ratelimit.NewMiddleware(limitStore) - - return FleetEndpoints{ - Login: limiter.Limit( - throttled.RateQuota{MaxRate: throttled.PerMin(10), MaxBurst: 9})( - makeLoginEndpoint(svc), - ), - Logout: logged(makeLogoutEndpoint(svc)), - ForgotPassword: limiter.Limit( - throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})( - logged(makeForgotPasswordEndpoint(svc)), - ), - ResetPassword: logged(makeResetPasswordEndpoint(svc)), - CreateUserWithInvite: logged(makeCreateUserFromInviteEndpoint(svc)), - VerifyInvite: logged(makeVerifyInviteEndpoint(svc)), - InitiateSSO: logged(makeInitiateSSOEndpoint(svc)), - CallbackSSO: logged(makeCallbackSSOEndpoint(svc, urlPrefix)), - SSOSettings: logged(makeSSOSettingsEndpoint(svc)), - - // PerformRequiredPasswordReset needs only to authenticate the - // logged in user - PerformRequiredPasswordReset: logged(canPerformPasswordReset(makePerformRequiredPasswordResetEndpoint(svc))), - - // Osquery endpoints - EnrollAgent: logged(makeEnrollAgentEndpoint(svc)), - // For some reason osquery does not provide a node key with the block - // data. Instead the carve session ID should be verified in the service - // method. 
- CarveBlock: logged(makeCarveBlockEndpoint(svc)), - } -} - -type fleetHandlers struct { - Login http.Handler - Logout http.Handler - ForgotPassword http.Handler - ResetPassword http.Handler - CreateUserWithInvite http.Handler - PerformRequiredPasswordReset http.Handler - VerifyInvite http.Handler - EnrollAgent http.Handler - CarveBlock http.Handler - InitiateSSO http.Handler - CallbackSSO http.Handler - SettingsSSO http.Handler -} - -func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandlers { - newServer := func(e endpoint.Endpoint, decodeFn kithttp.DecodeRequestFunc) http.Handler { - e = authzcheck.NewMiddleware().AuthzCheck()(e) - return kithttp.NewServer(e, decodeFn, encodeResponse, opts...) - } - return &fleetHandlers{ - Login: newServer(e.Login, decodeLoginRequest), - Logout: newServer(e.Logout, decodeNoParamsRequest), - ForgotPassword: newServer(e.ForgotPassword, decodeForgotPasswordRequest), - ResetPassword: newServer(e.ResetPassword, decodeResetPasswordRequest), - CreateUserWithInvite: newServer(e.CreateUserWithInvite, decodeCreateUserRequest), - PerformRequiredPasswordReset: newServer(e.PerformRequiredPasswordReset, decodePerformRequiredPasswordResetRequest), - VerifyInvite: newServer(e.VerifyInvite, decodeVerifyInviteRequest), - EnrollAgent: newServer(e.EnrollAgent, decodeEnrollAgentRequest), - CarveBlock: newServer(e.CarveBlock, decodeCarveBlockRequest), - InitiateSSO: newServer(e.InitiateSSO, decodeInitiateSSORequest), - CallbackSSO: newServer(e.CallbackSSO, decodeCallbackSSORequest), - SettingsSSO: newServer(e.SSOSettings, decodeNoParamsRequest), - } -} - type errorHandler struct { logger kitlog.Logger } @@ -176,18 +90,16 @@ func MakeHandler(svc fleet.Service, config config.FleetConfig, logger kitlog.Log ), } - fleetEndpoints := MakeFleetServerEndpoints(svc, config.Server.URLPrefix, limitStore, logger) - fleetHandlers := makeKitHandlers(fleetEndpoints, fleetAPIOptions) - r := mux.NewRouter() if config.Logging.TracingEnabled && 
config.Logging.TracingType == "opentelemetry" { r.Use(otmiddleware.Middleware("fleet")) } - attachFleetAPIRoutes(r, fleetHandlers) - attachNewStyleFleetAPIRoutes(r, svc, logger, fleetAPIOptions) + attachFleetAPIRoutes(r, svc, config, logger, limitStore, fleetAPIOptions) // Results endpoint is handled different due to websockets use + // TODO: this would not work once v1 is deprecated - note that the handler too uses the /v1/ path + // and this routes on path prefix, not exact path (unlike the authendpointer struct). r.PathPrefix("/api/v1/fleet/results/"). Handler(makeStreamDistributedQueryCampaignResultsHandler(svc, logger)). Name("distributed_query_results") @@ -277,22 +189,9 @@ func addMetrics(r *mux.Router) { r.Walk(walkFn) } -func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) { - r.Handle("/api/v1/fleet/login", h.Login).Methods("POST").Name("login") - r.Handle("/api/v1/fleet/logout", h.Logout).Methods("POST").Name("logout") - r.Handle("/api/v1/fleet/forgot_password", h.ForgotPassword).Methods("POST").Name("forgot_password") - r.Handle("/api/v1/fleet/reset_password", h.ResetPassword).Methods("POST").Name("reset_password") - r.Handle("/api/v1/fleet/perform_required_password_reset", h.PerformRequiredPasswordReset).Methods("POST").Name("perform_required_password_reset") - r.Handle("/api/v1/fleet/sso", h.InitiateSSO).Methods("POST").Name("intiate_sso") - r.Handle("/api/v1/fleet/sso", h.SettingsSSO).Methods("GET").Name("sso_config") - r.Handle("/api/v1/fleet/sso/callback", h.CallbackSSO).Methods("POST").Name("callback_sso") - r.Handle("/api/v1/fleet/users", h.CreateUserWithInvite).Methods("POST").Name("create_user_with_invite") - r.Handle("/api/v1/fleet/invites/{token}", h.VerifyInvite).Methods("GET").Name("verify_invite") - r.Handle("/api/v1/osquery/enroll", h.EnrollAgent).Methods("POST").Name("enroll_agent") - r.Handle("/api/v1/osquery/carve/block", h.CarveBlock).Methods("POST").Name("carve_block") -} +func attachFleetAPIRoutes(r *mux.Router, svc 
fleet.Service, config config.FleetConfig, + logger kitlog.Logger, limitStore throttled.GCRAStore, opts []kithttp.ServerOption) { -func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, logger kitlog.Logger, opts []kithttp.ServerOption) { // user-authenticated endpoints ue := newUserAuthenticatedEndpointer(svc, opts, r, "v1") @@ -441,10 +340,42 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, logger kitlo he.POST("/api/_version_/osquery/distributed/write", submitDistributedQueryResultsEndpoint, submitDistributedQueryResultsRequestShim{}) he.POST("/api/_version_/osquery/carve/begin", carveBeginEndpoint, carveBeginRequest{}) he.POST("/api/_version_/osquery/log", submitLogsEndpoint, submitLogsRequest{}) + + // unauthenticated endpoints - most of those are either login-related, + // invite-related or host-enrolling. So they typically do some kind of + // one-time authentication by verifying that a valid secret token is provided + // with the request. + ne := newNoAuthEndpointer(svc, opts, r, "v1") + ne.POST("/api/_version_/osquery/enroll", enrollAgentEndpoint, enrollAgentRequest{}) + + // For some reason osquery does not provide a node key with the block data. + // Instead the carve session ID should be verified in the service method. 
+ ne.POST("/api/_version_/osquery/carve/block", carveBlockEndpoint, carveBlockRequest{}) + + ne.POST("/api/_version_/fleet/perform_required_password_reset", performRequiredPasswordResetEndpoint, performRequiredPasswordResetRequest{}) + ne.POST("/api/_version_/fleet/users", createUserFromInviteEndpoint, createUserRequest{}) + ne.GET("/api/_version_/fleet/invites/{token}", verifyInviteEndpoint, verifyInviteRequest{}) + ne.POST("/api/_version_/fleet/reset_password", resetPasswordEndpoint, resetPasswordRequest{}) + ne.POST("/api/_version_/fleet/logout", logoutEndpoint, nil) + ne.POST("/api/_version_/fleet/sso", initiateSSOEndpoint, initiateSSORequest{}) + ne.POST("/api/_version_/fleet/sso/callback", makeCallbackSSOEndpoint(config.Server.URLPrefix), callbackSSORequest{}) + ne.GET("/api/_version_/fleet/sso", settingsSSOEndpoint, nil) + + limiter := ratelimit.NewMiddleware(limitStore) + ne. + WithCustomMiddleware(limiter.Limit("forgot_password", throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: 9})). + POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{}) + + ne. + WithCustomMiddleware(limiter.Limit("login", throttled.RateQuota{MaxRate: throttled.PerMin(10), MaxBurst: 9})). + POST("/api/_version_/fleet/login", loginEndpoint, loginRequest{}) } -// TODO: this duplicates the one in makeKitHandler func newServer(e endpoint.Endpoint, decodeFn kithttp.DecodeRequestFunc, opts []kithttp.ServerOption) http.Handler { + // TODO: some handlers don't have authz checks, and because the SkipAuth call is done only in the + // endpoint handler, any middleware that raises errors before the handler is reached will end up + // returning authz check missing instead of the more relevant error. Should be addressed as part + // of #4406. e = authzcheck.NewMiddleware().AuthzCheck()(e) return kithttp.NewServer(e, decodeFn, encodeResponse, opts...) 
} @@ -453,15 +384,19 @@ func newServer(e endpoint.Endpoint, decodeFn kithttp.DecodeRequestFunc, opts []k // If setup hasn't been completed it serves the API with a setup middleware. // If the server is already configured, the default API handler is exposed. func WithSetup(svc fleet.Service, logger kitlog.Logger, next http.Handler) http.HandlerFunc { + + rxOsquery := regexp.MustCompile(`^/api/[^/]+/osquery`) return func(w http.ResponseWriter, r *http.Request) { configRouter := http.NewServeMux() + // TODO: hard-codes v1 as a path fragment, which would probably not work once we + // deprecate it for newer versions, unless we want to treat the setup differently (not versioned?) configRouter.Handle("/api/v1/setup", kithttp.NewServer( makeSetupEndpoint(svc), decodeSetupRequest, encodeResponse, )) // whitelist osqueryd endpoints - if strings.HasPrefix(r.URL.Path, "/api/v1/osquery") { + if rxOsquery.MatchString(r.URL.Path) { next.ServeHTTP(w, r) return } diff --git a/server/service/handler_test.go b/server/service/handler_test.go index 045a6a9619..a662bbefd1 100644 --- a/server/service/handler_test.go +++ b/server/service/handler_test.go @@ -16,63 +16,10 @@ import ( "github.com/gorilla/mux" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/throttled/throttled/v2/store/memstore" ) -func TestAPIRoutes(t *testing.T) { - ds := new(mock.Store) - - svc := newTestService(ds, nil, nil) - - r := mux.NewRouter() - limitStore, _ := memstore.New(0) - ke := MakeFleetServerEndpoints(svc, "", limitStore, kitlog.NewNopLogger()) - kh := makeKitHandlers(ke, nil) - attachFleetAPIRoutes(r, kh) - handler := mux.NewRouter() - handler.PathPrefix("/").Handler(r) - - routes := []struct { - verb string - uri string - }{ - { - verb: "POST", - uri: "/api/v1/fleet/users", - }, - { - verb: "POST", - uri: "/api/v1/fleet/login", - }, - { - verb: "POST", - uri: 
"/api/v1/fleet/forgot_password", - }, - { - verb: "POST", - uri: "/api/v1/fleet/reset_password", - }, - { - verb: "POST", - uri: "/api/v1/osquery/enroll", - }, - } - - for _, route := range routes { - t.Run(fmt.Sprintf(": %v", route.uri), func(st *testing.T) { - recorder := httptest.NewRecorder() - handler.ServeHTTP( - recorder, - httptest.NewRequest(route.verb, route.uri, nil), - ) - assert.NotEqual(st, 404, recorder.Code) - assert.NotEqual(st, 405, recorder.Code, route.verb) // if it matches a path but with wrong verb - }) - } -} - func TestAPIRoutesConflicts(t *testing.T) { ds := new(mock.Store) diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 34102912aa..14d2766a63 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -15,11 +15,13 @@ import ( "testing" "time" + "github.com/fleetdm/fleet/v4/server/datastore/mysql" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/test" "github.com/ghodss/yaml" "github.com/google/uuid" + "github.com/jmoiron/sqlx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -978,6 +980,20 @@ func (s *integrationTestSuite) TestInvites() { require.NotZero(t, createInviteResp.Invite.ID) validInvite := *createInviteResp.Invite + // create user from valid invite - the token was not returned via the + // response's json, must get it from the db + inv, err := s.ds.Invite(context.Background(), validInvite.ID) + require.NoError(t, err) + validInviteToken := inv.Token + + // verify the token with valid invite + var verifyInvResp verifyInviteResponse + s.DoJSON("GET", "/api/v1/fleet/invites/"+validInviteToken, nil, http.StatusOK, &verifyInvResp) + require.Equal(t, validInvite.ID, verifyInvResp.Invite.ID) + + // verify the token with an invalid invite + s.DoJSON("GET", "/api/v1/fleet/invites/invalid", nil, 
http.StatusNotFound, &verifyInvResp) + // create invite without an email createInviteReq = createInviteRequest{InvitePayload: fleet.InvitePayload{ Email: nil, @@ -1076,9 +1092,21 @@ func (s *integrationTestSuite) TestInvites() { require.Len(t, verify.Teams, 1) assert.Equal(t, team.ID, verify.Teams[0].ID) + var createFromInviteResp createUserResponse + s.DoJSON("POST", "/api/v1/fleet/users", fleet.UserPayload{ + Name: ptr.String("Full Name"), + Password: ptr.String("pass1word!"), + Email: ptr.String("a@b.c"), + InviteToken: ptr.String(validInviteToken), + }, http.StatusOK, &createFromInviteResp) + + // keep the invite token from the other valid invite (before deleting it) + inv, err = s.ds.Invite(context.Background(), createInviteResp.Invite.ID) + require.NoError(t, err) + deletedInviteToken := inv.Token + // delete an existing invite var delResp deleteInviteResponse - s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/invites/%d", validInvite.ID), nil, http.StatusOK, &delResp) s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/invites/%d", createInviteResp.Invite.ID), nil, http.StatusOK, &delResp) // list invites, is now empty @@ -1088,6 +1116,111 @@ func (s *integrationTestSuite) TestInvites() { // delete a now non-existing invite s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/invites/%d", validInvite.ID), nil, http.StatusNotFound, &delResp) + + // create user from never used but deleted invite + s.DoJSON("POST", "/api/v1/fleet/users", fleet.UserPayload{ + Name: ptr.String("Full Name"), + Password: ptr.String("pass1word!"), + Email: ptr.String("a@b.c"), + InviteToken: ptr.String(deletedInviteToken), + }, http.StatusNotFound, &createFromInviteResp) +} + +func (s *integrationTestSuite) TestCreateUserFromInviteErrors() { + t := s.T() + + // create a valid invite + createInviteReq := createInviteRequest{InvitePayload: fleet.InvitePayload{ + Email: ptr.String("a@b.c"), + Name: ptr.String("A"), + GlobalRole: null.StringFrom(fleet.RoleObserver), + }} + createInviteResp := 
createInviteResponse{} + s.DoJSON("POST", "/api/v1/fleet/invites", createInviteReq, http.StatusOK, &createInviteResp) + + // make sure to delete it on exit + defer func() { + var delResp deleteInviteResponse + s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/invites/%d", createInviteResp.Invite.ID), nil, http.StatusOK, &delResp) + }() + + // the token is not returned via the response's json, must get it from the db + invite, err := s.ds.Invite(context.Background(), createInviteResp.Invite.ID) + require.NoError(t, err) + + cases := []struct { + desc string + pld fleet.UserPayload + want int + }{ + { + "empty name", + fleet.UserPayload{ + Name: ptr.String(""), + Password: ptr.String("pass1word!"), + Email: ptr.String("a@b.c"), + InviteToken: ptr.String(invite.Token), + }, + http.StatusUnprocessableEntity, + }, + { + "empty email", + fleet.UserPayload{ + Name: ptr.String("Name"), + Password: ptr.String("pass1word!"), + Email: ptr.String(""), + InviteToken: ptr.String(invite.Token), + }, + http.StatusUnprocessableEntity, + }, + { + "empty password", + fleet.UserPayload{ + Name: ptr.String("Name"), + Password: ptr.String(""), + Email: ptr.String("a@b.c"), + InviteToken: ptr.String(invite.Token), + }, + http.StatusUnprocessableEntity, + }, + { + "empty token", + fleet.UserPayload{ + Name: ptr.String("Name"), + Password: ptr.String("pass1word!"), + Email: ptr.String("a@b.c"), + InviteToken: ptr.String(""), + }, + http.StatusUnprocessableEntity, + }, + { + "invalid token", + fleet.UserPayload{ + Name: ptr.String("Name"), + Password: ptr.String("pass1word!"), + Email: ptr.String("a@b.c"), + InviteToken: ptr.String("invalid"), + }, + http.StatusNotFound, + }, + { + "invalid password", + fleet.UserPayload{ + Name: ptr.String("Name"), + Password: ptr.String("password"), // no number or symbol + Email: ptr.String("a@b.c"), + InviteToken: ptr.String(invite.Token), + }, + http.StatusUnprocessableEntity, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + 
var resp createUserResponse + s.DoJSON("POST", "/api/v1/fleet/users", c.pld, c.want, &resp) + }) + } } func (s *integrationTestSuite) TestGetHostSummary() { @@ -2302,6 +2435,9 @@ func (s *integrationTestSuite) TestLabelSpecs() { } func (s *integrationTestSuite) TestUsers() { + // ensure that on exit, the admin token is used + defer func() { s.token = s.getTestAdminToken() }() + t := s.T() // list existing users @@ -2324,14 +2460,16 @@ func (s *integrationTestSuite) TestUsers() { // create a new user var createResp createUserResponse + userRawPwd := "pass" params := fleet.UserPayload{ Name: ptr.String("extra"), Email: ptr.String("extra@asd.com"), - Password: ptr.String("pass"), + Password: ptr.String(userRawPwd), GlobalRole: ptr.String(fleet.RoleObserver), } s.DoJSON("POST", "/api/v1/fleet/users/admin", params, http.StatusOK, &createResp) assert.NotZero(t, createResp.User.ID) + assert.True(t, createResp.User.AdminForcedPasswordReset) u := *createResp.User // login as that user and check that teams info is empty @@ -2407,6 +2545,46 @@ func (s *integrationTestSuite) TestUsers() { } s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID+1), params, http.StatusNotFound, &modResp) + // perform a required password change as the user themselves + s.token = s.getTestToken(u.Email, userRawPwd) + var perfPwdResetResp performRequiredPasswordResetResponse + newRawPwd := "new_password!" + s.DoJSON("POST", "/api/v1/fleet/perform_required_password_reset", performRequiredPasswordResetRequest{ + Password: newRawPwd, + ID: u.ID, + }, http.StatusOK, &perfPwdResetResp) + assert.False(t, perfPwdResetResp.User.AdminForcedPasswordReset) + oldUserRawPwd := userRawPwd + userRawPwd = newRawPwd + + // perform a required password change again, this time it fails as there is no request pending + perfPwdResetResp = performRequiredPasswordResetResponse{} + newRawPwd = "new_password2!" 
+ s.DoJSON("POST", "/api/v1/fleet/perform_required_password_reset", performRequiredPasswordResetRequest{ + Password: newRawPwd, + ID: u.ID, + }, http.StatusInternalServerError, &perfPwdResetResp) // TODO: should be 40?, see #4406 + s.token = s.getTestAdminToken() + + // login as that user to verify that the new password is active (userRawPwd was updated to the new pwd) + loginResp = loginResponse{} + s.DoJSON("POST", "/api/v1/fleet/login", loginRequest{Email: u.Email, Password: userRawPwd}, http.StatusOK, &loginResp) + require.Equal(t, loginResp.User.ID, u.ID) + + // logout for that user + s.token = loginResp.Token + var logoutResp logoutResponse + s.DoJSON("POST", "/api/v1/fleet/logout", nil, http.StatusOK, &logoutResp) + + // logout again, even though not logged in + s.DoJSON("POST", "/api/v1/fleet/logout", nil, http.StatusInternalServerError, &logoutResp) // TODO: should be OK even if not logged in, see #4406. + + s.token = s.getTestAdminToken() + + // login as that user with previous pwd fails + loginResp = loginResponse{} + s.DoJSON("POST", "/api/v1/fleet/login", loginRequest{Email: u.Email, Password: oldUserRawPwd}, http.StatusUnauthorized, &loginResp) + // require a password reset var reqResetResp requirePasswordResetResponse s.DoJSON("POST", fmt.Sprintf("/api/v1/fleet/users/%d/require_password_reset", u.ID), map[string]bool{"require": true}, http.StatusOK, &reqResetResp) @@ -3094,6 +3272,245 @@ func (s *integrationTestSuite) TestOsqueryConfig() { assert.Contains(t, errRes["error"], "invalid node key") } +func (s *integrationTestSuite) TestEnrollHost() { + t := s.T() + + // set the enroll secret + var applyResp applyEnrollSecretSpecResponse + s.DoJSON("POST", "/api/v1/fleet/spec/enroll_secret", applyEnrollSecretSpecRequest{ + Spec: &fleet.EnrollSecretSpec{ + Secrets: []*fleet.EnrollSecret{{Secret: t.Name()}}, + }, + }, http.StatusOK, &applyResp) + + // invalid enroll secret fails + j, err := json.Marshal(&enrollAgentRequest{ + EnrollSecret: "nosuchsecret", + 
HostIdentifier: "abcd", + }) + require.NoError(t, err) + s.DoRawNoAuth("POST", "/api/v1/osquery/enroll", j, http.StatusUnauthorized) + + // valid enroll secret succeeds + j, err = json.Marshal(&enrollAgentRequest{ + EnrollSecret: t.Name(), + HostIdentifier: t.Name(), + }) + require.NoError(t, err) + + var resp enrollAgentResponse + hres := s.DoRawNoAuth("POST", "/api/v1/osquery/enroll", j, http.StatusOK) + defer hres.Body.Close() + require.NoError(t, json.NewDecoder(hres.Body).Decode(&resp)) + require.NotEmpty(t, resp.NodeKey) +} + +func (s *integrationTestSuite) TestCarve() { + t := s.T() + hosts := s.createHosts(t) + + // begin a carve with an invalid node key + var errRes map[string]interface{} + s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{ + NodeKey: hosts[0].NodeKey + "zzz", + BlockCount: 1, + BlockSize: 1, + CarveSize: 1, + CarveId: "c1", + }, http.StatusUnauthorized, &errRes) + assert.Contains(t, errRes["error"], "invalid node key") + + // invalid carve size + s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{ + NodeKey: hosts[0].NodeKey, + BlockCount: 3, + BlockSize: 3, + CarveSize: 0, + CarveId: "c1", + }, http.StatusInternalServerError, &errRes) // TODO: should be 4xx, see #4406 + assert.Contains(t, errRes["error"], "carve_size must be greater") + + // invalid block size too big + s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{ + NodeKey: hosts[0].NodeKey, + BlockCount: 3, + BlockSize: maxBlockSize + 1, + CarveSize: maxCarveSize, + CarveId: "c1", + }, http.StatusInternalServerError, &errRes) // TODO: should be 4xx, see #4406 + assert.Contains(t, errRes["error"], "block_size exceeds max") + + // invalid carve size too big + s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{ + NodeKey: hosts[0].NodeKey, + BlockCount: 3, + BlockSize: maxBlockSize, + CarveSize: maxCarveSize + 1, + CarveId: "c1", + }, http.StatusInternalServerError, &errRes) // TODO: should be 4xx, see #4406 + 
assert.Contains(t, errRes["error"], "carve_size exceeds max") + + // invalid carve size, does not match blocks + s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{ + NodeKey: hosts[0].NodeKey, + BlockCount: 3, + BlockSize: 3, + CarveSize: 1, + CarveId: "c1", + }, http.StatusInternalServerError, &errRes) // TODO: should be 4xx, see #4406 + assert.Contains(t, errRes["error"], "carve_size does not match") + + // valid carve begin + var beginResp carveBeginResponse + s.DoJSON("POST", "/api/v1/osquery/carve/begin", carveBeginRequest{ + NodeKey: hosts[0].NodeKey, + BlockCount: 3, + BlockSize: 3, + CarveSize: 8, + CarveId: "c1", + RequestId: "r1", + }, http.StatusOK, &beginResp) + require.NotEmpty(t, beginResp.SessionId) + sid := beginResp.SessionId + + // sending a block with invalid session id + var blockResp carveBlockResponse + s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{ + BlockId: 1, + SessionId: sid + "zz", + RequestId: "??", + Data: []byte("p1."), + }, http.StatusNotFound, &blockResp) + + // sending a block with valid session id but invalid request id + s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{ + BlockId: 1, + SessionId: sid, + RequestId: "??", + Data: []byte("p1."), + }, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406 + + // sending a block with unexpected block id (expects 0, got 1) + s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{ + BlockId: 1, + SessionId: sid, + RequestId: "r1", + Data: []byte("p1."), + }, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406 + + // sending a block with valid payload, block 0 + s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{ + BlockId: 0, + SessionId: sid, + RequestId: "r1", + Data: []byte("p1."), + }, http.StatusOK, &blockResp) + require.True(t, blockResp.Success) + + // sending next block + blockResp = carveBlockResponse{} + s.DoJSON("POST", "/api/v1/osquery/carve/block", 
carveBlockRequest{ + BlockId: 1, + SessionId: sid, + RequestId: "r1", + Data: []byte("p2."), + }, http.StatusOK, &blockResp) + require.True(t, blockResp.Success) + + // sending already-sent block again + blockResp = carveBlockResponse{} + s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{ + BlockId: 1, + SessionId: sid, + RequestId: "r1", + Data: []byte("p2."), + }, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406 + + // sending final block with too many bytes + blockResp = carveBlockResponse{} + s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{ + BlockId: 2, + SessionId: sid, + RequestId: "r1", + Data: []byte("p3extra"), + }, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406 + + // sending actual final block + blockResp = carveBlockResponse{} + s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{ + BlockId: 2, + SessionId: sid, + RequestId: "r1", + Data: []byte("p3"), + }, http.StatusOK, &blockResp) + require.True(t, blockResp.Success) + + // sending unexpected block + blockResp = carveBlockResponse{} + s.DoJSON("POST", "/api/v1/osquery/carve/block", carveBlockRequest{ + BlockId: 3, + SessionId: sid, + RequestId: "r1", + Data: []byte("p4."), + }, http.StatusInternalServerError, &blockResp) // TODO: should be 400, see #4406 +} + +func (s *integrationTestSuite) TestPasswordReset() { + t := s.T() + + // create a new user + var createResp createUserResponse + userRawPwd := "passw0rd!" 
+ params := fleet.UserPayload{ + Name: ptr.String("forgotpwd"), + Email: ptr.String("forgotpwd@example.com"), + Password: ptr.String(userRawPwd), + GlobalRole: ptr.String(fleet.RoleObserver), + } + s.DoJSON("POST", "/api/v1/fleet/users/admin", params, http.StatusOK, &createResp) + require.NotZero(t, createResp.User.ID) + u := *createResp.User + _ = u + + // request forgot password, invalid email + res := s.DoRawNoAuth("POST", "/api/v1/fleet/forgot_password", jsonMustMarshal(t, forgotPasswordRequest{Email: "invalid@asd.com"}), http.StatusAccepted) + res.Body.Close() + + // TODO: tested manually (adds too much time to the test), works but hitting the rate + // limit returns 500 instead of 429, see #4406. We get the authz check missing error instead. + //// trigger the rate limit with a batch of requests in a short burst + //for i := 0; i < 20; i++ { + // s.DoJSON("POST", "/api/v1/fleet/forgot_password", forgotPasswordRequest{Email: "invalid@asd.com"}, http.StatusAccepted, &forgotResp) + //} + + // request forgot password, valid email + res = s.DoRawNoAuth("POST", "/api/v1/fleet/forgot_password", jsonMustMarshal(t, forgotPasswordRequest{Email: u.Email}), http.StatusAccepted) + res.Body.Close() + + var token string + mysql.ExecAdhocSQL(t, s.ds, func(db sqlx.ExtContext) error { + return sqlx.GetContext(context.Background(), db, &token, "SELECT token FROM password_reset_requests WHERE user_id = ?", u.ID) + }) + + // proceed with reset password + userNewPwd := "newpassw0rd!" + res = s.DoRawNoAuth("POST", "/api/v1/fleet/reset_password", jsonMustMarshal(t, resetPasswordRequest{PasswordResetToken: token, NewPassword: userNewPwd}), http.StatusOK) + res.Body.Close() + + // attempt it again with already-used token + userUnusedPwd := "unusedpassw0rd!" 
+ res = s.DoRawNoAuth("POST", "/api/v1/fleet/reset_password", jsonMustMarshal(t, resetPasswordRequest{PasswordResetToken: token, NewPassword: userUnusedPwd}), http.StatusInternalServerError) // TODO: should be 40x, see #4406 + res.Body.Close() + + // login with the old password, should not succeed + res = s.DoRawNoAuth("POST", "/api/v1/fleet/login", jsonMustMarshal(t, loginRequest{Email: u.Email, Password: userRawPwd}), http.StatusUnauthorized) + res.Body.Close() + + // login with the new password, should succeed + res = s.DoRawNoAuth("POST", "/api/v1/fleet/login", jsonMustMarshal(t, loginRequest{Email: u.Email, Password: userNewPwd}), http.StatusOK) + res.Body.Close() +} + // creates a session and returns it, its key is to be passed as authorization header. func createSession(t *testing.T, uid uint, ds fleet.Datastore) *fleet.Session { key := make([]byte, 64) @@ -3116,3 +3533,9 @@ func cleanupQuery(s *integrationTestSuite, queryID uint) { var delResp deleteQueryByIDResponse s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/queries/id/%d", queryID), nil, http.StatusOK, &delResp) } + +func jsonMustMarshal(t testing.TB, v interface{}) []byte { + b, err := json.Marshal(v) + require.NoError(t, err) + return b +} diff --git a/server/service/invites.go b/server/service/invites.go index b727a84654..8304c56d13 100644 --- a/server/service/invites.go +++ b/server/service/invites.go @@ -10,6 +10,7 @@ import ( "github.com/fleetdm/fleet/v4/server" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + "github.com/fleetdm/fleet/v4/server/contexts/logging" "github.com/fleetdm/fleet/v4/server/contexts/viewer" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mail" @@ -264,3 +265,51 @@ func (svc *Service) DeleteInvite(ctx context.Context, id uint) error { } return svc.ds.DeleteInvite(ctx, id) } + +//////////////////////////////////////////////////////////////////////////////// +// Verify invite 
+//////////////////////////////////////////////////////////////////////////////// + +type verifyInviteRequest struct { + Token string `url:"token"` +} + +type verifyInviteResponse struct { + Invite *fleet.Invite `json:"invite"` + Err error `json:"error,omitempty"` +} + +func (r verifyInviteResponse) error() error { return r.Err } + +func verifyInviteEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*verifyInviteRequest) + invite, err := svc.VerifyInvite(ctx, req.Token) + if err != nil { + return verifyInviteResponse{Err: err}, nil + } + return verifyInviteResponse{Invite: invite}, nil +} + +func (svc *Service) VerifyInvite(ctx context.Context, token string) (*fleet.Invite, error) { + // skipauth: There is no viewer context at this point. We rely on verifying + // the invite for authNZ. + svc.authz.SkipAuthorization(ctx) + + logging.WithExtras(ctx, "token", token) + + invite, err := svc.ds.InviteByToken(ctx, token) + if err != nil { + return nil, err + } + + if invite.Token != token { + return nil, fleet.NewInvalidArgumentError("invite_token", "Invite Token does not match Email Address.") + } + + expiresAt := invite.CreatedAt.Add(svc.config.App.InviteTokenValidityPeriod) + if svc.clock.Now().After(expiresAt) { + return nil, fleet.NewInvalidArgumentError("invite_token", "Invite token has expired.") + } + + return invite, nil +} diff --git a/server/service/service_invites_test.go b/server/service/invites_test.go similarity index 100% rename from server/service/service_invites_test.go rename to server/service/invites_test.go diff --git a/server/service/jitter.go b/server/service/jitter.go new file mode 100644 index 0000000000..e69fd39c82 --- /dev/null +++ b/server/service/jitter.go @@ -0,0 +1,84 @@ +package service + +import ( + "sync" + "time" +) + +// jitterHashTable implements a data structure that allows a fleet to generate a static jitter value +// that is properly balanced. 
Balance in this context means that hosts would be distributed uniformly
+across the total jitter time so there are no spikes.
+The way this structure works is as follows:
+Given a number of buckets, we want to place hosts in buckets evenly. So we don't want bucket 0 to
+have 1000 hosts, and all the other buckets 0. If there were 1000 buckets, and 1000 hosts, we should
+end up with 1 per bucket.
+The total number of online hosts is unknown, so first it assumes that the number of buckets >= number
+of total hosts (maxCapacity of 1 per bucket). Once we have more hosts than buckets, then we
+increase the maxCapacity by 1 for all buckets, and start placing hosts.
+Hosts that have been placed in a bucket remain in that bucket for as long as the fleet instance is
+running.
+The preferred bucket for a host is the one at (host id % bucketCount). If that bucket is full, the
+next one will be tried. If all buckets are full, then capacity gets increased and the bucket
+selection process restarts.
+Once a bucket is found, the index for the bucket (going from 0 to bucketCount) will be the amount of
+minutes added to the host check-in time.
+For example: at a 1hr interval and the default 10% max jitter percent, hosts are allowed to
+distribute within 6 minutes around the hour mark. We would have 6 buckets in that case.
+In the worst possible case that all hosts start at the same time, max jitter percent can be set to
+100, and this method will distribute hosts evenly.
+The main caveat of this approach is that it works at the fleet instance level. So depending on what
+instance gets chosen by the load balancer, the jitter might be different. However, load tests have
+shown that the distribution in practice is pretty balanced even when all hosts try to check in at
+the same time. 
+type jitterHashTable struct { + mu sync.Mutex + maxCapacity int + bucketCount int + buckets map[int]int + cache map[uint]time.Duration +} + +func newJitterHashTable(bucketCount int) *jitterHashTable { + if bucketCount == 0 { + bucketCount = 1 + } + return &jitterHashTable{ + maxCapacity: 1, + bucketCount: bucketCount, + buckets: make(map[int]int), + cache: make(map[uint]time.Duration), + } +} + +func (jh *jitterHashTable) jitterForHost(hostID uint) time.Duration { + // if no jitter is configured just return 0 + if jh.bucketCount <= 1 { + return 0 + } + + jh.mu.Lock() + if jitter, ok := jh.cache[hostID]; ok { + jh.mu.Unlock() + return jitter + } + + for i := 0; i < jh.bucketCount; i++ { + possibleBucket := (int(hostID) + i) % jh.bucketCount + + // if the next bucket has capacity, great! + if jh.buckets[possibleBucket] < jh.maxCapacity { + jh.buckets[possibleBucket]++ + jitter := time.Duration(possibleBucket) * time.Minute + jh.cache[hostID] = jitter + + jh.mu.Unlock() + return jitter + } + } + + // otherwise, bump the capacity and restart the process + jh.maxCapacity++ + + jh.mu.Unlock() + return jh.jitterForHost(hostID) +} diff --git a/server/service/jitter_test.go b/server/service/jitter_test.go new file mode 100644 index 0000000000..cc838fc973 --- /dev/null +++ b/server/service/jitter_test.go @@ -0,0 +1,52 @@ +package service + +import ( + crand "crypto/rand" + "math" + "math/big" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestJitterForHost(t *testing.T) { + jh := newJitterHashTable(30) + + histogram := make(map[int64]int) + hostCount := 3000 + for i := 0; i < hostCount; i++ { + hostID, err := crand.Int(crand.Reader, big.NewInt(10000)) + require.NoError(t, err) + jitter := jh.jitterForHost(uint(hostID.Int64() + 10000)) + jitterMinutes := int64(jitter.Minutes()) + histogram[jitterMinutes]++ + } + min, max := math.MaxInt, 0 + for jitterMinutes, count := range histogram { + if count < min { + min = count + } + if count > max { + max = count + 
} + t.Logf("jitterMinutes=%d \t count=%d\n", jitterMinutes, count) + } + variation := max - min + t.Logf("min=%d \t max=%d \t variation=%d\n", min, max, variation) + + // check that variation is below 1% of the total amount of hosts + require.Less(t, variation, int(float32(hostCount)/0.01)) +} + +func TestNoJitter(t *testing.T) { + jh := newJitterHashTable(0) + + hostCount := 3000 + for i := 0; i < hostCount; i++ { + hostID, err := crand.Int(crand.Reader, big.NewInt(10000)) + require.NoError(t, err) + jitter := jh.jitterForHost(uint(hostID.Int64() + 10000)) + jitterMinutes := int64(jitter.Minutes()) + require.Equal(t, int64(0), jitterMinutes) + } +} diff --git a/server/service/middleware/ratelimit/ratelimit.go b/server/service/middleware/ratelimit/ratelimit.go index 77a27b5d5b..3fb3971ff2 100644 --- a/server/service/middleware/ratelimit/ratelimit.go +++ b/server/service/middleware/ratelimit/ratelimit.go @@ -4,8 +4,6 @@ import ( "context" "fmt" "net/http" - "reflect" - "runtime" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/go-kit/kit/endpoint" @@ -28,19 +26,15 @@ func NewMiddleware(store throttled.GCRAStore) *Middleware { } // Limit returns a new middleware function enforcing the provided quota. 
-func (m *Middleware) Limit(quota throttled.RateQuota) endpoint.Middleware { +func (m *Middleware) Limit(keyName string, quota throttled.RateQuota) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { - // Get function name to use as a key for rate limiting (each wrapped function - // gets a separate quota) - funcName := runtime.FuncForPC(reflect.ValueOf(next).Pointer()).Name() - limiter, err := throttled.NewGCRARateLimiter(m.store, quota) if err != nil { panic(err) } return func(ctx context.Context, req interface{}) (response interface{}, err error) { - limited, result, err := limiter.RateLimit(funcName, 1) + limited, result, err := limiter.RateLimit(keyName, 1) if err != nil { return nil, ctxerr.Wrap(ctx, err, "check rate limit") } diff --git a/server/service/middleware/ratelimit/ratelimit_test.go b/server/service/middleware/ratelimit/ratelimit_test.go index e719fd0c42..8e3338840e 100644 --- a/server/service/middleware/ratelimit/ratelimit_test.go +++ b/server/service/middleware/ratelimit/ratelimit_test.go @@ -20,6 +20,7 @@ func TestLimit(t *testing.T) { limiter := NewMiddleware(store) endpoint := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } wrapped := limiter.Limit( + "test_limit", throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 0}, )(endpoint) diff --git a/server/service/osquery.go b/server/service/osquery.go index 22db162ee5..a90e36b88d 100644 --- a/server/service/osquery.go +++ b/server/service/osquery.go @@ -8,8 +8,10 @@ import ( "regexp" "strconv" "strings" + "sync/atomic" "time" + "github.com/fleetdm/fleet/v4/server" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" hostctx "github.com/fleetdm/fleet/v4/server/contexts/host" "github.com/fleetdm/fleet/v4/server/contexts/logging" @@ -22,6 +24,255 @@ import ( "github.com/spf13/cast" ) +type osqueryError struct { + message string + nodeInvalid bool +} + +func (e osqueryError) Error() string { + return e.message +} + +func (e 
osqueryError) NodeInvalid() bool { + return e.nodeInvalid +} + +func (svc *Service) AuthenticateHost(ctx context.Context, nodeKey string) (*fleet.Host, bool, error) { + // skipauth: Authorization is currently for user endpoints only. + svc.authz.SkipAuthorization(ctx) + + if nodeKey == "" { + return nil, false, osqueryError{ + message: "authentication error: missing node key", + nodeInvalid: true, + } + } + + host, err := svc.ds.LoadHostByNodeKey(ctx, nodeKey) + switch { + case err == nil: + // OK + case fleet.IsNotFound(err): + return nil, false, osqueryError{ + message: "authentication error: invalid node key: " + nodeKey, + nodeInvalid: true, + } + default: + return nil, false, osqueryError{ + message: "authentication error: " + err.Error(), + } + } + + // Update the "seen" time used to calculate online status. These updates are + // batched for MySQL performance reasons. Because this is done + // asynchronously, it is possible for the server to shut down before + // updating the seen time for these hosts. This seems to be an acceptable + // tradeoff as an online host will continue to check in and quickly be + // marked online again. 
+ svc.seenHostSet.addHostID(host.ID) + host.SeenTime = svc.clock.Now() + + return host, svc.debugEnabledForHost(ctx, host.ID), nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Enroll Agent +//////////////////////////////////////////////////////////////////////////////// + +type enrollAgentRequest struct { + EnrollSecret string `json:"enroll_secret"` + HostIdentifier string `json:"host_identifier"` + HostDetails map[string](map[string]string) `json:"host_details"` +} + +type enrollAgentResponse struct { + NodeKey string `json:"node_key,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r enrollAgentResponse) error() error { return r.Err } + +func enrollAgentEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*enrollAgentRequest) + nodeKey, err := svc.EnrollAgent(ctx, req.EnrollSecret, req.HostIdentifier, req.HostDetails) + if err != nil { + return enrollAgentResponse{Err: err}, nil + } + return enrollAgentResponse{NodeKey: nodeKey}, nil +} + +func (svc *Service) EnrollAgent(ctx context.Context, enrollSecret, hostIdentifier string, hostDetails map[string](map[string]string)) (string, error) { + // skipauth: Authorization is currently for user endpoints only. 
+ svc.authz.SkipAuthorization(ctx) + + logging.WithExtras(ctx, "hostIdentifier", hostIdentifier) + + secret, err := svc.ds.VerifyEnrollSecret(ctx, enrollSecret) + if err != nil { + return "", osqueryError{ + message: "enroll failed: " + err.Error(), + nodeInvalid: true, + } + } + + nodeKey, err := server.GenerateRandomText(svc.config.Osquery.NodeKeySize) + if err != nil { + return "", osqueryError{ + message: "generate node key failed: " + err.Error(), + nodeInvalid: true, + } + } + + hostIdentifier = getHostIdentifier(svc.logger, svc.config.Osquery.HostIdentifier, hostIdentifier, hostDetails) + + host, err := svc.ds.EnrollHost(ctx, hostIdentifier, nodeKey, secret.TeamID, svc.config.Osquery.EnrollCooldown) + if err != nil { + return "", osqueryError{message: "save enroll failed: " + err.Error(), nodeInvalid: true} + } + + appConfig, err := svc.ds.AppConfig(ctx) + if err != nil { + return "", osqueryError{message: "app config load failed: " + err.Error(), nodeInvalid: true} + } + + // Save enrollment details if provided + detailQueries := osquery_utils.GetDetailQueries(appConfig, svc.config) + save := false + if r, ok := hostDetails["os_version"]; ok { + err := detailQueries["os_version"].IngestFunc(svc.logger, host, []map[string]string{r}) + if err != nil { + return "", ctxerr.Wrap(ctx, err, "Ingesting os_version") + } + save = true + } + if r, ok := hostDetails["osquery_info"]; ok { + err := detailQueries["osquery_info"].IngestFunc(svc.logger, host, []map[string]string{r}) + if err != nil { + return "", ctxerr.Wrap(ctx, err, "Ingesting osquery_info") + } + save = true + } + if r, ok := hostDetails["system_info"]; ok { + err := detailQueries["system_info"].IngestFunc(svc.logger, host, []map[string]string{r}) + if err != nil { + return "", ctxerr.Wrap(ctx, err, "Ingesting system_info") + } + save = true + } + + if save { + if appConfig.ServerSettings.DeferredSaveHost { + go svc.serialUpdateHost(host) + } else { + if err := svc.ds.UpdateHost(ctx, host); err != nil { 
+ return "", ctxerr.Wrap(ctx, err, "save host in enroll agent") + } + } + } + + return nodeKey, nil +} + +var counter = int64(0) + +func (svc *Service) serialUpdateHost(host *fleet.Host) { + newVal := atomic.AddInt64(&counter, 1) + defer func() { + atomic.AddInt64(&counter, -1) + }() + level.Debug(svc.logger).Log("background", newVal) + + ctx, cancelFunc := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelFunc() + err := svc.ds.SerialUpdateHost(ctx, host) + if err != nil { + level.Error(svc.logger).Log("background-err", err) + } +} + +func getHostIdentifier(logger log.Logger, identifierOption, providedIdentifier string, details map[string](map[string]string)) string { + switch identifierOption { + case "provided": + // Use the host identifier already provided in the request. + return providedIdentifier + + case "instance": + r, ok := details["osquery_info"] + if !ok { + level.Info(logger).Log( + "msg", "could not get host identifier", + "reason", "missing osquery_info", + "identifier", "instance", + ) + } else if r["instance_id"] == "" { + level.Info(logger).Log( + "msg", "could not get host identifier", + "reason", "missing instance_id in osquery_info", + "identifier", "instance", + ) + } else { + return r["instance_id"] + } + + case "uuid": + r, ok := details["osquery_info"] + if !ok { + level.Info(logger).Log( + "msg", "could not get host identifier", + "reason", "missing osquery_info", + "identifier", "uuid", + ) + } else if r["uuid"] == "" { + level.Info(logger).Log( + "msg", "could not get host identifier", + "reason", "missing instance_id in osquery_info", + "identifier", "uuid", + ) + } else { + return r["uuid"] + } + + case "hostname": + r, ok := details["system_info"] + if !ok { + level.Info(logger).Log( + "msg", "could not get host identifier", + "reason", "missing system_info", + "identifier", "hostname", + ) + } else if r["hostname"] == "" { + level.Info(logger).Log( + "msg", "could not get host identifier", + "reason", "missing 
instance_id in system_info", + "identifier", "hostname", + ) + } else { + return r["hostname"] + } + + default: + panic("Unknown option for host_identifier: " + identifierOption) + } + + return providedIdentifier +} + +func (svc *Service) debugEnabledForHost(ctx context.Context, id uint) bool { + hlogger := log.With(svc.logger, "host-id", id) + ac, err := svc.ds.AppConfig(ctx) + if err != nil { + level.Debug(hlogger).Log("err", ctxerr.Wrap(ctx, err, "getting app config for host debug")) + return false + } + + for _, hostID := range ac.ServerSettings.DebugHostIDs { + if hostID == id { + return true + } + } + return false +} + //////////////////////////////////////////////////////////////////////////////// // Get Client Config //////////////////////////////////////////////////////////////////////////////// diff --git a/server/service/osquery_test.go b/server/service/osquery_test.go index b02a8602ba..248ab86c25 100644 --- a/server/service/osquery_test.go +++ b/server/service/osquery_test.go @@ -1,15 +1,37 @@ package service import ( + "bytes" "context" "encoding/json" "errors" + "fmt" + "io/ioutil" + "reflect" + "sort" + "strconv" + "strings" + "sync" "testing" + "time" + "github.com/WatchBeam/clock" + "github.com/fleetdm/fleet/v4/server/authz" + "github.com/fleetdm/fleet/v4/server/config" hostctx "github.com/fleetdm/fleet/v4/server/contexts/host" + fleetLogging "github.com/fleetdm/fleet/v4/server/contexts/logging" + "github.com/fleetdm/fleet/v4/server/contexts/viewer" + "github.com/fleetdm/fleet/v4/server/datastore/redis/redistest" "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/live_query" + "github.com/fleetdm/fleet/v4/server/logging" "github.com/fleetdm/fleet/v4/server/mock" "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/fleetdm/fleet/v4/server/pubsub" + "github.com/fleetdm/fleet/v4/server/service/osquery_utils" + "github.com/fleetdm/fleet/v4/server/service/redis_policy_set" + "github.com/go-kit/kit/log" + 
"github.com/go-kit/kit/log/level" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -161,3 +183,2325 @@ func TestAgentOptionsForHost(t *testing.T) { require.NoError(t, err) assert.JSONEq(t, `{"foo":"override2"}`, string(opt)) } + +// One of these queries is the disk space, only one of the two works in a platform +var expectedDetailQueries = len(osquery_utils.GetDetailQueries(&fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true}}, config.FleetConfig{})) - 1 + +func TestEnrollAgent(t *testing.T) { + ds := new(mock.Store) + ds.VerifyEnrollSecretFunc = func(ctx context.Context, secret string) (*fleet.EnrollSecret, error) { + switch secret { + case "valid_secret": + return &fleet.EnrollSecret{Secret: "valid_secret", TeamID: ptr.Uint(3)}, nil + default: + return nil, errors.New("not found") + } + } + ds.EnrollHostFunc = func(ctx context.Context, osqueryHostId, nodeKey string, teamID *uint, cooldown time.Duration) (*fleet.Host, error) { + assert.Equal(t, ptr.Uint(3), teamID) + return &fleet.Host{ + OsqueryHostID: osqueryHostId, NodeKey: nodeKey, + }, nil + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + + svc := newTestService(ds, nil, nil) + + nodeKey, err := svc.EnrollAgent(context.Background(), "valid_secret", "host123", nil) + require.NoError(t, err) + assert.NotEmpty(t, nodeKey) +} + +func TestEnrollAgentIncorrectEnrollSecret(t *testing.T) { + ds := new(mock.Store) + ds.VerifyEnrollSecretFunc = func(ctx context.Context, secret string) (*fleet.EnrollSecret, error) { + switch secret { + case "valid_secret": + return &fleet.EnrollSecret{Secret: "valid_secret", TeamID: ptr.Uint(3)}, nil + default: + return nil, errors.New("not found") + } + } + + svc := newTestService(ds, nil, nil) + + nodeKey, err := svc.EnrollAgent(context.Background(), "not_correct", "host123", nil) + assert.NotNil(t, err) + assert.Empty(t, nodeKey) +} + +func TestEnrollAgentDetails(t 
*testing.T) { + ds := new(mock.Store) + ds.VerifyEnrollSecretFunc = func(ctx context.Context, secret string) (*fleet.EnrollSecret, error) { + return &fleet.EnrollSecret{}, nil + } + ds.EnrollHostFunc = func(ctx context.Context, osqueryHostId, nodeKey string, teamID *uint, cooldown time.Duration) (*fleet.Host, error) { + return &fleet.Host{ + OsqueryHostID: osqueryHostId, NodeKey: nodeKey, + }, nil + } + var gotHost *fleet.Host + ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { + gotHost = host + return nil + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + + svc := newTestService(ds, nil, nil) + + details := map[string](map[string]string){ + "osquery_info": {"version": "2.12.0"}, + "system_info": {"hostname": "zwass.local", "uuid": "froobling_uuid"}, + "os_version": { + "name": "Mac OS X", + "major": "10", + "minor": "14", + "patch": "5", + "platform": "darwin", + }, + "foo": {"foo": "bar"}, + } + nodeKey, err := svc.EnrollAgent(context.Background(), "", "host123", details) + require.NoError(t, err) + assert.NotEmpty(t, nodeKey) + + assert.Equal(t, "Mac OS X 10.14.5", gotHost.OSVersion) + assert.Equal(t, "darwin", gotHost.Platform) + assert.Equal(t, "2.12.0", gotHost.OsqueryVersion) + assert.Equal(t, "zwass.local", gotHost.Hostname) + assert.Equal(t, "froobling_uuid", gotHost.UUID) +} + +func TestAuthenticateHost(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + var gotKey string + host := fleet.Host{ID: 1, Hostname: "foobar"} + ds.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { + gotKey = nodeKey + return &host, nil + } + var gotHostIDs []uint + ds.MarkHostsSeenFunc = func(ctx context.Context, hostIDs []uint, t time.Time) error { + gotHostIDs = hostIDs + return nil + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + + _, _, err := 
svc.AuthenticateHost(context.Background(), "test") + require.NoError(t, err) + assert.Equal(t, "test", gotKey) + assert.False(t, ds.MarkHostsSeenFuncInvoked) + + host = fleet.Host{ID: 7, Hostname: "foobar"} + _, _, err = svc.AuthenticateHost(context.Background(), "floobar") + require.NoError(t, err) + assert.Equal(t, "floobar", gotKey) + assert.False(t, ds.MarkHostsSeenFuncInvoked) + // Host checks in twice + host = fleet.Host{ID: 7, Hostname: "foobar"} + _, _, err = svc.AuthenticateHost(context.Background(), "floobar") + require.NoError(t, err) + assert.Equal(t, "floobar", gotKey) + assert.False(t, ds.MarkHostsSeenFuncInvoked) + + err = svc.FlushSeenHosts(context.Background()) + require.NoError(t, err) + assert.True(t, ds.MarkHostsSeenFuncInvoked) + assert.ElementsMatch(t, []uint{1, 7}, gotHostIDs) + + err = svc.FlushSeenHosts(context.Background()) + require.NoError(t, err) + assert.True(t, ds.MarkHostsSeenFuncInvoked) + require.Len(t, gotHostIDs, 0) +} + +func TestAuthenticateHostFailure(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + ds.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { + return nil, errors.New("not found") + } + + _, _, err := svc.AuthenticateHost(context.Background(), "test") + require.NotNil(t, err) +} + +type testJSONLogger struct { + logs []json.RawMessage +} + +func (n *testJSONLogger) Write(ctx context.Context, logs []json.RawMessage) error { + n.logs = logs + return nil +} + +func TestSubmitStatusLogs(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + // Hack to get at the service internals and modify the writer + serv := ((svc.(validationMiddleware)).Service).(*Service) + + testLogger := &testJSONLogger{} + serv.osqueryLogWriter = &logging.OsqueryLogger{Status: testLogger} + + logs := []string{ + `{"severity":"0","filename":"tls.cpp","line":"216","message":"some 
message","version":"1.8.2","decorations":{"host_uuid":"uuid_foobar","username":"zwass"}}`, + `{"severity":"1","filename":"buffered.cpp","line":"122","message":"warning!","version":"1.8.2","decorations":{"host_uuid":"uuid_foobar","username":"zwass"}}`, + } + logJSON := fmt.Sprintf("[%s]", strings.Join(logs, ",")) + + var status []json.RawMessage + err := json.Unmarshal([]byte(logJSON), &status) + require.NoError(t, err) + + host := fleet.Host{} + ctx := hostctx.NewContext(context.Background(), &host) + err = serv.SubmitStatusLogs(ctx, status) + require.NoError(t, err) + + assert.Equal(t, status, testLogger.logs) +} + +func TestSubmitResultLogs(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + // Hack to get at the service internals and modify the writer + serv := ((svc.(validationMiddleware)).Service).(*Service) + + testLogger := &testJSONLogger{} + serv.osqueryLogWriter = &logging.OsqueryLogger{Result: testLogger} + + logs := []string{ + `{"name":"system_info","hostIdentifier":"some_uuid","calendarTime":"Fri Sep 30 17:55:15 2016 UTC","unixTime":"1475258115","decorations":{"host_uuid":"some_uuid","username":"zwass"},"columns":{"cpu_brand":"Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz","hostname":"hostimus","physical_memory":"17179869184"},"action":"added"}`, + `{"name":"encrypted","hostIdentifier":"some_uuid","calendarTime":"Fri Sep 30 21:19:15 2016 UTC","unixTime":"1475270355","decorations":{"host_uuid":"4740D59F-699E-5B29-960B-979AAF9BBEEB","username":"zwass"},"columns":{"encrypted":"1","name":"\/dev\/disk1","type":"AES-XTS","uid":"","user_uuid":"","uuid":"some_uuid"},"action":"added"}`, + `{"snapshot":[{"hour":"20","minutes":"8"}],"action":"snapshot","name":"time","hostIdentifier":"1379f59d98f4","calendarTime":"Tue Jan 10 20:08:51 2017 UTC","unixTime":"1484078931","decorations":{"host_uuid":"EB714C9D-C1F8-A436-B6DA-3F853C5502EA"}}`, + 
`{"diffResults":{"removed":[{"address":"127.0.0.1","hostnames":"kl.groob.io"}],"added":""},"name":"pack\/test\/hosts","hostIdentifier":"FA01680E-98CA-5557-8F59-7716ECFEE964","calendarTime":"Sun Nov 19 00:02:08 2017 UTC","unixTime":"1511049728","epoch":"0","counter":"10","decorations":{"host_uuid":"FA01680E-98CA-5557-8F59-7716ECFEE964","hostname":"kl.groob.io"}}`, + // fleet will accept anything in the "data" field of a log request. + `{"unknown":{"foo": [] }}`, + } + logJSON := fmt.Sprintf("[%s]", strings.Join(logs, ",")) + + var results []json.RawMessage + err := json.Unmarshal([]byte(logJSON), &results) + require.NoError(t, err) + + host := fleet.Host{} + ctx := hostctx.NewContext(context.Background(), &host) + err = serv.SubmitResultLogs(ctx, results) + require.NoError(t, err) + + assert.Equal(t, results, testLogger.logs) +} + +func TestHostDetailQueries(t *testing.T) { + ds := new(mock.Store) + additional := json.RawMessage(`{"foobar": "select foo", "bim": "bam"}`) + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{HostSettings: fleet.HostSettings{AdditionalQueries: &additional, EnableHostUsers: true}}, nil + } + + mockClock := clock.NewMockClock() + host := fleet.Host{ + ID: 1, + UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ + UpdateTimestamp: fleet.UpdateTimestamp{ + UpdatedAt: mockClock.Now(), + }, + CreateTimestamp: fleet.CreateTimestamp{ + CreatedAt: mockClock.Now(), + }, + }, + + Platform: "darwin", + DetailUpdatedAt: mockClock.Now(), + NodeKey: "test_key", + Hostname: "test_hostname", + UUID: "test_uuid", + } + + svc := &Service{ + clock: mockClock, + logger: log.NewNopLogger(), + config: config.TestConfig(), + ds: ds, + jitterMu: new(sync.Mutex), + jitterH: make(map[time.Duration]*jitterHashTable), + } + + queries, err := svc.detailQueriesForHost(context.Background(), &host) + require.NoError(t, err) + assert.Empty(t, queries) + + // With refetch requested detail queries should be returned + 
host.RefetchRequested = true + queries, err = svc.detailQueriesForHost(context.Background(), &host) + require.NoError(t, err) + assert.NotEmpty(t, queries) + host.RefetchRequested = false + + // Advance the time + mockClock.AddTime(1*time.Hour + 1*time.Minute) + + queries, err = svc.detailQueriesForHost(context.Background(), &host) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries+2) + for name := range queries { + assert.True(t, + strings.HasPrefix(name, hostDetailQueryPrefix) || strings.HasPrefix(name, hostAdditionalQueryPrefix), + ) + } + assert.Equal(t, "bam", queries[hostAdditionalQueryPrefix+"bim"]) + assert.Equal(t, "select foo", queries[hostAdditionalQueryPrefix+"foobar"]) +} + +func TestGetDistributedQueriesMissingHost(t *testing.T) { + svc := newTestService(&mock.Store{}, nil, nil) + + _, _, err := svc.GetDistributedQueries(context.Background()) + require.NotNil(t, err) + assert.Contains(t, err.Error(), "missing host") +} + +func TestLabelQueries(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + lq := new(live_query.MockLiveQuery) + svc := newTestServiceWithClock(ds, nil, lq, mockClock) + + host := &fleet.Host{ + Platform: "darwin", + } + + ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + return host, nil + } + ds.UpdateHostFunc = func(ctx context.Context, gotHost *fleet.Host) error { + host = gotHost + return nil + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true}}, nil + } + ds.PolicyQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + + lq.On("QueriesForHost", uint(0)).Return(map[string]string{}, nil) + + ctx := 
hostctx.NewContext(context.Background(), host) + + // With a new host, we should get the detail queries (and accelerate + // should be turned on so that we can quickly fill labels) + queries, acc, err := svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries) + assert.NotZero(t, acc) + + // Simulate the detail queries being added. + host.DetailUpdatedAt = mockClock.Now().Add(-1 * time.Minute) + host.Hostname = "zwass.local" + ctx = hostctx.NewContext(ctx, host) + + queries, acc, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, 0) + assert.Zero(t, acc) + + ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{ + "label1": "query1", + "label2": "query2", + "label3": "query3", + }, nil + } + + // Now we should get the label queries + queries, acc, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, 3) + assert.Zero(t, acc) + + var gotHost *fleet.Host + var gotResults map[uint]*bool + var gotTime time.Time + ds.RecordLabelQueryExecutionsFunc = func(ctx context.Context, host *fleet.Host, results map[uint]*bool, t time.Time, deferred bool) error { + gotHost = host + gotResults = results + gotTime = t + return nil + } + + // Record a query execution + err = svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostLabelQueryPrefix + "1": {{"col1": "val1"}}, + }, + map[string]fleet.OsqueryStatus{}, + map[string]string{}, + ) + require.NoError(t, err) + host.LabelUpdatedAt = mockClock.Now() + assert.Equal(t, host, gotHost) + assert.Equal(t, mockClock.Now(), gotTime) + require.Len(t, gotResults, 1) + assert.Equal(t, true, *gotResults[1]) + + mockClock.AddTime(1 * time.Second) + + // Record a query execution + err = svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostLabelQueryPrefix + "2": {{"col1": "val1"}}, + 
hostLabelQueryPrefix + "3": {}, + }, + map[string]fleet.OsqueryStatus{}, + map[string]string{}, + ) + require.NoError(t, err) + host.LabelUpdatedAt = mockClock.Now() + assert.Equal(t, host, gotHost) + assert.Equal(t, mockClock.Now(), gotTime) + require.Len(t, gotResults, 2) + assert.Equal(t, true, *gotResults[2]) + assert.Equal(t, false, *gotResults[3]) + + // We should get no labels now. + host.LabelUpdatedAt = mockClock.Now() + ctx = hostctx.NewContext(ctx, host) + queries, acc, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, 0) + assert.Zero(t, acc) + + // With refetch requested details+label queries should be returned. + host.RefetchRequested = true + ctx = hostctx.NewContext(ctx, host) + queries, acc, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries+3) + assert.Zero(t, acc) + + // Record a query execution + err = svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostLabelQueryPrefix + "2": {{"col1": "val1"}}, + hostLabelQueryPrefix + "3": {}, + }, + map[string]fleet.OsqueryStatus{}, + map[string]string{}, + ) + require.NoError(t, err) + host.LabelUpdatedAt = mockClock.Now() + assert.Equal(t, host, gotHost) + assert.Equal(t, mockClock.Now(), gotTime) + require.Len(t, gotResults, 2) + assert.Equal(t, true, *gotResults[2]) + assert.Equal(t, false, *gotResults[3]) + + // SubmitDistributedQueryResults will set RefetchRequested to false. + require.False(t, host.RefetchRequested) + + // There shouldn't be any labels now. 
+ ctx = hostctx.NewContext(context.Background(), host) + queries, acc, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, 0) + assert.Zero(t, acc) +} + +func TestDetailQueriesWithEmptyStrings(t *testing.T) { + ds := new(mock.Store) + mockClock := clock.NewMockClock() + lq := new(live_query.MockLiveQuery) + svc := newTestServiceWithClock(ds, nil, lq, mockClock) + + host := &fleet.Host{ + ID: 1, + Platform: "windows", + } + ctx := hostctx.NewContext(context.Background(), host) + + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true}}, nil + } + ds.LabelQueriesForHostFunc = func(context.Context, *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + ds.PolicyQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + if id != 1 { + return nil, errors.New("not found") + } + return host, nil + } + + lq.On("QueriesForHost", host.ID).Return(map[string]string{}, nil) + + // With a new host, we should get the detail queries (and accelerated + // queries) + queries, acc, err := svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries-3) + assert.NotZero(t, acc) + + resultJSON := ` +{ +"fleet_detail_query_network_interface": [ + { + "address": "192.168.0.1", + "broadcast": "192.168.0.255", + "ibytes": "", + "ierrors": "", + "interface": "en0", + "ipackets": "25698094", + "last_change": "1474233476", + "mac": "5f:3d:4b:10:25:82", + "mask": "255.255.255.0", + "metric": "", + "mtu": "", + "obytes": "", + "oerrors": "", + "opackets": "", + "point_to_point": "", + "type": "" + } +], +"fleet_detail_query_os_version": [ + { + "platform": "darwin", + "build": "15G1004", + "major": "10", + "minor": "10", + "name": "Mac OS 
X", + "patch": "6" + } +], +"fleet_detail_query_osquery_info": [ + { + "build_distro": "10.10", + "build_platform": "darwin", + "config_hash": "3c6e4537c4d0eb71a7c6dda19d", + "config_valid": "1", + "extensions": "active", + "pid": "38113", + "start_time": "1475603155", + "version": "1.8.2", + "watcher": "38112" + } +], +"fleet_detail_query_system_info": [ + { + "computer_name": "computer", + "cpu_brand": "Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz", + "cpu_logical_cores": "8", + "cpu_physical_cores": "4", + "cpu_subtype": "Intel x86-64h Haswell", + "cpu_type": "x86_64h", + "hardware_model": "MacBookPro11,4", + "hardware_serial": "ABCDEFGH", + "hardware_vendor": "Apple Inc.", + "hardware_version": "1.0", + "hostname": "computer.local", + "physical_memory": "17179869184", + "uuid": "uuid" + } +], +"fleet_detail_query_uptime": [ + { + "days": "20", + "hours": "0", + "minutes": "48", + "seconds": "13", + "total_seconds": "1730893" + } +], +"fleet_detail_query_osquery_flags": [ + { + "name":"config_tls_refresh", + "value":"" + }, + { + "name":"distributed_interval", + "value":"" + }, + { + "name":"logger_tls_period", + "value":"" + } +] +} +` + + var results fleet.OsqueryDistributedQueryResults + err = json.Unmarshal([]byte(resultJSON), &results) + require.NoError(t, err) + + var gotHost *fleet.Host + ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { + gotHost = host + return nil + } + + // Verify that results are ingested properly + svc.SubmitDistributedQueryResults(ctx, results, map[string]fleet.OsqueryStatus{}, map[string]string{}) + + // osquery_info + assert.Equal(t, "darwin", gotHost.Platform) + assert.Equal(t, "1.8.2", gotHost.OsqueryVersion) + + // system_info + assert.Equal(t, int64(17179869184), gotHost.Memory) + assert.Equal(t, "computer.local", gotHost.Hostname) + assert.Equal(t, "uuid", gotHost.UUID) + + // os_version + assert.Equal(t, "Mac OS X 10.10.6", gotHost.OSVersion) + + // uptime + assert.Equal(t, 1730893*time.Second, 
gotHost.Uptime) + + // osquery_flags + assert.Equal(t, uint(0), gotHost.ConfigTLSRefresh) + assert.Equal(t, uint(0), gotHost.DistributedInterval) + assert.Equal(t, uint(0), gotHost.LoggerTLSPeriod) + + host.Hostname = "computer.local" + host.DetailUpdatedAt = mockClock.Now() + mockClock.AddTime(1 * time.Minute) + + // Now no detail queries should be required + ctx = hostctx.NewContext(context.Background(), host) + queries, acc, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, 0) + assert.Zero(t, acc) + + // Advance clock and queries should exist again + mockClock.AddTime(1*time.Hour + 1*time.Minute) + + queries, acc, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries) + assert.Zero(t, acc) +} + +func TestDetailQueries(t *testing.T) { + ds := new(mock.Store) + mockClock := clock.NewMockClock() + lq := new(live_query.MockLiveQuery) + svc := newTestServiceWithClock(ds, nil, lq, mockClock) + + host := &fleet.Host{ + ID: 1, + Platform: "linux", + } + ctx := hostctx.NewContext(context.Background(), host) + + lq.On("QueriesForHost", host.ID).Return(map[string]string{}, nil) + + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true, EnableSoftwareInventory: true}}, nil + } + ds.LabelQueriesForHostFunc = func(context.Context, *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + ds.PolicyQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + ds.SetOrUpdateMDMDataFunc = func(ctx context.Context, hostID uint, enrolled bool, serverURL string, installedFromDep bool) error { + require.True(t, enrolled) + require.False(t, installedFromDep) + require.Equal(t, "hi.com", serverURL) + return nil + } + ds.SetOrUpdateMunkiVersionFunc = func(ctx context.Context, hostID uint, version 
string) error { + require.Equal(t, "3.4.5", version) + return nil + } + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + if id != 1 { + return nil, errors.New("not found") + } + return host, nil + } + + // With a new host, we should get the detail queries (and accelerated + // queries) + queries, acc, err := svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries-2) + assert.NotZero(t, acc) + + resultJSON := ` +{ +"fleet_detail_query_network_interface": [ + { + "address": "192.168.0.1", + "broadcast": "192.168.0.255", + "ibytes": "1601207629", + "ierrors": "314179", + "interface": "en0", + "ipackets": "25698094", + "last_change": "1474233476", + "mac": "5f:3d:4b:10:25:82", + "mask": "255.255.255.0", + "metric": "1", + "mtu": "1453", + "obytes": "2607283152", + "oerrors": "101010", + "opackets": "12264603", + "point_to_point": "", + "type": "6" + } +], +"fleet_detail_query_os_version": [ + { + "platform": "darwin", + "build": "15G1004", + "major": "10", + "minor": "10", + "name": "Mac OS X", + "patch": "6" + } +], +"fleet_detail_query_osquery_info": [ + { + "build_distro": "10.10", + "build_platform": "darwin", + "config_hash": "3c6e4537c4d0eb71a7c6dda19d", + "config_valid": "1", + "extensions": "active", + "pid": "38113", + "start_time": "1475603155", + "version": "1.8.2", + "watcher": "38112" + } +], +"fleet_detail_query_system_info": [ + { + "computer_name": "computer", + "cpu_brand": "Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz", + "cpu_logical_cores": "8", + "cpu_physical_cores": "4", + "cpu_subtype": "Intel x86-64h Haswell", + "cpu_type": "x86_64h", + "hardware_model": "MacBookPro11,4", + "hardware_serial": "ABCDEFGH", + "hardware_vendor": "Apple Inc.", + "hardware_version": "1.0", + "hostname": "computer.local", + "physical_memory": "17179869184", + "uuid": "uuid" + } +], +"fleet_detail_query_uptime": [ + { + "days": "20", + "hours": "0", + "minutes": "48", + "seconds": "13", + 
"total_seconds": "1730893" + } +], +"fleet_detail_query_osquery_flags": [ + { + "name":"config_tls_refresh", + "value":"10" + }, + { + "name":"config_refresh", + "value":"9" + }, + { + "name":"distributed_interval", + "value":"5" + }, + { + "name":"logger_tls_period", + "value":"60" + } +], +"fleet_detail_query_users": [ + { + "uid": "1234", + "username": "user1", + "type": "sometype", + "groupname": "somegroup", + "shell": "someloginshell" + }, + { + "uid": "5678", + "username": "user2", + "type": "sometype", + "groupname": "somegroup" + } +], +"fleet_detail_query_software_macos": [ + { + "name": "app1", + "version": "1.0.0", + "source": "source1" + }, + { + "name": "app2", + "version": "1.0.0", + "source": "source2", + "bundle_identifier": "somebundle" + } +], +"fleet_detail_query_disk_space_unix": [ + { + "percent_disk_space_available": "56", + "gigs_disk_space_available": "277.0" + } +], +"fleet_detail_query_mdm": [ + { + "enrolled": "true", + "server_url": "hi.com", + "installed_from_dep": "false" + } +], +"fleet_detail_query_munki_info": [ + { + "version": "3.4.5" + } +] +} +` + + var results fleet.OsqueryDistributedQueryResults + err = json.Unmarshal([]byte(resultJSON), &results) + require.NoError(t, err) + + var gotHost *fleet.Host + ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { + gotHost = host + return nil + } + var gotUsers []fleet.HostUser + ds.SaveHostUsersFunc = func(ctx context.Context, hostID uint, users []fleet.HostUser) error { + if hostID != 1 { + return errors.New("not found") + } + gotUsers = users + return nil + } + var gotSoftware []fleet.Software + ds.UpdateHostSoftwareFunc = func(ctx context.Context, hostID uint, software []fleet.Software) error { + if hostID != 1 { + return errors.New("not found") + } + gotSoftware = software + return nil + } + + // Verify that results are ingested properly + require.NoError(t, svc.SubmitDistributedQueryResults(ctx, results, map[string]fleet.OsqueryStatus{}, map[string]string{})) 
+ require.NotNil(t, gotHost) + + require.True(t, ds.SetOrUpdateMDMDataFuncInvoked) + require.True(t, ds.SetOrUpdateMunkiVersionFuncInvoked) + + // osquery_info + assert.Equal(t, "darwin", gotHost.Platform) + assert.Equal(t, "1.8.2", gotHost.OsqueryVersion) + + // system_info + assert.Equal(t, int64(17179869184), gotHost.Memory) + assert.Equal(t, "computer.local", gotHost.Hostname) + assert.Equal(t, "uuid", gotHost.UUID) + + // os_version + assert.Equal(t, "Mac OS X 10.10.6", gotHost.OSVersion) + + // uptime + assert.Equal(t, 1730893*time.Second, gotHost.Uptime) + + // osquery_flags + assert.Equal(t, uint(10), gotHost.ConfigTLSRefresh) + assert.Equal(t, uint(5), gotHost.DistributedInterval) + assert.Equal(t, uint(60), gotHost.LoggerTLSPeriod) + + // users + require.Len(t, gotUsers, 2) + assert.Equal(t, fleet.HostUser{ + Uid: 1234, + Username: "user1", + Type: "sometype", + GroupName: "somegroup", + Shell: "someloginshell", + }, gotUsers[0]) + assert.Equal(t, fleet.HostUser{ + Uid: 5678, + Username: "user2", + Type: "sometype", + GroupName: "somegroup", + Shell: "", + }, gotUsers[1]) + + // software + require.Len(t, gotSoftware, 2) + assert.Equal(t, []fleet.Software{ + { + Name: "app1", + Version: "1.0.0", + Source: "source1", + }, + { + Name: "app2", + Version: "1.0.0", + BundleIdentifier: "somebundle", + Source: "source2", + }, + }, gotSoftware) + + assert.Equal(t, 56.0, gotHost.PercentDiskSpaceAvailable) + assert.Equal(t, 277.0, gotHost.GigsDiskSpaceAvailable) + + host.Hostname = "computer.local" + host.Platform = "darwin" + host.DetailUpdatedAt = mockClock.Now() + mockClock.AddTime(1 * time.Minute) + + // Now no detail queries should be required + ctx = hostctx.NewContext(ctx, host) + queries, acc, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, 0) + assert.Zero(t, acc) + + // Advance clock and queries should exist again + mockClock.AddTime(1*time.Hour + 1*time.Minute) + + queries, acc, err = svc.GetDistributedQueries(ctx) 
+ require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries+1) + assert.Zero(t, acc) +} + +func TestNewDistributedQueryCampaign(t *testing.T) { + ds := new(mock.Store) + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + rs := &mock.QueryResultStore{ + HealthCheckFunc: func() error { + return nil + }, + } + lq := &live_query.MockLiveQuery{} + mockClock := clock.NewMockClock() + svc := newTestServiceWithClock(ds, rs, lq, mockClock) + + ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error { + return nil + } + var gotQuery *fleet.Query + ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) { + gotQuery = query + query.ID = 42 + return query, nil + } + var gotCampaign *fleet.DistributedQueryCampaign + ds.NewDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) { + gotCampaign = camp + camp.ID = 21 + return camp, nil + } + var gotTargets []*fleet.DistributedQueryCampaignTarget + ds.NewDistributedQueryCampaignTargetFunc = func(ctx context.Context, target *fleet.DistributedQueryCampaignTarget) (*fleet.DistributedQueryCampaignTarget, error) { + gotTargets = append(gotTargets, target) + return target, nil + } + + ds.CountHostsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets, now time.Time) (fleet.TargetMetrics, error) { + return fleet.TargetMetrics{}, nil + } + ds.HostIDsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets) ([]uint, error) { + return []uint{1, 3, 5}, nil + } + lq.On("RunQuery", "21", "select year, month, day, hour, minutes, seconds from time", []uint{1, 3, 5}).Return(nil) + viewerCtx := 
viewer.NewContext(context.Background(), viewer.Viewer{ + User: &fleet.User{ + ID: 0, + GlobalRole: ptr.String(fleet.RoleAdmin), + }, + }) + q := "select year, month, day, hour, minutes, seconds from time" + ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error { + return nil + } + campaign, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}}) + require.NoError(t, err) + assert.Equal(t, gotQuery.ID, gotCampaign.QueryID) + assert.True(t, ds.NewActivityFuncInvoked) + assert.Equal(t, []*fleet.DistributedQueryCampaignTarget{ + { + Type: fleet.TargetHost, + DistributedQueryCampaignID: campaign.ID, + TargetID: 2, + }, + { + Type: fleet.TargetLabel, + DistributedQueryCampaignID: campaign.ID, + TargetID: 1, + }, + }, gotTargets, + ) +} + +func TestDistributedQueryResults(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := newTestServiceWithClock(ds, rs, lq, mockClock) + + campaign := &fleet.DistributedQueryCampaign{ID: 42} + + ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + ds.PolicyQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + host := &fleet.Host{ + ID: 1, + Platform: "windows", + } + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + if id != 1 { + return nil, errors.New("not found") + } + return host, nil + } + ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { + if host.ID != 1 { + return errors.New("not found") + } + return nil + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true}}, nil + } + + 
hostCtx := hostctx.NewContext(context.Background(), host) + + lq.On("QueriesForHost", uint(1)).Return( + map[string]string{ + strconv.Itoa(int(campaign.ID)): "select * from time", + }, + nil, + ) + lq.On("QueryCompletedByHost", strconv.Itoa(int(campaign.ID)), host.ID).Return(nil) + + // Now we should get the active distributed query + queries, acc, err := svc.GetDistributedQueries(hostCtx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries-2) + queryKey := fmt.Sprintf("%s%d", hostDistributedQueryPrefix, campaign.ID) + assert.Equal(t, "select * from time", queries[queryKey]) + assert.NotZero(t, acc) + + expectedRows := []map[string]string{ + { + "year": "2016", + "month": "11", + "day": "11", + "hour": "6", + "minutes": "12", + "seconds": "10", + }, + } + results := map[string][]map[string]string{ + queryKey: expectedRows, + } + + // TODO use service method + readChan, err := rs.ReadChannel(context.Background(), *campaign) + require.NoError(t, err) + + // We need to listen for the result in a separate thread to prevent the + // write to the result channel from failing + var waitSetup, waitComplete sync.WaitGroup + waitSetup.Add(1) + waitComplete.Add(1) + go func() { + waitSetup.Done() + select { + case val := <-readChan: + if res, ok := val.(fleet.DistributedQueryResult); ok { + assert.Equal(t, campaign.ID, res.DistributedQueryCampaignID) + assert.Equal(t, expectedRows, res.Rows) + assert.Equal(t, *host, res.Host) + } else { + t.Error("Wrong result type") + } + assert.NotNil(t, val) + + case <-time.After(1 * time.Second): + t.Error("No result received") + } + waitComplete.Done() + }() + + waitSetup.Wait() + // Sleep a short time to ensure that the above goroutine is blocking on + // the channel read (the waitSetup.Wait() is not necessarily sufficient + // if there is a context switch immediately after waitSetup.Done() is + // called). This should be a small price to pay to prevent flakiness in + // this test. 
+ time.Sleep(10 * time.Millisecond) + + err = svc.SubmitDistributedQueryResults(hostCtx, results, map[string]fleet.OsqueryStatus{}, map[string]string{}) + require.NoError(t, err) +} + +func TestIngestDistributedQueryParseIdError(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := &Service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + host := fleet.Host{ID: 1} + err := svc.ingestDistributedQuery(context.Background(), host, "bad_name", []map[string]string{}, false, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to parse campaign") +} + +func TestIngestDistributedQueryOrphanedCampaignLoadError(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := &Service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) { + return nil, errors.New("missing campaign") + } + + lq.On("StopQuery", "42").Return(nil) + + host := fleet.Host{ID: 1} + + err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "loading orphaned campaign") +} + +func TestIngestDistributedQueryOrphanedCampaignWaitListener(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := &Service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + campaign := &fleet.DistributedQueryCampaign{ + ID: 42, + UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ + CreateTimestamp: 
fleet.CreateTimestamp{ + CreatedAt: mockClock.Now().Add(-1 * time.Second), + }, + }, + } + + ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) { + return campaign, nil + } + + host := fleet.Host{ID: 1} + + err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "campaign waiting for listener") +} + +func TestIngestDistributedQueryOrphanedCloseError(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := &Service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + campaign := &fleet.DistributedQueryCampaign{ + ID: 42, + UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ + CreateTimestamp: fleet.CreateTimestamp{ + CreatedAt: mockClock.Now().Add(-2 * time.Minute), + }, + }, + } + + ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) { + return campaign, nil + } + ds.SaveDistributedQueryCampaignFunc = func(ctx context.Context, campaign *fleet.DistributedQueryCampaign) error { + return errors.New("failed save") + } + + host := fleet.Host{ID: 1} + + err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "closing orphaned campaign") +} + +func TestIngestDistributedQueryOrphanedStopError(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := &Service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + campaign := &fleet.DistributedQueryCampaign{ + ID: 42, + UpdateCreateTimestamps: 
fleet.UpdateCreateTimestamps{ + CreateTimestamp: fleet.CreateTimestamp{ + CreatedAt: mockClock.Now().Add(-2 * time.Minute), + }, + }, + } + + ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) { + return campaign, nil + } + ds.SaveDistributedQueryCampaignFunc = func(ctx context.Context, campaign *fleet.DistributedQueryCampaign) error { + return nil + } + lq.On("StopQuery", strconv.Itoa(int(campaign.ID))).Return(errors.New("failed")) + + host := fleet.Host{ID: 1} + + err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "stopping orphaned campaign") +} + +func TestIngestDistributedQueryOrphanedStop(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := &Service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + campaign := &fleet.DistributedQueryCampaign{ + ID: 42, + UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ + CreateTimestamp: fleet.CreateTimestamp{ + CreatedAt: mockClock.Now().Add(-2 * time.Minute), + }, + }, + } + + ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) { + return campaign, nil + } + ds.SaveDistributedQueryCampaignFunc = func(ctx context.Context, campaign *fleet.DistributedQueryCampaign) error { + return nil + } + lq.On("StopQuery", strconv.Itoa(int(campaign.ID))).Return(nil) + + host := fleet.Host{ID: 1} + + err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "campaign stopped") + lq.AssertExpectations(t) +} + +func TestIngestDistributedQueryRecordCompletionError(t *testing.T) { + mockClock := 
clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := &Service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + campaign := &fleet.DistributedQueryCampaign{ID: 42} + host := fleet.Host{ID: 1} + + lq.On("QueryCompletedByHost", strconv.Itoa(int(campaign.ID)), host.ID).Return(errors.New("fail")) + + go func() { + ch, err := rs.ReadChannel(context.Background(), *campaign) + require.NoError(t, err) + <-ch + }() + time.Sleep(10 * time.Millisecond) + + err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "record query completion") + lq.AssertExpectations(t) +} + +func TestIngestDistributedQuery(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := &Service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + campaign := &fleet.DistributedQueryCampaign{ID: 42} + host := fleet.Host{ID: 1} + + lq.On("QueryCompletedByHost", strconv.Itoa(int(campaign.ID)), host.ID).Return(nil) + + go func() { + ch, err := rs.ReadChannel(context.Background(), *campaign) + require.NoError(t, err) + <-ch + }() + time.Sleep(10 * time.Millisecond) + + err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") + require.NoError(t, err) + lq.AssertExpectations(t) +} + +func TestUpdateHostIntervals(t *testing.T) { + ds := new(mock.Store) + + svc := newTestService(ds, nil, nil) + + ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) { + return []*fleet.Pack{}, nil + } + + testCases := []struct { + name string + initIntervals fleet.HostOsqueryIntervals + finalIntervals 
fleet.HostOsqueryIntervals + configOptions json.RawMessage + updateIntervalsCalled bool + }{ + { + "Both updated", + fleet.HostOsqueryIntervals{ + ConfigTLSRefresh: 60, + }, + fleet.HostOsqueryIntervals{ + DistributedInterval: 11, + LoggerTLSPeriod: 33, + ConfigTLSRefresh: 60, + }, + json.RawMessage(`{"options": { + "distributed_interval": 11, + "logger_tls_period": 33, + "logger_plugin": "tls" + }}`), + true, + }, + { + "Only logger_tls_period updated", + fleet.HostOsqueryIntervals{ + DistributedInterval: 11, + ConfigTLSRefresh: 60, + }, + fleet.HostOsqueryIntervals{ + DistributedInterval: 11, + LoggerTLSPeriod: 33, + ConfigTLSRefresh: 60, + }, + json.RawMessage(`{"options": { + "distributed_interval": 11, + "logger_tls_period": 33 + }}`), + true, + }, + { + "Only distributed_interval updated", + fleet.HostOsqueryIntervals{ + ConfigTLSRefresh: 60, + LoggerTLSPeriod: 33, + }, + fleet.HostOsqueryIntervals{ + DistributedInterval: 11, + LoggerTLSPeriod: 33, + ConfigTLSRefresh: 60, + }, + json.RawMessage(`{"options": { + "distributed_interval": 11, + "logger_tls_period": 33 + }}`), + true, + }, + { + "Fleet not managing distributed_interval", + fleet.HostOsqueryIntervals{ + ConfigTLSRefresh: 60, + DistributedInterval: 11, + }, + fleet.HostOsqueryIntervals{ + DistributedInterval: 11, + LoggerTLSPeriod: 33, + ConfigTLSRefresh: 60, + }, + json.RawMessage(`{"options":{ + "logger_tls_period": 33 + }}`), + true, + }, + { + "config_refresh should also cause an update", + fleet.HostOsqueryIntervals{ + DistributedInterval: 11, + LoggerTLSPeriod: 33, + ConfigTLSRefresh: 60, + }, + fleet.HostOsqueryIntervals{ + DistributedInterval: 11, + LoggerTLSPeriod: 33, + ConfigTLSRefresh: 42, + }, + json.RawMessage(`{"options":{ + "distributed_interval": 11, + "logger_tls_period": 33, + "config_refresh": 42 + }}`), + true, + }, + { + "update intervals should not be called with no changes", + fleet.HostOsqueryIntervals{ + DistributedInterval: 11, + LoggerTLSPeriod: 33, + ConfigTLSRefresh: 
60, + }, + fleet.HostOsqueryIntervals{ + DistributedInterval: 11, + LoggerTLSPeriod: 33, + ConfigTLSRefresh: 60, + }, + json.RawMessage(`{"options":{ + "distributed_interval": 11, + "logger_tls_period": 33 + }}`), + false, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + ctx := hostctx.NewContext(context.Background(), &fleet.Host{ + ID: 1, + NodeKey: "123456", + DistributedInterval: tt.initIntervals.DistributedInterval, + ConfigTLSRefresh: tt.initIntervals.ConfigTLSRefresh, + LoggerTLSPeriod: tt.initIntervals.LoggerTLSPeriod, + }) + + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{AgentOptions: ptr.RawMessage(json.RawMessage(`{"config":` + string(tt.configOptions) + `}`))}, nil + } + + updateIntervalsCalled := false + ds.UpdateHostOsqueryIntervalsFunc = func(ctx context.Context, hostID uint, intervals fleet.HostOsqueryIntervals) error { + if hostID != 1 { + return errors.New("not found") + } + updateIntervalsCalled = true + assert.Equal(t, tt.finalIntervals, intervals) + return nil + } + + _, err := svc.GetClientConfig(ctx) + require.NoError(t, err) + assert.Equal(t, tt.updateIntervalsCalled, updateIntervalsCalled) + }) + } +} + +type notFoundError struct{} + +func (e notFoundError) Error() string { + return "not found" +} + +func (e notFoundError) IsNotFound() bool { + return true +} + +func TestAuthenticationErrors(t *testing.T) { + ms := new(mock.Store) + ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { + return nil, nil + } + + svc := newTestService(ms, nil, nil) + ctx := context.Background() + + _, _, err := svc.AuthenticateHost(ctx, "") + require.Error(t, err) + require.True(t, err.(osqueryError).NodeInvalid()) + + ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { + return &fleet.Host{ID: 1}, nil + } + ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return 
&fleet.AppConfig{}, nil + } + _, _, err = svc.AuthenticateHost(ctx, "foo") + require.NoError(t, err) + + // return not found error + ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { + return nil, notFoundError{} + } + + _, _, err = svc.AuthenticateHost(ctx, "foo") + require.Error(t, err) + require.True(t, err.(osqueryError).NodeInvalid()) + + // return other error + ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { + return nil, errors.New("foo") + } + + _, _, err = svc.AuthenticateHost(ctx, "foo") + require.NotNil(t, err) + require.False(t, err.(osqueryError).NodeInvalid()) +} + +func TestGetHostIdentifier(t *testing.T) { + t.Parallel() + + details := map[string](map[string]string){ + "osquery_info": map[string]string{ + "uuid": "foouuid", + "instance_id": "fooinstance", + }, + "system_info": map[string]string{ + "hostname": "foohost", + }, + } + + emptyDetails := map[string](map[string]string){ + "osquery_info": map[string]string{ + "uuid": "", + "instance_id": "", + }, + "system_info": map[string]string{ + "hostname": "", + }, + } + + testCases := []struct { + identifierOption string + providedIdentifier string + details map[string](map[string]string) + expected string + shouldPanic bool + }{ + // Panix + {identifierOption: "bad", shouldPanic: true}, + {identifierOption: "", shouldPanic: true}, + + // Missing details + {identifierOption: "instance", providedIdentifier: "foobar", expected: "foobar"}, + {identifierOption: "uuid", providedIdentifier: "foobar", expected: "foobar"}, + {identifierOption: "hostname", providedIdentifier: "foobar", expected: "foobar"}, + {identifierOption: "provided", providedIdentifier: "foobar", expected: "foobar"}, + + // Empty details + {identifierOption: "instance", providedIdentifier: "foobar", details: emptyDetails, expected: "foobar"}, + {identifierOption: "uuid", providedIdentifier: "foobar", details: emptyDetails, expected: "foobar"}, + 
{identifierOption: "hostname", providedIdentifier: "foobar", details: emptyDetails, expected: "foobar"}, + {identifierOption: "provided", providedIdentifier: "foobar", details: emptyDetails, expected: "foobar"}, + + // Successes + {identifierOption: "instance", providedIdentifier: "foobar", details: details, expected: "fooinstance"}, + {identifierOption: "uuid", providedIdentifier: "foobar", details: details, expected: "foouuid"}, + {identifierOption: "hostname", providedIdentifier: "foobar", details: details, expected: "foohost"}, + {identifierOption: "provided", providedIdentifier: "foobar", details: details, expected: "foobar"}, + } + logger := log.NewNopLogger() + + for _, tt := range testCases { + t.Run("", func(t *testing.T) { + if tt.shouldPanic { + assert.Panics( + t, + func() { getHostIdentifier(logger, tt.identifierOption, tt.providedIdentifier, tt.details) }, + ) + return + } + + assert.Equal( + t, + tt.expected, + getHostIdentifier(logger, tt.identifierOption, tt.providedIdentifier, tt.details), + ) + }) + } +} + +func TestDistributedQueriesLogsManyErrors(t *testing.T) { + buf := new(bytes.Buffer) + logger := log.NewJSONLogger(buf) + logger = level.NewFilter(logger, level.AllowDebug()) + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + host := &fleet.Host{ + ID: 1, + Platform: "darwin", + } + + ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { + return authz.CheckMissingWithResponse(nil) + } + ds.RecordLabelQueryExecutionsFunc = func(ctx context.Context, host *fleet.Host, results map[uint]*bool, t time.Time, deferred bool) error { + return errors.New("something went wrong") + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + ds.SaveHostAdditionalFunc = func(ctx context.Context, hostID uint, additional *json.RawMessage) error { + return errors.New("something went wrong") + } + + lCtx := &fleetLogging.LoggingContext{} + ctx := 
fleetLogging.NewContext(context.Background(), lCtx) + ctx = hostctx.NewContext(ctx, host) + + err := svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostDetailQueryPrefix + "network_interface": {{"col1": "val1"}}, // we need one detail query that updates hosts. + hostLabelQueryPrefix + "1": {{"col1": "val1"}}, + hostAdditionalQueryPrefix + "1": {{"col1": "val1"}}, + }, + map[string]fleet.OsqueryStatus{}, + map[string]string{}, + ) + require.NoError(t, err) + + lCtx.Log(ctx, logger) + + logs := buf.String() + parts := strings.Split(strings.TrimSpace(logs), "\n") + require.Len(t, parts, 1) + logData := make(map[string]json.RawMessage) + err = json.Unmarshal([]byte(parts[0]), &logData) + require.NoError(t, err) + assert.Equal(t, json.RawMessage(`"something went wrong || something went wrong"`), logData["err"]) + assert.Equal(t, json.RawMessage(`"Missing authorization check"`), logData["internal"]) +} + +func TestDistributedQueriesReloadsHostIfDetailsAreIn(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + host := &fleet.Host{ + ID: 42, + Platform: "darwin", + } + + ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { + return nil + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + + ctx := hostctx.NewContext(context.Background(), host) + + err := svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostDetailQueryPrefix + "network_interface": {{"col1": "val1"}}, + }, + map[string]fleet.OsqueryStatus{}, + map[string]string{}, + ) + require.NoError(t, err) + assert.True(t, ds.UpdateHostFuncInvoked) +} + +func TestObserversCanOnlyRunDistributedCampaigns(t *testing.T) { + ds := new(mock.Store) + rs := &mock.QueryResultStore{ + HealthCheckFunc: func() error { + return nil + }, + } + lq := &live_query.MockLiveQuery{} + mockClock := clock.NewMockClock() + svc := newTestServiceWithClock(ds, rs, lq, 
mockClock) + + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + + ds.NewDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) { + return camp, nil + } + ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) { + return &fleet.Query{ + ID: 42, + Name: "query", + Query: "select 1;", + ObserverCanRun: false, + }, nil + } + viewerCtx := viewer.NewContext(context.Background(), viewer.Viewer{ + User: &fleet.User{ID: 0, GlobalRole: ptr.String(fleet.RoleObserver)}, + }) + + q := "select year, month, day, hour, minutes, seconds from time" + ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error { + return nil + } + _, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}}) + require.Error(t, err) + + _, err = svc.NewDistributedQueryCampaign(viewerCtx, "", ptr.Uint(42), fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}}) + require.Error(t, err) + + ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) { + return &fleet.Query{ + ID: 42, + Name: "query", + Query: "select 1;", + ObserverCanRun: true, + }, nil + } + + ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error { return nil } + ds.NewDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) { + camp.ID = 21 + return camp, nil + } + ds.NewDistributedQueryCampaignTargetFunc = func(ctx context.Context, target *fleet.DistributedQueryCampaignTarget) (*fleet.DistributedQueryCampaignTarget, error) { + return target, nil + } + ds.CountHostsInTargetsFunc = func(ctx context.Context, 
filter fleet.TeamFilter, targets fleet.HostTargets, now time.Time) (fleet.TargetMetrics, error) { + return fleet.TargetMetrics{}, nil + } + ds.HostIDsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets) ([]uint, error) { + return []uint{1, 3, 5}, nil + } + ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error { + return nil + } + lq.On("RunQuery", "21", "select 1;", []uint{1, 3, 5}).Return(nil) + _, err = svc.NewDistributedQueryCampaign(viewerCtx, "", ptr.Uint(42), fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}}) + require.NoError(t, err) +} + +func TestTeamMaintainerCanRunNewDistributedCampaigns(t *testing.T) { + ds := new(mock.Store) + rs := &mock.QueryResultStore{ + HealthCheckFunc: func() error { + return nil + }, + } + lq := &live_query.MockLiveQuery{} + mockClock := clock.NewMockClock() + svc := newTestServiceWithClock(ds, rs, lq, mockClock) + + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + + ds.NewDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) { + return camp, nil + } + ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) { + return &fleet.Query{ + ID: 42, + AuthorID: ptr.Uint(99), + Name: "query", + Query: "select 1;", + ObserverCanRun: false, + }, nil + } + viewerCtx := viewer.NewContext(context.Background(), viewer.Viewer{ + User: &fleet.User{ID: 99, Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 123}, Role: fleet.RoleMaintainer}}}, + }) + + q := "select year, month, day, hour, minutes, seconds from time" + ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error { + return nil + } + // var gotQuery *fleet.Query + ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts 
...fleet.OptionalArg) (*fleet.Query, error) { + // gotQuery = query + query.ID = 42 + return query, nil + } + ds.NewDistributedQueryCampaignTargetFunc = func(ctx context.Context, target *fleet.DistributedQueryCampaignTarget) (*fleet.DistributedQueryCampaignTarget, error) { + return target, nil + } + ds.CountHostsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets, now time.Time) (fleet.TargetMetrics, error) { + return fleet.TargetMetrics{}, nil + } + ds.HostIDsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets) ([]uint, error) { + return []uint{1, 3, 5}, nil + } + ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error { + return nil + } + lq.On("RunQuery", "0", "select year, month, day, hour, minutes, seconds from time", []uint{1, 3, 5}).Return(nil) + _, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}, TeamIDs: []uint{123}}) + require.NoError(t, err) +} + +func TestPolicyQueries(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + lq := new(live_query.MockLiveQuery) + svc := newTestServiceWithClock(ds, nil, lq, mockClock) + + host := &fleet.Host{ + Platform: "darwin", + } + + ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + return host, nil + } + ds.UpdateHostFunc = func(ctx context.Context, gotHost *fleet.Host) error { + host = gotHost + return nil + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true}}, nil + } + + lq.On("QueriesForHost", uint(0)).Return(map[string]string{}, nil) + + ds.PolicyQueriesForHostFunc = func(ctx context.Context, host 
*fleet.Host) (map[string]string, error) { + return map[string]string{"1": "select 1", "2": "select 42;"}, nil + } + recordedResults := make(map[uint]*bool) + ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, gotHost *fleet.Host, results map[uint]*bool, updated time.Time, deferred bool) error { + recordedResults = results + host = gotHost + return nil + } + ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error) { + return nil, nil, nil + } + + ctx := hostctx.NewContext(context.Background(), host) + + queries, _, err := svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries+2) + + checkPolicyResults := func(queries map[string]string) { + hasPolicy1, hasPolicy2 := false, false + for name := range queries { + if strings.HasPrefix(name, hostPolicyQueryPrefix) { + if name[len(hostPolicyQueryPrefix):] == "1" { + hasPolicy1 = true + } + if name[len(hostPolicyQueryPrefix):] == "2" { + hasPolicy2 = true + } + } + } + assert.True(t, hasPolicy1) + assert.True(t, hasPolicy2) + } + + checkPolicyResults(queries) + + // Record a query execution. 
+ err = svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostPolicyQueryPrefix + "1": {{"col1": "val1"}}, + hostPolicyQueryPrefix + "2": {}, + }, + map[string]fleet.OsqueryStatus{ + hostPolicyQueryPrefix + "2": 1, + }, + map[string]string{}, + ) + require.NoError(t, err) + require.Len(t, recordedResults, 2) + require.NotNil(t, recordedResults[1]) + require.True(t, *recordedResults[1]) + result, ok := recordedResults[2] + require.True(t, ok) + require.Nil(t, result) + + noPolicyResults := func(queries map[string]string) { + hasAnyPolicy := false + for name := range queries { + if strings.HasPrefix(name, hostPolicyQueryPrefix) { + hasAnyPolicy = true + break + } + } + assert.False(t, hasAnyPolicy) + } + + // After the first time we get policies and update the host, then there shouldn't be any policies. + ctx = hostctx.NewContext(context.Background(), host) + queries, _, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries) + noPolicyResults(queries) + + // Let's move time forward, there should be policies now. + mockClock.AddTime(2 * time.Hour) + + queries, _, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries+2) + checkPolicyResults(queries) + + // Record another query execution. + err = svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostPolicyQueryPrefix + "1": {{"col1": "val1"}}, + hostPolicyQueryPrefix + "2": {}, + }, + map[string]fleet.OsqueryStatus{ + hostPolicyQueryPrefix + "2": 1, + }, + map[string]string{}, + ) + require.NoError(t, err) + require.Len(t, recordedResults, 2) + require.NotNil(t, recordedResults[1]) + require.True(t, *recordedResults[1]) + result, ok = recordedResults[2] + require.True(t, ok) + require.Nil(t, result) + + // There shouldn't be any policies now. 
+ ctx = hostctx.NewContext(context.Background(), host) + queries, _, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries) + noPolicyResults(queries) + + // With refetch requested policy queries should be returned. + host.RefetchRequested = true + ctx = hostctx.NewContext(context.Background(), host) + queries, _, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries+2) + checkPolicyResults(queries) + + // Record another query execution. + err = svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostPolicyQueryPrefix + "1": {{"col1": "val1"}}, + hostPolicyQueryPrefix + "2": {}, + }, + map[string]fleet.OsqueryStatus{ + hostPolicyQueryPrefix + "2": 1, + }, + map[string]string{}, + ) + require.NoError(t, err) + require.NotNil(t, recordedResults[1]) + require.True(t, *recordedResults[1]) + result, ok = recordedResults[2] + require.True(t, ok) + require.Nil(t, result) + + // SubmitDistributedQueryResults will set RefetchRequested to false. + require.False(t, host.RefetchRequested) + + // There shouldn't be any policies now. 
+ ctx = hostctx.NewContext(context.Background(), host) + queries, _, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries) + noPolicyResults(queries) +} + +func TestPolicyWebhooks(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + lq := new(live_query.MockLiveQuery) + pool := redistest.SetupRedis(t, t.Name(), false, false, false) + failingPolicySet := redis_policy_set.NewFailingTest(t, pool) + testConfig := config.TestConfig() + svc := newTestServiceWithConfig(ds, testConfig, nil, lq, TestServerOpts{ + FailingPolicySet: failingPolicySet, + Clock: mockClock, + }) + + host := &fleet.Host{ + ID: 5, + Platform: "darwin", + Hostname: "test.hostname", + } + + lq.On("QueriesForHost", uint(5)).Return(map[string]string{}, nil) + ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + return host, nil + } + ds.UpdateHostFunc = func(ctx context.Context, gotHost *fleet.Host) error { + host = gotHost + return nil + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{ + HostSettings: fleet.HostSettings{ + EnableHostUsers: true, + }, + WebhookSettings: fleet.WebhookSettings{ + FailingPoliciesWebhook: fleet.FailingPoliciesWebhookSettings{ + Enable: true, + PolicyIDs: []uint{1, 2, 3}, + }, + }, + }, nil + } + + ds.PolicyQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{ + "1": "select 1;", // passing policy + "2": "select * from unexistent_table;", // policy that fails to execute (e.g. 
missing table) + "3": "select 1 where 1 = 0;", // failing policy + }, nil + } + recordedResults := make(map[uint]*bool) + ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, gotHost *fleet.Host, results map[uint]*bool, updated time.Time, deferred bool) error { + recordedResults = results + host = gotHost + return nil + } + ctx := hostctx.NewContext(context.Background(), host) + + queries, _, err := svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries+3) + + checkPolicyResults := func(queries map[string]string) { + hasPolicy1, hasPolicy2, hasPolicy3 := false, false, false + for name := range queries { + if strings.HasPrefix(name, hostPolicyQueryPrefix) { + switch name[len(hostPolicyQueryPrefix):] { + case "1": + hasPolicy1 = true + case "2": + hasPolicy2 = true + case "3": + hasPolicy3 = true + } + } + } + assert.True(t, hasPolicy1) + assert.True(t, hasPolicy2) + assert.True(t, hasPolicy3) + } + + checkPolicyResults(queries) + + ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error) { + return []uint{3}, nil, nil + } + + // Record a query execution. 
+ err = svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostPolicyQueryPrefix + "1": {{"col1": "val1"}}, // succeeds + hostPolicyQueryPrefix + "2": {}, // didn't execute + hostPolicyQueryPrefix + "3": {}, // fails + }, + map[string]fleet.OsqueryStatus{ + hostPolicyQueryPrefix + "2": 1, // didn't execute + }, + map[string]string{}, + ) + require.NoError(t, err) + require.Len(t, recordedResults, 3) + require.NotNil(t, recordedResults[1]) + require.True(t, *recordedResults[1]) + result, ok := recordedResults[2] + require.True(t, ok) + require.Nil(t, result) + require.NotNil(t, recordedResults[3]) + require.False(t, *recordedResults[3]) + + cmpSets := func(expSets map[uint][]fleet.PolicySetHost) error { + actualSets, err := failingPolicySet.ListSets() + if err != nil { + return err + } + var expSets_ []uint + for expSet := range expSets { + expSets_ = append(expSets_, expSet) + } + sort.Slice(expSets_, func(i, j int) bool { + return expSets_[i] < expSets_[j] + }) + sort.Slice(actualSets, func(i, j int) bool { + return actualSets[i] < actualSets[j] + }) + if !reflect.DeepEqual(actualSets, expSets_) { + return fmt.Errorf("sets mismatch: %+v vs %+v", actualSets, expSets_) + } + for expID, expHosts := range expSets { + actualHosts, err := failingPolicySet.ListHosts(expID) + if err != nil { + return err + } + sort.Slice(actualHosts, func(i, j int) bool { + return actualHosts[i].ID < actualHosts[j].ID + }) + sort.Slice(expHosts, func(i, j int) bool { + return expHosts[i].ID < expHosts[j].ID + }) + if !reflect.DeepEqual(actualHosts, expHosts) { + return fmt.Errorf("hosts mismatch %d: %+v vs %+v", expID, actualHosts, expHosts) + } + } + return nil + } + + assert.Eventually(t, func() bool { + err = cmpSets(map[uint][]fleet.PolicySetHost{ + 3: {{ + ID: host.ID, + Hostname: host.Hostname, + }}, + }) + return err == nil + }, 1*time.Minute, 250*time.Millisecond) + require.NoError(t, err) + + noPolicyResults := func(queries map[string]string) { + 
hasAnyPolicy := false + for name := range queries { + if strings.HasPrefix(name, hostPolicyQueryPrefix) { + hasAnyPolicy = true + break + } + } + assert.False(t, hasAnyPolicy) + } + + // After the first time we get policies and update the host, then there shouldn't be any policies. + ctx = hostctx.NewContext(context.Background(), host) + queries, _, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries) + noPolicyResults(queries) + + // Let's move time forward, there should be policies now. + mockClock.AddTime(2 * time.Hour) + + queries, _, err = svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries+3) + checkPolicyResults(queries) + + ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error) { + return []uint{1}, []uint{3}, nil + } + + // Record another query execution. + err = svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostPolicyQueryPrefix + "1": {}, // 1 now fails + hostPolicyQueryPrefix + "2": {}, // didn't execute + hostPolicyQueryPrefix + "3": {{"col1": "val1"}}, // 1 now succeeds + }, + map[string]fleet.OsqueryStatus{ + hostPolicyQueryPrefix + "2": 1, // didn't execute + }, + map[string]string{}, + ) + require.NoError(t, err) + require.Len(t, recordedResults, 3) + require.NotNil(t, recordedResults[1]) + require.False(t, *recordedResults[1]) + result, ok = recordedResults[2] + require.True(t, ok) + require.Nil(t, result) + require.NotNil(t, recordedResults[3]) + require.True(t, *recordedResults[3]) + + assert.Eventually(t, func() bool { + err = cmpSets(map[uint][]fleet.PolicySetHost{ + 1: {{ + ID: host.ID, + Hostname: host.Hostname, + }}, + 3: {}, + }) + return err == nil + }, 1*time.Minute, 250*time.Millisecond) + require.NoError(t, err) + + // Simulate webhook trigger by removing the hosts. 
+ err = failingPolicySet.RemoveHosts(1, []fleet.PolicySetHost{{ + ID: host.ID, + Hostname: host.Hostname, + }}) + require.NoError(t, err) + + ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error) { + return []uint{}, []uint{2}, nil + } + + // Record another query execution. + err = svc.SubmitDistributedQueryResults( + ctx, + map[string][]map[string]string{ + hostPolicyQueryPrefix + "1": {}, // continues to fail + hostPolicyQueryPrefix + "2": {{"col1": "val1"}}, // now passes + hostPolicyQueryPrefix + "3": {{"col1": "val1"}}, // continues to succeed + }, + map[string]fleet.OsqueryStatus{}, + map[string]string{}, + ) + require.NoError(t, err) + require.Len(t, recordedResults, 3) + require.NotNil(t, recordedResults[1]) + require.False(t, *recordedResults[1]) + require.NotNil(t, recordedResults[2]) + require.True(t, *recordedResults[2]) + require.NotNil(t, recordedResults[3]) + require.True(t, *recordedResults[3]) + + assert.Eventually(t, func() bool { + err = cmpSets(map[uint][]fleet.PolicySetHost{ + 1: {}, + 3: {}, + }) + return err == nil + }, 1*time.Minute, 250*time.Millisecond) + require.NoError(t, err) +} + +// If the live query store (Redis) is down we still (see #3503) +// want hosts to get queries and continue to check in. 
+func TestLiveQueriesFailing(t *testing.T) { + ds := new(mock.Store) + lq := new(live_query.MockLiveQuery) + cfg := config.TestConfig() + buf := new(bytes.Buffer) + logger := log.NewLogfmtLogger(buf) + svc := newTestServiceWithConfig(ds, cfg, nil, lq, TestServerOpts{ + Logger: logger, + }) + + hostID := uint(1) + host := &fleet.Host{ + ID: hostID, + Platform: "darwin", + } + lq.On("QueriesForHost", hostID).Return( + map[string]string{}, + errors.New("failed to get queries for host"), + ) + + ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { + return host, nil + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true}}, nil + } + ds.PolicyQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { + return map[string]string{}, nil + } + + ctx := hostctx.NewContext(context.Background(), host) + + queries, _, err := svc.GetDistributedQueries(ctx) + require.NoError(t, err) + require.Len(t, queries, expectedDetailQueries) + + logs, err := ioutil.ReadAll(buf) + require.NoError(t, err) + require.Contains(t, string(logs), "level=error") + require.Contains(t, string(logs), "failed to get queries for host") +} diff --git a/server/service/service.go b/server/service/service.go index 62c6622c3d..895d66d375 100644 --- a/server/service/service.go +++ b/server/service/service.go @@ -90,7 +90,7 @@ func NewService( return validationMiddleware{svc, ds, sso}, nil } -func (s Service) SendEmail(mail fleet.Email) error { +func (s *Service) SendEmail(mail fleet.Email) error { return s.mailService.SendEmail(mail) } diff --git a/server/service/service_carves.go b/server/service/service_carves.go deleted file mode 100644 index 5da8ba10fd..0000000000 --- a/server/service/service_carves.go +++ 
/dev/null @@ -1,46 +0,0 @@ -package service - -import ( - "context" - "errors" - "fmt" - - "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - "github.com/fleetdm/fleet/v4/server/fleet" -) - -func (svc *Service) CarveBlock(ctx context.Context, payload fleet.CarveBlockPayload) error { - // skipauth: Authorization is currently for user endpoints only. - svc.authz.SkipAuthorization(ctx) - - // Note host did not authenticate via node key. We need to authenticate them - // by the session ID and request ID - carve, err := svc.carveStore.CarveBySessionId(ctx, payload.SessionId) - if err != nil { - return ctxerr.Wrap(ctx, err, "find carve by session_id") - } - - if payload.RequestId != carve.RequestId { - return errors.New("request_id does not match") - } - - // Request is now authenticated - - if payload.BlockId > carve.BlockCount-1 { - return fmt.Errorf("block_id exceeds expected max (%d): %d", carve.BlockCount-1, payload.BlockId) - } - - if payload.BlockId != carve.MaxBlock+1 { - return fmt.Errorf("block_id does not match expected block (%d): %d", carve.MaxBlock+1, payload.BlockId) - } - - if int64(len(payload.Data)) > carve.BlockSize { - return fmt.Errorf("exceeded declared block size %d: %d", carve.BlockSize, len(payload.Data)) - } - - if err := svc.carveStore.NewBlock(ctx, carve, payload.BlockId, payload.Data); err != nil { - return ctxerr.Wrap(ctx, err, "save block data") - } - - return nil -} diff --git a/server/service/service_carves_test.go b/server/service/service_carves_test.go deleted file mode 100644 index 209b2b3ad4..0000000000 --- a/server/service/service_carves_test.go +++ /dev/null @@ -1,432 +0,0 @@ -package service - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/fleetdm/fleet/v4/server/fleet" - - hostctx "github.com/fleetdm/fleet/v4/server/contexts/host" - "github.com/fleetdm/fleet/v4/server/mock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCarveBegin(t *testing.T) { - host := 
fleet.Host{ID: 3} - payload := fleet.CarveBeginPayload{ - BlockCount: 23, - BlockSize: 64, - CarveSize: 23 * 64, - RequestId: "carve_request", - } - ms := new(mock.Store) - ds := new(mock.Store) - svc := &Service{ - carveStore: ms, - ds: ds, - } - expectedMetadata := fleet.CarveMetadata{ - ID: 7, - HostId: host.ID, - BlockCount: 23, - BlockSize: 64, - CarveSize: 23 * 64, - RequestId: "carve_request", - } - ms.NewCarveFunc = func(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) { - metadata.ID = 7 - return metadata, nil - } - ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { - if host.ID != id { - return nil, errors.New("not found") - } - return &fleet.Host{ - Hostname: host.Hostname, - }, nil - } - - ctx := hostctx.NewContext(context.Background(), &host) - - metadata, err := svc.CarveBegin(ctx, payload) - require.NoError(t, err) - assert.NotEmpty(t, metadata.SessionId) - metadata.SessionId = "" // Clear this before comparison - metadata.Name = "" // Clear this before comparison - metadata.CreatedAt = time.Time{} // Clear this before comparison - assert.Equal(t, expectedMetadata, *metadata) -} - -func TestCarveBeginNewCarveError(t *testing.T) { - host := fleet.Host{ID: 3} - payload := fleet.CarveBeginPayload{ - BlockCount: 23, - BlockSize: 64, - CarveSize: 23 * 64, - RequestId: "carve_request", - } - ms := new(mock.Store) - ds := new(mock.Store) - svc := &Service{ - carveStore: ms, - ds: ds, - } - ms.NewCarveFunc = func(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) { - return nil, errors.New("ouch!") - } - ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { - if host.ID != id { - return nil, errors.New("not found") - } - return &fleet.Host{ - Hostname: host.Hostname, - }, nil - } - - ctx := hostctx.NewContext(context.Background(), &host) - - _, err := svc.CarveBegin(ctx, payload) - require.Error(t, err) - assert.Contains(t, err.Error(), "ouch!") -} 
- -func TestCarveBeginEmptyError(t *testing.T) { - ms := new(mock.Store) - ds := new(mock.Store) - svc := &Service{ - carveStore: ms, - ds: ds, - } - ctx := hostctx.NewContext(context.Background(), &fleet.Host{ID: 1}) - - ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { - if id != 1 { - return nil, errors.New("not found") - } - return &fleet.Host{}, nil - } - - _, err := svc.CarveBegin(ctx, fleet.CarveBeginPayload{}) - require.Error(t, err) - assert.Contains(t, err.Error(), "carve_size must be greater than 0") -} - -func TestCarveBeginMissingHostError(t *testing.T) { - ms := new(mock.Store) - svc := &Service{carveStore: ms} - - _, err := svc.CarveBegin(context.Background(), fleet.CarveBeginPayload{}) - require.Error(t, err) - assert.Contains(t, err.Error(), "missing host") -} - -func TestCarveBeginBlockSizeMaxError(t *testing.T) { - host := fleet.Host{ID: 3} - payload := fleet.CarveBeginPayload{ - BlockCount: 10, - BlockSize: 1024 * 1024 * 1024 * 1024, // 1TB - CarveSize: 10 * 1024 * 1024 * 1024 * 1024, // 10TB - RequestId: "carve_request", - } - ms := new(mock.Store) - ds := new(mock.Store) - svc := &Service{ - carveStore: ms, - ds: ds, - } - - ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { - if host.ID != id { - return nil, errors.New("not found") - } - return &fleet.Host{ - Hostname: host.Hostname, - }, nil - } - - ctx := hostctx.NewContext(context.Background(), &host) - - _, err := svc.CarveBegin(ctx, payload) - require.Error(t, err) - assert.Contains(t, err.Error(), "block_size exceeds max") -} - -func TestCarveBeginCarveSizeMaxError(t *testing.T) { - host := fleet.Host{ID: 3} - payload := fleet.CarveBeginPayload{ - BlockCount: 1024 * 1024, - BlockSize: 10 * 1024 * 1024, // 1TB - CarveSize: 10 * 1024 * 1024 * 1024 * 1024, // 10TB - RequestId: "carve_request", - } - ms := new(mock.Store) - ds := new(mock.Store) - svc := &Service{ - carveStore: ms, - ds: ds, - } - - ds.HostLiteFunc = func(ctx 
context.Context, id uint) (*fleet.Host, error) { - if host.ID != id { - return nil, errors.New("not found") - } - return &fleet.Host{ - Hostname: host.Hostname, - }, nil - } - - ctx := hostctx.NewContext(context.Background(), &host) - - _, err := svc.CarveBegin(ctx, payload) - require.Error(t, err) - assert.Contains(t, err.Error(), "carve_size exceeds max") -} - -func TestCarveBeginCarveSizeError(t *testing.T) { - host := fleet.Host{ID: 3} - payload := fleet.CarveBeginPayload{ - BlockCount: 7, - BlockSize: 13, - CarveSize: 7*13 + 1, - RequestId: "carve_request", - } - ms := new(mock.Store) - ds := new(mock.Store) - svc := &Service{ - carveStore: ms, - ds: ds, - } - ctx := hostctx.NewContext(context.Background(), &host) - - ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { - if host.ID != id { - return nil, errors.New("not found") - } - return &fleet.Host{ - Hostname: host.Hostname, - }, nil - } - - // Too big - _, err := svc.CarveBegin(ctx, payload) - require.Error(t, err) - assert.Contains(t, err.Error(), "carve_size does not match") - - // Too small - payload.CarveSize = 6 * 13 - _, err = svc.CarveBegin(ctx, payload) - require.Error(t, err) - assert.Contains(t, err.Error(), "carve_size does not match") -} - -func TestCarveCarveBlockGetCarveError(t *testing.T) { - sessionId := "foobar" - ms := new(mock.Store) - svc := &Service{carveStore: ms} - ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { - return nil, errors.New("ouch!") - } - - payload := fleet.CarveBlockPayload{ - Data: []byte("this is the carve data :)"), - SessionId: sessionId, - } - - err := svc.CarveBlock(context.Background(), payload) - require.Error(t, err) - assert.Contains(t, err.Error(), "ouch!") -} - -func TestCarveCarveBlockRequestIdError(t *testing.T) { - sessionId := "foobar" - metadata := &fleet.CarveMetadata{ - ID: 2, - HostId: 3, - BlockCount: 23, - BlockSize: 64, - CarveSize: 23 * 64, - RequestId: 
"carve_request", - SessionId: sessionId, - } - ms := new(mock.Store) - svc := &Service{carveStore: ms} - ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { - assert.Equal(t, metadata.SessionId, sessionId) - return metadata, nil - } - - payload := fleet.CarveBlockPayload{ - Data: []byte("this is the carve data :)"), - RequestId: "not_matching", - SessionId: sessionId, - } - - err := svc.CarveBlock(context.Background(), payload) - require.Error(t, err) - assert.Contains(t, err.Error(), "request_id does not match") -} - -func TestCarveCarveBlockBlockCountExceedError(t *testing.T) { - sessionId := "foobar" - metadata := &fleet.CarveMetadata{ - ID: 2, - HostId: 3, - BlockCount: 23, - BlockSize: 64, - CarveSize: 23 * 64, - RequestId: "carve_request", - SessionId: sessionId, - } - ms := new(mock.Store) - svc := &Service{carveStore: ms} - ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { - assert.Equal(t, metadata.SessionId, sessionId) - return metadata, nil - } - - payload := fleet.CarveBlockPayload{ - Data: []byte("this is the carve data :)"), - RequestId: "carve_request", - SessionId: sessionId, - BlockId: 23, - } - - err := svc.CarveBlock(context.Background(), payload) - require.Error(t, err) - assert.Contains(t, err.Error(), "block_id exceeds expected max") -} - -func TestCarveCarveBlockBlockCountMatchError(t *testing.T) { - sessionId := "foobar" - metadata := &fleet.CarveMetadata{ - ID: 2, - HostId: 3, - BlockCount: 23, - BlockSize: 64, - CarveSize: 23 * 64, - RequestId: "carve_request", - SessionId: sessionId, - MaxBlock: 3, - } - ms := new(mock.Store) - svc := &Service{carveStore: ms} - ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { - assert.Equal(t, metadata.SessionId, sessionId) - return metadata, nil - } - - payload := fleet.CarveBlockPayload{ - Data: []byte("this is the carve data :)"), - RequestId: 
"carve_request", - SessionId: sessionId, - BlockId: 7, - } - - err := svc.CarveBlock(context.Background(), payload) - require.Error(t, err) - assert.Contains(t, err.Error(), "block_id does not match") -} - -func TestCarveCarveBlockBlockSizeError(t *testing.T) { - sessionId := "foobar" - metadata := &fleet.CarveMetadata{ - ID: 2, - HostId: 3, - BlockCount: 23, - BlockSize: 16, - CarveSize: 23 * 64, - RequestId: "carve_request", - SessionId: sessionId, - MaxBlock: 3, - } - ms := new(mock.Store) - svc := &Service{carveStore: ms} - ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { - assert.Equal(t, metadata.SessionId, sessionId) - return metadata, nil - } - - payload := fleet.CarveBlockPayload{ - Data: []byte("this is the carve data :) TOO LONG!!!"), - RequestId: "carve_request", - SessionId: sessionId, - BlockId: 4, - } - - err := svc.CarveBlock(context.Background(), payload) - require.Error(t, err) - assert.Contains(t, err.Error(), "exceeded declared block size") -} - -func TestCarveCarveBlockNewBlockError(t *testing.T) { - sessionId := "foobar" - metadata := &fleet.CarveMetadata{ - ID: 2, - HostId: 3, - BlockCount: 23, - BlockSize: 64, - CarveSize: 23 * 64, - RequestId: "carve_request", - SessionId: sessionId, - MaxBlock: 3, - } - ms := new(mock.Store) - svc := &Service{carveStore: ms} - ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { - assert.Equal(t, metadata.SessionId, sessionId) - return metadata, nil - } - ms.NewBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64, data []byte) error { - return errors.New("kaboom!") - } - - payload := fleet.CarveBlockPayload{ - Data: []byte("this is the carve data :)"), - RequestId: "carve_request", - SessionId: sessionId, - BlockId: 4, - } - - err := svc.CarveBlock(context.Background(), payload) - require.Error(t, err) - assert.Contains(t, err.Error(), "kaboom!") -} - -func TestCarveCarveBlock(t 
*testing.T) { - sessionId := "foobar" - metadata := &fleet.CarveMetadata{ - ID: 2, - HostId: 3, - BlockCount: 23, - BlockSize: 64, - CarveSize: 23 * 64, - RequestId: "carve_request", - SessionId: sessionId, - MaxBlock: 3, - } - payload := fleet.CarveBlockPayload{ - Data: []byte("this is the carve data :)"), - RequestId: "carve_request", - SessionId: sessionId, - BlockId: 4, - } - ms := new(mock.Store) - svc := &Service{carveStore: ms} - ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) { - assert.Equal(t, metadata.SessionId, sessionId) - return metadata, nil - } - ms.NewBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64, data []byte) error { - assert.Equal(t, metadata, carve) - assert.Equal(t, int64(4), blockId) - assert.Equal(t, payload.Data, data) - return nil - } - - err := svc.CarveBlock(context.Background(), payload) - require.NoError(t, err) - assert.True(t, ms.NewBlockFuncInvoked) -} diff --git a/server/service/service_invites.go b/server/service/service_invites.go deleted file mode 100644 index 7007214c4e..0000000000 --- a/server/service/service_invites.go +++ /dev/null @@ -1,34 +0,0 @@ -package service - -import ( - "context" - - "github.com/fleetdm/fleet/v4/server/contexts/logging" - - "github.com/fleetdm/fleet/v4/server/fleet" -) - -func (svc *Service) VerifyInvite(ctx context.Context, token string) (*fleet.Invite, error) { - // skipauth: There is no viewer context at this point. We rely on verifying - // the invite for authNZ. 
- svc.authz.SkipAuthorization(ctx) - - logging.WithExtras(ctx, "token", token) - - invite, err := svc.ds.InviteByToken(ctx, token) - if err != nil { - return nil, err - } - - if invite.Token != token { - return nil, fleet.NewInvalidArgumentError("invite_token", "Invite Token does not match Email Address.") - } - - expiresAt := invite.CreatedAt.Add(svc.config.App.InviteTokenValidityPeriod) - if svc.clock.Now().After(expiresAt) { - return nil, fleet.NewInvalidArgumentError("invite_token", "Invite token has expired.") - } - - return invite, nil - -} diff --git a/server/service/service_osquery.go b/server/service/service_osquery.go deleted file mode 100644 index 00f005c847..0000000000 --- a/server/service/service_osquery.go +++ /dev/null @@ -1,317 +0,0 @@ -package service - -import ( - "context" - "sync" - "sync/atomic" - "time" - - "github.com/fleetdm/fleet/v4/server" - "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - "github.com/fleetdm/fleet/v4/server/contexts/logging" - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/fleetdm/fleet/v4/server/service/osquery_utils" - "github.com/go-kit/kit/log" - "github.com/go-kit/kit/log/level" -) - -type osqueryError struct { - message string - nodeInvalid bool -} - -func (e osqueryError) Error() string { - return e.message -} - -func (e osqueryError) NodeInvalid() bool { - return e.nodeInvalid -} - -var counter = int64(0) - -func (svc Service) AuthenticateHost(ctx context.Context, nodeKey string) (*fleet.Host, bool, error) { - // skipauth: Authorization is currently for user endpoints only. 
- svc.authz.SkipAuthorization(ctx) - - if nodeKey == "" { - return nil, false, osqueryError{ - message: "authentication error: missing node key", - nodeInvalid: true, - } - } - - host, err := svc.ds.LoadHostByNodeKey(ctx, nodeKey) - switch { - case err == nil: - // OK - case fleet.IsNotFound(err): - return nil, false, osqueryError{ - message: "authentication error: invalid node key: " + nodeKey, - nodeInvalid: true, - } - default: - return nil, false, osqueryError{ - message: "authentication error: " + err.Error(), - } - } - - // Update the "seen" time used to calculate online status. These updates are - // batched for MySQL performance reasons. Because this is done - // asynchronously, it is possible for the server to shut down before - // updating the seen time for these hosts. This seems to be an acceptable - // tradeoff as an online host will continue to check in and quickly be - // marked online again. - svc.seenHostSet.addHostID(host.ID) - host.SeenTime = svc.clock.Now() - - return host, svc.debugEnabledForHost(ctx, host.ID), nil -} - -func (svc Service) debugEnabledForHost(ctx context.Context, id uint) bool { - hlogger := log.With(svc.logger, "host-id", id) - ac, err := svc.ds.AppConfig(ctx) - if err != nil { - level.Debug(hlogger).Log("err", ctxerr.Wrap(ctx, err, "getting app config for host debug")) - return false - } - - for _, hostID := range ac.ServerSettings.DebugHostIDs { - if hostID == id { - return true - } - } - return false -} - -func (svc Service) EnrollAgent(ctx context.Context, enrollSecret, hostIdentifier string, hostDetails map[string](map[string]string)) (string, error) { - // skipauth: Authorization is currently for user endpoints only. 
- svc.authz.SkipAuthorization(ctx) - - logging.WithExtras(ctx, "hostIdentifier", hostIdentifier) - - secret, err := svc.ds.VerifyEnrollSecret(ctx, enrollSecret) - if err != nil { - return "", osqueryError{ - message: "enroll failed: " + err.Error(), - nodeInvalid: true, - } - } - - nodeKey, err := server.GenerateRandomText(svc.config.Osquery.NodeKeySize) - if err != nil { - return "", osqueryError{ - message: "generate node key failed: " + err.Error(), - nodeInvalid: true, - } - } - - hostIdentifier = getHostIdentifier(svc.logger, svc.config.Osquery.HostIdentifier, hostIdentifier, hostDetails) - - host, err := svc.ds.EnrollHost(ctx, hostIdentifier, nodeKey, secret.TeamID, svc.config.Osquery.EnrollCooldown) - if err != nil { - return "", osqueryError{message: "save enroll failed: " + err.Error(), nodeInvalid: true} - } - - appConfig, err := svc.ds.AppConfig(ctx) - if err != nil { - return "", osqueryError{message: "app config load failed: " + err.Error(), nodeInvalid: true} - } - - // Save enrollment details if provided - detailQueries := osquery_utils.GetDetailQueries(appConfig, svc.config) - save := false - if r, ok := hostDetails["os_version"]; ok { - err := detailQueries["os_version"].IngestFunc(svc.logger, host, []map[string]string{r}) - if err != nil { - return "", ctxerr.Wrap(ctx, err, "Ingesting os_version") - } - save = true - } - if r, ok := hostDetails["osquery_info"]; ok { - err := detailQueries["osquery_info"].IngestFunc(svc.logger, host, []map[string]string{r}) - if err != nil { - return "", ctxerr.Wrap(ctx, err, "Ingesting osquery_info") - } - save = true - } - if r, ok := hostDetails["system_info"]; ok { - err := detailQueries["system_info"].IngestFunc(svc.logger, host, []map[string]string{r}) - if err != nil { - return "", ctxerr.Wrap(ctx, err, "Ingesting system_info") - } - save = true - } - - if save { - if appConfig.ServerSettings.DeferredSaveHost { - go svc.serialUpdateHost(host) - } else { - if err := svc.ds.UpdateHost(ctx, host); err != nil { 
- return "", ctxerr.Wrap(ctx, err, "save host in enroll agent") - } - } - } - - return nodeKey, nil -} - -func (svc Service) serialUpdateHost(host *fleet.Host) { - newVal := atomic.AddInt64(&counter, 1) - defer func() { - atomic.AddInt64(&counter, -1) - }() - level.Debug(svc.logger).Log("background", newVal) - - ctx, cancelFunc := context.WithTimeout(context.Background(), 30*time.Second) - defer cancelFunc() - err := svc.ds.SerialUpdateHost(ctx, host) - if err != nil { - level.Error(svc.logger).Log("background-err", err) - } -} - -func getHostIdentifier(logger log.Logger, identifierOption, providedIdentifier string, details map[string](map[string]string)) string { - switch identifierOption { - case "provided": - // Use the host identifier already provided in the request. - return providedIdentifier - - case "instance": - r, ok := details["osquery_info"] - if !ok { - level.Info(logger).Log( - "msg", "could not get host identifier", - "reason", "missing osquery_info", - "identifier", "instance", - ) - } else if r["instance_id"] == "" { - level.Info(logger).Log( - "msg", "could not get host identifier", - "reason", "missing instance_id in osquery_info", - "identifier", "instance", - ) - } else { - return r["instance_id"] - } - - case "uuid": - r, ok := details["osquery_info"] - if !ok { - level.Info(logger).Log( - "msg", "could not get host identifier", - "reason", "missing osquery_info", - "identifier", "uuid", - ) - } else if r["uuid"] == "" { - level.Info(logger).Log( - "msg", "could not get host identifier", - "reason", "missing instance_id in osquery_info", - "identifier", "uuid", - ) - } else { - return r["uuid"] - } - - case "hostname": - r, ok := details["system_info"] - if !ok { - level.Info(logger).Log( - "msg", "could not get host identifier", - "reason", "missing system_info", - "identifier", "hostname", - ) - } else if r["hostname"] == "" { - level.Info(logger).Log( - "msg", "could not get host identifier", - "reason", "missing instance_id in 
system_info", - "identifier", "hostname", - ) - } else { - return r["hostname"] - } - - default: - panic("Unknown option for host_identifier: " + identifierOption) - } - - return providedIdentifier -} - -// jitterHashTable implements a data structure that allows a fleet to generate a static jitter value -// that is properly balanced. Balance in this context means that hosts would be distributed uniformly -// across the total jitter time so there are no spikes. -// The way this structure works is as follows: -// Given an amount of buckets, we want to place hosts in buckets evenly. So we don't want bucket 0 to -// have 1000 hosts, and all the other buckets 0. If there were 1000 buckets, and 1000 hosts, we should -// end up with 1 per bucket. -// The total amount of online hosts is unknown, so first it assumes that amount of buckets >= amount -// of total hosts (maxCapacity of 1 per bucket). Once we have more hosts than buckets, then we -// increase the maxCapacity by 1 for all buckets, and start placing hosts. -// Hosts that have been placed in a bucket remain in that bucket for as long as the fleet instance is -// running. -// The preferred bucket for a host is the one at (host id % bucketCount). If that bucket is full, the -// next one will be tried. If all buckets are full, then capacity gets increased and the bucket -// selection process restarts. -// Once a bucket is found, the index for the bucket (going from 0 to bucketCount) will be the amount of -// minutes added to the host check in time. -// For example: at a 1hr interval, and the default 10% max jitter percent. That allows hosts to -// distribute within 6 minutes around the hour mark. We would have 6 buckets in that case. -// In the worst possible case that all hosts start at the same time, max jitter percent can be set to -// 100, and this method will distribute hosts evenly. -// The main caveat of this approach is that it works at the fleet instance. 
So depending on what -// instance gets chosen by the load balancer, the jitter might be different. However, load tests have -// shown that the distribution in practice is pretty balance even when all hosts try to check in at -// the same time. -type jitterHashTable struct { - mu sync.Mutex - maxCapacity int - bucketCount int - buckets map[int]int - cache map[uint]time.Duration -} - -func newJitterHashTable(bucketCount int) *jitterHashTable { - if bucketCount == 0 { - bucketCount = 1 - } - return &jitterHashTable{ - maxCapacity: 1, - bucketCount: bucketCount, - buckets: make(map[int]int), - cache: make(map[uint]time.Duration), - } -} - -func (jh *jitterHashTable) jitterForHost(hostID uint) time.Duration { - // if no jitter is configured just return 0 - if jh.bucketCount <= 1 { - return 0 - } - - jh.mu.Lock() - if jitter, ok := jh.cache[hostID]; ok { - jh.mu.Unlock() - return jitter - } - - for i := 0; i < jh.bucketCount; i++ { - possibleBucket := (int(hostID) + i) % jh.bucketCount - - // if the next bucket has capacity, great! 
- if jh.buckets[possibleBucket] < jh.maxCapacity { - jh.buckets[possibleBucket]++ - jitter := time.Duration(possibleBucket) * time.Minute - jh.cache[hostID] = jitter - - jh.mu.Unlock() - return jitter - } - } - - // otherwise, bump the capacity and restart the process - jh.maxCapacity++ - - jh.mu.Unlock() - return jh.jitterForHost(hostID) -} diff --git a/server/service/service_osquery_test.go b/server/service/service_osquery_test.go deleted file mode 100644 index 244fceb4b8..0000000000 --- a/server/service/service_osquery_test.go +++ /dev/null @@ -1,2404 +0,0 @@ -package service - -import ( - "bytes" - "context" - crand "crypto/rand" - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "math" - "math/big" - "reflect" - "sort" - "strconv" - "strings" - "sync" - "testing" - "time" - - "github.com/WatchBeam/clock" - "github.com/fleetdm/fleet/v4/server/authz" - "github.com/fleetdm/fleet/v4/server/config" - hostctx "github.com/fleetdm/fleet/v4/server/contexts/host" - fleetLogging "github.com/fleetdm/fleet/v4/server/contexts/logging" - "github.com/fleetdm/fleet/v4/server/contexts/viewer" - "github.com/fleetdm/fleet/v4/server/datastore/redis/redistest" - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/fleetdm/fleet/v4/server/live_query" - "github.com/fleetdm/fleet/v4/server/logging" - "github.com/fleetdm/fleet/v4/server/mock" - "github.com/fleetdm/fleet/v4/server/ptr" - "github.com/fleetdm/fleet/v4/server/pubsub" - "github.com/fleetdm/fleet/v4/server/service/osquery_utils" - "github.com/fleetdm/fleet/v4/server/service/redis_policy_set" - "github.com/go-kit/kit/log" - "github.com/go-kit/kit/log/level" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// One of these queries is the disk space, only one of the two works in a platform -var expectedDetailQueries = len(osquery_utils.GetDetailQueries(&fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true}}, config.FleetConfig{})) - 1 - -func TestEnrollAgent(t 
*testing.T) { - ds := new(mock.Store) - ds.VerifyEnrollSecretFunc = func(ctx context.Context, secret string) (*fleet.EnrollSecret, error) { - switch secret { - case "valid_secret": - return &fleet.EnrollSecret{Secret: "valid_secret", TeamID: ptr.Uint(3)}, nil - default: - return nil, errors.New("not found") - } - } - ds.EnrollHostFunc = func(ctx context.Context, osqueryHostId, nodeKey string, teamID *uint, cooldown time.Duration) (*fleet.Host, error) { - assert.Equal(t, ptr.Uint(3), teamID) - return &fleet.Host{ - OsqueryHostID: osqueryHostId, NodeKey: nodeKey, - }, nil - } - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{}, nil - } - - svc := newTestService(ds, nil, nil) - - nodeKey, err := svc.EnrollAgent(context.Background(), "valid_secret", "host123", nil) - require.NoError(t, err) - assert.NotEmpty(t, nodeKey) -} - -func TestEnrollAgentIncorrectEnrollSecret(t *testing.T) { - ds := new(mock.Store) - ds.VerifyEnrollSecretFunc = func(ctx context.Context, secret string) (*fleet.EnrollSecret, error) { - switch secret { - case "valid_secret": - return &fleet.EnrollSecret{Secret: "valid_secret", TeamID: ptr.Uint(3)}, nil - default: - return nil, errors.New("not found") - } - } - - svc := newTestService(ds, nil, nil) - - nodeKey, err := svc.EnrollAgent(context.Background(), "not_correct", "host123", nil) - assert.NotNil(t, err) - assert.Empty(t, nodeKey) -} - -func TestEnrollAgentDetails(t *testing.T) { - ds := new(mock.Store) - ds.VerifyEnrollSecretFunc = func(ctx context.Context, secret string) (*fleet.EnrollSecret, error) { - return &fleet.EnrollSecret{}, nil - } - ds.EnrollHostFunc = func(ctx context.Context, osqueryHostId, nodeKey string, teamID *uint, cooldown time.Duration) (*fleet.Host, error) { - return &fleet.Host{ - OsqueryHostID: osqueryHostId, NodeKey: nodeKey, - }, nil - } - var gotHost *fleet.Host - ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { - gotHost = host - return nil - 
} - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{}, nil - } - - svc := newTestService(ds, nil, nil) - - details := map[string](map[string]string){ - "osquery_info": {"version": "2.12.0"}, - "system_info": {"hostname": "zwass.local", "uuid": "froobling_uuid"}, - "os_version": { - "name": "Mac OS X", - "major": "10", - "minor": "14", - "patch": "5", - "platform": "darwin", - }, - "foo": {"foo": "bar"}, - } - nodeKey, err := svc.EnrollAgent(context.Background(), "", "host123", details) - require.NoError(t, err) - assert.NotEmpty(t, nodeKey) - - assert.Equal(t, "Mac OS X 10.14.5", gotHost.OSVersion) - assert.Equal(t, "darwin", gotHost.Platform) - assert.Equal(t, "2.12.0", gotHost.OsqueryVersion) - assert.Equal(t, "zwass.local", gotHost.Hostname) - assert.Equal(t, "froobling_uuid", gotHost.UUID) -} - -func TestAuthenticateHost(t *testing.T) { - ds := new(mock.Store) - svc := newTestService(ds, nil, nil) - - var gotKey string - host := fleet.Host{ID: 1, Hostname: "foobar"} - ds.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { - gotKey = nodeKey - return &host, nil - } - var gotHostIDs []uint - ds.MarkHostsSeenFunc = func(ctx context.Context, hostIDs []uint, t time.Time) error { - gotHostIDs = hostIDs - return nil - } - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{}, nil - } - - _, _, err := svc.AuthenticateHost(context.Background(), "test") - require.NoError(t, err) - assert.Equal(t, "test", gotKey) - assert.False(t, ds.MarkHostsSeenFuncInvoked) - - host = fleet.Host{ID: 7, Hostname: "foobar"} - _, _, err = svc.AuthenticateHost(context.Background(), "floobar") - require.NoError(t, err) - assert.Equal(t, "floobar", gotKey) - assert.False(t, ds.MarkHostsSeenFuncInvoked) - // Host checks in twice - host = fleet.Host{ID: 7, Hostname: "foobar"} - _, _, err = svc.AuthenticateHost(context.Background(), "floobar") - 
require.NoError(t, err) - assert.Equal(t, "floobar", gotKey) - assert.False(t, ds.MarkHostsSeenFuncInvoked) - - err = svc.FlushSeenHosts(context.Background()) - require.NoError(t, err) - assert.True(t, ds.MarkHostsSeenFuncInvoked) - assert.ElementsMatch(t, []uint{1, 7}, gotHostIDs) - - err = svc.FlushSeenHosts(context.Background()) - require.NoError(t, err) - assert.True(t, ds.MarkHostsSeenFuncInvoked) - require.Len(t, gotHostIDs, 0) -} - -func TestAuthenticateHostFailure(t *testing.T) { - ds := new(mock.Store) - svc := newTestService(ds, nil, nil) - - ds.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { - return nil, errors.New("not found") - } - - _, _, err := svc.AuthenticateHost(context.Background(), "test") - require.NotNil(t, err) -} - -type testJSONLogger struct { - logs []json.RawMessage -} - -func (n *testJSONLogger) Write(ctx context.Context, logs []json.RawMessage) error { - n.logs = logs - return nil -} - -func TestSubmitStatusLogs(t *testing.T) { - ds := new(mock.Store) - svc := newTestService(ds, nil, nil) - - // Hack to get at the service internals and modify the writer - serv := ((svc.(validationMiddleware)).Service).(*Service) - - testLogger := &testJSONLogger{} - serv.osqueryLogWriter = &logging.OsqueryLogger{Status: testLogger} - - logs := []string{ - `{"severity":"0","filename":"tls.cpp","line":"216","message":"some message","version":"1.8.2","decorations":{"host_uuid":"uuid_foobar","username":"zwass"}}`, - `{"severity":"1","filename":"buffered.cpp","line":"122","message":"warning!","version":"1.8.2","decorations":{"host_uuid":"uuid_foobar","username":"zwass"}}`, - } - logJSON := fmt.Sprintf("[%s]", strings.Join(logs, ",")) - - var status []json.RawMessage - err := json.Unmarshal([]byte(logJSON), &status) - require.NoError(t, err) - - host := fleet.Host{} - ctx := hostctx.NewContext(context.Background(), &host) - err = serv.SubmitStatusLogs(ctx, status) - require.NoError(t, err) - - assert.Equal(t, status, 
testLogger.logs) -} - -func TestSubmitResultLogs(t *testing.T) { - ds := new(mock.Store) - svc := newTestService(ds, nil, nil) - - // Hack to get at the service internals and modify the writer - serv := ((svc.(validationMiddleware)).Service).(*Service) - - testLogger := &testJSONLogger{} - serv.osqueryLogWriter = &logging.OsqueryLogger{Result: testLogger} - - logs := []string{ - `{"name":"system_info","hostIdentifier":"some_uuid","calendarTime":"Fri Sep 30 17:55:15 2016 UTC","unixTime":"1475258115","decorations":{"host_uuid":"some_uuid","username":"zwass"},"columns":{"cpu_brand":"Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz","hostname":"hostimus","physical_memory":"17179869184"},"action":"added"}`, - `{"name":"encrypted","hostIdentifier":"some_uuid","calendarTime":"Fri Sep 30 21:19:15 2016 UTC","unixTime":"1475270355","decorations":{"host_uuid":"4740D59F-699E-5B29-960B-979AAF9BBEEB","username":"zwass"},"columns":{"encrypted":"1","name":"\/dev\/disk1","type":"AES-XTS","uid":"","user_uuid":"","uuid":"some_uuid"},"action":"added"}`, - `{"snapshot":[{"hour":"20","minutes":"8"}],"action":"snapshot","name":"time","hostIdentifier":"1379f59d98f4","calendarTime":"Tue Jan 10 20:08:51 2017 UTC","unixTime":"1484078931","decorations":{"host_uuid":"EB714C9D-C1F8-A436-B6DA-3F853C5502EA"}}`, - `{"diffResults":{"removed":[{"address":"127.0.0.1","hostnames":"kl.groob.io"}],"added":""},"name":"pack\/test\/hosts","hostIdentifier":"FA01680E-98CA-5557-8F59-7716ECFEE964","calendarTime":"Sun Nov 19 00:02:08 2017 UTC","unixTime":"1511049728","epoch":"0","counter":"10","decorations":{"host_uuid":"FA01680E-98CA-5557-8F59-7716ECFEE964","hostname":"kl.groob.io"}}`, - // fleet will accept anything in the "data" field of a log request. 
- `{"unknown":{"foo": [] }}`, - } - logJSON := fmt.Sprintf("[%s]", strings.Join(logs, ",")) - - var results []json.RawMessage - err := json.Unmarshal([]byte(logJSON), &results) - require.NoError(t, err) - - host := fleet.Host{} - ctx := hostctx.NewContext(context.Background(), &host) - err = serv.SubmitResultLogs(ctx, results) - require.NoError(t, err) - - assert.Equal(t, results, testLogger.logs) -} - -func TestHostDetailQueries(t *testing.T) { - ds := new(mock.Store) - additional := json.RawMessage(`{"foobar": "select foo", "bim": "bam"}`) - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{HostSettings: fleet.HostSettings{AdditionalQueries: &additional, EnableHostUsers: true}}, nil - } - - mockClock := clock.NewMockClock() - host := fleet.Host{ - ID: 1, - UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ - UpdateTimestamp: fleet.UpdateTimestamp{ - UpdatedAt: mockClock.Now(), - }, - CreateTimestamp: fleet.CreateTimestamp{ - CreatedAt: mockClock.Now(), - }, - }, - - Platform: "darwin", - DetailUpdatedAt: mockClock.Now(), - NodeKey: "test_key", - Hostname: "test_hostname", - UUID: "test_uuid", - } - - svc := &Service{ - clock: mockClock, - logger: log.NewNopLogger(), - config: config.TestConfig(), - ds: ds, - jitterMu: new(sync.Mutex), - jitterH: make(map[time.Duration]*jitterHashTable), - } - - queries, err := svc.detailQueriesForHost(context.Background(), &host) - require.NoError(t, err) - assert.Empty(t, queries) - - // With refetch requested detail queries should be returned - host.RefetchRequested = true - queries, err = svc.detailQueriesForHost(context.Background(), &host) - require.NoError(t, err) - assert.NotEmpty(t, queries) - host.RefetchRequested = false - - // Advance the time - mockClock.AddTime(1*time.Hour + 1*time.Minute) - - queries, err = svc.detailQueriesForHost(context.Background(), &host) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries+2) - for name := range queries { 
- assert.True(t, - strings.HasPrefix(name, hostDetailQueryPrefix) || strings.HasPrefix(name, hostAdditionalQueryPrefix), - ) - } - assert.Equal(t, "bam", queries[hostAdditionalQueryPrefix+"bim"]) - assert.Equal(t, "select foo", queries[hostAdditionalQueryPrefix+"foobar"]) -} - -func TestGetDistributedQueriesMissingHost(t *testing.T) { - svc := newTestService(&mock.Store{}, nil, nil) - - _, _, err := svc.GetDistributedQueries(context.Background()) - require.NotNil(t, err) - assert.Contains(t, err.Error(), "missing host") -} - -func TestLabelQueries(t *testing.T) { - mockClock := clock.NewMockClock() - ds := new(mock.Store) - lq := new(live_query.MockLiveQuery) - svc := newTestServiceWithClock(ds, nil, lq, mockClock) - - host := &fleet.Host{ - Platform: "darwin", - } - - ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { - return host, nil - } - ds.UpdateHostFunc = func(ctx context.Context, gotHost *fleet.Host) error { - host = gotHost - return nil - } - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true}}, nil - } - ds.PolicyQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - - lq.On("QueriesForHost", uint(0)).Return(map[string]string{}, nil) - - ctx := hostctx.NewContext(context.Background(), host) - - // With a new host, we should get the detail queries (and accelerate - // should be turned on so that we can quickly fill labels) - queries, acc, err := svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries) - assert.NotZero(t, acc) - - // Simulate the detail queries being added. 
- host.DetailUpdatedAt = mockClock.Now().Add(-1 * time.Minute) - host.Hostname = "zwass.local" - ctx = hostctx.NewContext(ctx, host) - - queries, acc, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, 0) - assert.Zero(t, acc) - - ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{ - "label1": "query1", - "label2": "query2", - "label3": "query3", - }, nil - } - - // Now we should get the label queries - queries, acc, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, 3) - assert.Zero(t, acc) - - var gotHost *fleet.Host - var gotResults map[uint]*bool - var gotTime time.Time - ds.RecordLabelQueryExecutionsFunc = func(ctx context.Context, host *fleet.Host, results map[uint]*bool, t time.Time, deferred bool) error { - gotHost = host - gotResults = results - gotTime = t - return nil - } - - // Record a query execution - err = svc.SubmitDistributedQueryResults( - ctx, - map[string][]map[string]string{ - hostLabelQueryPrefix + "1": {{"col1": "val1"}}, - }, - map[string]fleet.OsqueryStatus{}, - map[string]string{}, - ) - require.NoError(t, err) - host.LabelUpdatedAt = mockClock.Now() - assert.Equal(t, host, gotHost) - assert.Equal(t, mockClock.Now(), gotTime) - require.Len(t, gotResults, 1) - assert.Equal(t, true, *gotResults[1]) - - mockClock.AddTime(1 * time.Second) - - // Record a query execution - err = svc.SubmitDistributedQueryResults( - ctx, - map[string][]map[string]string{ - hostLabelQueryPrefix + "2": {{"col1": "val1"}}, - hostLabelQueryPrefix + "3": {}, - }, - map[string]fleet.OsqueryStatus{}, - map[string]string{}, - ) - require.NoError(t, err) - host.LabelUpdatedAt = mockClock.Now() - assert.Equal(t, host, gotHost) - assert.Equal(t, mockClock.Now(), gotTime) - require.Len(t, gotResults, 2) - assert.Equal(t, true, *gotResults[2]) - assert.Equal(t, false, *gotResults[3]) - - // We should get no labels now. 
- host.LabelUpdatedAt = mockClock.Now() - ctx = hostctx.NewContext(ctx, host) - queries, acc, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, 0) - assert.Zero(t, acc) - - // With refetch requested details+label queries should be returned. - host.RefetchRequested = true - ctx = hostctx.NewContext(ctx, host) - queries, acc, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries+3) - assert.Zero(t, acc) - - // Record a query execution - err = svc.SubmitDistributedQueryResults( - ctx, - map[string][]map[string]string{ - hostLabelQueryPrefix + "2": {{"col1": "val1"}}, - hostLabelQueryPrefix + "3": {}, - }, - map[string]fleet.OsqueryStatus{}, - map[string]string{}, - ) - require.NoError(t, err) - host.LabelUpdatedAt = mockClock.Now() - assert.Equal(t, host, gotHost) - assert.Equal(t, mockClock.Now(), gotTime) - require.Len(t, gotResults, 2) - assert.Equal(t, true, *gotResults[2]) - assert.Equal(t, false, *gotResults[3]) - - // SubmitDistributedQueryResults will set RefetchRequested to false. - require.False(t, host.RefetchRequested) - - // There shouldn't be any labels now. 
- ctx = hostctx.NewContext(context.Background(), host) - queries, acc, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, 0) - assert.Zero(t, acc) -} - -func TestDetailQueriesWithEmptyStrings(t *testing.T) { - ds := new(mock.Store) - mockClock := clock.NewMockClock() - lq := new(live_query.MockLiveQuery) - svc := newTestServiceWithClock(ds, nil, lq, mockClock) - - host := &fleet.Host{ - ID: 1, - Platform: "windows", - } - ctx := hostctx.NewContext(context.Background(), host) - - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true}}, nil - } - ds.LabelQueriesForHostFunc = func(context.Context, *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - ds.PolicyQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { - if id != 1 { - return nil, errors.New("not found") - } - return host, nil - } - - lq.On("QueriesForHost", host.ID).Return(map[string]string{}, nil) - - // With a new host, we should get the detail queries (and accelerated - // queries) - queries, acc, err := svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries-3) - assert.NotZero(t, acc) - - resultJSON := ` -{ -"fleet_detail_query_network_interface": [ - { - "address": "192.168.0.1", - "broadcast": "192.168.0.255", - "ibytes": "", - "ierrors": "", - "interface": "en0", - "ipackets": "25698094", - "last_change": "1474233476", - "mac": "5f:3d:4b:10:25:82", - "mask": "255.255.255.0", - "metric": "", - "mtu": "", - "obytes": "", - "oerrors": "", - "opackets": "", - "point_to_point": "", - "type": "" - } -], -"fleet_detail_query_os_version": [ - { - "platform": "darwin", - "build": "15G1004", - "major": "10", - "minor": "10", - "name": "Mac OS 
X", - "patch": "6" - } -], -"fleet_detail_query_osquery_info": [ - { - "build_distro": "10.10", - "build_platform": "darwin", - "config_hash": "3c6e4537c4d0eb71a7c6dda19d", - "config_valid": "1", - "extensions": "active", - "pid": "38113", - "start_time": "1475603155", - "version": "1.8.2", - "watcher": "38112" - } -], -"fleet_detail_query_system_info": [ - { - "computer_name": "computer", - "cpu_brand": "Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz", - "cpu_logical_cores": "8", - "cpu_physical_cores": "4", - "cpu_subtype": "Intel x86-64h Haswell", - "cpu_type": "x86_64h", - "hardware_model": "MacBookPro11,4", - "hardware_serial": "ABCDEFGH", - "hardware_vendor": "Apple Inc.", - "hardware_version": "1.0", - "hostname": "computer.local", - "physical_memory": "17179869184", - "uuid": "uuid" - } -], -"fleet_detail_query_uptime": [ - { - "days": "20", - "hours": "0", - "minutes": "48", - "seconds": "13", - "total_seconds": "1730893" - } -], -"fleet_detail_query_osquery_flags": [ - { - "name":"config_tls_refresh", - "value":"" - }, - { - "name":"distributed_interval", - "value":"" - }, - { - "name":"logger_tls_period", - "value":"" - } -] -} -` - - var results fleet.OsqueryDistributedQueryResults - err = json.Unmarshal([]byte(resultJSON), &results) - require.NoError(t, err) - - var gotHost *fleet.Host - ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { - gotHost = host - return nil - } - - // Verify that results are ingested properly - svc.SubmitDistributedQueryResults(ctx, results, map[string]fleet.OsqueryStatus{}, map[string]string{}) - - // osquery_info - assert.Equal(t, "darwin", gotHost.Platform) - assert.Equal(t, "1.8.2", gotHost.OsqueryVersion) - - // system_info - assert.Equal(t, int64(17179869184), gotHost.Memory) - assert.Equal(t, "computer.local", gotHost.Hostname) - assert.Equal(t, "uuid", gotHost.UUID) - - // os_version - assert.Equal(t, "Mac OS X 10.10.6", gotHost.OSVersion) - - // uptime - assert.Equal(t, 1730893*time.Second, 
gotHost.Uptime) - - // osquery_flags - assert.Equal(t, uint(0), gotHost.ConfigTLSRefresh) - assert.Equal(t, uint(0), gotHost.DistributedInterval) - assert.Equal(t, uint(0), gotHost.LoggerTLSPeriod) - - host.Hostname = "computer.local" - host.DetailUpdatedAt = mockClock.Now() - mockClock.AddTime(1 * time.Minute) - - // Now no detail queries should be required - ctx = hostctx.NewContext(context.Background(), host) - queries, acc, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, 0) - assert.Zero(t, acc) - - // Advance clock and queries should exist again - mockClock.AddTime(1*time.Hour + 1*time.Minute) - - queries, acc, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries) - assert.Zero(t, acc) -} - -func TestDetailQueries(t *testing.T) { - ds := new(mock.Store) - mockClock := clock.NewMockClock() - lq := new(live_query.MockLiveQuery) - svc := newTestServiceWithClock(ds, nil, lq, mockClock) - - host := &fleet.Host{ - ID: 1, - Platform: "linux", - } - ctx := hostctx.NewContext(context.Background(), host) - - lq.On("QueriesForHost", host.ID).Return(map[string]string{}, nil) - - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true, EnableSoftwareInventory: true}}, nil - } - ds.LabelQueriesForHostFunc = func(context.Context, *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - ds.PolicyQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - ds.SetOrUpdateMDMDataFunc = func(ctx context.Context, hostID uint, enrolled bool, serverURL string, installedFromDep bool) error { - require.True(t, enrolled) - require.False(t, installedFromDep) - require.Equal(t, "hi.com", serverURL) - return nil - } - ds.SetOrUpdateMunkiVersionFunc = func(ctx context.Context, hostID uint, version 
string) error { - require.Equal(t, "3.4.5", version) - return nil - } - ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { - if id != 1 { - return nil, errors.New("not found") - } - return host, nil - } - - // With a new host, we should get the detail queries (and accelerated - // queries) - queries, acc, err := svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries-2) - assert.NotZero(t, acc) - - resultJSON := ` -{ -"fleet_detail_query_network_interface": [ - { - "address": "192.168.0.1", - "broadcast": "192.168.0.255", - "ibytes": "1601207629", - "ierrors": "314179", - "interface": "en0", - "ipackets": "25698094", - "last_change": "1474233476", - "mac": "5f:3d:4b:10:25:82", - "mask": "255.255.255.0", - "metric": "1", - "mtu": "1453", - "obytes": "2607283152", - "oerrors": "101010", - "opackets": "12264603", - "point_to_point": "", - "type": "6" - } -], -"fleet_detail_query_os_version": [ - { - "platform": "darwin", - "build": "15G1004", - "major": "10", - "minor": "10", - "name": "Mac OS X", - "patch": "6" - } -], -"fleet_detail_query_osquery_info": [ - { - "build_distro": "10.10", - "build_platform": "darwin", - "config_hash": "3c6e4537c4d0eb71a7c6dda19d", - "config_valid": "1", - "extensions": "active", - "pid": "38113", - "start_time": "1475603155", - "version": "1.8.2", - "watcher": "38112" - } -], -"fleet_detail_query_system_info": [ - { - "computer_name": "computer", - "cpu_brand": "Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz", - "cpu_logical_cores": "8", - "cpu_physical_cores": "4", - "cpu_subtype": "Intel x86-64h Haswell", - "cpu_type": "x86_64h", - "hardware_model": "MacBookPro11,4", - "hardware_serial": "ABCDEFGH", - "hardware_vendor": "Apple Inc.", - "hardware_version": "1.0", - "hostname": "computer.local", - "physical_memory": "17179869184", - "uuid": "uuid" - } -], -"fleet_detail_query_uptime": [ - { - "days": "20", - "hours": "0", - "minutes": "48", - "seconds": "13", - 
"total_seconds": "1730893" - } -], -"fleet_detail_query_osquery_flags": [ - { - "name":"config_tls_refresh", - "value":"10" - }, - { - "name":"config_refresh", - "value":"9" - }, - { - "name":"distributed_interval", - "value":"5" - }, - { - "name":"logger_tls_period", - "value":"60" - } -], -"fleet_detail_query_users": [ - { - "uid": "1234", - "username": "user1", - "type": "sometype", - "groupname": "somegroup", - "shell": "someloginshell" - }, - { - "uid": "5678", - "username": "user2", - "type": "sometype", - "groupname": "somegroup" - } -], -"fleet_detail_query_software_macos": [ - { - "name": "app1", - "version": "1.0.0", - "source": "source1" - }, - { - "name": "app2", - "version": "1.0.0", - "source": "source2", - "bundle_identifier": "somebundle" - } -], -"fleet_detail_query_disk_space_unix": [ - { - "percent_disk_space_available": "56", - "gigs_disk_space_available": "277.0" - } -], -"fleet_detail_query_mdm": [ - { - "enrolled": "true", - "server_url": "hi.com", - "installed_from_dep": "false" - } -], -"fleet_detail_query_munki_info": [ - { - "version": "3.4.5" - } -] -} -` - - var results fleet.OsqueryDistributedQueryResults - err = json.Unmarshal([]byte(resultJSON), &results) - require.NoError(t, err) - - var gotHost *fleet.Host - ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { - gotHost = host - return nil - } - var gotUsers []fleet.HostUser - ds.SaveHostUsersFunc = func(ctx context.Context, hostID uint, users []fleet.HostUser) error { - if hostID != 1 { - return errors.New("not found") - } - gotUsers = users - return nil - } - var gotSoftware []fleet.Software - ds.UpdateHostSoftwareFunc = func(ctx context.Context, hostID uint, software []fleet.Software) error { - if hostID != 1 { - return errors.New("not found") - } - gotSoftware = software - return nil - } - - // Verify that results are ingested properly - require.NoError(t, svc.SubmitDistributedQueryResults(ctx, results, map[string]fleet.OsqueryStatus{}, map[string]string{})) 
- require.NotNil(t, gotHost) - - require.True(t, ds.SetOrUpdateMDMDataFuncInvoked) - require.True(t, ds.SetOrUpdateMunkiVersionFuncInvoked) - - // osquery_info - assert.Equal(t, "darwin", gotHost.Platform) - assert.Equal(t, "1.8.2", gotHost.OsqueryVersion) - - // system_info - assert.Equal(t, int64(17179869184), gotHost.Memory) - assert.Equal(t, "computer.local", gotHost.Hostname) - assert.Equal(t, "uuid", gotHost.UUID) - - // os_version - assert.Equal(t, "Mac OS X 10.10.6", gotHost.OSVersion) - - // uptime - assert.Equal(t, 1730893*time.Second, gotHost.Uptime) - - // osquery_flags - assert.Equal(t, uint(10), gotHost.ConfigTLSRefresh) - assert.Equal(t, uint(5), gotHost.DistributedInterval) - assert.Equal(t, uint(60), gotHost.LoggerTLSPeriod) - - // users - require.Len(t, gotUsers, 2) - assert.Equal(t, fleet.HostUser{ - Uid: 1234, - Username: "user1", - Type: "sometype", - GroupName: "somegroup", - Shell: "someloginshell", - }, gotUsers[0]) - assert.Equal(t, fleet.HostUser{ - Uid: 5678, - Username: "user2", - Type: "sometype", - GroupName: "somegroup", - Shell: "", - }, gotUsers[1]) - - // software - require.Len(t, gotSoftware, 2) - assert.Equal(t, []fleet.Software{ - { - Name: "app1", - Version: "1.0.0", - Source: "source1", - }, - { - Name: "app2", - Version: "1.0.0", - BundleIdentifier: "somebundle", - Source: "source2", - }, - }, gotSoftware) - - assert.Equal(t, 56.0, gotHost.PercentDiskSpaceAvailable) - assert.Equal(t, 277.0, gotHost.GigsDiskSpaceAvailable) - - host.Hostname = "computer.local" - host.Platform = "darwin" - host.DetailUpdatedAt = mockClock.Now() - mockClock.AddTime(1 * time.Minute) - - // Now no detail queries should be required - ctx = hostctx.NewContext(ctx, host) - queries, acc, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, 0) - assert.Zero(t, acc) - - // Advance clock and queries should exist again - mockClock.AddTime(1*time.Hour + 1*time.Minute) - - queries, acc, err = svc.GetDistributedQueries(ctx) 
- require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries+1) - assert.Zero(t, acc) -} - -func TestNewDistributedQueryCampaign(t *testing.T) { - ds := new(mock.Store) - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{}, nil - } - rs := &mock.QueryResultStore{ - HealthCheckFunc: func() error { - return nil - }, - } - lq := &live_query.MockLiveQuery{} - mockClock := clock.NewMockClock() - svc := newTestServiceWithClock(ds, rs, lq, mockClock) - - ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error { - return nil - } - var gotQuery *fleet.Query - ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) { - gotQuery = query - query.ID = 42 - return query, nil - } - var gotCampaign *fleet.DistributedQueryCampaign - ds.NewDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) { - gotCampaign = camp - camp.ID = 21 - return camp, nil - } - var gotTargets []*fleet.DistributedQueryCampaignTarget - ds.NewDistributedQueryCampaignTargetFunc = func(ctx context.Context, target *fleet.DistributedQueryCampaignTarget) (*fleet.DistributedQueryCampaignTarget, error) { - gotTargets = append(gotTargets, target) - return target, nil - } - - ds.CountHostsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets, now time.Time) (fleet.TargetMetrics, error) { - return fleet.TargetMetrics{}, nil - } - ds.HostIDsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets) ([]uint, error) { - return []uint{1, 3, 5}, nil - } - lq.On("RunQuery", "21", "select year, month, day, hour, minutes, seconds from time", []uint{1, 3, 5}).Return(nil) - viewerCtx := 
viewer.NewContext(context.Background(), viewer.Viewer{ - User: &fleet.User{ - ID: 0, - GlobalRole: ptr.String(fleet.RoleAdmin), - }, - }) - q := "select year, month, day, hour, minutes, seconds from time" - ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error { - return nil - } - campaign, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}}) - require.NoError(t, err) - assert.Equal(t, gotQuery.ID, gotCampaign.QueryID) - assert.True(t, ds.NewActivityFuncInvoked) - assert.Equal(t, []*fleet.DistributedQueryCampaignTarget{ - { - Type: fleet.TargetHost, - DistributedQueryCampaignID: campaign.ID, - TargetID: 2, - }, - { - Type: fleet.TargetLabel, - DistributedQueryCampaignID: campaign.ID, - TargetID: 1, - }, - }, gotTargets, - ) -} - -func TestDistributedQueryResults(t *testing.T) { - mockClock := clock.NewMockClock() - ds := new(mock.Store) - rs := pubsub.NewInmemQueryResults() - lq := new(live_query.MockLiveQuery) - svc := newTestServiceWithClock(ds, rs, lq, mockClock) - - campaign := &fleet.DistributedQueryCampaign{ID: 42} - - ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - ds.PolicyQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - host := &fleet.Host{ - ID: 1, - Platform: "windows", - } - ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { - if id != 1 { - return nil, errors.New("not found") - } - return host, nil - } - ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { - if host.ID != 1 { - return errors.New("not found") - } - return nil - } - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true}}, nil - } - - 
hostCtx := hostctx.NewContext(context.Background(), host) - - lq.On("QueriesForHost", uint(1)).Return( - map[string]string{ - strconv.Itoa(int(campaign.ID)): "select * from time", - }, - nil, - ) - lq.On("QueryCompletedByHost", strconv.Itoa(int(campaign.ID)), host.ID).Return(nil) - - // Now we should get the active distributed query - queries, acc, err := svc.GetDistributedQueries(hostCtx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries-2) - queryKey := fmt.Sprintf("%s%d", hostDistributedQueryPrefix, campaign.ID) - assert.Equal(t, "select * from time", queries[queryKey]) - assert.NotZero(t, acc) - - expectedRows := []map[string]string{ - { - "year": "2016", - "month": "11", - "day": "11", - "hour": "6", - "minutes": "12", - "seconds": "10", - }, - } - results := map[string][]map[string]string{ - queryKey: expectedRows, - } - - // TODO use service method - readChan, err := rs.ReadChannel(context.Background(), *campaign) - require.NoError(t, err) - - // We need to listen for the result in a separate thread to prevent the - // write to the result channel from failing - var waitSetup, waitComplete sync.WaitGroup - waitSetup.Add(1) - waitComplete.Add(1) - go func() { - waitSetup.Done() - select { - case val := <-readChan: - if res, ok := val.(fleet.DistributedQueryResult); ok { - assert.Equal(t, campaign.ID, res.DistributedQueryCampaignID) - assert.Equal(t, expectedRows, res.Rows) - assert.Equal(t, *host, res.Host) - } else { - t.Error("Wrong result type") - } - assert.NotNil(t, val) - - case <-time.After(1 * time.Second): - t.Error("No result received") - } - waitComplete.Done() - }() - - waitSetup.Wait() - // Sleep a short time to ensure that the above goroutine is blocking on - // the channel read (the waitSetup.Wait() is not necessarily sufficient - // if there is a context switch immediately after waitSetup.Done() is - // called). This should be a small price to pay to prevent flakiness in - // this test. 
- time.Sleep(10 * time.Millisecond) - - err = svc.SubmitDistributedQueryResults(hostCtx, results, map[string]fleet.OsqueryStatus{}, map[string]string{}) - require.NoError(t, err) -} - -func TestIngestDistributedQueryParseIdError(t *testing.T) { - mockClock := clock.NewMockClock() - ds := new(mock.Store) - rs := pubsub.NewInmemQueryResults() - lq := new(live_query.MockLiveQuery) - svc := &Service{ - ds: ds, - resultStore: rs, - liveQueryStore: lq, - logger: log.NewNopLogger(), - clock: mockClock, - } - - host := fleet.Host{ID: 1} - err := svc.ingestDistributedQuery(context.Background(), host, "bad_name", []map[string]string{}, false, "") - require.Error(t, err) - assert.Contains(t, err.Error(), "unable to parse campaign") -} - -func TestIngestDistributedQueryOrphanedCampaignLoadError(t *testing.T) { - mockClock := clock.NewMockClock() - ds := new(mock.Store) - rs := pubsub.NewInmemQueryResults() - lq := new(live_query.MockLiveQuery) - svc := &Service{ - ds: ds, - resultStore: rs, - liveQueryStore: lq, - logger: log.NewNopLogger(), - clock: mockClock, - } - - ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) { - return nil, errors.New("missing campaign") - } - - lq.On("StopQuery", "42").Return(nil) - - host := fleet.Host{ID: 1} - - err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") - require.Error(t, err) - assert.Contains(t, err.Error(), "loading orphaned campaign") -} - -func TestIngestDistributedQueryOrphanedCampaignWaitListener(t *testing.T) { - mockClock := clock.NewMockClock() - ds := new(mock.Store) - rs := pubsub.NewInmemQueryResults() - lq := new(live_query.MockLiveQuery) - svc := &Service{ - ds: ds, - resultStore: rs, - liveQueryStore: lq, - logger: log.NewNopLogger(), - clock: mockClock, - } - - campaign := &fleet.DistributedQueryCampaign{ - ID: 42, - UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ - CreateTimestamp: 
fleet.CreateTimestamp{ - CreatedAt: mockClock.Now().Add(-1 * time.Second), - }, - }, - } - - ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) { - return campaign, nil - } - - host := fleet.Host{ID: 1} - - err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") - require.Error(t, err) - assert.Contains(t, err.Error(), "campaign waiting for listener") -} - -func TestIngestDistributedQueryOrphanedCloseError(t *testing.T) { - mockClock := clock.NewMockClock() - ds := new(mock.Store) - rs := pubsub.NewInmemQueryResults() - lq := new(live_query.MockLiveQuery) - svc := &Service{ - ds: ds, - resultStore: rs, - liveQueryStore: lq, - logger: log.NewNopLogger(), - clock: mockClock, - } - - campaign := &fleet.DistributedQueryCampaign{ - ID: 42, - UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ - CreateTimestamp: fleet.CreateTimestamp{ - CreatedAt: mockClock.Now().Add(-2 * time.Minute), - }, - }, - } - - ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) { - return campaign, nil - } - ds.SaveDistributedQueryCampaignFunc = func(ctx context.Context, campaign *fleet.DistributedQueryCampaign) error { - return errors.New("failed save") - } - - host := fleet.Host{ID: 1} - - err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") - require.Error(t, err) - assert.Contains(t, err.Error(), "closing orphaned campaign") -} - -func TestIngestDistributedQueryOrphanedStopError(t *testing.T) { - mockClock := clock.NewMockClock() - ds := new(mock.Store) - rs := pubsub.NewInmemQueryResults() - lq := new(live_query.MockLiveQuery) - svc := &Service{ - ds: ds, - resultStore: rs, - liveQueryStore: lq, - logger: log.NewNopLogger(), - clock: mockClock, - } - - campaign := &fleet.DistributedQueryCampaign{ - ID: 42, - UpdateCreateTimestamps: 
fleet.UpdateCreateTimestamps{ - CreateTimestamp: fleet.CreateTimestamp{ - CreatedAt: mockClock.Now().Add(-2 * time.Minute), - }, - }, - } - - ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) { - return campaign, nil - } - ds.SaveDistributedQueryCampaignFunc = func(ctx context.Context, campaign *fleet.DistributedQueryCampaign) error { - return nil - } - lq.On("StopQuery", strconv.Itoa(int(campaign.ID))).Return(errors.New("failed")) - - host := fleet.Host{ID: 1} - - err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") - require.Error(t, err) - assert.Contains(t, err.Error(), "stopping orphaned campaign") -} - -func TestIngestDistributedQueryOrphanedStop(t *testing.T) { - mockClock := clock.NewMockClock() - ds := new(mock.Store) - rs := pubsub.NewInmemQueryResults() - lq := new(live_query.MockLiveQuery) - svc := &Service{ - ds: ds, - resultStore: rs, - liveQueryStore: lq, - logger: log.NewNopLogger(), - clock: mockClock, - } - - campaign := &fleet.DistributedQueryCampaign{ - ID: 42, - UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ - CreateTimestamp: fleet.CreateTimestamp{ - CreatedAt: mockClock.Now().Add(-2 * time.Minute), - }, - }, - } - - ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) { - return campaign, nil - } - ds.SaveDistributedQueryCampaignFunc = func(ctx context.Context, campaign *fleet.DistributedQueryCampaign) error { - return nil - } - lq.On("StopQuery", strconv.Itoa(int(campaign.ID))).Return(nil) - - host := fleet.Host{ID: 1} - - err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") - require.Error(t, err) - assert.Contains(t, err.Error(), "campaign stopped") - lq.AssertExpectations(t) -} - -func TestIngestDistributedQueryRecordCompletionError(t *testing.T) { - mockClock := 
clock.NewMockClock() - ds := new(mock.Store) - rs := pubsub.NewInmemQueryResults() - lq := new(live_query.MockLiveQuery) - svc := &Service{ - ds: ds, - resultStore: rs, - liveQueryStore: lq, - logger: log.NewNopLogger(), - clock: mockClock, - } - - campaign := &fleet.DistributedQueryCampaign{ID: 42} - host := fleet.Host{ID: 1} - - lq.On("QueryCompletedByHost", strconv.Itoa(int(campaign.ID)), host.ID).Return(errors.New("fail")) - - go func() { - ch, err := rs.ReadChannel(context.Background(), *campaign) - require.NoError(t, err) - <-ch - }() - time.Sleep(10 * time.Millisecond) - - err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") - require.Error(t, err) - assert.Contains(t, err.Error(), "record query completion") - lq.AssertExpectations(t) -} - -func TestIngestDistributedQuery(t *testing.T) { - mockClock := clock.NewMockClock() - ds := new(mock.Store) - rs := pubsub.NewInmemQueryResults() - lq := new(live_query.MockLiveQuery) - svc := &Service{ - ds: ds, - resultStore: rs, - liveQueryStore: lq, - logger: log.NewNopLogger(), - clock: mockClock, - } - - campaign := &fleet.DistributedQueryCampaign{ID: 42} - host := fleet.Host{ID: 1} - - lq.On("QueryCompletedByHost", strconv.Itoa(int(campaign.ID)), host.ID).Return(nil) - - go func() { - ch, err := rs.ReadChannel(context.Background(), *campaign) - require.NoError(t, err) - <-ch - }() - time.Sleep(10 * time.Millisecond) - - err := svc.ingestDistributedQuery(context.Background(), host, "fleet_distributed_query_42", []map[string]string{}, false, "") - require.NoError(t, err) - lq.AssertExpectations(t) -} - -func TestUpdateHostIntervals(t *testing.T) { - ds := new(mock.Store) - - svc := newTestService(ds, nil, nil) - - ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) { - return []*fleet.Pack{}, nil - } - - testCases := []struct { - name string - initIntervals fleet.HostOsqueryIntervals - finalIntervals 
fleet.HostOsqueryIntervals - configOptions json.RawMessage - updateIntervalsCalled bool - }{ - { - "Both updated", - fleet.HostOsqueryIntervals{ - ConfigTLSRefresh: 60, - }, - fleet.HostOsqueryIntervals{ - DistributedInterval: 11, - LoggerTLSPeriod: 33, - ConfigTLSRefresh: 60, - }, - json.RawMessage(`{"options": { - "distributed_interval": 11, - "logger_tls_period": 33, - "logger_plugin": "tls" - }}`), - true, - }, - { - "Only logger_tls_period updated", - fleet.HostOsqueryIntervals{ - DistributedInterval: 11, - ConfigTLSRefresh: 60, - }, - fleet.HostOsqueryIntervals{ - DistributedInterval: 11, - LoggerTLSPeriod: 33, - ConfigTLSRefresh: 60, - }, - json.RawMessage(`{"options": { - "distributed_interval": 11, - "logger_tls_period": 33 - }}`), - true, - }, - { - "Only distributed_interval updated", - fleet.HostOsqueryIntervals{ - ConfigTLSRefresh: 60, - LoggerTLSPeriod: 33, - }, - fleet.HostOsqueryIntervals{ - DistributedInterval: 11, - LoggerTLSPeriod: 33, - ConfigTLSRefresh: 60, - }, - json.RawMessage(`{"options": { - "distributed_interval": 11, - "logger_tls_period": 33 - }}`), - true, - }, - { - "Fleet not managing distributed_interval", - fleet.HostOsqueryIntervals{ - ConfigTLSRefresh: 60, - DistributedInterval: 11, - }, - fleet.HostOsqueryIntervals{ - DistributedInterval: 11, - LoggerTLSPeriod: 33, - ConfigTLSRefresh: 60, - }, - json.RawMessage(`{"options":{ - "logger_tls_period": 33 - }}`), - true, - }, - { - "config_refresh should also cause an update", - fleet.HostOsqueryIntervals{ - DistributedInterval: 11, - LoggerTLSPeriod: 33, - ConfigTLSRefresh: 60, - }, - fleet.HostOsqueryIntervals{ - DistributedInterval: 11, - LoggerTLSPeriod: 33, - ConfigTLSRefresh: 42, - }, - json.RawMessage(`{"options":{ - "distributed_interval": 11, - "logger_tls_period": 33, - "config_refresh": 42 - }}`), - true, - }, - { - "update intervals should not be called with no changes", - fleet.HostOsqueryIntervals{ - DistributedInterval: 11, - LoggerTLSPeriod: 33, - ConfigTLSRefresh: 
60, - }, - fleet.HostOsqueryIntervals{ - DistributedInterval: 11, - LoggerTLSPeriod: 33, - ConfigTLSRefresh: 60, - }, - json.RawMessage(`{"options":{ - "distributed_interval": 11, - "logger_tls_period": 33 - }}`), - false, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - ctx := hostctx.NewContext(context.Background(), &fleet.Host{ - ID: 1, - NodeKey: "123456", - DistributedInterval: tt.initIntervals.DistributedInterval, - ConfigTLSRefresh: tt.initIntervals.ConfigTLSRefresh, - LoggerTLSPeriod: tt.initIntervals.LoggerTLSPeriod, - }) - - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{AgentOptions: ptr.RawMessage(json.RawMessage(`{"config":` + string(tt.configOptions) + `}`))}, nil - } - - updateIntervalsCalled := false - ds.UpdateHostOsqueryIntervalsFunc = func(ctx context.Context, hostID uint, intervals fleet.HostOsqueryIntervals) error { - if hostID != 1 { - return errors.New("not found") - } - updateIntervalsCalled = true - assert.Equal(t, tt.finalIntervals, intervals) - return nil - } - - _, err := svc.GetClientConfig(ctx) - require.NoError(t, err) - assert.Equal(t, tt.updateIntervalsCalled, updateIntervalsCalled) - }) - } -} - -type notFoundError struct{} - -func (e notFoundError) Error() string { - return "not found" -} - -func (e notFoundError) IsNotFound() bool { - return true -} - -func TestAuthenticationErrors(t *testing.T) { - ms := new(mock.Store) - ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { - return nil, nil - } - - svc := newTestService(ms, nil, nil) - ctx := context.Background() - - _, _, err := svc.AuthenticateHost(ctx, "") - require.Error(t, err) - require.True(t, err.(osqueryError).NodeInvalid()) - - ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { - return &fleet.Host{ID: 1}, nil - } - ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return 
&fleet.AppConfig{}, nil - } - _, _, err = svc.AuthenticateHost(ctx, "foo") - require.NoError(t, err) - - // return not found error - ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { - return nil, notFoundError{} - } - - _, _, err = svc.AuthenticateHost(ctx, "foo") - require.Error(t, err) - require.True(t, err.(osqueryError).NodeInvalid()) - - // return other error - ms.LoadHostByNodeKeyFunc = func(ctx context.Context, nodeKey string) (*fleet.Host, error) { - return nil, errors.New("foo") - } - - _, _, err = svc.AuthenticateHost(ctx, "foo") - require.NotNil(t, err) - require.False(t, err.(osqueryError).NodeInvalid()) -} - -func TestGetHostIdentifier(t *testing.T) { - t.Parallel() - - details := map[string](map[string]string){ - "osquery_info": map[string]string{ - "uuid": "foouuid", - "instance_id": "fooinstance", - }, - "system_info": map[string]string{ - "hostname": "foohost", - }, - } - - emptyDetails := map[string](map[string]string){ - "osquery_info": map[string]string{ - "uuid": "", - "instance_id": "", - }, - "system_info": map[string]string{ - "hostname": "", - }, - } - - testCases := []struct { - identifierOption string - providedIdentifier string - details map[string](map[string]string) - expected string - shouldPanic bool - }{ - // Panix - {identifierOption: "bad", shouldPanic: true}, - {identifierOption: "", shouldPanic: true}, - - // Missing details - {identifierOption: "instance", providedIdentifier: "foobar", expected: "foobar"}, - {identifierOption: "uuid", providedIdentifier: "foobar", expected: "foobar"}, - {identifierOption: "hostname", providedIdentifier: "foobar", expected: "foobar"}, - {identifierOption: "provided", providedIdentifier: "foobar", expected: "foobar"}, - - // Empty details - {identifierOption: "instance", providedIdentifier: "foobar", details: emptyDetails, expected: "foobar"}, - {identifierOption: "uuid", providedIdentifier: "foobar", details: emptyDetails, expected: "foobar"}, - 
{identifierOption: "hostname", providedIdentifier: "foobar", details: emptyDetails, expected: "foobar"}, - {identifierOption: "provided", providedIdentifier: "foobar", details: emptyDetails, expected: "foobar"}, - - // Successes - {identifierOption: "instance", providedIdentifier: "foobar", details: details, expected: "fooinstance"}, - {identifierOption: "uuid", providedIdentifier: "foobar", details: details, expected: "foouuid"}, - {identifierOption: "hostname", providedIdentifier: "foobar", details: details, expected: "foohost"}, - {identifierOption: "provided", providedIdentifier: "foobar", details: details, expected: "foobar"}, - } - logger := log.NewNopLogger() - - for _, tt := range testCases { - t.Run("", func(t *testing.T) { - if tt.shouldPanic { - assert.Panics( - t, - func() { getHostIdentifier(logger, tt.identifierOption, tt.providedIdentifier, tt.details) }, - ) - return - } - - assert.Equal( - t, - tt.expected, - getHostIdentifier(logger, tt.identifierOption, tt.providedIdentifier, tt.details), - ) - }) - } -} - -func TestDistributedQueriesLogsManyErrors(t *testing.T) { - buf := new(bytes.Buffer) - logger := log.NewJSONLogger(buf) - logger = level.NewFilter(logger, level.AllowDebug()) - ds := new(mock.Store) - svc := newTestService(ds, nil, nil) - - host := &fleet.Host{ - ID: 1, - Platform: "darwin", - } - - ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { - return authz.CheckMissingWithResponse(nil) - } - ds.RecordLabelQueryExecutionsFunc = func(ctx context.Context, host *fleet.Host, results map[uint]*bool, t time.Time, deferred bool) error { - return errors.New("something went wrong") - } - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{}, nil - } - ds.SaveHostAdditionalFunc = func(ctx context.Context, hostID uint, additional *json.RawMessage) error { - return errors.New("something went wrong") - } - - lCtx := &fleetLogging.LoggingContext{} - ctx := 
fleetLogging.NewContext(context.Background(), lCtx) - ctx = hostctx.NewContext(ctx, host) - - err := svc.SubmitDistributedQueryResults( - ctx, - map[string][]map[string]string{ - hostDetailQueryPrefix + "network_interface": {{"col1": "val1"}}, // we need one detail query that updates hosts. - hostLabelQueryPrefix + "1": {{"col1": "val1"}}, - hostAdditionalQueryPrefix + "1": {{"col1": "val1"}}, - }, - map[string]fleet.OsqueryStatus{}, - map[string]string{}, - ) - require.NoError(t, err) - - lCtx.Log(ctx, logger) - - logs := buf.String() - parts := strings.Split(strings.TrimSpace(logs), "\n") - require.Len(t, parts, 1) - logData := make(map[string]json.RawMessage) - err = json.Unmarshal([]byte(parts[0]), &logData) - require.NoError(t, err) - assert.Equal(t, json.RawMessage(`"something went wrong || something went wrong"`), logData["err"]) - assert.Equal(t, json.RawMessage(`"Missing authorization check"`), logData["internal"]) -} - -func TestDistributedQueriesReloadsHostIfDetailsAreIn(t *testing.T) { - ds := new(mock.Store) - svc := newTestService(ds, nil, nil) - - host := &fleet.Host{ - ID: 42, - Platform: "darwin", - } - - ds.UpdateHostFunc = func(ctx context.Context, host *fleet.Host) error { - return nil - } - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{}, nil - } - - ctx := hostctx.NewContext(context.Background(), host) - - err := svc.SubmitDistributedQueryResults( - ctx, - map[string][]map[string]string{ - hostDetailQueryPrefix + "network_interface": {{"col1": "val1"}}, - }, - map[string]fleet.OsqueryStatus{}, - map[string]string{}, - ) - require.NoError(t, err) - assert.True(t, ds.UpdateHostFuncInvoked) -} - -func TestObserversCanOnlyRunDistributedCampaigns(t *testing.T) { - ds := new(mock.Store) - rs := &mock.QueryResultStore{ - HealthCheckFunc: func() error { - return nil - }, - } - lq := &live_query.MockLiveQuery{} - mockClock := clock.NewMockClock() - svc := newTestServiceWithClock(ds, rs, lq, 
mockClock) - - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{}, nil - } - - ds.NewDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) { - return camp, nil - } - ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) { - return &fleet.Query{ - ID: 42, - Name: "query", - Query: "select 1;", - ObserverCanRun: false, - }, nil - } - viewerCtx := viewer.NewContext(context.Background(), viewer.Viewer{ - User: &fleet.User{ID: 0, GlobalRole: ptr.String(fleet.RoleObserver)}, - }) - - q := "select year, month, day, hour, minutes, seconds from time" - ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error { - return nil - } - _, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}}) - require.Error(t, err) - - _, err = svc.NewDistributedQueryCampaign(viewerCtx, "", ptr.Uint(42), fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}}) - require.Error(t, err) - - ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) { - return &fleet.Query{ - ID: 42, - Name: "query", - Query: "select 1;", - ObserverCanRun: true, - }, nil - } - - ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - ds.SaveHostFunc = func(ctx context.Context, host *fleet.Host) error { return nil } - ds.NewDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) { - camp.ID = 21 - return camp, nil - } - ds.NewDistributedQueryCampaignTargetFunc = func(ctx context.Context, target *fleet.DistributedQueryCampaignTarget) (*fleet.DistributedQueryCampaignTarget, error) { - return target, nil - } - ds.CountHostsInTargetsFunc = func(ctx context.Context, 
filter fleet.TeamFilter, targets fleet.HostTargets, now time.Time) (fleet.TargetMetrics, error) { - return fleet.TargetMetrics{}, nil - } - ds.HostIDsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets) ([]uint, error) { - return []uint{1, 3, 5}, nil - } - ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error { - return nil - } - lq.On("RunQuery", "21", "select 1;", []uint{1, 3, 5}).Return(nil) - _, err = svc.NewDistributedQueryCampaign(viewerCtx, "", ptr.Uint(42), fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}}) - require.NoError(t, err) -} - -func TestTeamMaintainerCanRunNewDistributedCampaigns(t *testing.T) { - ds := new(mock.Store) - rs := &mock.QueryResultStore{ - HealthCheckFunc: func() error { - return nil - }, - } - lq := &live_query.MockLiveQuery{} - mockClock := clock.NewMockClock() - svc := newTestServiceWithClock(ds, rs, lq, mockClock) - - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{}, nil - } - - ds.NewDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) { - return camp, nil - } - ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) { - return &fleet.Query{ - ID: 42, - AuthorID: ptr.Uint(99), - Name: "query", - Query: "select 1;", - ObserverCanRun: false, - }, nil - } - viewerCtx := viewer.NewContext(context.Background(), viewer.Viewer{ - User: &fleet.User{ID: 99, Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 123}, Role: fleet.RoleMaintainer}}}, - }) - - q := "select year, month, day, hour, minutes, seconds from time" - ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error { - return nil - } - // var gotQuery *fleet.Query - ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts 
...fleet.OptionalArg) (*fleet.Query, error) { - // gotQuery = query - query.ID = 42 - return query, nil - } - ds.NewDistributedQueryCampaignTargetFunc = func(ctx context.Context, target *fleet.DistributedQueryCampaignTarget) (*fleet.DistributedQueryCampaignTarget, error) { - return target, nil - } - ds.CountHostsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets, now time.Time) (fleet.TargetMetrics, error) { - return fleet.TargetMetrics{}, nil - } - ds.HostIDsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets) ([]uint, error) { - return []uint{1, 3, 5}, nil - } - ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error { - return nil - } - lq.On("RunQuery", "0", "select year, month, day, hour, minutes, seconds from time", []uint{1, 3, 5}).Return(nil) - _, err := svc.NewDistributedQueryCampaign(viewerCtx, q, nil, fleet.HostTargets{HostIDs: []uint{2}, LabelIDs: []uint{1}, TeamIDs: []uint{123}}) - require.NoError(t, err) -} - -func TestPolicyQueries(t *testing.T) { - mockClock := clock.NewMockClock() - ds := new(mock.Store) - lq := new(live_query.MockLiveQuery) - svc := newTestServiceWithClock(ds, nil, lq, mockClock) - - host := &fleet.Host{ - Platform: "darwin", - } - - ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { - return host, nil - } - ds.UpdateHostFunc = func(ctx context.Context, gotHost *fleet.Host) error { - host = gotHost - return nil - } - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true}}, nil - } - - lq.On("QueriesForHost", uint(0)).Return(map[string]string{}, nil) - - ds.PolicyQueriesForHostFunc = func(ctx context.Context, host 
*fleet.Host) (map[string]string, error) { - return map[string]string{"1": "select 1", "2": "select 42;"}, nil - } - recordedResults := make(map[uint]*bool) - ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, gotHost *fleet.Host, results map[uint]*bool, updated time.Time, deferred bool) error { - recordedResults = results - host = gotHost - return nil - } - ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error) { - return nil, nil, nil - } - - ctx := hostctx.NewContext(context.Background(), host) - - queries, _, err := svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries+2) - - checkPolicyResults := func(queries map[string]string) { - hasPolicy1, hasPolicy2 := false, false - for name := range queries { - if strings.HasPrefix(name, hostPolicyQueryPrefix) { - if name[len(hostPolicyQueryPrefix):] == "1" { - hasPolicy1 = true - } - if name[len(hostPolicyQueryPrefix):] == "2" { - hasPolicy2 = true - } - } - } - assert.True(t, hasPolicy1) - assert.True(t, hasPolicy2) - } - - checkPolicyResults(queries) - - // Record a query execution. 
- err = svc.SubmitDistributedQueryResults( - ctx, - map[string][]map[string]string{ - hostPolicyQueryPrefix + "1": {{"col1": "val1"}}, - hostPolicyQueryPrefix + "2": {}, - }, - map[string]fleet.OsqueryStatus{ - hostPolicyQueryPrefix + "2": 1, - }, - map[string]string{}, - ) - require.NoError(t, err) - require.Len(t, recordedResults, 2) - require.NotNil(t, recordedResults[1]) - require.True(t, *recordedResults[1]) - result, ok := recordedResults[2] - require.True(t, ok) - require.Nil(t, result) - - noPolicyResults := func(queries map[string]string) { - hasAnyPolicy := false - for name := range queries { - if strings.HasPrefix(name, hostPolicyQueryPrefix) { - hasAnyPolicy = true - break - } - } - assert.False(t, hasAnyPolicy) - } - - // After the first time we get policies and update the host, then there shouldn't be any policies. - ctx = hostctx.NewContext(context.Background(), host) - queries, _, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries) - noPolicyResults(queries) - - // Let's move time forward, there should be policies now. - mockClock.AddTime(2 * time.Hour) - - queries, _, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries+2) - checkPolicyResults(queries) - - // Record another query execution. - err = svc.SubmitDistributedQueryResults( - ctx, - map[string][]map[string]string{ - hostPolicyQueryPrefix + "1": {{"col1": "val1"}}, - hostPolicyQueryPrefix + "2": {}, - }, - map[string]fleet.OsqueryStatus{ - hostPolicyQueryPrefix + "2": 1, - }, - map[string]string{}, - ) - require.NoError(t, err) - require.Len(t, recordedResults, 2) - require.NotNil(t, recordedResults[1]) - require.True(t, *recordedResults[1]) - result, ok = recordedResults[2] - require.True(t, ok) - require.Nil(t, result) - - // There shouldn't be any policies now. 
- ctx = hostctx.NewContext(context.Background(), host) - queries, _, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries) - noPolicyResults(queries) - - // With refetch requested policy queries should be returned. - host.RefetchRequested = true - ctx = hostctx.NewContext(context.Background(), host) - queries, _, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries+2) - checkPolicyResults(queries) - - // Record another query execution. - err = svc.SubmitDistributedQueryResults( - ctx, - map[string][]map[string]string{ - hostPolicyQueryPrefix + "1": {{"col1": "val1"}}, - hostPolicyQueryPrefix + "2": {}, - }, - map[string]fleet.OsqueryStatus{ - hostPolicyQueryPrefix + "2": 1, - }, - map[string]string{}, - ) - require.NoError(t, err) - require.NotNil(t, recordedResults[1]) - require.True(t, *recordedResults[1]) - result, ok = recordedResults[2] - require.True(t, ok) - require.Nil(t, result) - - // SubmitDistributedQueryResults will set RefetchRequested to false. - require.False(t, host.RefetchRequested) - - // There shouldn't be any policies now. 
- ctx = hostctx.NewContext(context.Background(), host) - queries, _, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries) - noPolicyResults(queries) -} - -func TestPolicyWebhooks(t *testing.T) { - mockClock := clock.NewMockClock() - ds := new(mock.Store) - lq := new(live_query.MockLiveQuery) - pool := redistest.SetupRedis(t, t.Name(), false, false, false) - failingPolicySet := redis_policy_set.NewFailingTest(t, pool) - testConfig := config.TestConfig() - svc := newTestServiceWithConfig(ds, testConfig, nil, lq, TestServerOpts{ - FailingPolicySet: failingPolicySet, - Clock: mockClock, - }) - - host := &fleet.Host{ - ID: 5, - Platform: "darwin", - Hostname: "test.hostname", - } - - lq.On("QueriesForHost", uint(5)).Return(map[string]string{}, nil) - ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { - return host, nil - } - ds.UpdateHostFunc = func(ctx context.Context, gotHost *fleet.Host) error { - host = gotHost - return nil - } - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{ - HostSettings: fleet.HostSettings{ - EnableHostUsers: true, - }, - WebhookSettings: fleet.WebhookSettings{ - FailingPoliciesWebhook: fleet.FailingPoliciesWebhookSettings{ - Enable: true, - PolicyIDs: []uint{1, 2, 3}, - }, - }, - }, nil - } - - ds.PolicyQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{ - "1": "select 1;", // passing policy - "2": "select * from unexistent_table;", // policy that fails to execute (e.g. 
missing table) - "3": "select 1 where 1 = 0;", // failing policy - }, nil - } - recordedResults := make(map[uint]*bool) - ds.RecordPolicyQueryExecutionsFunc = func(ctx context.Context, gotHost *fleet.Host, results map[uint]*bool, updated time.Time, deferred bool) error { - recordedResults = results - host = gotHost - return nil - } - ctx := hostctx.NewContext(context.Background(), host) - - queries, _, err := svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries+3) - - checkPolicyResults := func(queries map[string]string) { - hasPolicy1, hasPolicy2, hasPolicy3 := false, false, false - for name := range queries { - if strings.HasPrefix(name, hostPolicyQueryPrefix) { - switch name[len(hostPolicyQueryPrefix):] { - case "1": - hasPolicy1 = true - case "2": - hasPolicy2 = true - case "3": - hasPolicy3 = true - } - } - } - assert.True(t, hasPolicy1) - assert.True(t, hasPolicy2) - assert.True(t, hasPolicy3) - } - - checkPolicyResults(queries) - - ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error) { - return []uint{3}, nil, nil - } - - // Record a query execution. 
- err = svc.SubmitDistributedQueryResults( - ctx, - map[string][]map[string]string{ - hostPolicyQueryPrefix + "1": {{"col1": "val1"}}, // succeeds - hostPolicyQueryPrefix + "2": {}, // didn't execute - hostPolicyQueryPrefix + "3": {}, // fails - }, - map[string]fleet.OsqueryStatus{ - hostPolicyQueryPrefix + "2": 1, // didn't execute - }, - map[string]string{}, - ) - require.NoError(t, err) - require.Len(t, recordedResults, 3) - require.NotNil(t, recordedResults[1]) - require.True(t, *recordedResults[1]) - result, ok := recordedResults[2] - require.True(t, ok) - require.Nil(t, result) - require.NotNil(t, recordedResults[3]) - require.False(t, *recordedResults[3]) - - cmpSets := func(expSets map[uint][]fleet.PolicySetHost) error { - actualSets, err := failingPolicySet.ListSets() - if err != nil { - return err - } - var expSets_ []uint - for expSet := range expSets { - expSets_ = append(expSets_, expSet) - } - sort.Slice(expSets_, func(i, j int) bool { - return expSets_[i] < expSets_[j] - }) - sort.Slice(actualSets, func(i, j int) bool { - return actualSets[i] < actualSets[j] - }) - if !reflect.DeepEqual(actualSets, expSets_) { - return fmt.Errorf("sets mismatch: %+v vs %+v", actualSets, expSets_) - } - for expID, expHosts := range expSets { - actualHosts, err := failingPolicySet.ListHosts(expID) - if err != nil { - return err - } - sort.Slice(actualHosts, func(i, j int) bool { - return actualHosts[i].ID < actualHosts[j].ID - }) - sort.Slice(expHosts, func(i, j int) bool { - return expHosts[i].ID < expHosts[j].ID - }) - if !reflect.DeepEqual(actualHosts, expHosts) { - return fmt.Errorf("hosts mismatch %d: %+v vs %+v", expID, actualHosts, expHosts) - } - } - return nil - } - - assert.Eventually(t, func() bool { - err = cmpSets(map[uint][]fleet.PolicySetHost{ - 3: {{ - ID: host.ID, - Hostname: host.Hostname, - }}, - }) - return err == nil - }, 1*time.Minute, 250*time.Millisecond) - require.NoError(t, err) - - noPolicyResults := func(queries map[string]string) { - 
hasAnyPolicy := false - for name := range queries { - if strings.HasPrefix(name, hostPolicyQueryPrefix) { - hasAnyPolicy = true - break - } - } - assert.False(t, hasAnyPolicy) - } - - // After the first time we get policies and update the host, then there shouldn't be any policies. - ctx = hostctx.NewContext(context.Background(), host) - queries, _, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries) - noPolicyResults(queries) - - // Let's move time forward, there should be policies now. - mockClock.AddTime(2 * time.Hour) - - queries, _, err = svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries+3) - checkPolicyResults(queries) - - ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error) { - return []uint{1}, []uint{3}, nil - } - - // Record another query execution. - err = svc.SubmitDistributedQueryResults( - ctx, - map[string][]map[string]string{ - hostPolicyQueryPrefix + "1": {}, // 1 now fails - hostPolicyQueryPrefix + "2": {}, // didn't execute - hostPolicyQueryPrefix + "3": {{"col1": "val1"}}, // 1 now succeeds - }, - map[string]fleet.OsqueryStatus{ - hostPolicyQueryPrefix + "2": 1, // didn't execute - }, - map[string]string{}, - ) - require.NoError(t, err) - require.Len(t, recordedResults, 3) - require.NotNil(t, recordedResults[1]) - require.False(t, *recordedResults[1]) - result, ok = recordedResults[2] - require.True(t, ok) - require.Nil(t, result) - require.NotNil(t, recordedResults[3]) - require.True(t, *recordedResults[3]) - - assert.Eventually(t, func() bool { - err = cmpSets(map[uint][]fleet.PolicySetHost{ - 1: {{ - ID: host.ID, - Hostname: host.Hostname, - }}, - 3: {}, - }) - return err == nil - }, 1*time.Minute, 250*time.Millisecond) - require.NoError(t, err) - - // Simulate webhook trigger by removing the hosts. 
- err = failingPolicySet.RemoveHosts(1, []fleet.PolicySetHost{{ - ID: host.ID, - Hostname: host.Hostname, - }}) - require.NoError(t, err) - - ds.FlippingPoliciesForHostFunc = func(ctx context.Context, hostID uint, incomingResults map[uint]*bool) (newFailing []uint, newPassing []uint, err error) { - return []uint{}, []uint{2}, nil - } - - // Record another query execution. - err = svc.SubmitDistributedQueryResults( - ctx, - map[string][]map[string]string{ - hostPolicyQueryPrefix + "1": {}, // continues to fail - hostPolicyQueryPrefix + "2": {{"col1": "val1"}}, // now passes - hostPolicyQueryPrefix + "3": {{"col1": "val1"}}, // continues to succeed - }, - map[string]fleet.OsqueryStatus{}, - map[string]string{}, - ) - require.NoError(t, err) - require.Len(t, recordedResults, 3) - require.NotNil(t, recordedResults[1]) - require.False(t, *recordedResults[1]) - require.NotNil(t, recordedResults[2]) - require.True(t, *recordedResults[2]) - require.NotNil(t, recordedResults[3]) - require.True(t, *recordedResults[3]) - - assert.Eventually(t, func() bool { - err = cmpSets(map[uint][]fleet.PolicySetHost{ - 1: {}, - 3: {}, - }) - return err == nil - }, 1*time.Minute, 250*time.Millisecond) - require.NoError(t, err) -} - -// If the live query store (Redis) is down we still (see #3503) -// want hosts to get queries and continue to check in. 
-func TestLiveQueriesFailing(t *testing.T) { - ds := new(mock.Store) - lq := new(live_query.MockLiveQuery) - cfg := config.TestConfig() - buf := new(bytes.Buffer) - logger := log.NewLogfmtLogger(buf) - svc := newTestServiceWithConfig(ds, cfg, nil, lq, TestServerOpts{ - Logger: logger, - }) - - hostID := uint(1) - host := &fleet.Host{ - ID: hostID, - Platform: "darwin", - } - lq.On("QueriesForHost", hostID).Return( - map[string]string{}, - errors.New("failed to get queries for host"), - ) - - ds.LabelQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) { - return host, nil - } - ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - return &fleet.AppConfig{HostSettings: fleet.HostSettings{EnableHostUsers: true}}, nil - } - ds.PolicyQueriesForHostFunc = func(ctx context.Context, host *fleet.Host) (map[string]string, error) { - return map[string]string{}, nil - } - - ctx := hostctx.NewContext(context.Background(), host) - - queries, _, err := svc.GetDistributedQueries(ctx) - require.NoError(t, err) - require.Len(t, queries, expectedDetailQueries) - - logs, err := ioutil.ReadAll(buf) - require.NoError(t, err) - require.Contains(t, string(logs), "level=error") - require.Contains(t, string(logs), "failed to get queries for host") -} - -func TestJitterForHost(t *testing.T) { - jh := newJitterHashTable(30) - - histogram := make(map[int64]int) - hostCount := 3000 - for i := 0; i < hostCount; i++ { - hostID, err := crand.Int(crand.Reader, big.NewInt(10000)) - require.NoError(t, err) - jitter := jh.jitterForHost(uint(hostID.Int64() + 10000)) - jitterMinutes := int64(jitter.Minutes()) - histogram[jitterMinutes]++ - } - min, max := math.MaxInt, 0 - for jitterMinutes, count := range histogram { - if count < min { - min = count - } - if count > max { - max = count - } - t.Logf("jitterMinutes=%d \t 
count=%d\n", jitterMinutes, count) - } - variation := max - min - t.Logf("min=%d \t max=%d \t variation=%d\n", min, max, variation) - - // check that variation is below 1% of the total amount of hosts - require.Less(t, variation, int(float32(hostCount)/0.01)) -} - -func TestNoJitter(t *testing.T) { - jh := newJitterHashTable(0) - - hostCount := 3000 - for i := 0; i < hostCount; i++ { - hostID, err := crand.Int(crand.Reader, big.NewInt(10000)) - require.NoError(t, err) - jitter := jh.jitterForHost(uint(hostID.Int64() + 10000)) - jitterMinutes := int64(jitter.Minutes()) - require.Equal(t, int64(0), jitterMinutes) - } -} diff --git a/server/service/service_sessions.go b/server/service/service_sessions.go deleted file mode 100644 index 9aade15b3f..0000000000 --- a/server/service/service_sessions.go +++ /dev/null @@ -1,325 +0,0 @@ -package service - -import ( - "context" - "crypto/rand" - "encoding/base64" - "encoding/xml" - "errors" - "fmt" - "net/url" - "time" - - "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - "github.com/fleetdm/fleet/v4/server/contexts/logging" - "github.com/fleetdm/fleet/v4/server/contexts/viewer" - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/fleetdm/fleet/v4/server/sso" - "github.com/go-kit/kit/log/level" -) - -// SSOSettings returns a subset of the Single Sign-On settings as configured in -// the app config. Those can be exposed e.g. via the response to an HTTP request, -// and as such should not contain sensitive information. -func (svc *Service) SSOSettings(ctx context.Context) (*fleet.SessionSSOSettings, error) { - // skipauth: Basic SSO settings are available to unauthenticated users (so - // that they have the necessary information to initiate SSO). 
- svc.authz.SkipAuthorization(ctx) - - logging.WithLevel(ctx, level.Info) - - appConfig, err := svc.ds.AppConfig(ctx) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "SessionSSOSettings getting app config") - } - - settings := &fleet.SessionSSOSettings{ - IDPName: appConfig.SSOSettings.IDPName, - IDPImageURL: appConfig.SSOSettings.IDPImageURL, - SSOEnabled: appConfig.SSOSettings.EnableSSO, - } - return settings, nil -} - -// InitiateSSO initiates a Single Sign-On flow for a request to visit the -// protected URL identified by redirectURL. It returns the URL of the identity -// provider to make a request to to proceed with the authentication via that -// external service, and stores ephemeral session state to validate the -// callback from the identity provider to finalize the SSO flow. -func (svc *Service) InitiateSSO(ctx context.Context, redirectURL string) (string, error) { - // skipauth: User context does not yet exist. Unauthenticated users may - // initiate SSO. - svc.authz.SkipAuthorization(ctx) - - logging.WithLevel(ctx, level.Info) - - appConfig, err := svc.ds.AppConfig(ctx) - if err != nil { - return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting app config") - } - - metadata, err := svc.getMetadata(appConfig) - if err != nil { - return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting metadata") - } - - serverURL := appConfig.ServerSettings.ServerURL - settings := sso.Settings{ - Metadata: metadata, - // Construct call back url to send to idp - AssertionConsumerServiceURL: serverURL + svc.config.Server.URLPrefix + "/api/v1/fleet/sso/callback", - SessionStore: svc.ssoSessionStore, - OriginalURL: redirectURL, - } - - // If issuer is not explicitly set, default to host name. 
- var issuer string - entityID := appConfig.SSOSettings.EntityID - if entityID == "" { - u, err := url.Parse(serverURL) - if err != nil { - return "", ctxerr.Wrap(ctx, err, "parse server url") - } - issuer = u.Hostname() - } else { - issuer = entityID - } - - idpURL, err := sso.CreateAuthorizationRequest(&settings, issuer) - if err != nil { - return "", ctxerr.Wrap(ctx, err, "InitiateSSO creating authorization") - } - - return idpURL, nil -} - -func (svc *Service) getMetadata(config *fleet.AppConfig) (*sso.Metadata, error) { - if config.SSOSettings.MetadataURL != "" { - metadata, err := sso.GetMetadata(config.SSOSettings.MetadataURL) - if err != nil { - return nil, err - } - return metadata, nil - } - - if config.SSOSettings.Metadata != "" { - metadata, err := sso.ParseMetadata(config.SSOSettings.Metadata) - if err != nil { - return nil, err - } - return metadata, nil - } - - return nil, fmt.Errorf("missing metadata for idp %s", config.SSOSettings.IDPName) -} - -func (svc *Service) CallbackSSO(ctx context.Context, auth fleet.Auth) (*fleet.SSOSession, error) { - // skipauth: User context does not yet exist. Unauthenticated users may - // hit the SSO callback. - svc.authz.SkipAuthorization(ctx) - - logging.WithLevel(ctx, level.Info) - - appConfig, err := svc.ds.AppConfig(ctx) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "get config for sso") - } - - // Load the request metadata if available - - // localhost:9080/simplesaml/saml2/idp/SSOService.php?spentityid=https://localhost:8080 - var metadata *sso.Metadata - var redirectURL string - - if appConfig.SSOSettings.EnableSSOIdPLogin && auth.RequestID() == "" { - // Missing request ID indicates this was IdP-initiated. Only allow if - // configured to do so. 
- metadata, err = svc.getMetadata(appConfig) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "get sso metadata") - } - redirectURL = "/" - } else { - session, err := svc.ssoSessionStore.Get(auth.RequestID()) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "sso request invalid") - } - // Remove session to so that is can't be reused before it expires. - err = svc.ssoSessionStore.Expire(auth.RequestID()) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "remove sso request") - } - if err := xml.Unmarshal([]byte(session.Metadata), &metadata); err != nil { - return nil, ctxerr.Wrap(ctx, err, "unmarshal metadata") - } - redirectURL = session.OriginalURL - } - - // Validate response - validator, err := sso.NewValidator(*metadata, sso.WithExpectedAudience( - appConfig.SSOSettings.EntityID, - appConfig.ServerSettings.ServerURL, - appConfig.ServerSettings.ServerURL+svc.config.Server.URLPrefix+"/api/v1/fleet/sso/callback", // ACS - )) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "create validator from metadata") - } - // make sure the response hasn't been tampered with - auth, err = validator.ValidateSignature(auth) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "signature validation failed") - } - // make sure the response isn't stale - err = validator.ValidateResponse(auth) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "response validation failed") - } - - // Get and log in user - user, err := svc.ds.UserByEmail(ctx, auth.UserID()) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "find user in sso callback") - } - // if the user is not sso enabled they are not authorized - if !user.SSOEnabled { - return nil, ctxerr.New(ctx, "user not configured to use sso") - } - token, err := svc.makeSession(ctx, user.ID) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "make session in sso callback") - } - result := &fleet.SSOSession{ - Token: token, - RedirectURL: redirectURL, - } - return result, nil -} - -func (svc *Service) Login(ctx 
context.Context, email, password string) (*fleet.User, string, error) { - // skipauth: No user context available yet to authorize against. - svc.authz.SkipAuthorization(ctx) - - logging.WithLevel(logging.WithNoUser(ctx), level.Info) - - // If there is an error, sleep until the request has taken at least 1 - // second. This means that generally a login failure for any reason will - // take ~1s and frustrate a timing attack. - var err error - defer func(start time.Time) { - if err != nil { - time.Sleep(time.Until(start.Add(1 * time.Second))) - } - }(time.Now()) - - user, err := svc.ds.UserByEmail(ctx, email) - var nfe fleet.NotFoundError - if errors.As(err, &nfe) { - return nil, "", fleet.NewAuthFailedError("user not found") - } - if err != nil { - return nil, "", fleet.NewAuthFailedError(err.Error()) - } - - if err = user.ValidatePassword(password); err != nil { - return nil, "", fleet.NewAuthFailedError("invalid password") - } - - if user.SSOEnabled { - return nil, "", fleet.NewAuthFailedError("password login disabled for sso users") - } - - token, err := svc.makeSession(ctx, user.ID) - if err != nil { - return nil, "", fleet.NewAuthFailedError(err.Error()) - } - - return user, token, nil -} - -// makeSession is a helper that creates a new session after authentication -func (svc *Service) makeSession(ctx context.Context, id uint) (string, error) { - sessionKeySize := svc.config.Session.KeySize - key := make([]byte, sessionKeySize) - _, err := rand.Read(key) - if err != nil { - return "", err - } - - sessionKey := base64.StdEncoding.EncodeToString(key) - session := &fleet.Session{ - UserID: id, - Key: sessionKey, - AccessedAt: time.Now().UTC(), - } - - _, err = svc.ds.NewSession(ctx, session) - if err != nil { - return "", ctxerr.Wrap(ctx, err, "creating new session") - } - - return sessionKey, nil -} - -func (svc *Service) Logout(ctx context.Context) error { - // skipauth: Any user can always log out of their own session. 
- svc.authz.SkipAuthorization(ctx) - - logging.WithLevel(ctx, level.Info) - - // TODO: this should not return an error if the user wasn't logged in - return svc.DestroySession(ctx) -} - -func (svc *Service) DestroySession(ctx context.Context) error { - vc, ok := viewer.FromContext(ctx) - if !ok { - return fleet.ErrNoContext - } - - session, err := svc.ds.SessionByID(ctx, vc.SessionID()) - if err != nil { - return err - } - - if err := svc.authz.Authorize(ctx, session, fleet.ActionWrite); err != nil { - return err - } - - return svc.ds.DestroySession(ctx, session) -} - -func (svc *Service) GetSessionByKey(ctx context.Context, key string) (*fleet.Session, error) { - session, err := svc.ds.SessionByKey(ctx, key) - if err != nil { - return nil, err - } - - err = svc.validateSession(ctx, session) - if err != nil { - return nil, err - } - - return session, nil -} - -func (svc *Service) validateSession(ctx context.Context, session *fleet.Session) error { - if session == nil { - return fleet.NewAuthRequiredError("active session not present") - } - - sessionDuration := svc.config.Session.Duration - if session.APIOnly != nil && *session.APIOnly { - sessionDuration = 0 // make API-only tokens unlimited - } - - // duration 0 = unlimited - if sessionDuration != 0 && time.Since(session.AccessedAt) >= sessionDuration { - err := svc.ds.DestroySession(ctx, session) - if err != nil { - return ctxerr.Wrap(ctx, err, "destroying session") - } - return fleet.NewAuthRequiredError("expired session") - } - - return svc.ds.MarkSessionAccessed(ctx, session) -} diff --git a/server/service/service_sessions_test.go b/server/service/service_sessions_test.go deleted file mode 100644 index c1e24c6afe..0000000000 --- a/server/service/service_sessions_test.go +++ /dev/null @@ -1,111 +0,0 @@ -package service - -import ( - "context" - "testing" - "time" - - "github.com/fleetdm/fleet/v4/server/config" - "github.com/fleetdm/fleet/v4/server/datastore/mysql" - "github.com/fleetdm/fleet/v4/server/fleet" - 
"github.com/fleetdm/fleet/v4/server/mock" - "github.com/fleetdm/fleet/v4/server/ptr" - "github.com/fleetdm/fleet/v4/server/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestAuthenticate(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - svc := newTestService(ds, nil, nil) - createTestUsers(t, ds) - - var loginTests = []struct { - name string - email string - password string - wantErr error - }{ - { - name: "admin1", - email: testUsers["admin1"].Email, - password: testUsers["admin1"].PlaintextPassword, - }, - { - name: "user1", - email: testUsers["user1"].Email, - password: testUsers["user1"].PlaintextPassword, - }, - } - - for _, tt := range loginTests { - t.Run(tt.email, func(st *testing.T) { - loggedIn, token, err := svc.Login(test.UserContext(test.UserAdmin), tt.email, tt.password) - require.Nil(st, err, "login unsuccessful") - assert.Equal(st, tt.email, loggedIn.Email) - assert.NotEmpty(st, token) - - sessions, err := svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), loggedIn.ID) - require.Nil(st, err) - require.Len(st, sessions, 1, "user should have one session") - session := sessions[0] - assert.NotZero(st, session.UserID) - assert.WithinDuration(st, time.Now(), session.AccessedAt, 3*time.Second, - "access time should be set with current time at session creation") - }) - } -} - -func TestGetSessionByKey(t *testing.T) { - ds := new(mock.Store) - svc := newTestService(ds, nil, nil) - cfg := config.TestConfig() - - theSession := &fleet.Session{UserID: 123, Key: "abc"} - - ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) { - return theSession, nil - } - ds.DestroySessionFunc = func(ctx context.Context, ssn *fleet.Session) error { - return nil - } - ds.MarkSessionAccessedFunc = func(ctx context.Context, ssn *fleet.Session) error { - return nil - } - - cases := []struct { - desc string - accessed time.Duration - apiOnly bool - fail bool - }{ - 
{"real user, accessed recently", -1 * time.Hour, false, false}, - {"real user, accessed too long ago", -(cfg.Session.Duration + time.Hour), false, true}, - {"api-only, accessed recently", -1 * time.Hour, true, false}, - {"api-only, accessed long ago", -(cfg.Session.Duration + time.Hour), true, false}, - } - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - var authErr *fleet.AuthRequiredError - ds.SessionByKeyFuncInvoked, ds.DestroySessionFuncInvoked, ds.MarkSessionAccessedFuncInvoked = false, false, false - - theSession.AccessedAt = time.Now().Add(tc.accessed) - theSession.APIOnly = ptr.Bool(tc.apiOnly) - _, err := svc.GetSessionByKey(context.Background(), theSession.Key) - if tc.fail { - require.Error(t, err) - require.ErrorAs(t, err, &authErr) - require.True(t, ds.SessionByKeyFuncInvoked) - require.True(t, ds.DestroySessionFuncInvoked) - require.False(t, ds.MarkSessionAccessedFuncInvoked) - } else { - require.NoError(t, err) - require.True(t, ds.SessionByKeyFuncInvoked) - require.False(t, ds.DestroySessionFuncInvoked) - require.True(t, ds.MarkSessionAccessedFuncInvoked) - } - }) - } -} diff --git a/server/service/service_teams.go b/server/service/service_teams.go deleted file mode 100644 index ed5236c91a..0000000000 --- a/server/service/service_teams.go +++ /dev/null @@ -1,15 +0,0 @@ -package service - -import ( - "context" - - "github.com/fleetdm/fleet/v4/server/fleet" -) - -func (svc *Service) ListAvailableTeamsForUser(ctx context.Context, user *fleet.User) ([]*fleet.TeamSummary, error) { - // skipauth: No authorization check needed due to implementation returning - // only license error. 
- svc.authz.SkipAuthorization(ctx) - - return nil, fleet.ErrMissingLicense -} diff --git a/server/service/service_users.go b/server/service/service_users.go index ddde508650..b6a5f1b0b1 100644 --- a/server/service/service_users.go +++ b/server/service/service_users.go @@ -2,45 +2,14 @@ package service import ( "context" - "encoding/base64" - "html/template" - "time" "github.com/fleetdm/fleet/v4/server" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - "github.com/fleetdm/fleet/v4/server/contexts/viewer" "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/fleetdm/fleet/v4/server/mail" "github.com/fleetdm/fleet/v4/server/ptr" ) -func (svc *Service) CreateUserFromInvite(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) { - // skipauth: There is no viewer context at this point. We rely on verifying - // the invite for authNZ. - svc.authz.SkipAuthorization(ctx) - - invite, err := svc.VerifyInvite(ctx, *p.InviteToken) - if err != nil { - return nil, err - } - - // set the payload role property based on an existing invite. - p.GlobalRole = invite.GlobalRole.Ptr() - p.Teams = &invite.Teams - - user, err := svc.newUser(ctx, p) - if err != nil { - return nil, err - } - - err = svc.ds.DeleteInvite(ctx, invite.ID) - if err != nil { - return nil, err - } - return user, nil -} - func (svc *Service) CreateInitialUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) { // skipauth: Only the initial user creation should be allowed to skip // authorization (because there is not yet a user context to check against). @@ -88,154 +57,3 @@ func (svc *Service) UserUnauthorized(ctx context.Context, id uint) (*fleet.User, // Explicitly no authorization check. Should only be used by middleware. return svc.ds.UserByID(ctx, id) } - -// setNewPassword is a helper for changing a user's password. It should be -// called to set the new password after proper authorization has been -// performed. 
-func (svc *Service) setNewPassword(ctx context.Context, user *fleet.User, password string) error { - err := user.SetPassword(password, svc.config.Auth.SaltKeySize, svc.config.Auth.BcryptCost) - if err != nil { - return ctxerr.Wrap(ctx, err, "setting new password") - } - if user.SSOEnabled { - return ctxerr.New(ctx, "set password for single sign on user not allowed") - } - err = svc.saveUser(ctx, user) - if err != nil { - return ctxerr.Wrap(ctx, err, "saving changed password") - } - - return nil -} - -func (svc *Service) ResetPassword(ctx context.Context, token, password string) error { - // skipauth: No viewer context available. The user is locked out of their - // account and authNZ is performed entirely by providing a valid password - // reset token. - svc.authz.SkipAuthorization(ctx) - - reset, err := svc.ds.FindPassswordResetByToken(ctx, token) - if err != nil { - return ctxerr.Wrap(ctx, err, "looking up reset by token") - } - user, err := svc.ds.UserByID(ctx, reset.UserID) - if err != nil { - return ctxerr.Wrap(ctx, err, "retrieving user") - } - - if user.SSOEnabled { - return ctxerr.New(ctx, "password reset for single sign on user not allowed") - } - - // prevent setting the same password - if err := user.ValidatePassword(password); err == nil { - return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password") - } - - err = svc.setNewPassword(ctx, user, password) - if err != nil { - return ctxerr.Wrap(ctx, err, "setting new password") - } - - // delete password reset tokens for user - if err := svc.ds.DeletePasswordResetRequestsForUser(ctx, user.ID); err != nil { - return ctxerr.Wrap(ctx, err, "delete password reset requests") - } - - // Clear sessions so that any other browsers will have to log in with - // the new password - if err := svc.ds.DestroyAllSessionsForUser(ctx, user.ID); err != nil { - return ctxerr.Wrap(ctx, err, "delete user sessions") - } - - return nil -} - -func (svc *Service) PerformRequiredPasswordReset(ctx 
context.Context, password string) (*fleet.User, error) { - vc, ok := viewer.FromContext(ctx) - if !ok { - return nil, fleet.ErrNoContext - } - user := vc.User - - if err := svc.authz.Authorize(ctx, user, fleet.ActionWrite); err != nil { - return nil, err - } - - if user.SSOEnabled { - return nil, ctxerr.New(ctx, "password reset for single sign on user not allowed") - } - if !user.IsAdminForcedPasswordReset() { - return nil, ctxerr.New(ctx, "user does not require password reset") - } - - // prevent setting the same password - if err := user.ValidatePassword(password); err == nil { - return nil, fleet.NewInvalidArgumentError("new_password", "cannot reuse old password") - } - - user.AdminForcedPasswordReset = false - err := svc.setNewPassword(ctx, user, password) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "setting new password") - } - - // Sessions should already have been cleared when the reset was - // required - - return user, nil -} - -func (svc *Service) RequestPasswordReset(ctx context.Context, email string) error { - // skipauth: No viewer context available. The user is locked out of their - // account and trying to reset their password. - svc.authz.SkipAuthorization(ctx) - - // Regardless of error, sleep until the request has taken at least 1 second. - // This means that any request to this method will take ~1s and frustrate a timing attack. 
- defer func(start time.Time) { - time.Sleep(time.Until(start.Add(1 * time.Second))) - }(time.Now()) - - user, err := svc.ds.UserByEmail(ctx, email) - if err != nil { - return err - } - if user.SSOEnabled { - return ctxerr.New(ctx, "password reset for single sign on user not allowed") - } - - random, err := server.GenerateRandomText(svc.config.App.TokenKeySize) - if err != nil { - return err - } - token := base64.URLEncoding.EncodeToString([]byte(random)) - - request := &fleet.PasswordResetRequest{ - ExpiresAt: time.Now().Add(time.Hour * 24), - UserID: user.ID, - Token: token, - } - _, err = svc.ds.NewPasswordResetRequest(ctx, request) - if err != nil { - return err - } - - config, err := svc.ds.AppConfig(ctx) - if err != nil { - return err - } - - resetEmail := fleet.Email{ - Subject: "Reset Your Fleet Password", - To: []string{user.Email}, - Config: config, - Mailer: &mail.PasswordResetMailer{ - BaseURL: template.URL(config.ServerSettings.ServerURL + svc.config.Server.URLPrefix), - AssetURL: getAssetURL(), - Token: token, - }, - } - - return svc.mailService.SendEmail(resetEmail) -} diff --git a/server/service/service_users_test.go b/server/service/service_users_test.go deleted file mode 100644 index f87b354ec1..0000000000 --- a/server/service/service_users_test.go +++ /dev/null @@ -1,188 +0,0 @@ -package service - -import ( - "context" - "database/sql" - "testing" - "time" - - "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - "github.com/fleetdm/fleet/v4/server/contexts/viewer" - "github.com/fleetdm/fleet/v4/server/datastore/mysql" - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/fleetdm/fleet/v4/server/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestAuthenticatedUser(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - - createTestUsers(t, ds) - svc := newTestService(ds, nil, nil) - admin1, err := ds.UserByEmail(context.Background(), "admin1@example.com") - assert.Nil(t, err) - admin1Session, err := 
ds.NewSession(context.Background(), &fleet.Session{ - UserID: admin1.ID, - Key: "admin1", - }) - assert.Nil(t, err) - - ctx := context.Background() - ctx = viewer.NewContext(ctx, viewer.Viewer{User: admin1, Session: admin1Session}) - user, err := svc.AuthenticatedUser(ctx) - assert.Nil(t, err) - assert.Equal(t, user, admin1) -} - -func TestResetPassword(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - - svc := newTestService(ds, nil, nil) - createTestUsers(t, ds) - passwordResetTests := []struct { - token string - newPassword string - wantErr error - }{ - { // all good - token: "abcd", - newPassword: "123cat!", - }, - { // prevent reuse - token: "abcd", - newPassword: "123cat!", - wantErr: fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"), - }, - { // bad token - token: "dcbaz", - newPassword: "123cat!", - wantErr: sql.ErrNoRows, - }, - { // missing token - newPassword: "123cat!", - wantErr: fleet.NewInvalidArgumentError("token", "Token cannot be empty field"), - }, - } - - for _, tt := range passwordResetTests { - t.Run("", func(t *testing.T) { - request := &fleet.PasswordResetRequest{ - UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ - CreateTimestamp: fleet.CreateTimestamp{ - CreatedAt: time.Now(), - }, - UpdateTimestamp: fleet.UpdateTimestamp{ - UpdatedAt: time.Now(), - }, - }, - ExpiresAt: time.Now().Add(time.Hour * 24), - UserID: 1, - Token: "abcd", - } - _, err := ds.NewPasswordResetRequest(context.Background(), request) - assert.Nil(t, err) - - serr := svc.ResetPassword(test.UserContext(&fleet.User{ID: 1}), tt.token, tt.newPassword) - if tt.wantErr != nil { - assert.Equal(t, tt.wantErr.Error(), ctxerr.Cause(serr).Error()) - } else { - assert.Nil(t, serr) - } - }) - } -} - -func refreshCtx(t *testing.T, ctx context.Context, user *fleet.User, ds fleet.Datastore, session *fleet.Session) context.Context { - reloadedUser, err := ds.UserByEmail(ctx, user.Email) - require.NoError(t, err) - - return viewer.NewContext(ctx, 
viewer.Viewer{User: reloadedUser, Session: session}) -} - -func TestPerformRequiredPasswordReset(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - - svc := newTestService(ds, nil, nil) - - createTestUsers(t, ds) - - for _, tt := range testUsers { - t.Run(tt.Email, func(t *testing.T) { - user, err := ds.UserByEmail(context.Background(), tt.Email) - require.Nil(t, err) - - ctx := context.Background() - - _, err = svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, true) - require.Nil(t, err) - - ctx = refreshCtx(t, ctx, user, ds, nil) - - session, err := ds.NewSession(context.Background(), &fleet.Session{UserID: user.ID}) - require.Nil(t, err) - ctx = refreshCtx(t, ctx, user, ds, session) - - // should error when reset not required - _, err = svc.RequirePasswordReset(ctx, user.ID, false) - require.Nil(t, err) - ctx = refreshCtx(t, ctx, user, ds, session) - _, err = svc.PerformRequiredPasswordReset(ctx, "new_pass") - require.NotNil(t, err) - - _, err = svc.RequirePasswordReset(ctx, user.ID, true) - require.Nil(t, err) - ctx = refreshCtx(t, ctx, user, ds, session) - - // should error when using same password - _, err = svc.PerformRequiredPasswordReset(ctx, tt.PlaintextPassword) - require.Equal(t, "validation failed: new_password cannot reuse old password", err.Error()) - - // should succeed with good new password - u, err := svc.PerformRequiredPasswordReset(ctx, "new_pass") - require.Nil(t, err) - assert.False(t, u.AdminForcedPasswordReset) - - ctx = context.Background() - - // Now user should be able to login with new password - u, _, err = svc.Login(ctx, tt.Email, "new_pass") - require.Nil(t, err) - assert.False(t, u.AdminForcedPasswordReset) - }) - } -} - -func TestUserPasswordRequirements(t *testing.T) { - passwordTests := []struct { - password string - wantErr bool - }{ - { - password: "foobar", - wantErr: true, - }, - { - password: "foobarbaz", - wantErr: true, - }, - { - password: "foobarbaz!", - wantErr: true, - }, - { - password: "foobarbaz!3", 
- }, - } - - for _, tt := range passwordTests { - t.Run(tt.password, func(t *testing.T) { - err := validatePasswordRequirements(tt.password) - if tt.wantErr { - assert.NotNil(t, err) - } else { - assert.Nil(t, err) - } - }) - } -} diff --git a/server/service/sessions.go b/server/service/sessions.go index 41de763afd..a2ce672098 100644 --- a/server/service/sessions.go +++ b/server/service/sessions.go @@ -1,10 +1,25 @@ package service import ( + "bytes" "context" + "crypto/rand" + "encoding/base64" + "encoding/xml" + "errors" + "fmt" + "html/template" + "net/http" + "net/url" + "strings" "time" + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + "github.com/fleetdm/fleet/v4/server/contexts/logging" + "github.com/fleetdm/fleet/v4/server/contexts/viewer" "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/sso" + "github.com/go-kit/kit/log/level" ) //////////////////////////////////////////////////////////////////////////////// @@ -91,3 +106,480 @@ func (svc *Service) DeleteSession(ctx context.Context, id uint) error { return svc.ds.DestroySession(ctx, session) } + +//////////////////////////////////////////////////////////////////////////////// +// Login +//////////////////////////////////////////////////////////////////////////////// + +type loginRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +type loginResponse struct { + User *fleet.User `json:"user,omitempty"` + AvailableTeams []*fleet.TeamSummary `json:"available_teams"` + Token string `json:"token,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r loginResponse) error() error { return r.Err } + +func loginEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*loginRequest) + req.Email = strings.ToLower(req.Email) + + user, token, err := svc.Login(ctx, req.Email, req.Password) + if err != nil { + return loginResponse{Err: err}, nil + } + // Add viewer context allow access 
to service teams for list of available teams + v, err := authViewer(ctx, token, svc) + if err != nil { + return loginResponse{Err: err}, nil + } + ctx = viewer.NewContext(ctx, *v) + availableTeams, err := svc.ListAvailableTeamsForUser(ctx, user) + if err != nil { + if errors.Is(err, fleet.ErrMissingLicense) { + availableTeams = []*fleet.TeamSummary{} + } else { + return loginResponse{Err: err}, nil + } + } + return loginResponse{user, availableTeams, token, nil}, nil +} + +func (svc *Service) Login(ctx context.Context, email, password string) (*fleet.User, string, error) { + // skipauth: No user context available yet to authorize against. + svc.authz.SkipAuthorization(ctx) + + logging.WithLevel(logging.WithNoUser(ctx), level.Info) + + // If there is an error, sleep until the request has taken at least 1 + // second. This means that generally a login failure for any reason will + // take ~1s and frustrate a timing attack. + var err error + defer func(start time.Time) { + if err != nil { + time.Sleep(time.Until(start.Add(1 * time.Second))) + } + }(time.Now()) + + user, err := svc.ds.UserByEmail(ctx, email) + var nfe fleet.NotFoundError + if errors.As(err, &nfe) { + return nil, "", fleet.NewAuthFailedError("user not found") + } + if err != nil { + return nil, "", fleet.NewAuthFailedError(err.Error()) + } + + if err = user.ValidatePassword(password); err != nil { + return nil, "", fleet.NewAuthFailedError("invalid password") + } + + if user.SSOEnabled { + return nil, "", fleet.NewAuthFailedError("password login disabled for sso users") + } + + token, err := svc.makeSession(ctx, user.ID) + if err != nil { + return nil, "", fleet.NewAuthFailedError(err.Error()) + } + + return user, token, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Logout +//////////////////////////////////////////////////////////////////////////////// + +type logoutResponse struct { + Err error `json:"error,omitempty"` +} + +func (r logoutResponse) 
error() error { return r.Err } + +func logoutEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + err := svc.Logout(ctx) + if err != nil { + return logoutResponse{Err: err}, nil + } + return logoutResponse{}, nil +} + +func (svc *Service) Logout(ctx context.Context) error { + // skipauth: Any user can always log out of their own session. + svc.authz.SkipAuthorization(ctx) + + logging.WithLevel(ctx, level.Info) + + // TODO: this should not return an error if the user wasn't logged in + return svc.DestroySession(ctx) +} + +func (svc *Service) DestroySession(ctx context.Context) error { + vc, ok := viewer.FromContext(ctx) + if !ok { + return fleet.ErrNoContext + } + + session, err := svc.ds.SessionByID(ctx, vc.SessionID()) + if err != nil { + return err + } + + if err := svc.authz.Authorize(ctx, session, fleet.ActionWrite); err != nil { + return err + } + + return svc.ds.DestroySession(ctx, session) +} + +//////////////////////////////////////////////////////////////////////////////// +// Initiate SSO +//////////////////////////////////////////////////////////////////////////////// + +type initiateSSORequest struct { + RelayURL string `json:"relay_url"` +} + +type initiateSSOResponse struct { + URL string `json:"url,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r initiateSSOResponse) error() error { return r.Err } + +func initiateSSOEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*initiateSSORequest) + idProviderURL, err := svc.InitiateSSO(ctx, req.RelayURL) + if err != nil { + return initiateSSOResponse{Err: err}, nil + } + return initiateSSOResponse{URL: idProviderURL}, nil +} + +// InitiateSSO initiates a Single Sign-On flow for a request to visit the +// protected URL identified by redirectURL. 
It returns the URL of the identity +// provider to make a request to proceed with the authentication via that +// external service, and stores ephemeral session state to validate the +// callback from the identity provider to finalize the SSO flow. +func (svc *Service) InitiateSSO(ctx context.Context, redirectURL string) (string, error) { + // skipauth: User context does not yet exist. Unauthenticated users may + // initiate SSO. + svc.authz.SkipAuthorization(ctx) + + logging.WithLevel(ctx, level.Info) + + appConfig, err := svc.ds.AppConfig(ctx) + if err != nil { + return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting app config") + } + + metadata, err := svc.getMetadata(appConfig) + if err != nil { + return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting metadata") + } + + serverURL := appConfig.ServerSettings.ServerURL + settings := sso.Settings{ + Metadata: metadata, + // Construct callback URL to send to the IdP + AssertionConsumerServiceURL: serverURL + svc.config.Server.URLPrefix + "/api/v1/fleet/sso/callback", + SessionStore: svc.ssoSessionStore, + OriginalURL: redirectURL, + } + + // If issuer is not explicitly set, default to host name. 
+ var issuer string + entityID := appConfig.SSOSettings.EntityID + if entityID == "" { + u, err := url.Parse(serverURL) + if err != nil { + return "", ctxerr.Wrap(ctx, err, "parse server url") + } + issuer = u.Hostname() + } else { + issuer = entityID + } + + idpURL, err := sso.CreateAuthorizationRequest(&settings, issuer) + if err != nil { + return "", ctxerr.Wrap(ctx, err, "InitiateSSO creating authorization") + } + + return idpURL, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Callback SSO +//////////////////////////////////////////////////////////////////////////////// + +type callbackSSORequest struct{} + +func (callbackSSORequest) DecodeRequest(ctx context.Context, r *http.Request) (interface{}, error) { + err := r.ParseForm() + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "decode sso callback") + } + authResponse, err := sso.DecodeAuthResponse(r.FormValue("SAMLResponse")) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "decoding sso callback") + } + return authResponse, nil +} + +type callbackSSOResponse struct { + content string + Err error `json:"error,omitempty"` +} + +func (r callbackSSOResponse) error() error { return r.Err } + +// If html is present we return a web page +func (r callbackSSOResponse) html() string { return r.content } + +func makeCallbackSSOEndpoint(urlPrefix string) handlerFunc { + return func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + authResponse := request.(fleet.Auth) + session, err := svc.CallbackSSO(ctx, authResponse) + var resp callbackSSOResponse + if err != nil { + // redirect to login page on front end if there was some problem, + // errors should still be logged + session = &fleet.SSOSession{ + RedirectURL: urlPrefix + "/login", + Token: "", + } + resp.Err = err + } + relayStateLoadPage := ` + + + Redirecting to Fleet at {{ .RedirectURL }} ... 
+ + + ` + tmpl, err := template.New("relayStateLoader").Parse(relayStateLoadPage) + if err != nil { + return nil, err + } + var writer bytes.Buffer + err = tmpl.Execute(&writer, session) + if err != nil { + return nil, err + } + resp.content = writer.String() + return resp, nil + } +} + +func (svc *Service) CallbackSSO(ctx context.Context, auth fleet.Auth) (*fleet.SSOSession, error) { + // skipauth: User context does not yet exist. Unauthenticated users may + // hit the SSO callback. + svc.authz.SkipAuthorization(ctx) + + logging.WithLevel(ctx, level.Info) + + appConfig, err := svc.ds.AppConfig(ctx) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "get config for sso") + } + + // Load the request metadata if available + + // localhost:9080/simplesaml/saml2/idp/SSOService.php?spentityid=https://localhost:8080 + var metadata *sso.Metadata + var redirectURL string + + if appConfig.SSOSettings.EnableSSOIdPLogin && auth.RequestID() == "" { + // Missing request ID indicates this was IdP-initiated. Only allow if + // configured to do so. + metadata, err = svc.getMetadata(appConfig) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "get sso metadata") + } + redirectURL = "/" + } else { + session, err := svc.ssoSessionStore.Get(auth.RequestID()) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "sso request invalid") + } + // Remove session so that it can't be reused before it expires. 
+ err = svc.ssoSessionStore.Expire(auth.RequestID()) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "remove sso request") + } + if err := xml.Unmarshal([]byte(session.Metadata), &metadata); err != nil { + return nil, ctxerr.Wrap(ctx, err, "unmarshal metadata") + } + redirectURL = session.OriginalURL + } + + // Validate response + validator, err := sso.NewValidator(*metadata, sso.WithExpectedAudience( + appConfig.SSOSettings.EntityID, + appConfig.ServerSettings.ServerURL, + appConfig.ServerSettings.ServerURL+svc.config.Server.URLPrefix+"/api/v1/fleet/sso/callback", // ACS + )) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "create validator from metadata") + } + // make sure the response hasn't been tampered with + auth, err = validator.ValidateSignature(auth) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "signature validation failed") + } + // make sure the response isn't stale + err = validator.ValidateResponse(auth) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "response validation failed") + } + + // Get and log in user + user, err := svc.ds.UserByEmail(ctx, auth.UserID()) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "find user in sso callback") + } + // if the user is not sso enabled they are not authorized + if !user.SSOEnabled { + return nil, ctxerr.New(ctx, "user not configured to use sso") + } + token, err := svc.makeSession(ctx, user.ID) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "make session in sso callback") + } + result := &fleet.SSOSession{ + Token: token, + RedirectURL: redirectURL, + } + return result, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// SSO Settings +//////////////////////////////////////////////////////////////////////////////// + +type ssoSettingsResponse struct { + Settings *fleet.SessionSSOSettings `json:"settings,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r ssoSettingsResponse) error() error { return r.Err } + +func 
settingsSSOEndpoint(ctx context.Context, _ interface{}, svc fleet.Service) (interface{}, error) { + settings, err := svc.SSOSettings(ctx) + if err != nil { + return ssoSettingsResponse{Err: err}, nil + } + return ssoSettingsResponse{Settings: settings}, nil +} + +// SSOSettings returns a subset of the Single Sign-On settings as configured in +// the app config. Those can be exposed e.g. via the response to an HTTP request, +// and as such should not contain sensitive information. +func (svc *Service) SSOSettings(ctx context.Context) (*fleet.SessionSSOSettings, error) { + // skipauth: Basic SSO settings are available to unauthenticated users (so + // that they have the necessary information to initiate SSO). + svc.authz.SkipAuthorization(ctx) + + logging.WithLevel(ctx, level.Info) + + appConfig, err := svc.ds.AppConfig(ctx) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "SessionSSOSettings getting app config") + } + + settings := &fleet.SessionSSOSettings{ + IDPName: appConfig.SSOSettings.IDPName, + IDPImageURL: appConfig.SSOSettings.IDPImageURL, + SSOEnabled: appConfig.SSOSettings.EnableSSO, + } + return settings, nil +} + +// makeSession is a helper that creates a new session after authentication +func (svc *Service) makeSession(ctx context.Context, id uint) (string, error) { + sessionKeySize := svc.config.Session.KeySize + key := make([]byte, sessionKeySize) + _, err := rand.Read(key) + if err != nil { + return "", err + } + + sessionKey := base64.StdEncoding.EncodeToString(key) + session := &fleet.Session{ + UserID: id, + Key: sessionKey, + AccessedAt: time.Now().UTC(), + } + + _, err = svc.ds.NewSession(ctx, session) + if err != nil { + return "", ctxerr.Wrap(ctx, err, "creating new session") + } + + return sessionKey, nil +} + +func (svc *Service) getMetadata(config *fleet.AppConfig) (*sso.Metadata, error) { + if config.SSOSettings.MetadataURL != "" { + metadata, err := sso.GetMetadata(config.SSOSettings.MetadataURL) + if err != nil { + return nil, err 
+ } + return metadata, nil + } + + if config.SSOSettings.Metadata != "" { + metadata, err := sso.ParseMetadata(config.SSOSettings.Metadata) + if err != nil { + return nil, err + } + return metadata, nil + } + + return nil, fmt.Errorf("missing metadata for idp %s", config.SSOSettings.IDPName) +} + +func (svc *Service) GetSessionByKey(ctx context.Context, key string) (*fleet.Session, error) { + session, err := svc.ds.SessionByKey(ctx, key) + if err != nil { + return nil, err + } + + err = svc.validateSession(ctx, session) + if err != nil { + return nil, err + } + + return session, nil +} + +func (svc *Service) validateSession(ctx context.Context, session *fleet.Session) error { + if session == nil { + return fleet.NewAuthRequiredError("active session not present") + } + + sessionDuration := svc.config.Session.Duration + if session.APIOnly != nil && *session.APIOnly { + sessionDuration = 0 // make API-only tokens unlimited + } + + // duration 0 = unlimited + if sessionDuration != 0 && time.Since(session.AccessedAt) >= sessionDuration { + err := svc.ds.DestroySession(ctx, session) + if err != nil { + return ctxerr.Wrap(ctx, err, "destroying session") + } + return fleet.NewAuthRequiredError("expired session") + } + + return svc.ds.MarkSessionAccessed(ctx, session) +} diff --git a/server/service/sessions_test.go b/server/service/sessions_test.go index 1784b0c7d5..8fdf955a07 100644 --- a/server/service/sessions_test.go +++ b/server/service/sessions_test.go @@ -5,10 +5,15 @@ import ( "testing" "time" + "github.com/fleetdm/fleet/v4/server/config" "github.com/fleetdm/fleet/v4/server/contexts/viewer" + "github.com/fleetdm/fleet/v4/server/datastore/mysql" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mock" "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/fleetdm/fleet/v4/server/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSessionAuth(t *testing.T) { @@ -85,3 +90,98 @@ func 
TestSessionAuth(t *testing.T) { }) } } + +func TestAuthenticate(t *testing.T) { + ds := mysql.CreateMySQLDS(t) + defer ds.Close() + + svc := newTestService(ds, nil, nil) + createTestUsers(t, ds) + + var loginTests = []struct { + name string + email string + password string + wantErr error + }{ + { + name: "admin1", + email: testUsers["admin1"].Email, + password: testUsers["admin1"].PlaintextPassword, + }, + { + name: "user1", + email: testUsers["user1"].Email, + password: testUsers["user1"].PlaintextPassword, + }, + } + + for _, tt := range loginTests { + t.Run(tt.email, func(st *testing.T) { + loggedIn, token, err := svc.Login(test.UserContext(test.UserAdmin), tt.email, tt.password) + require.Nil(st, err, "login unsuccessful") + assert.Equal(st, tt.email, loggedIn.Email) + assert.NotEmpty(st, token) + + sessions, err := svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), loggedIn.ID) + require.Nil(st, err) + require.Len(st, sessions, 1, "user should have one session") + session := sessions[0] + assert.NotZero(st, session.UserID) + assert.WithinDuration(st, time.Now(), session.AccessedAt, 3*time.Second, + "access time should be set with current time at session creation") + }) + } +} + +func TestGetSessionByKey(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + cfg := config.TestConfig() + + theSession := &fleet.Session{UserID: 123, Key: "abc"} + + ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) { + return theSession, nil + } + ds.DestroySessionFunc = func(ctx context.Context, ssn *fleet.Session) error { + return nil + } + ds.MarkSessionAccessedFunc = func(ctx context.Context, ssn *fleet.Session) error { + return nil + } + + cases := []struct { + desc string + accessed time.Duration + apiOnly bool + fail bool + }{ + {"real user, accessed recently", -1 * time.Hour, false, false}, + {"real user, accessed too long ago", -(cfg.Session.Duration + time.Hour), false, true}, + {"api-only, 
accessed recently", -1 * time.Hour, true, false}, + {"api-only, accessed long ago", -(cfg.Session.Duration + time.Hour), true, false}, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + var authErr *fleet.AuthRequiredError + ds.SessionByKeyFuncInvoked, ds.DestroySessionFuncInvoked, ds.MarkSessionAccessedFuncInvoked = false, false, false + + theSession.AccessedAt = time.Now().Add(tc.accessed) + theSession.APIOnly = ptr.Bool(tc.apiOnly) + _, err := svc.GetSessionByKey(context.Background(), theSession.Key) + if tc.fail { + require.Error(t, err) + require.ErrorAs(t, err, &authErr) + require.True(t, ds.SessionByKeyFuncInvoked) + require.True(t, ds.DestroySessionFuncInvoked) + require.False(t, ds.MarkSessionAccessedFuncInvoked) + } else { + require.NoError(t, err) + require.True(t, ds.SessionByKeyFuncInvoked) + require.False(t, ds.DestroySessionFuncInvoked) + require.True(t, ds.MarkSessionAccessedFuncInvoked) + } + }) + } +} diff --git a/server/service/testing_client.go b/server/service/testing_client.go index 6e44782ea5..270f673a54 100644 --- a/server/service/testing_client.go +++ b/server/service/testing_client.go @@ -36,9 +36,10 @@ func (ts *withDS) TearDownSuite() { type withServer struct { withDS - server *httptest.Server - users map[string]fleet.User - token string + server *httptest.Server + users map[string]fleet.User + token string + cachedAdminToken string } func (ts *withServer) SetupSuite(dbName string) { @@ -49,6 +50,7 @@ func (ts *withServer) SetupSuite(dbName string) { ts.server = server ts.users = users ts.token = ts.getTestAdminToken() + ts.cachedAdminToken = ts.token } func (ts *withServer) TearDownSuite() { @@ -122,7 +124,13 @@ func (ts *withServer) DoJSON(verb, path string, params interface{}, expectedStat func (ts *withServer) getTestAdminToken() string { testUser := testUsers["admin1"] - return ts.getTestToken(testUser.Email, testUser.PlaintextPassword) + // because the login endpoint is rate-limited, use the cached admin token 
+ // if available (if for some reason a test needs to logout the admin user, + // then set cachedAdminToken = "" so that a new token is retrieved). + if ts.cachedAdminToken == "" { + ts.cachedAdminToken = ts.getTestToken(testUser.Email, testUser.PlaintextPassword) + } + return ts.cachedAdminToken } func (ts *withServer) getTestToken(email string, password string) string { diff --git a/server/service/translator.go b/server/service/translator.go index 7919e0da99..cefc12b2ab 100644 --- a/server/service/translator.go +++ b/server/service/translator.go @@ -60,7 +60,7 @@ func translateHostToID(ctx context.Context, ds fleet.Datastore, identifier strin return host.ID, nil } -func (svc Service) Translate(ctx context.Context, payloads []fleet.TranslatePayload) ([]fleet.TranslatePayload, error) { +func (svc *Service) Translate(ctx context.Context, payloads []fleet.TranslatePayload) ([]fleet.TranslatePayload, error) { var finalPayload []fleet.TranslatePayload for _, payload := range payloads { diff --git a/server/service/transport.go b/server/service/transport.go index 61a2295265..ddbe5dbea8 100644 --- a/server/service/transport.go +++ b/server/service/transport.go @@ -294,10 +294,6 @@ func userListOptionsFromRequest(r *http.Request) (fleet.UserListOptions, error) return uopt, nil } -func decodeNoParamsRequest(ctx context.Context, r *http.Request) (interface{}, error) { - return nil, nil -} - type getGenericSpecRequest struct { Name string `url:"name"` } diff --git a/server/service/transport_carves.go b/server/service/transport_carves.go deleted file mode 100644 index aa9e46f7c0..0000000000 --- a/server/service/transport_carves.go +++ /dev/null @@ -1,20 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "net/http" - - "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" -) - -func decodeCarveBlockRequest(ctx context.Context, r *http.Request) (interface{}, error) { - defer r.Body.Close() - - var req carveBlockRequest - if err := 
json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, ctxerr.Wrap(ctx, err, "decoding JSON") - } - - return req, nil -} diff --git a/server/service/transport_invites.go b/server/service/transport_invites.go deleted file mode 100644 index 442ed20a3e..0000000000 --- a/server/service/transport_invites.go +++ /dev/null @@ -1,17 +0,0 @@ -package service - -import ( - "context" - "net/http" - - "github.com/gorilla/mux" -) - -func decodeVerifyInviteRequest(ctx context.Context, r *http.Request) (interface{}, error) { - vars := mux.Vars(r) - token, ok := vars["token"] - if !ok { - return 0, errBadRoute - } - return verifyInviteRequest{Token: token}, nil -} diff --git a/server/service/transport_osquery.go b/server/service/transport_osquery.go deleted file mode 100644 index 3ee0aba0c8..0000000000 --- a/server/service/transport_osquery.go +++ /dev/null @@ -1,17 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "net/http" -) - -func decodeEnrollAgentRequest(ctx context.Context, r *http.Request) (interface{}, error) { - var req enrollAgentRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, err - } - defer r.Body.Close() - - return req, nil -} diff --git a/server/service/transport_osquery_test.go b/server/service/transport_osquery_test.go deleted file mode 100644 index 68503ca205..0000000000 --- a/server/service/transport_osquery_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package service - -import ( - "bytes" - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gorilla/mux" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestDecodeEnrollAgentRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeEnrollAgentRequest(context.Background(), request) - require.Nil(t, err) - - params := r.(enrollAgentRequest) - assert.Equal(t, "secret", params.EnrollSecret) - 
assert.Equal(t, "uuid", params.HostIdentifier) - }).Methods("POST") - - var body bytes.Buffer - body.Write([]byte(`{ - "enroll_secret": "secret", - "host_identifier": "uuid" - }`)) - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("POST", "/", &body), - ) -} diff --git a/server/service/transport_sessions.go b/server/service/transport_sessions.go deleted file mode 100644 index 07d815e547..0000000000 --- a/server/service/transport_sessions.go +++ /dev/null @@ -1,41 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "net/http" - "strings" - - "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - "github.com/fleetdm/fleet/v4/server/sso" -) - -func decodeLoginRequest(ctx context.Context, r *http.Request) (interface{}, error) { - var req loginRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, err - } - req.Email = strings.ToLower(req.Email) - return req, nil -} - -func decodeInitiateSSORequest(ctx context.Context, r *http.Request) (interface{}, error) { - var req initiateSSORequest - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - return nil, err - } - return req, nil -} - -func decodeCallbackSSORequest(ctx context.Context, r *http.Request) (interface{}, error) { - err := r.ParseForm() - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "decode sso callback") - } - authResponse, err := sso.DecodeAuthResponse(r.FormValue("SAMLResponse")) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "decoding sso callback") - } - return authResponse, nil -} diff --git a/server/service/transport_sessions_test.go b/server/service/transport_sessions_test.go deleted file mode 100644 index 3a9060c5a1..0000000000 --- a/server/service/transport_sessions_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package service - -import ( - "bytes" - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gorilla/mux" - "github.com/stretchr/testify/assert" -) - -func TestDecodeLoginRequest(t 
*testing.T) { - router := mux.NewRouter() - router.HandleFunc("/api/v1/fleet/login", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeLoginRequest(context.Background(), request) - assert.Nil(t, err) - - params := r.(loginRequest) - assert.Equal(t, "foo", params.Email) - assert.Equal(t, "bar", params.Password) - }).Methods("POST") - t.Run("lowercase email", func(t *testing.T) { - var body bytes.Buffer - body.Write([]byte(`{ - "email": "foo", - "password": "bar" - }`)) - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("POST", "/api/v1/fleet/login", &body), - ) - }) - t.Run("uppercase email", func(t *testing.T) { - var body bytes.Buffer - body.Write([]byte(`{ - "email": "Foo", - "password": "bar" - }`)) - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("POST", "/api/v1/fleet/login", &body), - ) - }) - -} diff --git a/server/service/transport_users.go b/server/service/transport_users.go deleted file mode 100644 index 88e1854b05..0000000000 --- a/server/service/transport_users.go +++ /dev/null @@ -1,42 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "net/http" - - "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" -) - -func decodeCreateUserRequest(ctx context.Context, r *http.Request) (interface{}, error) { - var req createUserRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, err - } - - return req, nil -} - -func decodePerformRequiredPasswordResetRequest(ctx context.Context, r *http.Request) (interface{}, error) { - var req performRequiredPasswordResetRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, ctxerr.Wrap(ctx, err, "decoding JSON") - } - return req, nil -} - -func decodeForgotPasswordRequest(ctx context.Context, r *http.Request) (interface{}, error) { - var req forgotPasswordRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, err - } - return req, nil -} - -func 
decodeResetPasswordRequest(ctx context.Context, r *http.Request) (interface{}, error) { - var req resetPasswordRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, err - } - return req, nil -} diff --git a/server/service/transport_users_test.go b/server/service/transport_users_test.go deleted file mode 100644 index 7d6e9ff571..0000000000 --- a/server/service/transport_users_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package service - -import ( - "bytes" - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gorilla/mux" - "github.com/stretchr/testify/assert" -) - -func TestDecodeResetPasswordRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/api/v1/fleet/users/{id}/password", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeResetPasswordRequest(context.Background(), request) - assert.Nil(t, err) - - params := r.(resetPasswordRequest) - assert.Equal(t, "bar", params.NewPassword) - assert.Equal(t, "baz", params.PasswordResetToken) - }).Methods("POST") - - var body bytes.Buffer - body.Write([]byte(`{ - "new_password": "bar", - "password_reset_token": "baz" - }`)) - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("POST", "/api/v1/fleet/users/1/password", &body), - ) -} diff --git a/server/service/user_roles.go b/server/service/user_roles.go index bb232301d5..1dc1dde404 100644 --- a/server/service/user_roles.go +++ b/server/service/user_roles.go @@ -26,7 +26,7 @@ func applyUserRoleSpecsEndpoint(ctx context.Context, request interface{}, svc fl return applyUserRoleSpecsResponse{}, nil } -func (svc Service) ApplyUserRolesSpecs(ctx context.Context, specs fleet.UsersRoleSpec) error { +func (svc *Service) ApplyUserRolesSpecs(ctx context.Context, specs fleet.UsersRoleSpec) error { if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionWrite); err != nil { return err } @@ -61,7 +61,7 @@ func (svc Service) ApplyUserRolesSpecs(ctx context.Context, specs 
fleet.UsersRol return svc.ds.SaveUsers(ctx, users) } -func (svc Service) checkAtLeastOneAdmin(ctx context.Context, user *fleet.User, spec *fleet.UserRoleSpec, email string) error { +func (svc *Service) checkAtLeastOneAdmin(ctx context.Context, user *fleet.User, spec *fleet.UserRoleSpec, email string) error { if null.StringFromPtr(user.GlobalRole).ValueOrZero() == fleet.RoleAdmin && null.StringFromPtr(spec.GlobalRole).ValueOrZero() != fleet.RoleAdmin { users, err := svc.ds.ListUsers(ctx, fleet.UserListOptions{}) diff --git a/server/service/users.go b/server/service/users.go index 863c18f83d..8f2f6ed6c0 100644 --- a/server/service/users.go +++ b/server/service/users.go @@ -6,6 +6,8 @@ import ( "encoding/base64" "errors" "html/template" + "net/http" + "time" "github.com/fleetdm/fleet/v4/server" "github.com/fleetdm/fleet/v4/server/authz" @@ -49,6 +51,10 @@ func (svc *Service) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet return nil, err } + if err := p.VerifyAdminCreate(); err != nil { + return nil, ctxerr.Wrap(ctx, err, "verify user payload") + } + if invite, err := svc.ds.InviteByEmail(ctx, *p.Email); err == nil && invite != nil { return nil, ctxerr.Errorf(ctx, "%s already invited", *p.Email) } @@ -61,6 +67,49 @@ func (svc *Service) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet return svc.newUser(ctx, p) } +//////////////////////////////////////////////////////////////////////////////// +// Create User From Invite +//////////////////////////////////////////////////////////////////////////////// + +func createUserFromInviteEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*createUserRequest) + user, err := svc.CreateUserFromInvite(ctx, req.UserPayload) + if err != nil { + return createUserResponse{Err: err}, nil + } + return createUserResponse{User: user}, nil +} + +func (svc *Service) CreateUserFromInvite(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) { + // 
skipauth: There is no viewer context at this point. We rely on verifying + // the invite for authNZ. + svc.authz.SkipAuthorization(ctx) + + if err := p.VerifyInviteCreate(); err != nil { + return nil, ctxerr.Wrap(ctx, err, "verify user payload") + } + + invite, err := svc.VerifyInvite(ctx, *p.InviteToken) + if err != nil { + return nil, err + } + + // set the payload role property based on an existing invite. + p.GlobalRole = invite.GlobalRole.Ptr() + p.Teams = &invite.Teams + + user, err := svc.newUser(ctx, p) + if err != nil { + return nil, err + } + + err = svc.ds.DeleteInvite(ctx, invite.ID) + if err != nil { + return nil, err + } + return user, nil +} + //////////////////////////////////////////////////////////////////////////////// // List Users //////////////////////////////////////////////////////////////////////////////// @@ -215,6 +264,14 @@ func (svc *Service) ModifyUser(ctx context.Context, userID uint, p fleet.UserPay return nil, err } + vc, ok := viewer.FromContext(ctx) + if !ok { + return nil, ctxerr.New(ctx, "viewer not present") // should never happen, authorize would've failed + } + if err := p.VerifyModify(vc.UserID() == userID); err != nil { + return nil, ctxerr.Wrap(ctx, err, "verify user payload") + } + if p.GlobalRole != nil || p.Teams != nil { if err := svc.authz.Authorize(ctx, user, fleet.ActionWriteRole); err != nil { return nil, err @@ -386,13 +443,21 @@ func (svc *Service) ChangePassword(ctx context.Context, oldPass, newPass string) return err } + if oldPass == "" { + return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("old_password", "Old password cannot be empty")) + } + if newPass == "" { + return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", "New password cannot be empty")) + } + if err := fleet.ValidatePasswordRequirements(newPass); err != nil { + return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", err.Error())) + } if vc.User.SSOEnabled { return ctxerr.New(ctx, "change password for single sign 
on user not allowed") } if err := vc.User.ValidatePassword(newPass); err == nil { - return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password") + return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")) } - if err := vc.User.ValidatePassword(oldPass); err != nil { return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("old_password", "old password does not match")) } @@ -622,3 +687,245 @@ func (svc *Service) modifyEmailAddress(ctx context.Context, user *fleet.User, em func (svc *Service) saveUser(ctx context.Context, user *fleet.User) error { return svc.ds.SaveUser(ctx, user) } + +//////////////////////////////////////////////////////////////////////////////// +// Perform Required Password Reset +//////////////////////////////////////////////////////////////////////////////// + +type performRequiredPasswordResetRequest struct { + Password string `json:"new_password"` + ID uint `json:"id"` +} + +type performRequiredPasswordResetResponse struct { + User *fleet.User `json:"user,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r performRequiredPasswordResetResponse) error() error { return r.Err } + +func performRequiredPasswordResetEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*performRequiredPasswordResetRequest) + user, err := svc.PerformRequiredPasswordReset(ctx, req.Password) + if err != nil { + return performRequiredPasswordResetResponse{Err: err}, nil + } + return performRequiredPasswordResetResponse{User: user}, nil +} + +func (svc *Service) PerformRequiredPasswordReset(ctx context.Context, password string) (*fleet.User, error) { + vc, ok := viewer.FromContext(ctx) + if !ok { + return nil, fleet.ErrNoContext + } + if !vc.CanPerformPasswordReset() { + return nil, fleet.NewPermissionError("cannot reset password") + } + user := vc.User + + if err := svc.authz.Authorize(ctx, user, fleet.ActionWrite); err != nil { + return nil, 
err + } + + if user.SSOEnabled { + return nil, ctxerr.New(ctx, "password reset for single sign on user not allowed") + } + if !user.IsAdminForcedPasswordReset() { + return nil, ctxerr.New(ctx, "user does not require password reset") + } + + // prevent setting the same password + if err := user.ValidatePassword(password); err == nil { + return nil, fleet.NewInvalidArgumentError("new_password", "cannot reuse old password") + } + + user.AdminForcedPasswordReset = false + err := svc.setNewPassword(ctx, user, password) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "setting new password") + } + + // Sessions should already have been cleared when the reset was + // required + + return user, nil +} + +// setNewPassword is a helper for changing a user's password. It should be +// called to set the new password after proper authorization has been +// performed. +func (svc *Service) setNewPassword(ctx context.Context, user *fleet.User, password string) error { + err := user.SetPassword(password, svc.config.Auth.SaltKeySize, svc.config.Auth.BcryptCost) + if err != nil { + return ctxerr.Wrap(ctx, err, "setting new password") + } + if user.SSOEnabled { + return ctxerr.New(ctx, "set password for single sign on user not allowed") + } + err = svc.saveUser(ctx, user) + if err != nil { + return ctxerr.Wrap(ctx, err, "saving changed password") + } + + return nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Reset Password +//////////////////////////////////////////////////////////////////////////////// + +type resetPasswordRequest struct { + PasswordResetToken string `json:"password_reset_token"` + NewPassword string `json:"new_password"` +} + +type resetPasswordResponse struct { + Err error `json:"error,omitempty"` +} + +func (r resetPasswordResponse) error() error { return r.Err } + +func resetPasswordEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := 
request.(*resetPasswordRequest) + err := svc.ResetPassword(ctx, req.PasswordResetToken, req.NewPassword) + return resetPasswordResponse{Err: err}, nil +} + +func (svc *Service) ResetPassword(ctx context.Context, token, password string) error { + // skipauth: No viewer context available. The user is locked out of their + // account and authNZ is performed entirely by providing a valid password + // reset token. + svc.authz.SkipAuthorization(ctx) + + if token == "" { + return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("token", "Token cannot be empty field")) + } + if password == "" { + return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", "New password cannot be empty field")) + } + if err := fleet.ValidatePasswordRequirements(password); err != nil { + return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("new_password", err.Error())) + } + + reset, err := svc.ds.FindPasswordResetByToken(ctx, token) + if err != nil { + return ctxerr.Wrap(ctx, err, "looking up reset by token") + } + user, err := svc.ds.UserByID(ctx, reset.UserID) + if err != nil { + return ctxerr.Wrap(ctx, err, "retrieving user") + } + + if user.SSOEnabled { + return ctxerr.New(ctx, "password reset for single sign on user not allowed") + } + + // prevent setting the same password + if err := user.ValidatePassword(password); err == nil { + return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password") + } + + err = svc.setNewPassword(ctx, user, password) + if err != nil { + return ctxerr.Wrap(ctx, err, "setting new password") + } + + // delete password reset tokens for user + if err := svc.ds.DeletePasswordResetRequestsForUser(ctx, user.ID); err != nil { + return ctxerr.Wrap(ctx, err, "delete password reset requests") + } + + // Clear sessions so that any other browsers will have to log in with + // the new password + if err := svc.ds.DestroyAllSessionsForUser(ctx, user.ID); err != nil { + return ctxerr.Wrap(ctx, err, "delete user sessions") + } + + return nil +} 
+ +//////////////////////////////////////////////////////////////////////////////// +// Forgot Password +//////////////////////////////////////////////////////////////////////////////// + +type forgotPasswordRequest struct { + Email string `json:"email"` +} + +type forgotPasswordResponse struct { + Err error `json:"error,omitempty"` +} + +func (r forgotPasswordResponse) error() error { return r.Err } +func (r forgotPasswordResponse) status() int { return http.StatusAccepted } + +func forgotPasswordEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*forgotPasswordRequest) + // Any error returned by the service should not be returned to the + // client to prevent information disclosure (it will be logged in the + // server logs). + _ = svc.RequestPasswordReset(ctx, req.Email) + return forgotPasswordResponse{}, nil +} + +func (svc *Service) RequestPasswordReset(ctx context.Context, email string) error { + // skipauth: No viewer context available. The user is locked out of their + // account and trying to reset their password. + svc.authz.SkipAuthorization(ctx) + + // Regardless of error, sleep until the request has taken at least 1 second. + // This means that any request to this method will take ~1s and frustrate a timing attack. 
+ defer func(start time.Time) { + time.Sleep(time.Until(start.Add(1 * time.Second))) + }(time.Now()) + + user, err := svc.ds.UserByEmail(ctx, email) + if err != nil { + return err + } + if user.SSOEnabled { + return ctxerr.New(ctx, "password reset for single sign on user not allowed") + } + + random, err := server.GenerateRandomText(svc.config.App.TokenKeySize) + if err != nil { + return err + } + token := base64.URLEncoding.EncodeToString([]byte(random)) + + request := &fleet.PasswordResetRequest{ + ExpiresAt: time.Now().Add(time.Hour * 24), + UserID: user.ID, + Token: token, + } + _, err = svc.ds.NewPasswordResetRequest(ctx, request) + if err != nil { + return err + } + + config, err := svc.ds.AppConfig(ctx) + if err != nil { + return err + } + + resetEmail := fleet.Email{ + Subject: "Reset Your Fleet Password", + To: []string{user.Email}, + Config: config, + Mailer: &mail.PasswordResetMailer{ + BaseURL: template.URL(config.ServerSettings.ServerURL + svc.config.Server.URLPrefix), + AssetURL: getAssetURL(), + Token: token, + }, + } + + return svc.mailService.SendEmail(resetEmail) +} + +func (svc *Service) ListAvailableTeamsForUser(ctx context.Context, user *fleet.User) ([]*fleet.TeamSummary, error) { + // skipauth: No authorization check needed due to implementation returning + // only license error. 
+ svc.authz.SkipAuthorization(ctx) + + return nil, fleet.ErrMissingLicense +} diff --git a/server/service/users_test.go b/server/service/users_test.go index afc53ce9e3..6ed4ab0739 100644 --- a/server/service/users_test.go +++ b/server/service/users_test.go @@ -2,6 +2,7 @@ package service import ( "context" + "database/sql" "errors" "testing" "time" @@ -471,6 +472,7 @@ func testUsersChangePassword(t *testing.T, ds *mysql.Datastore) { anyErr: true, }, { // missing old password + user: users["user1@example.com"], newPassword: "123cataaa!", wantErr: fleet.NewInvalidArgumentError("old_password", "Old password cannot be empty"), }, @@ -540,3 +542,141 @@ func testUsersRequirePasswordReset(t *testing.T, ds *mysql.Datastore) { }) } } + +func TestPerformRequiredPasswordReset(t *testing.T) { + ds := mysql.CreateMySQLDS(t) + + svc := newTestService(ds, nil, nil) + + createTestUsers(t, ds) + + for _, tt := range testUsers { + t.Run(tt.Email, func(t *testing.T) { + user, err := ds.UserByEmail(context.Background(), tt.Email) + require.Nil(t, err) + + ctx := context.Background() + + _, err = svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, true) + require.Nil(t, err) + + ctx = refreshCtx(t, ctx, user, ds, nil) + + session, err := ds.NewSession(context.Background(), &fleet.Session{UserID: user.ID}) + require.Nil(t, err) + ctx = refreshCtx(t, ctx, user, ds, session) + + // should error when reset not required + _, err = svc.RequirePasswordReset(ctx, user.ID, false) + require.Nil(t, err) + ctx = refreshCtx(t, ctx, user, ds, session) + _, err = svc.PerformRequiredPasswordReset(ctx, "new_pass") + require.NotNil(t, err) + + _, err = svc.RequirePasswordReset(ctx, user.ID, true) + require.Nil(t, err) + ctx = refreshCtx(t, ctx, user, ds, session) + + // should error when using same password + _, err = svc.PerformRequiredPasswordReset(ctx, tt.PlaintextPassword) + require.Equal(t, "validation failed: new_password cannot reuse old password", err.Error()) + + // should 
succeed with good new password + u, err := svc.PerformRequiredPasswordReset(ctx, "new_pass") + require.Nil(t, err) + assert.False(t, u.AdminForcedPasswordReset) + + ctx = context.Background() + + // Now user should be able to login with new password + u, _, err = svc.Login(ctx, tt.Email, "new_pass") + require.Nil(t, err) + assert.False(t, u.AdminForcedPasswordReset) + }) + } +} + +func TestResetPassword(t *testing.T) { + ds := mysql.CreateMySQLDS(t) + + svc := newTestService(ds, nil, nil) + createTestUsers(t, ds) + passwordResetTests := []struct { + token string + newPassword string + wantErr error + }{ + { // all good + token: "abcd", + newPassword: "123cat!", + }, + { // prevent reuse + token: "abcd", + newPassword: "123cat!", + wantErr: fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"), + }, + { // bad token + token: "dcbaz", + newPassword: "123cat!", + wantErr: sql.ErrNoRows, + }, + { // missing token + newPassword: "123cat!", + wantErr: fleet.NewInvalidArgumentError("token", "Token cannot be empty field"), + }, + } + + for _, tt := range passwordResetTests { + t.Run("", func(t *testing.T) { + request := &fleet.PasswordResetRequest{ + UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ + CreateTimestamp: fleet.CreateTimestamp{ + CreatedAt: time.Now(), + }, + UpdateTimestamp: fleet.UpdateTimestamp{ + UpdatedAt: time.Now(), + }, + }, + ExpiresAt: time.Now().Add(time.Hour * 24), + UserID: 1, + Token: "abcd", + } + _, err := ds.NewPasswordResetRequest(context.Background(), request) + assert.Nil(t, err) + + serr := svc.ResetPassword(test.UserContext(&fleet.User{ID: 1}), tt.token, tt.newPassword) + if tt.wantErr != nil { + assert.Equal(t, tt.wantErr.Error(), ctxerr.Cause(serr).Error()) + } else { + assert.Nil(t, serr) + } + }) + } +} + +func refreshCtx(t *testing.T, ctx context.Context, user *fleet.User, ds fleet.Datastore, session *fleet.Session) context.Context { + reloadedUser, err := ds.UserByEmail(ctx, user.Email) + require.NoError(t, 
err) + + return viewer.NewContext(ctx, viewer.Viewer{User: reloadedUser, Session: session}) +} + +func TestAuthenticatedUser(t *testing.T) { + ds := mysql.CreateMySQLDS(t) + + createTestUsers(t, ds) + svc := newTestService(ds, nil, nil) + admin1, err := ds.UserByEmail(context.Background(), "admin1@example.com") + assert.Nil(t, err) + admin1Session, err := ds.NewSession(context.Background(), &fleet.Session{ + UserID: admin1.ID, + Key: "admin1", + }) + assert.Nil(t, err) + + ctx := context.Background() + ctx = viewer.NewContext(ctx, viewer.Viewer{User: admin1, Session: admin1Session}) + user, err := svc.AuthenticatedUser(ctx) + assert.Nil(t, err) + assert.Equal(t, user, admin1) +} diff --git a/server/service/validation_users.go b/server/service/validation_users.go deleted file mode 100644 index 28f41ca9aa..0000000000 --- a/server/service/validation_users.go +++ /dev/null @@ -1,202 +0,0 @@ -package service - -import ( - "context" - "errors" - "unicode" - - "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" - "github.com/fleetdm/fleet/v4/server/contexts/viewer" - "github.com/fleetdm/fleet/v4/server/fleet" -) - -func (mw validationMiddleware) CreateUserFromInvite(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) { - invalid := &fleet.InvalidArgumentError{} - if p.Name == nil { - invalid.Append("name", "Full name missing required argument") - } else { - if *p.Name == "" { - invalid.Append("name", "Full name cannot be empty") - } - } - - // we don't need a password for single sign on - if p.SSOInvite == nil || !*p.SSOInvite { - if p.Password == nil { - invalid.Append("password", "Password missing required argument") - } else { - if *p.Password == "" { - invalid.Append("password", "Password cannot be empty") - } - if err := validatePasswordRequirements(*p.Password); err != nil { - invalid.Append("password", err.Error()) - } - } - } - - if p.Email == nil { - invalid.Append("email", "Email missing required argument") - } else { - if *p.Email == "" { - 
invalid.Append("email", "Email cannot be empty") - } - } - - if p.InviteToken == nil { - invalid.Append("invite_token", "Invite token missing required argument") - } else { - if *p.InviteToken == "" { - invalid.Append("invite_token", "Invite token cannot be empty") - } - } - - if invalid.HasErrors() { - return nil, ctxerr.Wrap(ctx, invalid) - } - return mw.Service.CreateUserFromInvite(ctx, p) -} - -func (mw validationMiddleware) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) { - invalid := &fleet.InvalidArgumentError{} - if p.Name == nil { - invalid.Append("name", "Full name missing required argument") - } else { - if *p.Name == "" { - invalid.Append("name", "Full name cannot be empty") - } - } - - // we don't need a password for single sign on - if (p.SSOInvite == nil || !*p.SSOInvite) && (p.SSOEnabled == nil || !*p.SSOEnabled) { - if p.Password == nil { - invalid.Append("password", "Password missing required argument") - } else { - if *p.Password == "" { - invalid.Append("password", "Password cannot be empty") - } - // Skip password validation in the case of admin created users - } - } - - if p.SSOEnabled != nil && *p.SSOEnabled && p.Password != nil && len(*p.Password) > 0 { - invalid.Append("password", "not allowed for SSO users") - } - - if p.Email == nil { - invalid.Append("email", "Email missing required argument") - } else { - if *p.Email == "" { - invalid.Append("email", "Email cannot be empty") - } - } - - if p.InviteToken != nil { - invalid.Append("invite_token", "Invite token should not be specified with admin user creation") - } - - if invalid.HasErrors() { - return nil, ctxerr.Wrap(ctx, invalid) - } - return mw.Service.CreateUser(ctx, p) -} - -func (mw validationMiddleware) ModifyUser(ctx context.Context, userID uint, p fleet.UserPayload) (*fleet.User, error) { - invalid := &fleet.InvalidArgumentError{} - if p.Name != nil { - if *p.Name == "" { - invalid.Append("name", "Full name cannot be empty") - } - } - - if p.Email != 
nil { - if *p.Email == "" { - invalid.Append("email", "Email cannot be empty") - } - // if the user is not an admin, or if an admin is changing their own email - // address a password is required, - if passwordRequiredForEmailChange(ctx, userID, invalid) { - if p.Password == nil { - invalid.Append("password", "Password cannot be empty if email is changed") - } - } - } - - if invalid.HasErrors() { - return nil, ctxerr.Wrap(ctx, invalid) - } - return mw.Service.ModifyUser(ctx, userID, p) -} - -func passwordRequiredForEmailChange(ctx context.Context, uid uint, invalid *fleet.InvalidArgumentError) bool { - vc, ok := viewer.FromContext(ctx) - if !ok { - invalid.Append("viewer", "Viewer not present") - return false - } - // if a user is changing own email need a password no matter what - return vc.UserID() == uid -} - -func (mw validationMiddleware) ChangePassword(ctx context.Context, oldPass, newPass string) error { - invalid := &fleet.InvalidArgumentError{} - if oldPass == "" { - invalid.Append("old_password", "Old password cannot be empty") - } - if newPass == "" { - invalid.Append("new_password", "New password cannot be empty") - } - - if err := validatePasswordRequirements(newPass); err != nil { - invalid.Append("new_password", err.Error()) - } - - if invalid.HasErrors() { - return ctxerr.Wrap(ctx, invalid) - } - return mw.Service.ChangePassword(ctx, oldPass, newPass) -} - -func (mw validationMiddleware) ResetPassword(ctx context.Context, token, password string) error { - invalid := &fleet.InvalidArgumentError{} - if token == "" { - invalid.Append("token", "Token cannot be empty field") - } - if password == "" { - invalid.Append("new_password", "New password cannot be empty field") - } - if err := validatePasswordRequirements(password); err != nil { - invalid.Append("new_password", err.Error()) - } - if invalid.HasErrors() { - return ctxerr.Wrap(ctx, invalid) - } - return mw.Service.ResetPassword(ctx, token, password) -} - -// Requirements for user password: -// at 
least 7 character length -// at least 1 symbol -// at least 1 number -func validatePasswordRequirements(password string) error { - var ( - number bool - symbol bool - ) - - for _, s := range password { - switch { - case unicode.IsNumber(s): - number = true - case unicode.IsPunct(s) || unicode.IsSymbol(s): - symbol = true - } - } - - if len(password) >= 7 && - number && - symbol { - return nil - } - - return errors.New("Password does not meet validation requirements") -}