diff --git a/server/service/endpoint_sessions.go b/server/service/endpoint_sessions.go index 1019fb0a15..cde324a992 100644 --- a/server/service/endpoint_sessions.go +++ b/server/service/endpoint_sessions.go @@ -91,40 +91,6 @@ func makeGetInfoAboutSessionEndpoint(svc fleet.Service) endpoint.Endpoint { } } -//////////////////////////////////////////////////////////////////////////////// -// Get Info About Sessions For User -//////////////////////////////////////////////////////////////////////////////// - -type getInfoAboutSessionsForUserRequest struct { - ID uint -} - -type getInfoAboutSessionsForUserResponse struct { - Sessions []getInfoAboutSessionResponse `json:"sessions"` - Err error `json:"error,omitempty"` -} - -func (r getInfoAboutSessionsForUserResponse) error() error { return r.Err } - -func makeGetInfoAboutSessionsForUserEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(getInfoAboutSessionsForUserRequest) - sessions, err := svc.GetInfoAboutSessionsForUser(ctx, req.ID) - if err != nil { - return getInfoAboutSessionsForUserResponse{Err: err}, nil - } - var resp getInfoAboutSessionsForUserResponse - for _, session := range sessions { - resp.Sessions = append(resp.Sessions, getInfoAboutSessionResponse{ - SessionID: session.ID, - UserID: session.UserID, - CreatedAt: session.CreatedAt, - }) - } - return resp, nil - } -} - //////////////////////////////////////////////////////////////////////////////// // Delete Session //////////////////////////////////////////////////////////////////////////////// @@ -150,31 +116,6 @@ func makeDeleteSessionEndpoint(svc fleet.Service) endpoint.Endpoint { } } -//////////////////////////////////////////////////////////////////////////////// -// Delete Sessions For User -//////////////////////////////////////////////////////////////////////////////// - -type deleteSessionsForUserRequest struct { - ID uint -} - -type deleteSessionsForUserResponse struct { - Err error `json:"error,omitempty"` -} - -func (r deleteSessionsForUserResponse) error() error { return r.Err } - -func makeDeleteSessionsForUserEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(deleteSessionsForUserRequest) - err := svc.DeleteSessionsForUser(ctx, req.ID) - if err != nil { - return deleteSessionsForUserResponse{Err: err}, nil - } - return deleteSessionsForUserResponse{}, nil - } -} - type initiateSSORequest struct { RelayURL string `json:"relay_url"` } diff --git a/server/service/endpoint_users.go b/server/service/endpoint_users.go index b7eba94a42..7905da1214 100644 --- a/server/service/endpoint_users.go +++ b/server/service/endpoint_users.go @@ -12,21 +12,10 @@ import ( // Create User With Invite //////////////////////////////////////////////////////////////////////////////// -type createUserRequest struct { - payload fleet.UserPayload -} - -type createUserResponse struct { - User *fleet.User `json:"user,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r createUserResponse) error() error { return r.Err } - func makeCreateUserFromInviteEndpoint(svc fleet.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { req := request.(createUserRequest) - user, err := svc.CreateUserFromInvite(ctx, req.payload) + user, err := svc.CreateUserFromInvite(ctx, req.UserPayload) if err != nil { return createUserResponse{Err: err}, nil } @@ -34,47 +23,6 @@ func makeCreateUserFromInviteEndpoint(svc fleet.Service) endpoint.Endpoint { } } -//////////////////////////////////////////////////////////////////////////////// -// Create User -//////////////////////////////////////////////////////////////////////////////// - -func makeCreateUserEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(createUserRequest) - user, err := svc.CreateUser(ctx, req.payload) - if err != nil { - return createUserResponse{Err: err}, nil - } - return createUserResponse{User: user}, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Get User -//////////////////////////////////////////////////////////////////////////////// - -type getUserRequest struct { - ID uint `json:"id"` -} - -type getUserResponse struct { - User *fleet.User `json:"user,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r getUserResponse) error() error { return r.Err } - -func makeGetUserEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(getUserRequest) - user, err := svc.User(ctx, req.ID) - if err != nil { - return getUserResponse{Err: err}, nil - } - return getUserResponse{User: user}, nil - } -} - func makeGetSessionUserEndpoint(svc fleet.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { user, err := svc.AuthenticatedUser(ctx) @@ -85,60 +33,6 @@ func makeGetSessionUserEndpoint(svc fleet.Service) endpoint.Endpoint { } } -//////////////////////////////////////////////////////////////////////////////// -// List Users -//////////////////////////////////////////////////////////////////////////////// - -type listUsersRequest struct { - ListOptions fleet.UserListOptions -} - -type listUsersResponse struct { - Users []fleet.User `json:"users"` - Err error `json:"error,omitempty"` -} - -func (r listUsersResponse) error() error { return r.Err } - -func makeListUsersEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(listUsersRequest) - users, err := svc.ListUsers(ctx, req.ListOptions) - if err != nil { - return listUsersResponse{Err: err}, nil - } - - resp := listUsersResponse{Users: []fleet.User{}} - for _, user := range users { - resp.Users = append(resp.Users, *user) - } - return resp, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Change Password -//////////////////////////////////////////////////////////////////////////////// - -type changePasswordRequest struct { - OldPassword string `json:"old_password"` - NewPassword string `json:"new_password"` -} - -type changePasswordResponse struct { - Err error `json:"error,omitempty"` -} - -func (r changePasswordResponse) error() error { return r.Err } - -func makeChangePasswordEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(changePasswordRequest) - err := svc.ChangePassword(ctx, req.OldPassword, req.NewPassword) - return changePasswordResponse{Err: err}, nil - } -} - //////////////////////////////////////////////////////////////////////////////// // Reset Password //////////////////////////////////////////////////////////////////////////////// @@ -162,59 +56,6 @@ func makeResetPasswordEndpoint(svc fleet.Service) endpoint.Endpoint { } } -//////////////////////////////////////////////////////////////////////////////// -// Modify User -//////////////////////////////////////////////////////////////////////////////// - -type modifyUserRequest struct { - ID uint - payload fleet.UserPayload -} - -type modifyUserResponse struct { - User *fleet.User `json:"user,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r modifyUserResponse) error() error { return r.Err } - -func makeModifyUserEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(modifyUserRequest) - user, err := svc.ModifyUser(ctx, req.ID, req.payload) - if err != nil { - return modifyUserResponse{Err: err}, nil - } - - return modifyUserResponse{User: user}, nil - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Delete User -//////////////////////////////////////////////////////////////////////////////// - -type deleteUserRequest struct { - ID uint `json:"id"` -} - -type deleteUserResponse struct { - Err error `json:"error,omitempty"` -} - -func (r deleteUserResponse) error() error { return r.Err } - -func makeDeleteUserEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(deleteUserRequest) - err := svc.DeleteUser(ctx, req.ID) - if err != nil { - return deleteUserResponse{Err: err}, nil - } - return deleteUserResponse{}, nil - } -} - //////////////////////////////////////////////////////////////////////////////// // Perform Required Password Reset //////////////////////////////////////////////////////////////////////////////// @@ -242,33 +83,6 @@ func makePerformRequiredPasswordResetEndpoint(svc fleet.Service) endpoint.Endpoi } } -//////////////////////////////////////////////////////////////////////////////// -// Require Password Reset -//////////////////////////////////////////////////////////////////////////////// - -type requirePasswordResetRequest struct { - Require bool `json:"require"` - ID uint `json:"id"` -} - -type requirePasswordResetResponse struct { - User *fleet.User `json:"user,omitempty"` - Err error `json:"error,omitempty"` -} - -func (r requirePasswordResetResponse) error() error { return r.Err } - -func makeRequirePasswordResetEndpoint(svc fleet.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(requirePasswordResetRequest) - user, err := svc.RequirePasswordReset(ctx, req.ID, req.Require) - if err != nil { - return requirePasswordResetResponse{Err: err}, nil - } - return requirePasswordResetResponse{User: user}, nil - } -} - //////////////////////////////////////////////////////////////////////////////// // Forgot Password //////////////////////////////////////////////////////////////////////////////// diff --git a/server/service/endpoint_utils.go b/server/service/endpoint_utils.go index 5a481c64fc..9b7db84781 100644 --- a/server/service/endpoint_utils.go +++ b/server/service/endpoint_utils.go @@ -67,8 +67,8 @@ func allFields(ifv reflect.Value) []reflect.StructField { // makeDecoder creates a decoder for the type for the struct passed on. If the // struct has at least 1 json tag it'll unmarshall the body. If the struct has // a `url` tag with value list_options it'll gather fleet.ListOptions from the -// URL (similarly for host_options, carve_options that derive from the common -// list_options). +// URL (similarly for host_options, carve_options, user_options that derive +// from the common list_options). // // Finally, any other `url` tag will be treated as a path variable (of the form // /path/{name} in the route's path) from the URL path pattern, and it'll be @@ -123,6 +123,13 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc { } field.Set(reflect.ValueOf(opts)) + case "user_options": + opts, err := userListOptionsFromRequest(r) + if err != nil { + return nil, err + } + field.Set(reflect.ValueOf(opts)) + case "host_options": opts, err := hostListOptionsFromRequest(r) if err != nil { diff --git a/server/service/handler.go b/server/service/handler.go index ef7e47e4cf..0186541234 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -28,17 +28,8 @@ type FleetEndpoints struct { ForgotPassword endpoint.Endpoint ResetPassword endpoint.Endpoint Me endpoint.Endpoint - ChangePassword endpoint.Endpoint CreateUserWithInvite endpoint.Endpoint - CreateUser endpoint.Endpoint - GetUser endpoint.Endpoint - ListUsers endpoint.Endpoint - ModifyUser endpoint.Endpoint - DeleteUser endpoint.Endpoint - RequirePasswordReset endpoint.Endpoint PerformRequiredPasswordReset endpoint.Endpoint - GetSessionsForUserInfo endpoint.Endpoint - DeleteSessionsForUser endpoint.Endpoint GetSessionInfo endpoint.Endpoint DeleteSession endpoint.Endpoint GetAppConfig endpoint.Endpoint @@ -115,15 +106,6 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th // Standard user authentication routes Me: authenticatedUser(svc, makeGetSessionUserEndpoint(svc)), - ChangePassword: authenticatedUser(svc, makeChangePasswordEndpoint(svc)), - GetUser: authenticatedUser(svc, makeGetUserEndpoint(svc)), - ListUsers: authenticatedUser(svc, makeListUsersEndpoint(svc)), - ModifyUser: authenticatedUser(svc, makeModifyUserEndpoint(svc)), - DeleteUser: authenticatedUser(svc, makeDeleteUserEndpoint(svc)), - RequirePasswordReset: authenticatedUser(svc, makeRequirePasswordResetEndpoint(svc)), - CreateUser: authenticatedUser(svc, makeCreateUserEndpoint(svc)), - GetSessionsForUserInfo: authenticatedUser(svc, makeGetInfoAboutSessionsForUserEndpoint(svc)), - DeleteSessionsForUser: authenticatedUser(svc, makeDeleteSessionsForUserEndpoint(svc)), GetSessionInfo: authenticatedUser(svc, makeGetInfoAboutSessionEndpoint(svc)), DeleteSession: authenticatedUser(svc, makeDeleteSessionEndpoint(svc)), GetAppConfig: authenticatedUser(svc, makeGetAppConfigEndpoint(svc)), @@ -184,17 +166,8 @@ type fleetHandlers struct { ForgotPassword http.Handler ResetPassword http.Handler Me http.Handler - ChangePassword http.Handler CreateUserWithInvite http.Handler - CreateUser http.Handler - GetUser http.Handler - ListUsers http.Handler - ModifyUser http.Handler - DeleteUser http.Handler - RequirePasswordReset http.Handler PerformRequiredPasswordReset http.Handler - GetSessionsForUserInfo http.Handler - DeleteSessionsForUser http.Handler GetSessionInfo http.Handler DeleteSession http.Handler GetAppConfig http.Handler @@ -255,17 +228,8 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle ForgotPassword: newServer(e.ForgotPassword, decodeForgotPasswordRequest), ResetPassword: newServer(e.ResetPassword, decodeResetPasswordRequest), Me: newServer(e.Me, decodeNoParamsRequest), - ChangePassword: newServer(e.ChangePassword, decodeChangePasswordRequest), CreateUserWithInvite: newServer(e.CreateUserWithInvite, decodeCreateUserRequest), - CreateUser: newServer(e.CreateUser, decodeCreateUserRequest), - GetUser: newServer(e.GetUser, decodeGetUserRequest), - ListUsers: newServer(e.ListUsers, decodeListUsersRequest), - ModifyUser: newServer(e.ModifyUser, decodeModifyUserRequest), - DeleteUser: newServer(e.DeleteUser, decodeDeleteUserRequest), - RequirePasswordReset: newServer(e.RequirePasswordReset, decodeRequirePasswordResetRequest), PerformRequiredPasswordReset: newServer(e.PerformRequiredPasswordReset, decodePerformRequiredPasswordResetRequest), - GetSessionsForUserInfo: newServer(e.GetSessionsForUserInfo, decodeGetInfoAboutSessionsForUserRequest), - DeleteSessionsForUser: newServer(e.DeleteSessionsForUser, decodeDeleteSessionsForUserRequest), GetSessionInfo: newServer(e.GetSessionInfo, decodeGetInfoAboutSessionRequest), DeleteSession: newServer(e.DeleteSession, decodeDeleteSessionRequest), GetAppConfig: newServer(e.GetAppConfig, decodeNoParamsRequest), @@ -488,20 +452,12 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) { r.Handle("/api/v1/fleet/forgot_password", h.ForgotPassword).Methods("POST").Name("forgot_password") r.Handle("/api/v1/fleet/reset_password", h.ResetPassword).Methods("POST").Name("reset_password") r.Handle("/api/v1/fleet/me", h.Me).Methods("GET").Name("me") - r.Handle("/api/v1/fleet/change_password", h.ChangePassword).Methods("POST").Name("change_password") r.Handle("/api/v1/fleet/perform_required_password_reset", h.PerformRequiredPasswordReset).Methods("POST").Name("perform_required_password_reset") r.Handle("/api/v1/fleet/sso", h.InitiateSSO).Methods("POST").Name("intiate_sso") r.Handle("/api/v1/fleet/sso", h.SettingsSSO).Methods("GET").Name("sso_config") r.Handle("/api/v1/fleet/sso/callback", h.CallbackSSO).Methods("POST").Name("callback_sso") - r.Handle("/api/v1/fleet/users", h.ListUsers).Methods("GET").Name("list_users") + r.Handle("/api/v1/fleet/users", h.CreateUserWithInvite).Methods("POST").Name("create_user_with_invite") - r.Handle("/api/v1/fleet/users/admin", h.CreateUser).Methods("POST").Name("create_user") - r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.GetUser).Methods("GET").Name("get_user") - r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.ModifyUser).Methods("PATCH").Name("modify_user") - r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.DeleteUser).Methods("DELETE").Name("delete_user") - r.Handle("/api/v1/fleet/users/{id:[0-9]+}/require_password_reset", h.RequirePasswordReset).Methods("POST").Name("require_password_reset") - r.Handle("/api/v1/fleet/users/{id:[0-9]+}/sessions", h.GetSessionsForUserInfo).Methods("GET").Name("get_session_for_user") - r.Handle("/api/v1/fleet/users/{id:[0-9]+}/sessions", h.DeleteSessionsForUser).Methods("DELETE").Name("delete_session_for_user") r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.GetSessionInfo).Methods("GET").Name("get_session_info") r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.DeleteSession).Methods("DELETE").Name("delete_session") @@ -570,6 +526,16 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht e.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule/{scheduled_query_id}").PATCH("/api/_version_/fleet/teams/{team_id}/schedule/{scheduled_query_id}", modifyTeamScheduleEndpoint, modifyTeamScheduleRequest{}) e.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule/{scheduled_query_id}").DELETE("/api/_version_/fleet/teams/{team_id}/schedule/{scheduled_query_id}", deleteTeamScheduleEndpoint, deleteTeamScheduleRequest{}) + e.GET("/api/_version_/fleet/users", listUsersEndpoint, listUsersRequest{}) + e.POST("/api/_version_/fleet/users/admin", createUserEndpoint, createUserRequest{}) + e.GET("/api/_version_/fleet/users/{id:[0-9]+}", getUserEndpoint, getUserRequest{}) + e.PATCH("/api/_version_/fleet/users/{id:[0-9]+}", modifyUserEndpoint, modifyUserRequest{}) + e.DELETE("/api/_version_/fleet/users/{id:[0-9]+}", deleteUserEndpoint, deleteUserRequest{}) + e.POST("/api/_version_/fleet/users/{id:[0-9]+}/require_password_reset", requirePasswordResetEndpoint, requirePasswordResetRequest{}) + e.GET("/api/_version_/fleet/users/{id:[0-9]+}/sessions", getInfoAboutSessionsForUserEndpoint, getInfoAboutSessionsForUserRequest{}) + e.DELETE("/api/_version_/fleet/users/{id:[0-9]+}/sessions", deleteSessionsForUserEndpoint, deleteSessionsForUserRequest{}) + e.POST("/api/_version_/fleet/change_password", changePasswordEndpoint, changePasswordRequest{}) + e.POST("/api/_version_/fleet/global/policies", globalPolicyEndpoint, globalPolicyRequest{}) e.GET("/api/_version_/fleet/global/policies", listGlobalPoliciesEndpoint, nil) e.GET("/api/_version_/fleet/global/policies/{policy_id}", getPolicyByIDEndpoint, getPolicyByIDRequest{}) diff --git a/server/service/handler_test.go b/server/service/handler_test.go index 944f2c4e3e..cd1c8b6389 100644 --- a/server/service/handler_test.go +++ b/server/service/handler_test.go @@ -42,18 +42,6 @@ func TestAPIRoutes(t *testing.T) { verb: "POST", uri: "/api/v1/fleet/users", }, - { - verb: "GET", - uri: "/api/v1/fleet/users", - }, - { - verb: "GET", - uri: "/api/v1/fleet/users/1", - }, - { - verb: "PATCH", - uri: "/api/v1/fleet/users/1", - }, { verb: "POST", uri: "/api/v1/fleet/login", diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index cfff03f96a..c63c8c446a 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -35,21 +35,36 @@ func (s *integrationTestSuite) SetupSuite() { } func (s *integrationTestSuite) TearDownTest() { + t := s.T() + ctx := context.Background() + u := s.users["admin1@example.com"] filter := fleet.TeamFilter{User: &u} - hosts, _ := s.ds.ListHosts(context.Background(), filter, fleet.HostListOptions{}) + hosts, err := s.ds.ListHosts(ctx, filter, fleet.HostListOptions{}) + require.NoError(t, err) var ids []uint for _, host := range hosts { ids = append(ids, host.ID) } - s.ds.DeleteHosts(context.Background(), ids) + if len(ids) > 0 { + require.NoError(t, s.ds.DeleteHosts(ctx, ids)) + } - lbls, err := s.ds.ListLabels(context.Background(), fleet.TeamFilter{}, fleet.ListOptions{}) - require.NoError(s.T(), err) + lbls, err := s.ds.ListLabels(ctx, fleet.TeamFilter{}, fleet.ListOptions{}) + require.NoError(t, err) for _, lbl := range lbls { if lbl.LabelType != fleet.LabelTypeBuiltIn { - err := s.ds.DeleteLabel(context.Background(), lbl.Name) - require.NoError(s.T(), err) + err := s.ds.DeleteLabel(ctx, lbl.Name) + require.NoError(t, err) + } + } + + users, err := s.ds.ListUsers(ctx, fleet.UserListOptions{}) + require.NoError(t, err) + for _, u := range users { + if _, ok := s.users[u.Email]; !ok { + err := s.ds.DeleteUser(ctx, u.ID) + require.NoError(t, err) } } } @@ -1867,6 +1882,82 @@ func (s *integrationTestSuite) TestLabelSpecs() { s.DoJSON("GET", "/api/v1/fleet/spec/labels/zzz", nil, http.StatusNotFound, &getResp) } +func (s *integrationTestSuite) TestUsers() { + t := s.T() + + // list existing users + var listResp listUsersResponse + s.DoJSON("GET", "/api/v1/fleet/users", nil, http.StatusOK, &listResp) + assert.Len(t, listResp.Users, len(s.users)) + + // create a new user + var createResp createUserResponse + params := fleet.UserPayload{ + Name: ptr.String("extra"), + Email: ptr.String("extra@asd.com"), + Password: ptr.String("pass"), + GlobalRole: ptr.String(fleet.RoleObserver), + } + s.DoJSON("POST", "/api/v1/fleet/users/admin", params, http.StatusOK, &createResp) + assert.NotZero(t, createResp.User.ID) + u := *createResp.User + + // get that user + var getResp getUserResponse + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), nil, http.StatusOK, &getResp) + assert.Equal(t, u.ID, getResp.User.ID) + + // get non-existing user + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID+1), nil, http.StatusNotFound, &getResp) + + // modify that user - simple name change + var modResp modifyUserResponse + params = fleet.UserPayload{ + Name: ptr.String("extraz"), + } + s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), params, http.StatusOK, &modResp) + assert.Equal(t, u.ID, modResp.User.ID) + assert.Equal(t, u.Name+"z", modResp.User.Name) + + // modify user - email change, password does not match + params = fleet.UserPayload{ + Email: ptr.String("extra2@asd.com"), + Password: ptr.String("wrongpass"), + } + s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), params, http.StatusForbidden, &modResp) + + // modify user - email change, password ok + params = fleet.UserPayload{ + Email: ptr.String("extra2@asd.com"), + Password: ptr.String("pass"), + } + s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), params, http.StatusOK, &modResp) + assert.Equal(t, u.ID, modResp.User.ID) + assert.NotEqual(t, u.ID, modResp.User.Email) + + // modify invalid user + params = fleet.UserPayload{ + Name: ptr.String("nosuchuser"), + } + s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID+1), params, http.StatusNotFound, &modResp) + + // require a password reset + var reqResetResp requirePasswordResetResponse + s.DoJSON("POST", fmt.Sprintf("/api/v1/fleet/users/%d/require_password_reset", u.ID), map[string]bool{"require": true}, http.StatusOK, &reqResetResp) + assert.Equal(t, u.ID, reqResetResp.User.ID) + assert.True(t, reqResetResp.User.AdminForcedPasswordReset) + + // require a password reset to invalid user + s.DoJSON("POST", fmt.Sprintf("/api/v1/fleet/users/%d/require_password_reset", u.ID+1), map[string]bool{"require": true}, http.StatusNotFound, &reqResetResp) + + // delete user + var delResp deleteUserResponse + s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), nil, http.StatusOK, &delResp) + + // delete invalid user + s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), nil, http.StatusNotFound, &delResp) +} + func (s *integrationTestSuite) TestGlobalPoliciesAutomationConfig() { t := s.T() diff --git a/server/service/scheduled_queries.go b/server/service/scheduled_queries.go index bd0f64e519..32f492a1a5 100644 --- a/server/service/scheduled_queries.go +++ b/server/service/scheduled_queries.go @@ -12,8 +12,7 @@ import ( //////////////////////////////////////////////////////////////////////////////// type getScheduledQueriesInPackRequest struct { - ID uint `url:"id"` - // TODO(mna): was not set in the old pattern + ID uint `url:"id"` ListOptions fleet.ListOptions `url:"list_options"` } diff --git a/server/service/service_sessions.go b/server/service/service_sessions.go index 491be30f60..6fd04d97bd 100644 --- a/server/service/service_sessions.go +++ b/server/service/service_sessions.go @@ -284,35 +284,6 @@ func (svc *Service) DestroySession(ctx context.Context) error { return svc.ds.DestroySession(ctx, session) } -func (svc *Service) GetInfoAboutSessionsForUser(ctx context.Context, id uint) ([]*fleet.Session, error) { - if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionWrite); err != nil { - return nil, err - } - - var validatedSessions []*fleet.Session - - sessions, err := svc.ds.ListSessionsForUser(ctx, id) - if err != nil { - return validatedSessions, err - } - - for _, session := range sessions { - if svc.validateSession(ctx, session) == nil { - validatedSessions = append(validatedSessions, session) - } - } - - return validatedSessions, nil -} - -func (svc *Service) DeleteSessionsForUser(ctx context.Context, id uint) error { - if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionWrite); err != nil { - return err - } - - return svc.ds.DestroyAllSessionsForUser(ctx, id) -} - func (svc *Service) GetInfoAboutSession(ctx context.Context, id uint) (*fleet.Session, error) { session, err := svc.ds.SessionByID(ctx, id) if err != nil { diff --git a/server/service/service_users.go b/server/service/service_users.go index a74e110969..e4c8685ace 100644 --- a/server/service/service_users.go +++ b/server/service/service_users.go @@ -7,7 +7,6 @@ import ( "time" "github.com/fleetdm/fleet/v4/server" - "github.com/fleetdm/fleet/v4/server/authz" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/contexts/viewer" @@ -42,27 +41,6 @@ func (svc *Service) CreateUserFromInvite(ctx context.Context, p fleet.UserPayloa return user, nil } -func (svc *Service) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) { - var teams []fleet.UserTeam - if p.Teams != nil { - teams = *p.Teams - } - if err := svc.authz.Authorize(ctx, &fleet.User{Teams: teams}, fleet.ActionWrite); err != nil { - return nil, err - } - - if invite, err := svc.ds.InviteByEmail(ctx, *p.Email); err == nil && invite != nil { - return nil, ctxerr.Errorf(ctx, "%s already invited", *p.Email) - } - - if p.AdminForcedPasswordReset == nil { - // By default, force password reset for users created this way. - p.AdminForcedPasswordReset = ptr.Bool(true) - } - - return svc.newUser(ctx, p) -} - func (svc *Service) CreateInitialUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) { // skipauth: Only the initial user creation should be allowed to skip // authorization (because there is not yet a user context to check against). @@ -106,157 +84,6 @@ func (svc *Service) newUser(ctx context.Context, p fleet.UserPayload) (*fleet.Us return user, nil } -func (svc *Service) ModifyUser(ctx context.Context, userID uint, p fleet.UserPayload) (*fleet.User, error) { - if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionRead); err != nil { - return nil, err - } - - user, err := svc.User(ctx, userID) - if err != nil { - return nil, err - } - - if err := svc.authz.Authorize(ctx, user, fleet.ActionWrite); err != nil { - return nil, err - } - - if p.GlobalRole != nil || p.Teams != nil { - if err := svc.authz.Authorize(ctx, user, fleet.ActionWriteRole); err != nil { - return nil, err - } - } - if p.Name != nil { - user.Name = *p.Name - } - - if p.Email != nil && *p.Email != user.Email { - err = svc.modifyEmailAddress(ctx, user, *p.Email, p.Password) - if err != nil { - return nil, err - } - } - - if p.Position != nil { - user.Position = *p.Position - } - - if p.GravatarURL != nil { - user.GravatarURL = *p.GravatarURL - } - - if p.SSOEnabled != nil { - user.SSOEnabled = *p.SSOEnabled - } - - currentUser := authz.UserFromContext(ctx) - - if p.GlobalRole != nil && *p.GlobalRole != "" { - if currentUser.GlobalRole == nil { - return nil, ctxerr.New(ctx, "Cannot edit global role as a team member") - } - - if p.Teams != nil && len(*p.Teams) > 0 { - return nil, fleet.NewInvalidArgumentError("teams", "may not be specified with global_role") - } - user.GlobalRole = p.GlobalRole - user.Teams = []fleet.UserTeam{} - } else if p.Teams != nil { - if !isAdminOfTheModifiedTeams(currentUser, user.Teams, *p.Teams) { - return nil, ctxerr.New(ctx, "Cannot modify teams in that way") - } - user.Teams = *p.Teams - user.GlobalRole = nil - } - - err = svc.saveUser(ctx, user) - if err != nil { - return nil, err - } - - return user, nil -} - -func isAdminOfTheModifiedTeams(currentUser *fleet.User, originalUserTeams, newUserTeams []fleet.UserTeam) bool { - // If the user is of the right global role, then they can modify the teams - if currentUser.GlobalRole != nil && (*currentUser.GlobalRole == fleet.RoleAdmin || *currentUser.GlobalRole == fleet.RoleMaintainer) { - return true - } - - // otherwise, gather the resulting teams - resultingTeams := make(map[uint]string) - for _, team := range newUserTeams { - resultingTeams[team.ID] = team.Role - } - - // and see which ones were removed or changed from the original - teamsAffected := make(map[uint]struct{}) - for _, team := range originalUserTeams { - if resultingTeams[team.ID] != team.Role { - teamsAffected[team.ID] = struct{}{} - } - } - - // then gather the teams the current user is admin for - currentUserTeamAdmin := make(map[uint]struct{}) - for _, team := range currentUser.Teams { - if team.Role == fleet.RoleAdmin { - currentUserTeamAdmin[team.ID] = struct{}{} - } - } - - // and let's check that the teams that were either removed or changed are also teams this user is an admin of - for teamID := range teamsAffected { - if _, ok := currentUserTeamAdmin[teamID]; !ok { - return false - } - } - - return true -} - -func (svc *Service) modifyEmailAddress(ctx context.Context, user *fleet.User, email string, password *string) error { - // password requirement handled in validation middleware - if password != nil { - err := user.ValidatePassword(*password) - if err != nil { - return fleet.NewPermissionError("incorrect password") - } - } - random, err := server.GenerateRandomText(svc.config.App.TokenKeySize) - if err != nil { - return err - } - token := base64.URLEncoding.EncodeToString([]byte(random)) - err = svc.ds.PendingEmailChange(ctx, user.ID, email, token) - if err != nil { - return err - } - config, err := svc.AppConfig(ctx) - if err != nil { - return err - } - - changeEmail := fleet.Email{ - Subject: "Confirm Fleet Email Change", - To: []string{email}, - Config: config, - Mailer: &mail.ChangeEmailMailer{ - Token: token, - BaseURL: template.URL(config.ServerSettings.ServerURL + svc.config.Server.URLPrefix), - AssetURL: getAssetURL(), - }, - } - return svc.mailService.SendEmail(changeEmail) -} - -func (svc *Service) DeleteUser(ctx context.Context, id uint) error { - if err := svc.authz.Authorize(ctx, &fleet.User{ID: id}, fleet.ActionWrite); err != nil { - return err - } - - return svc.ds.DeleteUser(ctx, id) -} - func (svc *Service) ChangeUserEmail(ctx context.Context, token string) (string, error) { vc, ok := viewer.FromContext(ctx) if !ok { @@ -270,14 +97,6 @@ func (svc *Service) ChangeUserEmail(ctx context.Context, token string) (string, return svc.ds.ConfirmPendingEmailChange(ctx, vc.UserID(), token) } -func (svc *Service) User(ctx context.Context, id uint) (*fleet.User, error) { - if err := svc.authz.Authorize(ctx, &fleet.User{ID: id}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.ds.UserByID(ctx, id) -} - func (svc *Service) UserUnauthorized(ctx context.Context, id uint) (*fleet.User, error) { // Explicitly no authorization check. Should only be used by middleware. return svc.ds.UserByID(ctx, id) @@ -299,14 +118,6 @@ func (svc *Service) AuthenticatedUser(ctx context.Context) (*fleet.User, error) return vc.User, nil } -func (svc *Service) ListUsers(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) { - if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.ds.ListUsers(ctx, opt) -} - // setNewPassword is a helper for changing a user's password. It should be // called to set the new password after proper authorization has been // performed. @@ -326,33 +137,6 @@ func (svc *Service) setNewPassword(ctx context.Context, user *fleet.User, passwo return nil } -func (svc *Service) ChangePassword(ctx context.Context, oldPass, newPass string) error { - vc, ok := viewer.FromContext(ctx) - if !ok { - return fleet.ErrNoContext - } - - if err := svc.authz.Authorize(ctx, vc.User, fleet.ActionWrite); err != nil { - return err - } - - if vc.User.SSOEnabled { - return ctxerr.New(ctx, "change password for single sign on user not allowed") - } - if err := vc.User.ValidatePassword(newPass); err == nil { - return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password") - } - - if err := vc.User.ValidatePassword(oldPass); err != nil { - return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("old_password", "old password does not match")) - } - - if err := svc.setNewPassword(ctx, vc.User, newPass); err != nil { - return ctxerr.Wrap(ctx, err, "setting new password") - } - return nil -} - func (svc *Service) ResetPassword(ctx context.Context, token, password string) error { // skipauth: No viewer context available. The user is locked out of their // account and authNZ is performed entirely by providing a valid password @@ -431,34 +215,6 @@ func (svc *Service) PerformRequiredPasswordReset(ctx context.Context, password s return user, nil } -func (svc *Service) RequirePasswordReset(ctx context.Context, uid uint, require bool) (*fleet.User, error) { - if err := svc.authz.Authorize(ctx, &fleet.User{ID: uid}, fleet.ActionWrite); err != nil { - return nil, err - } - - user, err := svc.ds.UserByID(ctx, uid) - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "loading user by ID") - } - if user.SSOEnabled { - return nil, ctxerr.New(ctx, "password reset for single sign on user not allowed") - } - // Require reset on next login - user.AdminForcedPasswordReset = require - if err := svc.saveUser(ctx, user); err != nil { - return nil, ctxerr.Wrap(ctx, err, "saving user") - } - - if require { - // Clear all of the existing sessions - if err := svc.DeleteSessionsForUser(ctx, user.ID); err != nil { - return nil, ctxerr.Wrap(ctx, err, "deleting user sessions") - } - } - - return user, nil -} - func (svc *Service) RequestPasswordReset(ctx context.Context, email string) error { // skipauth: No viewer context available. The user is locked out of their // account and trying to reset their password. @@ -512,10 +268,3 @@ func (svc *Service) RequestPasswordReset(ctx context.Context, email string) erro return svc.mailService.SendEmail(resetEmail) } - -// saves user in datastore. -// doesn't need to be exposed to the transport -// the service should expose actions for modifying a user instead -func (svc *Service) saveUser(ctx context.Context, user *fleet.User) error { - return svc.ds.SaveUser(ctx, user) -} diff --git a/server/service/service_users_test.go b/server/service/service_users_test.go index 9c7f430eac..f87b354ec1 100644 --- a/server/service/service_users_test.go +++ b/server/service/service_users_test.go @@ -3,7 +3,6 @@ package service import ( "context" "database/sql" - "errors" "testing" "time" @@ -11,8 +10,6 @@ import ( "github.com/fleetdm/fleet/v4/server/contexts/viewer" "github.com/fleetdm/fleet/v4/server/datastore/mysql" "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/fleetdm/fleet/v4/server/mock" - "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -38,250 +35,6 @@ func TestAuthenticatedUser(t *testing.T) { assert.Equal(t, user, admin1) } -func TestModifyUserEmail(t *testing.T) { - user := &fleet.User{ - ID: 3, - Email: "foo@bar.com", - } - user.SetPassword("password", 10, 10) - ms := new(mock.Store) - ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error { - return nil - } - ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) { - return user, nil - } - ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - config := &fleet.AppConfig{ - SMTPSettings: fleet.SMTPSettings{ - SMTPConfigured: true, - SMTPAuthenticationType: fleet.AuthTypeNameNone, - SMTPPort: 1025, - SMTPServer: "127.0.0.1", - SMTPSenderAddress: "xxx@fleet.co", - }, - } - return config, nil - } - ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error { - // verify this isn't changed yet - assert.Equal(t, "foo@bar.com", u.Email) - // verify is changed per bug 1123 - assert.Equal(t, "minion", u.Position) - return nil - } - svc := newTestService(ms, nil, nil) - ctx := context.Background() - ctx = viewer.NewContext(ctx, viewer.Viewer{User: user}) - payload := fleet.UserPayload{ - Email: ptr.String("zip@zap.com"), - Password: ptr.String("password"), - Position: ptr.String("minion"), - } - _, err := svc.ModifyUser(ctx, 3, payload) - require.Nil(t, err) - assert.True(t, ms.PendingEmailChangeFuncInvoked) - assert.True(t, ms.SaveUserFuncInvoked) -} - -func TestModifyUserEmailNoPassword(t *testing.T) { - user := &fleet.User{ - ID: 3, - Email: "foo@bar.com", - } - user.SetPassword("password", 10, 10) - ms := new(mock.Store) - ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error { - return nil - } - ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) { - return user, nil - } - ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - config := &fleet.AppConfig{ - SMTPSettings: fleet.SMTPSettings{ - SMTPConfigured: true, - SMTPAuthenticationType: fleet.AuthTypeNameNone, - SMTPPort: 1025, - SMTPServer: "127.0.0.1", - SMTPSenderAddress: "xxx@fleet.co", - }, - } - return config, nil - } - ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error { - return nil - } - svc := newTestService(ms, nil, nil) - ctx := context.Background() - ctx = viewer.NewContext(ctx, viewer.Viewer{User: user}) - payload := fleet.UserPayload{ - Email: ptr.String("zip@zap.com"), - // NO PASSWORD - // Password: ptr.String("password"), - } - _, err := svc.ModifyUser(ctx, 3, payload) - require.NotNil(t, err) - var iae *fleet.InvalidArgumentError - ok := errors.As(err, &iae) - require.True(t, ok) - require.Len(t, *iae, 1) - assert.False(t, ms.PendingEmailChangeFuncInvoked) - assert.False(t, ms.SaveUserFuncInvoked) -} - -func TestModifyAdminUserEmailNoPassword(t *testing.T) { - user := &fleet.User{ - ID: 3, - Email: "foo@bar.com", - } - user.SetPassword("password", 10, 10) - ms := new(mock.Store) - ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error { - return nil - } - ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) { - return user, nil - } - ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - config := &fleet.AppConfig{ - SMTPSettings: fleet.SMTPSettings{ - SMTPConfigured: true, - SMTPAuthenticationType: fleet.AuthTypeNameNone, - SMTPPort: 1025, - SMTPServer: "127.0.0.1", - SMTPSenderAddress: "xxx@fleet.co", - }, - } - return config, nil - } - ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error { - return nil - } - svc := newTestService(ms, nil, nil) - ctx := context.Background() - ctx = viewer.NewContext(ctx, viewer.Viewer{User: user}) - payload := fleet.UserPayload{ - Email: ptr.String("zip@zap.com"), - // NO PASSWORD - // Password: ptr.String("password"), - } - _, err := svc.ModifyUser(ctx, 3, payload) - require.NotNil(t, err) - var iae *fleet.InvalidArgumentError - ok := errors.As(err, &iae) - require.True(t, ok) - require.Len(t, *iae, 1) - assert.False(t, ms.PendingEmailChangeFuncInvoked) - assert.False(t, ms.SaveUserFuncInvoked) -} - -func TestModifyAdminUserEmailPassword(t *testing.T) { - user := &fleet.User{ - ID: 3, - Email: "foo@bar.com", - } - user.SetPassword("password", 10, 10) - ms := new(mock.Store) - ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error { - return nil - } - ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) { - return user, nil - } - ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { - config := &fleet.AppConfig{ - SMTPSettings: fleet.SMTPSettings{ - SMTPConfigured: true, - SMTPAuthenticationType: fleet.AuthTypeNameNone, - SMTPPort: 1025, - SMTPServer: "127.0.0.1", - SMTPSenderAddress: "xxx@fleet.co", - }, - } - return config, nil - } - ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error { - return nil - } - svc := newTestService(ms, nil, nil) - ctx := context.Background() - ctx = viewer.NewContext(ctx, viewer.Viewer{User: user}) - payload := fleet.UserPayload{ - Email: ptr.String("zip@zap.com"), - Password: ptr.String("password"), - } - _, err := svc.ModifyUser(ctx, 3, payload) - require.Nil(t, err) - assert.True(t, ms.PendingEmailChangeFuncInvoked) - assert.True(t, ms.SaveUserFuncInvoked) -} - -func TestChangePassword(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - - svc := newTestService(ds, nil, nil) - users := createTestUsers(t, ds) - passwordChangeTests := []struct { - user fleet.User - oldPassword string - newPassword string - anyErr bool - wantErr error - }{ - { // all good - user: users["admin1@example.com"], - oldPassword: "foobarbaz1234!", - newPassword: "12345cat!", - }, - { // prevent password reuse - user: users["admin1@example.com"], - oldPassword: "12345cat!", - newPassword: "foobarbaz1234!", - wantErr: fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"), - }, - { // all good - user: users["user1@example.com"], - oldPassword: "foobarbaz1234!", - newPassword: "newpassa1234!", - }, - { // bad old password - user: users["user1@example.com"], - oldPassword: "wrong_password", - newPassword: "12345cat!", - anyErr: true, - }, - { // missing old password - newPassword: "123cataaa!", - wantErr: fleet.NewInvalidArgumentError("old_password", "Old password cannot be empty"), - }, - } - - for _, tt := range passwordChangeTests { - t.Run("", func(t *testing.T) { - ctx := context.Background() - ctx = viewer.NewContext(ctx, viewer.Viewer{User: &tt.user}) - - err := svc.ChangePassword(ctx, tt.oldPassword, tt.newPassword) - if tt.anyErr { - require.NotNil(t, err) - } else if tt.wantErr != nil { - require.Equal(t, tt.wantErr, ctxerr.Cause(err)) - } else { - require.Nil(t, err) - } - - if err != nil { - return - } - - // Attempt login after successful change - _, _, err = svc.Login(context.Background(), tt.user.Email, tt.newPassword) - require.Nil(t, err, "should be able to login with new password") - }) - } -} - func TestResetPassword(t *testing.T) { ds := mysql.CreateMySQLDS(t) @@ -340,48 +93,6 @@ func TestResetPassword(t *testing.T) { } } -func TestRequirePasswordReset(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - - svc := newTestService(ds, nil, nil) - createTestUsers(t, ds) - - for _, tt := range testUsers { - t.Run(tt.Email, func(t *testing.T) { - user, err := ds.UserByEmail(context.Background(), tt.Email) - require.Nil(t, err) - - var sessions []*fleet.Session - - // Log user in - _, _, err = svc.Login(test.UserContext(test.UserAdmin), tt.Email, tt.PlaintextPassword) - require.Nil(t, err, "login unsuccessful") - sessions, err = svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), user.ID) - require.Nil(t, err) - require.Len(t, sessions, 1, "user should have one session") - - // Reset and verify sessions destroyed - retUser, err := svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, true) - require.Nil(t, err) - assert.True(t, retUser.AdminForcedPasswordReset) - checkUser, err := ds.UserByEmail(context.Background(), tt.Email) - require.Nil(t, err) - assert.True(t, checkUser.AdminForcedPasswordReset) - sessions, err = svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), user.ID) - require.Nil(t, err) - require.Len(t, sessions, 0, "sessions should be destroyed") - - // try undo - retUser, err = svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, false) - require.Nil(t, err) - assert.False(t, retUser.AdminForcedPasswordReset) - checkUser, err = ds.UserByEmail(context.Background(), tt.Email) - require.Nil(t, err) - assert.False(t, checkUser.AdminForcedPasswordReset) - }) - } -} - func refreshCtx(t *testing.T, ctx context.Context, user *fleet.User, ds fleet.Datastore, session *fleet.Session) context.Context { reloadedUser, err := ds.UserByEmail(ctx, user.Email) require.NoError(t, err) @@ -475,178 +186,3 @@ func TestUserPasswordRequirements(t *testing.T) { }) } } - -func TestUserAuth(t *testing.T) { - ds := new(mock.Store) - svc := newTestService(ds, nil, nil) - - ds.InviteByTokenFunc = func(ctx context.Context, token string) (*fleet.Invite, error) { - return &fleet.Invite{ - Email: "some@email.com", - Token: "ABCD", - UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ - CreateTimestamp: fleet.CreateTimestamp{CreatedAt: time.Now()}, - UpdateTimestamp: fleet.UpdateTimestamp{UpdatedAt: time.Now()}, - }, - }, nil - } - ds.NewUserFunc = func(ctx context.Context, user *fleet.User) (*fleet.User, error) { - return &fleet.User{}, nil - } - ds.DeleteInviteFunc = func(ctx context.Context, id uint) error { - return nil - } - ds.InviteByEmailFunc = func(ctx context.Context, email string) (*fleet.Invite, error) { - return nil, errors.New("AA") - } - ds.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) { - if id == 999 { - return &fleet.User{ - ID: 999, - Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}, - }, nil - } - return &fleet.User{ - ID: 888, - GlobalRole: ptr.String(fleet.RoleMaintainer), - }, nil - } - ds.SaveUserFunc = func(ctx context.Context, user *fleet.User) error { - return nil - } - - testCases := []struct { - name string - user *fleet.User - shouldFailGlobalWrite bool - shouldFailTeamWrite bool - shouldFailRead bool - }{ - { - "global admin", - &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}, - false, - false, - false, - }, - { - "global maintainer", - &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)}, - true, - true, - true, - }, - { - "global observer", - &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)}, - true, - true, - true, - }, - { - "team admin, belongs to team", - &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}}, - true, - false, - false, - }, - { - "team maintainer, belongs to team", - &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}}, - true, - true, - false, - }, - { - "team observer, belongs to team", - &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}}, - true, - true, - true, - }, - { - "team maintainer, DOES NOT belong to team", - &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleMaintainer}}}, - true, - true, - true, - }, - { - "team admin, DOES NOT belong to team", - &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleAdmin}}}, - true, - true, - true, - }, - { - "team observer, DOES NOT belong to team", - &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleObserver}}}, - true, - true, - true, - }, - } - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: tt.user}) - - teams := []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}} - _, err := svc.CreateUser(ctx, fleet.UserPayload{ - Name: ptr.String("Some Name"), - Email: ptr.String("some@email.com"), - Password: ptr.String("passw0rd."), - Teams: &teams, - }) - checkAuthErr(t, tt.shouldFailTeamWrite, err) - - _, err = svc.CreateUser(ctx, fleet.UserPayload{ - Name: ptr.String("Some Name"), - Email: ptr.String("some@email.com"), - Password: ptr.String("passw0rd."), - GlobalRole: ptr.String(fleet.RoleAdmin), - }) - checkAuthErr(t, tt.shouldFailGlobalWrite, err) - - _, err = svc.ModifyUser(ctx, 999, fleet.UserPayload{Teams: &teams}) - checkAuthErr(t, tt.shouldFailTeamWrite, err) - - _, err = svc.ModifyUser(ctx, 888, fleet.UserPayload{Teams: &teams}) - checkAuthErr(t, tt.shouldFailGlobalWrite, err) - - _, err = svc.ModifyUser(ctx, 888, fleet.UserPayload{GlobalRole: ptr.String(fleet.RoleMaintainer)}) - checkAuthErr(t, tt.shouldFailGlobalWrite, err) - }) - } -} - -// Test that CreateUser creates a user that will be forced to -// reset its password upon first login (see #2570). -func TestCreateUserForcePasswdReset(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - svc := newTestService(ds, nil, nil) - - // Create admin user. - admin := &fleet.User{ - Name: "Fleet Admin", - Email: "admin@foo.com", - GlobalRole: ptr.String(fleet.RoleAdmin), - } - err := admin.SetPassword("p4ssw0rd.", 10, 10) - require.NoError(t, err) - admin, err = ds.NewUser(context.Background(), admin) - require.NoError(t, err) - - // As the admin, create a new user. - ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: admin}) - user, err := svc.CreateUser(ctx, fleet.UserPayload{ - Name: ptr.String("Some Observer"), - Email: ptr.String("some-observer@email.com"), - Password: ptr.String("passw0rd."), - GlobalRole: ptr.String(fleet.RoleObserver), - }) - require.NoError(t, err) - - user, err = ds.UserByID(context.Background(), user.ID) - require.NoError(t, err) - require.True(t, user.AdminForcedPasswordReset) -} diff --git a/server/service/sessions_test.go b/server/service/sessions_test.go new file mode 100644 index 0000000000..d3036aaf7e --- /dev/null +++ b/server/service/sessions_test.go @@ -0,0 +1,8 @@ +package service + +// TODO(mna): when migrating Session-related endpoints, add auth tests for those +// endpoints (the auth is session-based, not user-based). +//_, err = svc.GetInfoAboutSessionsForUser(ctx, 999) +//checkAuthErr(t, tt.shouldFailTeamWrite, err) +//_, err = svc.GetInfoAboutSessionsForUser(ctx, 888) +//checkAuthErr(t, tt.shouldFailGlobalWrite, err) diff --git a/server/service/transport_sessions.go b/server/service/transport_sessions.go index ec4a4a0483..f61a8fe626 100644 --- a/server/service/transport_sessions.go +++ b/server/service/transport_sessions.go @@ -18,14 +18,6 @@ func decodeGetInfoAboutSessionRequest(ctx context.Context, r *http.Request) (int return getInfoAboutSessionRequest{ID: uint(id)}, nil } -func decodeGetInfoAboutSessionsForUserRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := uintFromRequest(r, "id") - if err != nil { - return nil, err - } - return getInfoAboutSessionsForUserRequest{ID: uint(id)}, nil -} - func decodeDeleteSessionRequest(ctx context.Context, r *http.Request) (interface{}, error) { id, err := uintFromRequest(r, "id") if err != nil { @@ -34,14 +26,6 @@ func decodeDeleteSessionRequest(ctx context.Context, r *http.Request) (interface return deleteSessionRequest{ID: uint(id)}, nil } -func decodeDeleteSessionsForUserRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := uintFromRequest(r, "id") - if err != nil { - return nil, err - } - return deleteSessionsForUserRequest{ID: uint(id)}, nil -} - func decodeLoginRequest(ctx context.Context, r *http.Request) (interface{}, error) { var req loginRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { diff --git a/server/service/transport_sessions_test.go b/server/service/transport_sessions_test.go index ed028f1ed0..0cb01763c2 100644 --- a/server/service/transport_sessions_test.go +++ b/server/service/transport_sessions_test.go @@ -27,22 +27,6 @@ func TestDecodeGetInfoAboutSessionRequest(t *testing.T) { ) } -func TestDecodeGetInfoAboutSessionsForUserRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/api/v1/fleet/user/{id}/sessions", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeGetInfoAboutSessionsForUserRequest(context.Background(), request) - assert.Nil(t, err) - - params := r.(getInfoAboutSessionsForUserRequest) - assert.Equal(t, uint(1), params.ID) - }).Methods("GET") - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("GET", "/api/v1/fleet/users/1/sessions", nil), - ) -} - func TestDecodeDeleteSessionRequest(t *testing.T) { router := mux.NewRouter() router.HandleFunc("/api/v1/fleet/sessions/{id}", func(writer http.ResponseWriter, request *http.Request) { @@ -59,22 +43,6 @@ func TestDecodeDeleteSessionRequest(t *testing.T) { ) } -func TestDecodeDeleteSessionsForUserRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/api/v1/fleet/user/{id}/sessions", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeDeleteSessionsForUserRequest(context.Background(), request) - assert.Nil(t, err) - - params := r.(deleteSessionsForUserRequest) - assert.Equal(t, uint(1), params.ID) - }).Methods("DELETE") - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("DELETE", "/api/v1/fleet/users/1/sessions", nil), - ) -} - func TestDecodeLoginRequest(t *testing.T) { router := mux.NewRouter() router.HandleFunc("/api/v1/fleet/login", func(writer http.ResponseWriter, request *http.Request) { diff --git a/server/service/transport_users.go b/server/service/transport_users.go index 894ab510a4..88e1854b05 100644 --- a/server/service/transport_users.go +++ b/server/service/transport_users.go @@ -10,69 +10,9 @@ import ( func decodeCreateUserRequest(ctx context.Context, r *http.Request) (interface{}, error) { var req createUserRequest - if err := json.NewDecoder(r.Body).Decode(&req.payload); err != nil { - return nil, err - } - - return req, nil -} - -func decodeGetUserRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := uintFromRequest(r, "id") - if err != nil { - return nil, err - } - return getUserRequest{ID: uint(id)}, nil -} - -func decodeListUsersRequest(ctx context.Context, r *http.Request) (interface{}, error) { - opt, err := userListOptionsFromRequest(r) - if err != nil { - return nil, err - } - return listUsersRequest{ListOptions: opt}, nil -} - -func decodeModifyUserRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := uintFromRequest(r, "id") - if err != nil { - return nil, err - } - var req modifyUserRequest - if err := json.NewDecoder(r.Body).Decode(&req.payload); err != nil { - return nil, err - } - req.ID = uint(id) - return req, nil -} - -func decodeDeleteUserRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := uintFromRequest(r, "id") - if err != nil { - return nil, err - } - return deleteUserRequest{ID: uint(id)}, nil -} - -func decodeChangePasswordRequest(ctx context.Context, r *http.Request) (interface{}, error) { - var req changePasswordRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { return nil, err } - return req, nil -} - -func decodeRequirePasswordResetRequest(ctx context.Context, r *http.Request) (interface{}, error) { - id, err := uintFromRequest(r, "id") - if err != nil { - return nil, ctxerr.Wrap(ctx, err, "getting ID from request") - } - - var req requirePasswordResetRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, ctxerr.Wrap(ctx, err, "decoding JSON") - } - req.ID = uint(id) return req, nil } diff --git a/server/service/transport_users_test.go b/server/service/transport_users_test.go index 2c9bf39bfe..7d6e9ff571 100644 --- a/server/service/transport_users_test.go +++ b/server/service/transport_users_test.go @@ -11,68 +11,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestDecodeCreateUserRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/api/v1/fleet/users", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeCreateUserRequest(context.Background(), request) - assert.Nil(t, err) - - params := r.(createUserRequest) - assert.Equal(t, "foo", *params.payload.Name) - assert.Equal(t, "foo@fleet.co", *params.payload.Email) - }).Methods("POST") - - var body bytes.Buffer - body.Write([]byte(`{ - "name": "foo", - "email": "foo@fleet.co" - }`)) - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("POST", "/api/v1/fleet/users", &body), - ) -} - -func TestDecodeGetUserRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/api/v1/fleet/users/{id}", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeGetUserRequest(context.Background(), request) - assert.Nil(t, err) - - params := r.(getUserRequest) - assert.Equal(t, uint(1), params.ID) - }).Methods("GET") - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("GET", "/api/v1/fleet/users/1", nil), - ) -} - -func TestDecodeChangePasswordRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/api/v1/fleet/change_password", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeChangePasswordRequest(context.Background(), request) - assert.Nil(t, err) - - params := r.(changePasswordRequest) - assert.Equal(t, "foo", params.OldPassword) - assert.Equal(t, "bar", params.NewPassword) - }).Methods("POST") - - var body bytes.Buffer - body.Write([]byte(`{ - "old_password": "foo", - "new_password": "bar" - }`)) - - router.ServeHTTP( - httptest.NewRecorder(), - httptest.NewRequest("POST", "/api/v1/fleet/change_password", &body), - ) -} - func TestDecodeResetPasswordRequest(t *testing.T) { router := mux.NewRouter() router.HandleFunc("/api/v1/fleet/users/{id}/password", func(writer http.ResponseWriter, request *http.Request) { @@ -95,28 +33,3 @@ func TestDecodeResetPasswordRequest(t *testing.T) { httptest.NewRequest("POST", "/api/v1/fleet/users/1/password", &body), ) } - -func TestDecodeModifyUserRequest(t *testing.T) { - router := mux.NewRouter() - router.HandleFunc("/api/v1/fleet/users/{id}", func(writer http.ResponseWriter, request *http.Request) { - r, err := decodeModifyUserRequest(context.Background(), request) - assert.Nil(t, err) - - params := r.(modifyUserRequest) - assert.Equal(t, "foo", *params.payload.Name) - assert.Equal(t, "foo@fleet.co", *params.payload.Email) - assert.Equal(t, uint(1), params.ID) - }).Methods("PATCH") - - var body bytes.Buffer - body.Write([]byte(`{ - "name": "foo", - "email": "foo@fleet.co" - }`)) - - request := httptest.NewRequest("PATCH", "/api/v1/fleet/users/1", &body) - router.ServeHTTP( - httptest.NewRecorder(), - request, - ) -} diff --git a/server/service/users.go b/server/service/users.go new file mode 100644 index 0000000000..28a7f8e9f8 --- /dev/null +++ b/server/service/users.go @@ -0,0 +1,521 @@ +package service + +import ( + "context" + "encoding/base64" + "html/template" + + "github.com/fleetdm/fleet/v4/server" + "github.com/fleetdm/fleet/v4/server/authz" + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + "github.com/fleetdm/fleet/v4/server/contexts/viewer" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/mail" + "github.com/fleetdm/fleet/v4/server/ptr" +) + +//////////////////////////////////////////////////////////////////////////////// +// Create User +//////////////////////////////////////////////////////////////////////////////// + +type createUserRequest struct { + fleet.UserPayload +} + +type createUserResponse struct { + User *fleet.User `json:"user,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r createUserResponse) error() error { return r.Err } + +func createUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*createUserRequest) + user, err := svc.CreateUser(ctx, req.UserPayload) + if err != nil { + return createUserResponse{Err: err}, nil + } + return createUserResponse{User: user}, nil +} + +func (svc *Service) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) { + var teams []fleet.UserTeam + if p.Teams != nil { + teams = *p.Teams + } + if err := svc.authz.Authorize(ctx, &fleet.User{Teams: teams}, fleet.ActionWrite); err != nil { + return nil, err + } + + if invite, err := svc.ds.InviteByEmail(ctx, *p.Email); err == nil && invite != nil { + return nil, ctxerr.Errorf(ctx, "%s already invited", *p.Email) + } + + if p.AdminForcedPasswordReset == nil { + // By default, force password reset for users created this way. + p.AdminForcedPasswordReset = ptr.Bool(true) + } + + return svc.newUser(ctx, p) +} + +//////////////////////////////////////////////////////////////////////////////// +// List Users +//////////////////////////////////////////////////////////////////////////////// + +type listUsersRequest struct { + ListOptions fleet.UserListOptions `url:"user_options"` +} + +type listUsersResponse struct { + Users []fleet.User `json:"users"` + Err error `json:"error,omitempty"` +} + +func (r listUsersResponse) error() error { return r.Err } + +func listUsersEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*listUsersRequest) + users, err := svc.ListUsers(ctx, req.ListOptions) + if err != nil { + return listUsersResponse{Err: err}, nil + } + + resp := listUsersResponse{Users: []fleet.User{}} + for _, user := range users { + resp.Users = append(resp.Users, *user) + } + return resp, nil +} + +func (svc *Service) ListUsers(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) { + if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionRead); err != nil { + return nil, err + } + + return svc.ds.ListUsers(ctx, opt) +} + +//////////////////////////////////////////////////////////////////////////////// +// Get User +//////////////////////////////////////////////////////////////////////////////// + +type getUserRequest struct { + ID uint `url:"id"` +} + +type getUserResponse struct { + User *fleet.User `json:"user,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r getUserResponse) error() error { return r.Err } + +func getUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*getUserRequest) + user, err := svc.User(ctx, req.ID) + if err != nil { + return getUserResponse{Err: err}, nil + } + return getUserResponse{User: user}, nil +} + +func (svc *Service) User(ctx context.Context, id uint) (*fleet.User, error) { + if err := svc.authz.Authorize(ctx, &fleet.User{ID: id}, fleet.ActionRead); err != nil { + return nil, err + } + + return svc.ds.UserByID(ctx, id) +} + +//////////////////////////////////////////////////////////////////////////////// +// Modify User +//////////////////////////////////////////////////////////////////////////////// + +type modifyUserRequest struct { + ID uint `json:"-" url:"id"` + fleet.UserPayload +} + +type modifyUserResponse struct { + User *fleet.User `json:"user,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r modifyUserResponse) error() error { return r.Err } + +func modifyUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*modifyUserRequest) + user, err := svc.ModifyUser(ctx, req.ID, req.UserPayload) + if err != nil { + return modifyUserResponse{Err: err}, nil + } + + return modifyUserResponse{User: user}, nil +} + +func (svc *Service) ModifyUser(ctx context.Context, userID uint, p fleet.UserPayload) (*fleet.User, error) { + if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionRead); err != nil { + return nil, err + } + + user, err := svc.User(ctx, userID) + if err != nil { + return nil, err + } + + if err := svc.authz.Authorize(ctx, user, fleet.ActionWrite); err != nil { + return nil, err + } + + if p.GlobalRole != nil || p.Teams != nil { + if err := svc.authz.Authorize(ctx, user, fleet.ActionWriteRole); err != nil { + return nil, err + } + } + if p.Name != nil { + user.Name = *p.Name + } + + if p.Email != nil && *p.Email != user.Email { + err = svc.modifyEmailAddress(ctx, user, *p.Email, p.Password) + if err != nil { + return nil, err + } + } + + if p.Position != nil { + user.Position = *p.Position + } + + if p.GravatarURL != nil { + user.GravatarURL = *p.GravatarURL + } + + if p.SSOEnabled != nil { + user.SSOEnabled = *p.SSOEnabled + } + + currentUser := authz.UserFromContext(ctx) + + if p.GlobalRole != nil && *p.GlobalRole != "" { + if currentUser.GlobalRole == nil { + return nil, ctxerr.New(ctx, "Cannot edit global role as a team member") + } + + if p.Teams != nil && len(*p.Teams) > 0 { + return nil, fleet.NewInvalidArgumentError("teams", "may not be specified with global_role") + } + user.GlobalRole = p.GlobalRole + user.Teams = []fleet.UserTeam{} + } else if p.Teams != nil { + if !isAdminOfTheModifiedTeams(currentUser, user.Teams, *p.Teams) { + return nil, ctxerr.New(ctx, "Cannot modify teams in that way") + } + user.Teams = *p.Teams + user.GlobalRole = nil + } + + err = svc.saveUser(ctx, user) + if err != nil { + return nil, err + } + + return user, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Delete User +//////////////////////////////////////////////////////////////////////////////// + +type deleteUserRequest struct { + ID uint `url:"id"` +} + +type deleteUserResponse struct { + Err error `json:"error,omitempty"` +} + +func (r deleteUserResponse) error() error { return r.Err } + +func deleteUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*deleteUserRequest) + err := svc.DeleteUser(ctx, req.ID) + if err != nil { + return deleteUserResponse{Err: err}, nil + } + return deleteUserResponse{}, nil +} + +func (svc *Service) DeleteUser(ctx context.Context, id uint) error { + if err := svc.authz.Authorize(ctx, &fleet.User{ID: id}, fleet.ActionWrite); err != nil { + return err + } + + return svc.ds.DeleteUser(ctx, id) +} + +//////////////////////////////////////////////////////////////////////////////// +// Require Password Reset +//////////////////////////////////////////////////////////////////////////////// + +type requirePasswordResetRequest struct { + Require bool `json:"require"` + ID uint `json:"-" url:"id"` +} + +type requirePasswordResetResponse struct { + User *fleet.User `json:"user,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r requirePasswordResetResponse) error() error { return r.Err } + +func requirePasswordResetEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*requirePasswordResetRequest) + user, err := svc.RequirePasswordReset(ctx, req.ID, req.Require) + if err != nil { + return requirePasswordResetResponse{Err: err}, nil + } + return requirePasswordResetResponse{User: user}, nil +} + +func (svc *Service) RequirePasswordReset(ctx context.Context, uid uint, require bool) (*fleet.User, error) { + if err := svc.authz.Authorize(ctx, &fleet.User{ID: uid}, fleet.ActionWrite); err != nil { + return nil, err + } + + user, err := svc.ds.UserByID(ctx, uid) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "loading user by ID") + } + if user.SSOEnabled { + return nil, ctxerr.New(ctx, "password reset for single sign on user not allowed") + } + // Require reset on next login + user.AdminForcedPasswordReset = require + if err := svc.saveUser(ctx, user); err != nil { + return nil, ctxerr.Wrap(ctx, err, "saving user") + } + + if require { + // Clear all of the existing sessions + if err := svc.DeleteSessionsForUser(ctx, user.ID); err != nil { + return nil, ctxerr.Wrap(ctx, err, "deleting user sessions") + } + } + + return user, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Change Password +//////////////////////////////////////////////////////////////////////////////// + +type changePasswordRequest struct { + OldPassword string `json:"old_password"` + NewPassword string `json:"new_password"` +} + +type changePasswordResponse struct { + Err error `json:"error,omitempty"` +} + +func (r changePasswordResponse) error() error { return r.Err } + +func changePasswordEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*changePasswordRequest) + err := svc.ChangePassword(ctx, req.OldPassword, req.NewPassword) + return changePasswordResponse{Err: err}, nil +} + +func (svc *Service) ChangePassword(ctx context.Context, oldPass, newPass string) error { + vc, ok := viewer.FromContext(ctx) + if !ok { + return fleet.ErrNoContext + } + + if err := svc.authz.Authorize(ctx, vc.User, fleet.ActionWrite); err != nil { + return err + } + + if vc.User.SSOEnabled { + return ctxerr.New(ctx, "change password for single sign on user not allowed") + } + if err := vc.User.ValidatePassword(newPass); err == nil { + return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password") + } + + if err := vc.User.ValidatePassword(oldPass); err != nil { + return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("old_password", "old password does not match")) + } + + if err := svc.setNewPassword(ctx, vc.User, newPass); err != nil { + return ctxerr.Wrap(ctx, err, "setting new password") + } + return nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Get Info About Sessions For User +//////////////////////////////////////////////////////////////////////////////// + +type getInfoAboutSessionsForUserRequest struct { + ID uint `url:"id"` +} + +type getInfoAboutSessionsForUserResponse struct { + Sessions []getInfoAboutSessionResponse `json:"sessions"` + Err error `json:"error,omitempty"` +} + +func (r getInfoAboutSessionsForUserResponse) error() error { return r.Err } + +func getInfoAboutSessionsForUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*getInfoAboutSessionsForUserRequest) + sessions, err := svc.GetInfoAboutSessionsForUser(ctx, req.ID) + if err != nil { + return getInfoAboutSessionsForUserResponse{Err: err}, nil + } + var resp getInfoAboutSessionsForUserResponse + for _, session := range sessions { + resp.Sessions = append(resp.Sessions, getInfoAboutSessionResponse{ + SessionID: session.ID, + UserID: session.UserID, + CreatedAt: session.CreatedAt, + }) + } + return resp, nil +} + +func (svc *Service) GetInfoAboutSessionsForUser(ctx context.Context, id uint) ([]*fleet.Session, error) { + if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionWrite); err != nil { + return nil, err + } + + var validatedSessions []*fleet.Session + + sessions, err := svc.ds.ListSessionsForUser(ctx, id) + if err != nil { + return validatedSessions, err + } + + for _, session := range sessions { + if svc.validateSession(ctx, session) == nil { + validatedSessions = append(validatedSessions, session) + } + } + + return validatedSessions, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Delete Sessions For User +//////////////////////////////////////////////////////////////////////////////// + +type deleteSessionsForUserRequest struct { + ID uint `url:"id"` +} + +type deleteSessionsForUserResponse struct { + Err error `json:"error,omitempty"` +} + +func (r deleteSessionsForUserResponse) error() error { return r.Err } + +func deleteSessionsForUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { + req := request.(*deleteSessionsForUserRequest) + err := svc.DeleteSessionsForUser(ctx, req.ID) + if err != nil { + return deleteSessionsForUserResponse{Err: err}, nil + } + return deleteSessionsForUserResponse{}, nil +} + +func (svc *Service) DeleteSessionsForUser(ctx context.Context, id uint) error { + if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionWrite); err != nil { + return err + } + + return svc.ds.DestroyAllSessionsForUser(ctx, id) +} + +func isAdminOfTheModifiedTeams(currentUser *fleet.User, originalUserTeams, newUserTeams []fleet.UserTeam) bool { + // If the user is of the right global role, then they can modify the teams + if currentUser.GlobalRole != nil && (*currentUser.GlobalRole == fleet.RoleAdmin || *currentUser.GlobalRole == fleet.RoleMaintainer) { + return true + } + + // otherwise, gather the resulting teams + resultingTeams := make(map[uint]string) + for _, team := range newUserTeams { + resultingTeams[team.ID] = team.Role + } + + // and see which ones were removed or changed from the original + teamsAffected := make(map[uint]struct{}) + for _, team := range originalUserTeams { + if resultingTeams[team.ID] != team.Role { + teamsAffected[team.ID] = struct{}{} + } + } + + // then gather the teams the current user is admin for + currentUserTeamAdmin := make(map[uint]struct{}) + for _, team := range currentUser.Teams { + if team.Role == fleet.RoleAdmin { + currentUserTeamAdmin[team.ID] = struct{}{} + } + } + + // and let's check that the teams that were either removed or changed are also teams this user is an admin of + for teamID := range teamsAffected { + if _, ok := currentUserTeamAdmin[teamID]; !ok { + return false + } + } + + return true +} + +func (svc *Service) modifyEmailAddress(ctx context.Context, user *fleet.User, email string, password *string) error { + // password requirement handled in validation middleware + if password != nil { + err := user.ValidatePassword(*password) + if err != nil { + return fleet.NewPermissionError("incorrect password") + } + } + random, err := server.GenerateRandomText(svc.config.App.TokenKeySize) + if err != nil { + return err + } + token := base64.URLEncoding.EncodeToString([]byte(random)) + err = svc.ds.PendingEmailChange(ctx, user.ID, email, token) + if err != nil { + return err + } + config, err := svc.AppConfig(ctx) + if err != nil { + return err + } + + changeEmail := fleet.Email{ + Subject: "Confirm Fleet Email Change", + To: []string{email}, + Config: config, + Mailer: &mail.ChangeEmailMailer{ + Token: token, + BaseURL: template.URL(config.ServerSettings.ServerURL + svc.config.Server.URLPrefix), + AssetURL: getAssetURL(), + }, + } + return svc.mailService.SendEmail(changeEmail) +} + +// saves user in datastore. +// doesn't need to be exposed to the transport +// the service should expose actions for modifying a user instead +func (svc *Service) saveUser(ctx context.Context, user *fleet.User) error { + return svc.ds.SaveUser(ctx, user) +} diff --git a/server/service/users_test.go b/server/service/users_test.go new file mode 100644 index 0000000000..fe6ed4aed3 --- /dev/null +++ b/server/service/users_test.go @@ -0,0 +1,530 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + "github.com/fleetdm/fleet/v4/server/contexts/viewer" + "github.com/fleetdm/fleet/v4/server/datastore/mysql" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/mock" + "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/fleetdm/fleet/v4/server/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUserAuth(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + + ds.InviteByTokenFunc = func(ctx context.Context, token string) (*fleet.Invite, error) { + return &fleet.Invite{ + Email: "some@email.com", + Token: "ABCD", + UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{ + CreateTimestamp: fleet.CreateTimestamp{CreatedAt: time.Now()}, + UpdateTimestamp: fleet.UpdateTimestamp{UpdatedAt: time.Now()}, + }, + }, nil + } + ds.NewUserFunc = func(ctx context.Context, user *fleet.User) (*fleet.User, error) { + return &fleet.User{}, nil + } + ds.DeleteInviteFunc = func(ctx context.Context, id uint) error { + return nil + } + ds.InviteByEmailFunc = func(ctx context.Context, email string) (*fleet.Invite, error) { + return nil, errors.New("AA") + } + ds.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) { + if id == 999 { + return &fleet.User{ + ID: 999, + Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}, + }, nil + } + return &fleet.User{ + ID: 888, + GlobalRole: ptr.String(fleet.RoleMaintainer), + }, nil + } + ds.SaveUserFunc = func(ctx context.Context, user *fleet.User) error { + return nil + } + ds.ListUsersFunc = func(ctx context.Context, opts fleet.UserListOptions) ([]*fleet.User, error) { + return nil, nil + } + ds.DeleteUserFunc = func(ctx context.Context, id uint) error { + return nil + } + ds.DestroyAllSessionsForUserFunc = func(ctx context.Context, id uint) error { + return nil + } + ds.ListSessionsForUserFunc = func(ctx context.Context, id uint) ([]*fleet.Session, error) { + return nil, nil + } + + testCases := []struct { + name string + user *fleet.User + shouldFailGlobalWrite bool + shouldFailTeamWrite bool + shouldFailRead bool + shouldFailDeleteReset bool + }{ + { + "global admin", + &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}, + false, + false, + false, + false, + }, + { + "global maintainer", + &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)}, + true, + true, + false, + true, + }, + { + "global observer", + &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)}, + true, + true, + false, + true, + }, + { + "team admin, belongs to team", + &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}}, + true, + false, + false, + true, + }, + { + "team maintainer, belongs to team", + &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}}, + true, + true, + false, + true, + }, + { + "team observer, belongs to team", + &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}}, + true, + true, + false, + true, + }, + { + "team maintainer, DOES NOT belong to team", + &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleMaintainer}}}, + true, + true, + false, + true, + }, + { + "team admin, DOES NOT belong to team", + &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleAdmin}}}, + true, + true, + false, + true, + }, + { + "team observer, DOES NOT belong to team", + &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleObserver}}}, + true, + true, + false, + true, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: tt.user}) + + teams := []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}} + _, err := svc.CreateUser(ctx, fleet.UserPayload{ + Name: ptr.String("Some Name"), + Email: ptr.String("some@email.com"), + Password: ptr.String("passw0rd."), + Teams: &teams, + }) + checkAuthErr(t, tt.shouldFailTeamWrite, err) + + _, err = svc.CreateUser(ctx, fleet.UserPayload{ + Name: ptr.String("Some Name"), + Email: ptr.String("some@email.com"), + Password: ptr.String("passw0rd."), + GlobalRole: ptr.String(fleet.RoleAdmin), + }) + checkAuthErr(t, tt.shouldFailGlobalWrite, err) + + _, err = svc.ModifyUser(ctx, 999, fleet.UserPayload{Teams: &teams}) + checkAuthErr(t, tt.shouldFailTeamWrite, err) + + _, err = svc.ModifyUser(ctx, 888, fleet.UserPayload{Teams: &teams}) + checkAuthErr(t, tt.shouldFailGlobalWrite, err) + + _, err = svc.ModifyUser(ctx, 888, fleet.UserPayload{GlobalRole: ptr.String(fleet.RoleMaintainer)}) + checkAuthErr(t, tt.shouldFailGlobalWrite, err) + + err = svc.DeleteUser(ctx, 999) + checkAuthErr(t, tt.shouldFailDeleteReset, err) + + _, err = svc.RequirePasswordReset(ctx, 999, false) + checkAuthErr(t, tt.shouldFailDeleteReset, err) + + _, err = svc.ListUsers(ctx, fleet.UserListOptions{}) + checkAuthErr(t, tt.shouldFailRead, err) + + _, err = svc.User(ctx, 999) + checkAuthErr(t, tt.shouldFailRead, err) + + _, err = svc.User(ctx, 888) + checkAuthErr(t, tt.shouldFailRead, err) + }) + } +} + +func TestModifyUserEmail(t *testing.T) { + user := &fleet.User{ + ID: 3, + Email: "foo@bar.com", + } + user.SetPassword("password", 10, 10) + ms := new(mock.Store) + ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error { + return nil + } + ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) { + return user, nil + } + ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + config := &fleet.AppConfig{ + SMTPSettings: fleet.SMTPSettings{ + SMTPConfigured: true, + SMTPAuthenticationType: fleet.AuthTypeNameNone, + SMTPPort: 1025, + SMTPServer: "127.0.0.1", + SMTPSenderAddress: "xxx@fleet.co", + }, + } + return config, nil + } + ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error { + // verify this isn't changed yet + assert.Equal(t, "foo@bar.com", u.Email) + // verify is changed per bug 1123 + assert.Equal(t, "minion", u.Position) + return nil + } + svc := newTestService(ms, nil, nil) + ctx := context.Background() + ctx = viewer.NewContext(ctx, viewer.Viewer{User: user}) + payload := fleet.UserPayload{ + Email: ptr.String("zip@zap.com"), + Password: ptr.String("password"), + Position: ptr.String("minion"), + } + _, err := svc.ModifyUser(ctx, 3, payload) + require.Nil(t, err) + assert.True(t, ms.PendingEmailChangeFuncInvoked) + assert.True(t, ms.SaveUserFuncInvoked) +} + +func TestModifyUserEmailNoPassword(t *testing.T) { + user := &fleet.User{ + ID: 3, + Email: "foo@bar.com", + } + user.SetPassword("password", 10, 10) + ms := new(mock.Store) + ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error { + return nil + } + ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) { + return user, nil + } + ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + config := &fleet.AppConfig{ + SMTPSettings: fleet.SMTPSettings{ + SMTPConfigured: true, + SMTPAuthenticationType: fleet.AuthTypeNameNone, + SMTPPort: 1025, + SMTPServer: "127.0.0.1", + SMTPSenderAddress: "xxx@fleet.co", + }, + } + return config, nil + } + ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error { + return nil + } + svc := newTestService(ms, nil, nil) + ctx := context.Background() + ctx = viewer.NewContext(ctx, viewer.Viewer{User: user}) + payload := fleet.UserPayload{ + Email: ptr.String("zip@zap.com"), + // NO PASSWORD + // Password: ptr.String("password"), + } + _, err := svc.ModifyUser(ctx, 3, payload) + require.NotNil(t, err) + var iae *fleet.InvalidArgumentError + ok := errors.As(err, &iae) + require.True(t, ok) + require.Len(t, *iae, 1) + assert.False(t, ms.PendingEmailChangeFuncInvoked) + assert.False(t, ms.SaveUserFuncInvoked) +} + +func TestModifyAdminUserEmailNoPassword(t *testing.T) { + user := &fleet.User{ + ID: 3, + Email: "foo@bar.com", + } + user.SetPassword("password", 10, 10) + ms := new(mock.Store) + ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error { + return nil + } + ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) { + return user, nil + } + ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + config := &fleet.AppConfig{ + SMTPSettings: fleet.SMTPSettings{ + SMTPConfigured: true, + SMTPAuthenticationType: fleet.AuthTypeNameNone, + SMTPPort: 1025, + SMTPServer: "127.0.0.1", + SMTPSenderAddress: "xxx@fleet.co", + }, + } + return config, nil + } + ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error { + return nil + } + svc := newTestService(ms, nil, nil) + ctx := context.Background() + ctx = viewer.NewContext(ctx, viewer.Viewer{User: user}) + payload := fleet.UserPayload{ + Email: ptr.String("zip@zap.com"), + // NO PASSWORD + // Password: ptr.String("password"), + } + _, err := svc.ModifyUser(ctx, 3, payload) + require.NotNil(t, err) + var iae *fleet.InvalidArgumentError + ok := errors.As(err, &iae) + require.True(t, ok) + require.Len(t, *iae, 1) + assert.False(t, ms.PendingEmailChangeFuncInvoked) + assert.False(t, ms.SaveUserFuncInvoked) +} + +func TestModifyAdminUserEmailPassword(t *testing.T) { + user := &fleet.User{ + ID: 3, + Email: "foo@bar.com", + } + user.SetPassword("password", 10, 10) + ms := new(mock.Store) + ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error { + return nil + } + ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) { + return user, nil + } + ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + config := &fleet.AppConfig{ + SMTPSettings: fleet.SMTPSettings{ + SMTPConfigured: true, + SMTPAuthenticationType: fleet.AuthTypeNameNone, + SMTPPort: 1025, + SMTPServer: "127.0.0.1", + SMTPSenderAddress: "xxx@fleet.co", + }, + } + return config, nil + } + ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error { + return nil + } + svc := newTestService(ms, nil, nil) + ctx := context.Background() + ctx = viewer.NewContext(ctx, viewer.Viewer{User: user}) + payload := fleet.UserPayload{ + Email: ptr.String("zip@zap.com"), + Password: ptr.String("password"), + } + _, err := svc.ModifyUser(ctx, 3, payload) + require.Nil(t, err) + assert.True(t, ms.PendingEmailChangeFuncInvoked) + assert.True(t, ms.SaveUserFuncInvoked) +} + +func TestUsersWithDS(t *testing.T) { + ds := mysql.CreateMySQLDS(t) + + cases := []struct { + name string + fn func(t *testing.T, ds *mysql.Datastore) + }{ + {"CreateUserForcePasswdReset", testUsersCreateUserForcePasswdReset}, + {"ChangePassword", testUsersChangePassword}, + {"RequirePasswordReset", testUsersRequirePasswordReset}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + defer mysql.TruncateTables(t, ds) + c.fn(t, ds) + }) + } +} + +// Test that CreateUser creates a user that will be forced to +// reset its password upon first login (see #2570). +func testUsersCreateUserForcePasswdReset(t *testing.T, ds *mysql.Datastore) { + svc := newTestService(ds, nil, nil) + + // Create admin user. + admin := &fleet.User{ + Name: "Fleet Admin", + Email: "admin@foo.com", + GlobalRole: ptr.String(fleet.RoleAdmin), + } + err := admin.SetPassword("p4ssw0rd.", 10, 10) + require.NoError(t, err) + admin, err = ds.NewUser(context.Background(), admin) + require.NoError(t, err) + + // As the admin, create a new user. + ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: admin}) + user, err := svc.CreateUser(ctx, fleet.UserPayload{ + Name: ptr.String("Some Observer"), + Email: ptr.String("some-observer@email.com"), + Password: ptr.String("passw0rd."), + GlobalRole: ptr.String(fleet.RoleObserver), + }) + require.NoError(t, err) + + user, err = ds.UserByID(context.Background(), user.ID) + require.NoError(t, err) + require.True(t, user.AdminForcedPasswordReset) +} + +func testUsersChangePassword(t *testing.T, ds *mysql.Datastore) { + svc := newTestService(ds, nil, nil) + users := createTestUsers(t, ds) + passwordChangeTests := []struct { + user fleet.User + oldPassword string + newPassword string + anyErr bool + wantErr error + }{ + { // all good + user: users["admin1@example.com"], + oldPassword: "foobarbaz1234!", + newPassword: "12345cat!", + }, + { // prevent password reuse + user: users["admin1@example.com"], + oldPassword: "12345cat!", + newPassword: "foobarbaz1234!", + wantErr: fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"), + }, + { // all good + user: users["user1@example.com"], + oldPassword: "foobarbaz1234!", + newPassword: "newpassa1234!", + }, + { // bad old password + user: users["user1@example.com"], + oldPassword: "wrong_password", + newPassword: "12345cat!", + anyErr: true, + }, + { // missing old password + newPassword: "123cataaa!", + wantErr: fleet.NewInvalidArgumentError("old_password", "Old password cannot be empty"), + }, + } + + for _, tt := range passwordChangeTests { + t.Run("", func(t *testing.T) { + ctx := context.Background() + ctx = viewer.NewContext(ctx, viewer.Viewer{User: &tt.user}) + + err := svc.ChangePassword(ctx, tt.oldPassword, tt.newPassword) + if tt.anyErr { + require.NotNil(t, err) + } else if tt.wantErr != nil { + require.Equal(t, tt.wantErr, ctxerr.Cause(err)) + } else { + require.Nil(t, err) + } + + if err != nil { + return + } + + // Attempt login after successful change + _, _, err = svc.Login(context.Background(), tt.user.Email, tt.newPassword) + require.Nil(t, err, "should be able to login with new password") + }) + } +} + +func testUsersRequirePasswordReset(t *testing.T, ds *mysql.Datastore) { + svc := newTestService(ds, nil, nil) + createTestUsers(t, ds) + + for _, tt := range testUsers { + t.Run(tt.Email, func(t *testing.T) { + user, err := ds.UserByEmail(context.Background(), tt.Email) + require.Nil(t, err) + + var sessions []*fleet.Session + + // Log user in + _, _, err = svc.Login(test.UserContext(test.UserAdmin), tt.Email, tt.PlaintextPassword) + require.Nil(t, err, "login unsuccessful") + sessions, err = svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), user.ID) + require.Nil(t, err) + require.Len(t, sessions, 1, "user should have one session") + + // Reset and verify sessions destroyed + retUser, err := svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, true) + require.Nil(t, err) + assert.True(t, retUser.AdminForcedPasswordReset) + checkUser, err := ds.UserByEmail(context.Background(), tt.Email) + require.Nil(t, err) + assert.True(t, checkUser.AdminForcedPasswordReset) + sessions, err = svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), user.ID) + require.Nil(t, err) + require.Len(t, sessions, 0, "sessions should be destroyed") + + // try undo + retUser, err = svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, false) + require.Nil(t, err) + assert.False(t, retUser.AdminForcedPasswordReset) + checkUser, err = ds.UserByEmail(context.Background(), tt.Email) + require.Nil(t, err) + assert.False(t, checkUser.AdminForcedPasswordReset) + }) + } +}