mirror of
https://github.com/fleetdm/fleet
synced 2026-05-24 01:18:42 +00:00
Migrate most users endpoints to the new pattern (#3366)
This commit is contained in:
parent
a5bef8a990
commit
597144bfac
17 changed files with 1178 additions and 1252 deletions
|
|
@ -91,40 +91,6 @@ func makeGetInfoAboutSessionEndpoint(svc fleet.Service) endpoint.Endpoint {
|
|||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Get Info About Sessions For User
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type getInfoAboutSessionsForUserRequest struct {
|
||||
ID uint
|
||||
}
|
||||
|
||||
type getInfoAboutSessionsForUserResponse struct {
|
||||
Sessions []getInfoAboutSessionResponse `json:"sessions"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r getInfoAboutSessionsForUserResponse) error() error { return r.Err }
|
||||
|
||||
func makeGetInfoAboutSessionsForUserEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(getInfoAboutSessionsForUserRequest)
|
||||
sessions, err := svc.GetInfoAboutSessionsForUser(ctx, req.ID)
|
||||
if err != nil {
|
||||
return getInfoAboutSessionsForUserResponse{Err: err}, nil
|
||||
}
|
||||
var resp getInfoAboutSessionsForUserResponse
|
||||
for _, session := range sessions {
|
||||
resp.Sessions = append(resp.Sessions, getInfoAboutSessionResponse{
|
||||
SessionID: session.ID,
|
||||
UserID: session.UserID,
|
||||
CreatedAt: session.CreatedAt,
|
||||
})
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Delete Session
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
@ -150,31 +116,6 @@ func makeDeleteSessionEndpoint(svc fleet.Service) endpoint.Endpoint {
|
|||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Delete Sessions For User
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type deleteSessionsForUserRequest struct {
|
||||
ID uint
|
||||
}
|
||||
|
||||
type deleteSessionsForUserResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r deleteSessionsForUserResponse) error() error { return r.Err }
|
||||
|
||||
func makeDeleteSessionsForUserEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(deleteSessionsForUserRequest)
|
||||
err := svc.DeleteSessionsForUser(ctx, req.ID)
|
||||
if err != nil {
|
||||
return deleteSessionsForUserResponse{Err: err}, nil
|
||||
}
|
||||
return deleteSessionsForUserResponse{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
type initiateSSORequest struct {
|
||||
RelayURL string `json:"relay_url"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,21 +12,10 @@ import (
|
|||
// Create User With Invite
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type createUserRequest struct {
|
||||
payload fleet.UserPayload
|
||||
}
|
||||
|
||||
type createUserResponse struct {
|
||||
User *fleet.User `json:"user,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r createUserResponse) error() error { return r.Err }
|
||||
|
||||
func makeCreateUserFromInviteEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(createUserRequest)
|
||||
user, err := svc.CreateUserFromInvite(ctx, req.payload)
|
||||
user, err := svc.CreateUserFromInvite(ctx, req.UserPayload)
|
||||
if err != nil {
|
||||
return createUserResponse{Err: err}, nil
|
||||
}
|
||||
|
|
@ -34,47 +23,6 @@ func makeCreateUserFromInviteEndpoint(svc fleet.Service) endpoint.Endpoint {
|
|||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Create User
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func makeCreateUserEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(createUserRequest)
|
||||
user, err := svc.CreateUser(ctx, req.payload)
|
||||
if err != nil {
|
||||
return createUserResponse{Err: err}, nil
|
||||
}
|
||||
return createUserResponse{User: user}, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Get User
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type getUserRequest struct {
|
||||
ID uint `json:"id"`
|
||||
}
|
||||
|
||||
type getUserResponse struct {
|
||||
User *fleet.User `json:"user,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r getUserResponse) error() error { return r.Err }
|
||||
|
||||
func makeGetUserEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(getUserRequest)
|
||||
user, err := svc.User(ctx, req.ID)
|
||||
if err != nil {
|
||||
return getUserResponse{Err: err}, nil
|
||||
}
|
||||
return getUserResponse{User: user}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func makeGetSessionUserEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
user, err := svc.AuthenticatedUser(ctx)
|
||||
|
|
@ -85,60 +33,6 @@ func makeGetSessionUserEndpoint(svc fleet.Service) endpoint.Endpoint {
|
|||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// List Users
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type listUsersRequest struct {
|
||||
ListOptions fleet.UserListOptions
|
||||
}
|
||||
|
||||
type listUsersResponse struct {
|
||||
Users []fleet.User `json:"users"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r listUsersResponse) error() error { return r.Err }
|
||||
|
||||
func makeListUsersEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(listUsersRequest)
|
||||
users, err := svc.ListUsers(ctx, req.ListOptions)
|
||||
if err != nil {
|
||||
return listUsersResponse{Err: err}, nil
|
||||
}
|
||||
|
||||
resp := listUsersResponse{Users: []fleet.User{}}
|
||||
for _, user := range users {
|
||||
resp.Users = append(resp.Users, *user)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Change Password
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type changePasswordRequest struct {
|
||||
OldPassword string `json:"old_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
type changePasswordResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r changePasswordResponse) error() error { return r.Err }
|
||||
|
||||
func makeChangePasswordEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(changePasswordRequest)
|
||||
err := svc.ChangePassword(ctx, req.OldPassword, req.NewPassword)
|
||||
return changePasswordResponse{Err: err}, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Reset Password
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
@ -162,59 +56,6 @@ func makeResetPasswordEndpoint(svc fleet.Service) endpoint.Endpoint {
|
|||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Modify User
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type modifyUserRequest struct {
|
||||
ID uint
|
||||
payload fleet.UserPayload
|
||||
}
|
||||
|
||||
type modifyUserResponse struct {
|
||||
User *fleet.User `json:"user,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r modifyUserResponse) error() error { return r.Err }
|
||||
|
||||
func makeModifyUserEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(modifyUserRequest)
|
||||
user, err := svc.ModifyUser(ctx, req.ID, req.payload)
|
||||
if err != nil {
|
||||
return modifyUserResponse{Err: err}, nil
|
||||
}
|
||||
|
||||
return modifyUserResponse{User: user}, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Delete User
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type deleteUserRequest struct {
|
||||
ID uint `json:"id"`
|
||||
}
|
||||
|
||||
type deleteUserResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r deleteUserResponse) error() error { return r.Err }
|
||||
|
||||
func makeDeleteUserEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(deleteUserRequest)
|
||||
err := svc.DeleteUser(ctx, req.ID)
|
||||
if err != nil {
|
||||
return deleteUserResponse{Err: err}, nil
|
||||
}
|
||||
return deleteUserResponse{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Perform Required Password Reset
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
@ -242,33 +83,6 @@ func makePerformRequiredPasswordResetEndpoint(svc fleet.Service) endpoint.Endpoi
|
|||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Require Password Reset
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type requirePasswordResetRequest struct {
|
||||
Require bool `json:"require"`
|
||||
ID uint `json:"id"`
|
||||
}
|
||||
|
||||
type requirePasswordResetResponse struct {
|
||||
User *fleet.User `json:"user,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r requirePasswordResetResponse) error() error { return r.Err }
|
||||
|
||||
func makeRequirePasswordResetEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(requirePasswordResetRequest)
|
||||
user, err := svc.RequirePasswordReset(ctx, req.ID, req.Require)
|
||||
if err != nil {
|
||||
return requirePasswordResetResponse{Err: err}, nil
|
||||
}
|
||||
return requirePasswordResetResponse{User: user}, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Forgot Password
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
|||
|
|
@ -67,8 +67,8 @@ func allFields(ifv reflect.Value) []reflect.StructField {
|
|||
// makeDecoder creates a decoder for the type for the struct passed on. If the
|
||||
// struct has at least 1 json tag it'll unmarshall the body. If the struct has
|
||||
// a `url` tag with value list_options it'll gather fleet.ListOptions from the
|
||||
// URL (similarly for host_options, carve_options that derive from the common
|
||||
// list_options).
|
||||
// URL (similarly for host_options, carve_options, user_options that derive
|
||||
// from the common list_options).
|
||||
//
|
||||
// Finally, any other `url` tag will be treated as a path variable (of the form
|
||||
// /path/{name} in the route's path) from the URL path pattern, and it'll be
|
||||
|
|
@ -123,6 +123,13 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
|
|||
}
|
||||
field.Set(reflect.ValueOf(opts))
|
||||
|
||||
case "user_options":
|
||||
opts, err := userListOptionsFromRequest(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
field.Set(reflect.ValueOf(opts))
|
||||
|
||||
case "host_options":
|
||||
opts, err := hostListOptionsFromRequest(r)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -28,17 +28,8 @@ type FleetEndpoints struct {
|
|||
ForgotPassword endpoint.Endpoint
|
||||
ResetPassword endpoint.Endpoint
|
||||
Me endpoint.Endpoint
|
||||
ChangePassword endpoint.Endpoint
|
||||
CreateUserWithInvite endpoint.Endpoint
|
||||
CreateUser endpoint.Endpoint
|
||||
GetUser endpoint.Endpoint
|
||||
ListUsers endpoint.Endpoint
|
||||
ModifyUser endpoint.Endpoint
|
||||
DeleteUser endpoint.Endpoint
|
||||
RequirePasswordReset endpoint.Endpoint
|
||||
PerformRequiredPasswordReset endpoint.Endpoint
|
||||
GetSessionsForUserInfo endpoint.Endpoint
|
||||
DeleteSessionsForUser endpoint.Endpoint
|
||||
GetSessionInfo endpoint.Endpoint
|
||||
DeleteSession endpoint.Endpoint
|
||||
GetAppConfig endpoint.Endpoint
|
||||
|
|
@ -115,15 +106,6 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th
|
|||
|
||||
// Standard user authentication routes
|
||||
Me: authenticatedUser(svc, makeGetSessionUserEndpoint(svc)),
|
||||
ChangePassword: authenticatedUser(svc, makeChangePasswordEndpoint(svc)),
|
||||
GetUser: authenticatedUser(svc, makeGetUserEndpoint(svc)),
|
||||
ListUsers: authenticatedUser(svc, makeListUsersEndpoint(svc)),
|
||||
ModifyUser: authenticatedUser(svc, makeModifyUserEndpoint(svc)),
|
||||
DeleteUser: authenticatedUser(svc, makeDeleteUserEndpoint(svc)),
|
||||
RequirePasswordReset: authenticatedUser(svc, makeRequirePasswordResetEndpoint(svc)),
|
||||
CreateUser: authenticatedUser(svc, makeCreateUserEndpoint(svc)),
|
||||
GetSessionsForUserInfo: authenticatedUser(svc, makeGetInfoAboutSessionsForUserEndpoint(svc)),
|
||||
DeleteSessionsForUser: authenticatedUser(svc, makeDeleteSessionsForUserEndpoint(svc)),
|
||||
GetSessionInfo: authenticatedUser(svc, makeGetInfoAboutSessionEndpoint(svc)),
|
||||
DeleteSession: authenticatedUser(svc, makeDeleteSessionEndpoint(svc)),
|
||||
GetAppConfig: authenticatedUser(svc, makeGetAppConfigEndpoint(svc)),
|
||||
|
|
@ -184,17 +166,8 @@ type fleetHandlers struct {
|
|||
ForgotPassword http.Handler
|
||||
ResetPassword http.Handler
|
||||
Me http.Handler
|
||||
ChangePassword http.Handler
|
||||
CreateUserWithInvite http.Handler
|
||||
CreateUser http.Handler
|
||||
GetUser http.Handler
|
||||
ListUsers http.Handler
|
||||
ModifyUser http.Handler
|
||||
DeleteUser http.Handler
|
||||
RequirePasswordReset http.Handler
|
||||
PerformRequiredPasswordReset http.Handler
|
||||
GetSessionsForUserInfo http.Handler
|
||||
DeleteSessionsForUser http.Handler
|
||||
GetSessionInfo http.Handler
|
||||
DeleteSession http.Handler
|
||||
GetAppConfig http.Handler
|
||||
|
|
@ -255,17 +228,8 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle
|
|||
ForgotPassword: newServer(e.ForgotPassword, decodeForgotPasswordRequest),
|
||||
ResetPassword: newServer(e.ResetPassword, decodeResetPasswordRequest),
|
||||
Me: newServer(e.Me, decodeNoParamsRequest),
|
||||
ChangePassword: newServer(e.ChangePassword, decodeChangePasswordRequest),
|
||||
CreateUserWithInvite: newServer(e.CreateUserWithInvite, decodeCreateUserRequest),
|
||||
CreateUser: newServer(e.CreateUser, decodeCreateUserRequest),
|
||||
GetUser: newServer(e.GetUser, decodeGetUserRequest),
|
||||
ListUsers: newServer(e.ListUsers, decodeListUsersRequest),
|
||||
ModifyUser: newServer(e.ModifyUser, decodeModifyUserRequest),
|
||||
DeleteUser: newServer(e.DeleteUser, decodeDeleteUserRequest),
|
||||
RequirePasswordReset: newServer(e.RequirePasswordReset, decodeRequirePasswordResetRequest),
|
||||
PerformRequiredPasswordReset: newServer(e.PerformRequiredPasswordReset, decodePerformRequiredPasswordResetRequest),
|
||||
GetSessionsForUserInfo: newServer(e.GetSessionsForUserInfo, decodeGetInfoAboutSessionsForUserRequest),
|
||||
DeleteSessionsForUser: newServer(e.DeleteSessionsForUser, decodeDeleteSessionsForUserRequest),
|
||||
GetSessionInfo: newServer(e.GetSessionInfo, decodeGetInfoAboutSessionRequest),
|
||||
DeleteSession: newServer(e.DeleteSession, decodeDeleteSessionRequest),
|
||||
GetAppConfig: newServer(e.GetAppConfig, decodeNoParamsRequest),
|
||||
|
|
@ -488,20 +452,12 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
|
|||
r.Handle("/api/v1/fleet/forgot_password", h.ForgotPassword).Methods("POST").Name("forgot_password")
|
||||
r.Handle("/api/v1/fleet/reset_password", h.ResetPassword).Methods("POST").Name("reset_password")
|
||||
r.Handle("/api/v1/fleet/me", h.Me).Methods("GET").Name("me")
|
||||
r.Handle("/api/v1/fleet/change_password", h.ChangePassword).Methods("POST").Name("change_password")
|
||||
r.Handle("/api/v1/fleet/perform_required_password_reset", h.PerformRequiredPasswordReset).Methods("POST").Name("perform_required_password_reset")
|
||||
r.Handle("/api/v1/fleet/sso", h.InitiateSSO).Methods("POST").Name("intiate_sso")
|
||||
r.Handle("/api/v1/fleet/sso", h.SettingsSSO).Methods("GET").Name("sso_config")
|
||||
r.Handle("/api/v1/fleet/sso/callback", h.CallbackSSO).Methods("POST").Name("callback_sso")
|
||||
r.Handle("/api/v1/fleet/users", h.ListUsers).Methods("GET").Name("list_users")
|
||||
|
||||
r.Handle("/api/v1/fleet/users", h.CreateUserWithInvite).Methods("POST").Name("create_user_with_invite")
|
||||
r.Handle("/api/v1/fleet/users/admin", h.CreateUser).Methods("POST").Name("create_user")
|
||||
r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.GetUser).Methods("GET").Name("get_user")
|
||||
r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.ModifyUser).Methods("PATCH").Name("modify_user")
|
||||
r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.DeleteUser).Methods("DELETE").Name("delete_user")
|
||||
r.Handle("/api/v1/fleet/users/{id:[0-9]+}/require_password_reset", h.RequirePasswordReset).Methods("POST").Name("require_password_reset")
|
||||
r.Handle("/api/v1/fleet/users/{id:[0-9]+}/sessions", h.GetSessionsForUserInfo).Methods("GET").Name("get_session_for_user")
|
||||
r.Handle("/api/v1/fleet/users/{id:[0-9]+}/sessions", h.DeleteSessionsForUser).Methods("DELETE").Name("delete_session_for_user")
|
||||
|
||||
r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.GetSessionInfo).Methods("GET").Name("get_session_info")
|
||||
r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.DeleteSession).Methods("DELETE").Name("delete_session")
|
||||
|
|
@ -570,6 +526,16 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht
|
|||
e.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule/{scheduled_query_id}").PATCH("/api/_version_/fleet/teams/{team_id}/schedule/{scheduled_query_id}", modifyTeamScheduleEndpoint, modifyTeamScheduleRequest{})
|
||||
e.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule/{scheduled_query_id}").DELETE("/api/_version_/fleet/teams/{team_id}/schedule/{scheduled_query_id}", deleteTeamScheduleEndpoint, deleteTeamScheduleRequest{})
|
||||
|
||||
e.GET("/api/_version_/fleet/users", listUsersEndpoint, listUsersRequest{})
|
||||
e.POST("/api/_version_/fleet/users/admin", createUserEndpoint, createUserRequest{})
|
||||
e.GET("/api/_version_/fleet/users/{id:[0-9]+}", getUserEndpoint, getUserRequest{})
|
||||
e.PATCH("/api/_version_/fleet/users/{id:[0-9]+}", modifyUserEndpoint, modifyUserRequest{})
|
||||
e.DELETE("/api/_version_/fleet/users/{id:[0-9]+}", deleteUserEndpoint, deleteUserRequest{})
|
||||
e.POST("/api/_version_/fleet/users/{id:[0-9]+}/require_password_reset", requirePasswordResetEndpoint, requirePasswordResetRequest{})
|
||||
e.GET("/api/_version_/fleet/users/{id:[0-9]+}/sessions", getInfoAboutSessionsForUserEndpoint, getInfoAboutSessionsForUserRequest{})
|
||||
e.DELETE("/api/_version_/fleet/users/{id:[0-9]+}/sessions", deleteSessionsForUserEndpoint, deleteSessionsForUserRequest{})
|
||||
e.POST("/api/_version_/fleet/change_password", changePasswordEndpoint, changePasswordRequest{})
|
||||
|
||||
e.POST("/api/_version_/fleet/global/policies", globalPolicyEndpoint, globalPolicyRequest{})
|
||||
e.GET("/api/_version_/fleet/global/policies", listGlobalPoliciesEndpoint, nil)
|
||||
e.GET("/api/_version_/fleet/global/policies/{policy_id}", getPolicyByIDEndpoint, getPolicyByIDRequest{})
|
||||
|
|
|
|||
|
|
@ -42,18 +42,6 @@ func TestAPIRoutes(t *testing.T) {
|
|||
verb: "POST",
|
||||
uri: "/api/v1/fleet/users",
|
||||
},
|
||||
{
|
||||
verb: "GET",
|
||||
uri: "/api/v1/fleet/users",
|
||||
},
|
||||
{
|
||||
verb: "GET",
|
||||
uri: "/api/v1/fleet/users/1",
|
||||
},
|
||||
{
|
||||
verb: "PATCH",
|
||||
uri: "/api/v1/fleet/users/1",
|
||||
},
|
||||
{
|
||||
verb: "POST",
|
||||
uri: "/api/v1/fleet/login",
|
||||
|
|
|
|||
|
|
@ -35,21 +35,36 @@ func (s *integrationTestSuite) SetupSuite() {
|
|||
}
|
||||
|
||||
func (s *integrationTestSuite) TearDownTest() {
|
||||
t := s.T()
|
||||
ctx := context.Background()
|
||||
|
||||
u := s.users["admin1@example.com"]
|
||||
filter := fleet.TeamFilter{User: &u}
|
||||
hosts, _ := s.ds.ListHosts(context.Background(), filter, fleet.HostListOptions{})
|
||||
hosts, err := s.ds.ListHosts(ctx, filter, fleet.HostListOptions{})
|
||||
require.NoError(t, err)
|
||||
var ids []uint
|
||||
for _, host := range hosts {
|
||||
ids = append(ids, host.ID)
|
||||
}
|
||||
s.ds.DeleteHosts(context.Background(), ids)
|
||||
if len(ids) > 0 {
|
||||
require.NoError(t, s.ds.DeleteHosts(ctx, ids))
|
||||
}
|
||||
|
||||
lbls, err := s.ds.ListLabels(context.Background(), fleet.TeamFilter{}, fleet.ListOptions{})
|
||||
require.NoError(s.T(), err)
|
||||
lbls, err := s.ds.ListLabels(ctx, fleet.TeamFilter{}, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
for _, lbl := range lbls {
|
||||
if lbl.LabelType != fleet.LabelTypeBuiltIn {
|
||||
err := s.ds.DeleteLabel(context.Background(), lbl.Name)
|
||||
require.NoError(s.T(), err)
|
||||
err := s.ds.DeleteLabel(ctx, lbl.Name)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
users, err := s.ds.ListUsers(ctx, fleet.UserListOptions{})
|
||||
require.NoError(t, err)
|
||||
for _, u := range users {
|
||||
if _, ok := s.users[u.Email]; !ok {
|
||||
err := s.ds.DeleteUser(ctx, u.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1867,6 +1882,82 @@ func (s *integrationTestSuite) TestLabelSpecs() {
|
|||
s.DoJSON("GET", "/api/v1/fleet/spec/labels/zzz", nil, http.StatusNotFound, &getResp)
|
||||
}
|
||||
|
||||
func (s *integrationTestSuite) TestUsers() {
|
||||
t := s.T()
|
||||
|
||||
// list existing users
|
||||
var listResp listUsersResponse
|
||||
s.DoJSON("GET", "/api/v1/fleet/users", nil, http.StatusOK, &listResp)
|
||||
assert.Len(t, listResp.Users, len(s.users))
|
||||
|
||||
// create a new user
|
||||
var createResp createUserResponse
|
||||
params := fleet.UserPayload{
|
||||
Name: ptr.String("extra"),
|
||||
Email: ptr.String("extra@asd.com"),
|
||||
Password: ptr.String("pass"),
|
||||
GlobalRole: ptr.String(fleet.RoleObserver),
|
||||
}
|
||||
s.DoJSON("POST", "/api/v1/fleet/users/admin", params, http.StatusOK, &createResp)
|
||||
assert.NotZero(t, createResp.User.ID)
|
||||
u := *createResp.User
|
||||
|
||||
// get that user
|
||||
var getResp getUserResponse
|
||||
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), nil, http.StatusOK, &getResp)
|
||||
assert.Equal(t, u.ID, getResp.User.ID)
|
||||
|
||||
// get non-existing user
|
||||
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID+1), nil, http.StatusNotFound, &getResp)
|
||||
|
||||
// modify that user - simple name change
|
||||
var modResp modifyUserResponse
|
||||
params = fleet.UserPayload{
|
||||
Name: ptr.String("extraz"),
|
||||
}
|
||||
s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), params, http.StatusOK, &modResp)
|
||||
assert.Equal(t, u.ID, modResp.User.ID)
|
||||
assert.Equal(t, u.Name+"z", modResp.User.Name)
|
||||
|
||||
// modify user - email change, password does not match
|
||||
params = fleet.UserPayload{
|
||||
Email: ptr.String("extra2@asd.com"),
|
||||
Password: ptr.String("wrongpass"),
|
||||
}
|
||||
s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), params, http.StatusForbidden, &modResp)
|
||||
|
||||
// modify user - email change, password ok
|
||||
params = fleet.UserPayload{
|
||||
Email: ptr.String("extra2@asd.com"),
|
||||
Password: ptr.String("pass"),
|
||||
}
|
||||
s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), params, http.StatusOK, &modResp)
|
||||
assert.Equal(t, u.ID, modResp.User.ID)
|
||||
assert.NotEqual(t, u.ID, modResp.User.Email)
|
||||
|
||||
// modify invalid user
|
||||
params = fleet.UserPayload{
|
||||
Name: ptr.String("nosuchuser"),
|
||||
}
|
||||
s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID+1), params, http.StatusNotFound, &modResp)
|
||||
|
||||
// require a password reset
|
||||
var reqResetResp requirePasswordResetResponse
|
||||
s.DoJSON("POST", fmt.Sprintf("/api/v1/fleet/users/%d/require_password_reset", u.ID), map[string]bool{"require": true}, http.StatusOK, &reqResetResp)
|
||||
assert.Equal(t, u.ID, reqResetResp.User.ID)
|
||||
assert.True(t, reqResetResp.User.AdminForcedPasswordReset)
|
||||
|
||||
// require a password reset to invalid user
|
||||
s.DoJSON("POST", fmt.Sprintf("/api/v1/fleet/users/%d/require_password_reset", u.ID+1), map[string]bool{"require": true}, http.StatusNotFound, &reqResetResp)
|
||||
|
||||
// delete user
|
||||
var delResp deleteUserResponse
|
||||
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), nil, http.StatusOK, &delResp)
|
||||
|
||||
// delete invalid user
|
||||
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), nil, http.StatusNotFound, &delResp)
|
||||
}
|
||||
|
||||
func (s *integrationTestSuite) TestGlobalPoliciesAutomationConfig() {
|
||||
t := s.T()
|
||||
|
||||
|
|
|
|||
|
|
@ -12,8 +12,7 @@ import (
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type getScheduledQueriesInPackRequest struct {
|
||||
ID uint `url:"id"`
|
||||
// TODO(mna): was not set in the old pattern
|
||||
ID uint `url:"id"`
|
||||
ListOptions fleet.ListOptions `url:"list_options"`
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -284,35 +284,6 @@ func (svc *Service) DestroySession(ctx context.Context) error {
|
|||
return svc.ds.DestroySession(ctx, session)
|
||||
}
|
||||
|
||||
func (svc *Service) GetInfoAboutSessionsForUser(ctx context.Context, id uint) ([]*fleet.Session, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionWrite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var validatedSessions []*fleet.Session
|
||||
|
||||
sessions, err := svc.ds.ListSessionsForUser(ctx, id)
|
||||
if err != nil {
|
||||
return validatedSessions, err
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
if svc.validateSession(ctx, session) == nil {
|
||||
validatedSessions = append(validatedSessions, session)
|
||||
}
|
||||
}
|
||||
|
||||
return validatedSessions, nil
|
||||
}
|
||||
|
||||
func (svc *Service) DeleteSessionsForUser(ctx context.Context, id uint) error {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return svc.ds.DestroyAllSessionsForUser(ctx, id)
|
||||
}
|
||||
|
||||
func (svc *Service) GetInfoAboutSession(ctx context.Context, id uint) (*fleet.Session, error) {
|
||||
session, err := svc.ds.SessionByID(ctx, id)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server"
|
||||
"github.com/fleetdm/fleet/v4/server/authz"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
|
|
@ -42,27 +41,6 @@ func (svc *Service) CreateUserFromInvite(ctx context.Context, p fleet.UserPayloa
|
|||
return user, nil
|
||||
}
|
||||
|
||||
func (svc *Service) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
|
||||
var teams []fleet.UserTeam
|
||||
if p.Teams != nil {
|
||||
teams = *p.Teams
|
||||
}
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{Teams: teams}, fleet.ActionWrite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if invite, err := svc.ds.InviteByEmail(ctx, *p.Email); err == nil && invite != nil {
|
||||
return nil, ctxerr.Errorf(ctx, "%s already invited", *p.Email)
|
||||
}
|
||||
|
||||
if p.AdminForcedPasswordReset == nil {
|
||||
// By default, force password reset for users created this way.
|
||||
p.AdminForcedPasswordReset = ptr.Bool(true)
|
||||
}
|
||||
|
||||
return svc.newUser(ctx, p)
|
||||
}
|
||||
|
||||
func (svc *Service) CreateInitialUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
|
||||
// skipauth: Only the initial user creation should be allowed to skip
|
||||
// authorization (because there is not yet a user context to check against).
|
||||
|
|
@ -106,157 +84,6 @@ func (svc *Service) newUser(ctx context.Context, p fleet.UserPayload) (*fleet.Us
|
|||
return user, nil
|
||||
}
|
||||
|
||||
func (svc *Service) ModifyUser(ctx context.Context, userID uint, p fleet.UserPayload) (*fleet.User, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := svc.User(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, user, fleet.ActionWrite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.GlobalRole != nil || p.Teams != nil {
|
||||
if err := svc.authz.Authorize(ctx, user, fleet.ActionWriteRole); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if p.Name != nil {
|
||||
user.Name = *p.Name
|
||||
}
|
||||
|
||||
if p.Email != nil && *p.Email != user.Email {
|
||||
err = svc.modifyEmailAddress(ctx, user, *p.Email, p.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if p.Position != nil {
|
||||
user.Position = *p.Position
|
||||
}
|
||||
|
||||
if p.GravatarURL != nil {
|
||||
user.GravatarURL = *p.GravatarURL
|
||||
}
|
||||
|
||||
if p.SSOEnabled != nil {
|
||||
user.SSOEnabled = *p.SSOEnabled
|
||||
}
|
||||
|
||||
currentUser := authz.UserFromContext(ctx)
|
||||
|
||||
if p.GlobalRole != nil && *p.GlobalRole != "" {
|
||||
if currentUser.GlobalRole == nil {
|
||||
return nil, ctxerr.New(ctx, "Cannot edit global role as a team member")
|
||||
}
|
||||
|
||||
if p.Teams != nil && len(*p.Teams) > 0 {
|
||||
return nil, fleet.NewInvalidArgumentError("teams", "may not be specified with global_role")
|
||||
}
|
||||
user.GlobalRole = p.GlobalRole
|
||||
user.Teams = []fleet.UserTeam{}
|
||||
} else if p.Teams != nil {
|
||||
if !isAdminOfTheModifiedTeams(currentUser, user.Teams, *p.Teams) {
|
||||
return nil, ctxerr.New(ctx, "Cannot modify teams in that way")
|
||||
}
|
||||
user.Teams = *p.Teams
|
||||
user.GlobalRole = nil
|
||||
}
|
||||
|
||||
err = svc.saveUser(ctx, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func isAdminOfTheModifiedTeams(currentUser *fleet.User, originalUserTeams, newUserTeams []fleet.UserTeam) bool {
|
||||
// If the user is of the right global role, then they can modify the teams
|
||||
if currentUser.GlobalRole != nil && (*currentUser.GlobalRole == fleet.RoleAdmin || *currentUser.GlobalRole == fleet.RoleMaintainer) {
|
||||
return true
|
||||
}
|
||||
|
||||
// otherwise, gather the resulting teams
|
||||
resultingTeams := make(map[uint]string)
|
||||
for _, team := range newUserTeams {
|
||||
resultingTeams[team.ID] = team.Role
|
||||
}
|
||||
|
||||
// and see which ones were removed or changed from the original
|
||||
teamsAffected := make(map[uint]struct{})
|
||||
for _, team := range originalUserTeams {
|
||||
if resultingTeams[team.ID] != team.Role {
|
||||
teamsAffected[team.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// then gather the teams the current user is admin for
|
||||
currentUserTeamAdmin := make(map[uint]struct{})
|
||||
for _, team := range currentUser.Teams {
|
||||
if team.Role == fleet.RoleAdmin {
|
||||
currentUserTeamAdmin[team.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// and let's check that the teams that were either removed or changed are also teams this user is an admin of
|
||||
for teamID := range teamsAffected {
|
||||
if _, ok := currentUserTeamAdmin[teamID]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (svc *Service) modifyEmailAddress(ctx context.Context, user *fleet.User, email string, password *string) error {
|
||||
// password requirement handled in validation middleware
|
||||
if password != nil {
|
||||
err := user.ValidatePassword(*password)
|
||||
if err != nil {
|
||||
return fleet.NewPermissionError("incorrect password")
|
||||
}
|
||||
}
|
||||
random, err := server.GenerateRandomText(svc.config.App.TokenKeySize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
token := base64.URLEncoding.EncodeToString([]byte(random))
|
||||
err = svc.ds.PendingEmailChange(ctx, user.ID, email, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config, err := svc.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
changeEmail := fleet.Email{
|
||||
Subject: "Confirm Fleet Email Change",
|
||||
To: []string{email},
|
||||
Config: config,
|
||||
Mailer: &mail.ChangeEmailMailer{
|
||||
Token: token,
|
||||
BaseURL: template.URL(config.ServerSettings.ServerURL + svc.config.Server.URLPrefix),
|
||||
AssetURL: getAssetURL(),
|
||||
},
|
||||
}
|
||||
return svc.mailService.SendEmail(changeEmail)
|
||||
}
|
||||
|
||||
func (svc *Service) DeleteUser(ctx context.Context, id uint) error {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{ID: id}, fleet.ActionWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return svc.ds.DeleteUser(ctx, id)
|
||||
}
|
||||
|
||||
func (svc *Service) ChangeUserEmail(ctx context.Context, token string) (string, error) {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
|
|
@ -270,14 +97,6 @@ func (svc *Service) ChangeUserEmail(ctx context.Context, token string) (string,
|
|||
return svc.ds.ConfirmPendingEmailChange(ctx, vc.UserID(), token)
|
||||
}
|
||||
|
||||
func (svc *Service) User(ctx context.Context, id uint) (*fleet.User, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{ID: id}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return svc.ds.UserByID(ctx, id)
|
||||
}
|
||||
|
||||
func (svc *Service) UserUnauthorized(ctx context.Context, id uint) (*fleet.User, error) {
|
||||
// Explicitly no authorization check. Should only be used by middleware.
|
||||
return svc.ds.UserByID(ctx, id)
|
||||
|
|
@ -299,14 +118,6 @@ func (svc *Service) AuthenticatedUser(ctx context.Context) (*fleet.User, error)
|
|||
return vc.User, nil
|
||||
}
|
||||
|
||||
func (svc *Service) ListUsers(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return svc.ds.ListUsers(ctx, opt)
|
||||
}
|
||||
|
||||
// setNewPassword is a helper for changing a user's password. It should be
|
||||
// called to set the new password after proper authorization has been
|
||||
// performed.
|
||||
|
|
@ -326,33 +137,6 @@ func (svc *Service) setNewPassword(ctx context.Context, user *fleet.User, passwo
|
|||
return nil
|
||||
}
|
||||
|
||||
func (svc *Service) ChangePassword(ctx context.Context, oldPass, newPass string) error {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
return fleet.ErrNoContext
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, vc.User, fleet.ActionWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if vc.User.SSOEnabled {
|
||||
return ctxerr.New(ctx, "change password for single sign on user not allowed")
|
||||
}
|
||||
if err := vc.User.ValidatePassword(newPass); err == nil {
|
||||
return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
|
||||
}
|
||||
|
||||
if err := vc.User.ValidatePassword(oldPass); err != nil {
|
||||
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("old_password", "old password does not match"))
|
||||
}
|
||||
|
||||
if err := svc.setNewPassword(ctx, vc.User, newPass); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "setting new password")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *Service) ResetPassword(ctx context.Context, token, password string) error {
|
||||
// skipauth: No viewer context available. The user is locked out of their
|
||||
// account and authNZ is performed entirely by providing a valid password
|
||||
|
|
@ -431,34 +215,6 @@ func (svc *Service) PerformRequiredPasswordReset(ctx context.Context, password s
|
|||
return user, nil
|
||||
}
|
||||
|
||||
func (svc *Service) RequirePasswordReset(ctx context.Context, uid uint, require bool) (*fleet.User, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{ID: uid}, fleet.ActionWrite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := svc.ds.UserByID(ctx, uid)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "loading user by ID")
|
||||
}
|
||||
if user.SSOEnabled {
|
||||
return nil, ctxerr.New(ctx, "password reset for single sign on user not allowed")
|
||||
}
|
||||
// Require reset on next login
|
||||
user.AdminForcedPasswordReset = require
|
||||
if err := svc.saveUser(ctx, user); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "saving user")
|
||||
}
|
||||
|
||||
if require {
|
||||
// Clear all of the existing sessions
|
||||
if err := svc.DeleteSessionsForUser(ctx, user.ID); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "deleting user sessions")
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (svc *Service) RequestPasswordReset(ctx context.Context, email string) error {
|
||||
// skipauth: No viewer context available. The user is locked out of their
|
||||
// account and trying to reset their password.
|
||||
|
|
@ -512,10 +268,3 @@ func (svc *Service) RequestPasswordReset(ctx context.Context, email string) erro
|
|||
|
||||
return svc.mailService.SendEmail(resetEmail)
|
||||
}
|
||||
|
||||
// saves user in datastore.
|
||||
// doesn't need to be exposed to the transport
|
||||
// the service should expose actions for modifying a user instead
|
||||
func (svc *Service) saveUser(ctx context.Context, user *fleet.User) error {
|
||||
return svc.ds.SaveUser(ctx, user)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package service
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -11,8 +10,6 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mock"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -38,250 +35,6 @@ func TestAuthenticatedUser(t *testing.T) {
|
|||
assert.Equal(t, user, admin1)
|
||||
}
|
||||
|
||||
func TestModifyUserEmail(t *testing.T) {
|
||||
user := &fleet.User{
|
||||
ID: 3,
|
||||
Email: "foo@bar.com",
|
||||
}
|
||||
user.SetPassword("password", 10, 10)
|
||||
ms := new(mock.Store)
|
||||
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
|
||||
return nil
|
||||
}
|
||||
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
|
||||
return user, nil
|
||||
}
|
||||
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
config := &fleet.AppConfig{
|
||||
SMTPSettings: fleet.SMTPSettings{
|
||||
SMTPConfigured: true,
|
||||
SMTPAuthenticationType: fleet.AuthTypeNameNone,
|
||||
SMTPPort: 1025,
|
||||
SMTPServer: "127.0.0.1",
|
||||
SMTPSenderAddress: "xxx@fleet.co",
|
||||
},
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
|
||||
// verify this isn't changed yet
|
||||
assert.Equal(t, "foo@bar.com", u.Email)
|
||||
// verify is changed per bug 1123
|
||||
assert.Equal(t, "minion", u.Position)
|
||||
return nil
|
||||
}
|
||||
svc := newTestService(ms, nil, nil)
|
||||
ctx := context.Background()
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
|
||||
payload := fleet.UserPayload{
|
||||
Email: ptr.String("zip@zap.com"),
|
||||
Password: ptr.String("password"),
|
||||
Position: ptr.String("minion"),
|
||||
}
|
||||
_, err := svc.ModifyUser(ctx, 3, payload)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, ms.PendingEmailChangeFuncInvoked)
|
||||
assert.True(t, ms.SaveUserFuncInvoked)
|
||||
}
|
||||
|
||||
func TestModifyUserEmailNoPassword(t *testing.T) {
|
||||
user := &fleet.User{
|
||||
ID: 3,
|
||||
Email: "foo@bar.com",
|
||||
}
|
||||
user.SetPassword("password", 10, 10)
|
||||
ms := new(mock.Store)
|
||||
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
|
||||
return nil
|
||||
}
|
||||
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
|
||||
return user, nil
|
||||
}
|
||||
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
config := &fleet.AppConfig{
|
||||
SMTPSettings: fleet.SMTPSettings{
|
||||
SMTPConfigured: true,
|
||||
SMTPAuthenticationType: fleet.AuthTypeNameNone,
|
||||
SMTPPort: 1025,
|
||||
SMTPServer: "127.0.0.1",
|
||||
SMTPSenderAddress: "xxx@fleet.co",
|
||||
},
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
|
||||
return nil
|
||||
}
|
||||
svc := newTestService(ms, nil, nil)
|
||||
ctx := context.Background()
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
|
||||
payload := fleet.UserPayload{
|
||||
Email: ptr.String("zip@zap.com"),
|
||||
// NO PASSWORD
|
||||
// Password: ptr.String("password"),
|
||||
}
|
||||
_, err := svc.ModifyUser(ctx, 3, payload)
|
||||
require.NotNil(t, err)
|
||||
var iae *fleet.InvalidArgumentError
|
||||
ok := errors.As(err, &iae)
|
||||
require.True(t, ok)
|
||||
require.Len(t, *iae, 1)
|
||||
assert.False(t, ms.PendingEmailChangeFuncInvoked)
|
||||
assert.False(t, ms.SaveUserFuncInvoked)
|
||||
}
|
||||
|
||||
func TestModifyAdminUserEmailNoPassword(t *testing.T) {
|
||||
user := &fleet.User{
|
||||
ID: 3,
|
||||
Email: "foo@bar.com",
|
||||
}
|
||||
user.SetPassword("password", 10, 10)
|
||||
ms := new(mock.Store)
|
||||
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
|
||||
return nil
|
||||
}
|
||||
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
|
||||
return user, nil
|
||||
}
|
||||
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
config := &fleet.AppConfig{
|
||||
SMTPSettings: fleet.SMTPSettings{
|
||||
SMTPConfigured: true,
|
||||
SMTPAuthenticationType: fleet.AuthTypeNameNone,
|
||||
SMTPPort: 1025,
|
||||
SMTPServer: "127.0.0.1",
|
||||
SMTPSenderAddress: "xxx@fleet.co",
|
||||
},
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
|
||||
return nil
|
||||
}
|
||||
svc := newTestService(ms, nil, nil)
|
||||
ctx := context.Background()
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
|
||||
payload := fleet.UserPayload{
|
||||
Email: ptr.String("zip@zap.com"),
|
||||
// NO PASSWORD
|
||||
// Password: ptr.String("password"),
|
||||
}
|
||||
_, err := svc.ModifyUser(ctx, 3, payload)
|
||||
require.NotNil(t, err)
|
||||
var iae *fleet.InvalidArgumentError
|
||||
ok := errors.As(err, &iae)
|
||||
require.True(t, ok)
|
||||
require.Len(t, *iae, 1)
|
||||
assert.False(t, ms.PendingEmailChangeFuncInvoked)
|
||||
assert.False(t, ms.SaveUserFuncInvoked)
|
||||
}
|
||||
|
||||
func TestModifyAdminUserEmailPassword(t *testing.T) {
|
||||
user := &fleet.User{
|
||||
ID: 3,
|
||||
Email: "foo@bar.com",
|
||||
}
|
||||
user.SetPassword("password", 10, 10)
|
||||
ms := new(mock.Store)
|
||||
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
|
||||
return nil
|
||||
}
|
||||
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
|
||||
return user, nil
|
||||
}
|
||||
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
config := &fleet.AppConfig{
|
||||
SMTPSettings: fleet.SMTPSettings{
|
||||
SMTPConfigured: true,
|
||||
SMTPAuthenticationType: fleet.AuthTypeNameNone,
|
||||
SMTPPort: 1025,
|
||||
SMTPServer: "127.0.0.1",
|
||||
SMTPSenderAddress: "xxx@fleet.co",
|
||||
},
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
|
||||
return nil
|
||||
}
|
||||
svc := newTestService(ms, nil, nil)
|
||||
ctx := context.Background()
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
|
||||
payload := fleet.UserPayload{
|
||||
Email: ptr.String("zip@zap.com"),
|
||||
Password: ptr.String("password"),
|
||||
}
|
||||
_, err := svc.ModifyUser(ctx, 3, payload)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, ms.PendingEmailChangeFuncInvoked)
|
||||
assert.True(t, ms.SaveUserFuncInvoked)
|
||||
}
|
||||
|
||||
func TestChangePassword(t *testing.T) {
|
||||
ds := mysql.CreateMySQLDS(t)
|
||||
|
||||
svc := newTestService(ds, nil, nil)
|
||||
users := createTestUsers(t, ds)
|
||||
passwordChangeTests := []struct {
|
||||
user fleet.User
|
||||
oldPassword string
|
||||
newPassword string
|
||||
anyErr bool
|
||||
wantErr error
|
||||
}{
|
||||
{ // all good
|
||||
user: users["admin1@example.com"],
|
||||
oldPassword: "foobarbaz1234!",
|
||||
newPassword: "12345cat!",
|
||||
},
|
||||
{ // prevent password reuse
|
||||
user: users["admin1@example.com"],
|
||||
oldPassword: "12345cat!",
|
||||
newPassword: "foobarbaz1234!",
|
||||
wantErr: fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"),
|
||||
},
|
||||
{ // all good
|
||||
user: users["user1@example.com"],
|
||||
oldPassword: "foobarbaz1234!",
|
||||
newPassword: "newpassa1234!",
|
||||
},
|
||||
{ // bad old password
|
||||
user: users["user1@example.com"],
|
||||
oldPassword: "wrong_password",
|
||||
newPassword: "12345cat!",
|
||||
anyErr: true,
|
||||
},
|
||||
{ // missing old password
|
||||
newPassword: "123cataaa!",
|
||||
wantErr: fleet.NewInvalidArgumentError("old_password", "Old password cannot be empty"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range passwordChangeTests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &tt.user})
|
||||
|
||||
err := svc.ChangePassword(ctx, tt.oldPassword, tt.newPassword)
|
||||
if tt.anyErr {
|
||||
require.NotNil(t, err)
|
||||
} else if tt.wantErr != nil {
|
||||
require.Equal(t, tt.wantErr, ctxerr.Cause(err))
|
||||
} else {
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt login after successful change
|
||||
_, _, err = svc.Login(context.Background(), tt.user.Email, tt.newPassword)
|
||||
require.Nil(t, err, "should be able to login with new password")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetPassword(t *testing.T) {
|
||||
ds := mysql.CreateMySQLDS(t)
|
||||
|
||||
|
|
@ -340,48 +93,6 @@ func TestResetPassword(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRequirePasswordReset(t *testing.T) {
|
||||
ds := mysql.CreateMySQLDS(t)
|
||||
|
||||
svc := newTestService(ds, nil, nil)
|
||||
createTestUsers(t, ds)
|
||||
|
||||
for _, tt := range testUsers {
|
||||
t.Run(tt.Email, func(t *testing.T) {
|
||||
user, err := ds.UserByEmail(context.Background(), tt.Email)
|
||||
require.Nil(t, err)
|
||||
|
||||
var sessions []*fleet.Session
|
||||
|
||||
// Log user in
|
||||
_, _, err = svc.Login(test.UserContext(test.UserAdmin), tt.Email, tt.PlaintextPassword)
|
||||
require.Nil(t, err, "login unsuccessful")
|
||||
sessions, err = svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), user.ID)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, sessions, 1, "user should have one session")
|
||||
|
||||
// Reset and verify sessions destroyed
|
||||
retUser, err := svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, true)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, retUser.AdminForcedPasswordReset)
|
||||
checkUser, err := ds.UserByEmail(context.Background(), tt.Email)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, checkUser.AdminForcedPasswordReset)
|
||||
sessions, err = svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), user.ID)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, sessions, 0, "sessions should be destroyed")
|
||||
|
||||
// try undo
|
||||
retUser, err = svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, false)
|
||||
require.Nil(t, err)
|
||||
assert.False(t, retUser.AdminForcedPasswordReset)
|
||||
checkUser, err = ds.UserByEmail(context.Background(), tt.Email)
|
||||
require.Nil(t, err)
|
||||
assert.False(t, checkUser.AdminForcedPasswordReset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func refreshCtx(t *testing.T, ctx context.Context, user *fleet.User, ds fleet.Datastore, session *fleet.Session) context.Context {
|
||||
reloadedUser, err := ds.UserByEmail(ctx, user.Email)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -475,178 +186,3 @@ func TestUserPasswordRequirements(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserAuth(t *testing.T) {
|
||||
ds := new(mock.Store)
|
||||
svc := newTestService(ds, nil, nil)
|
||||
|
||||
ds.InviteByTokenFunc = func(ctx context.Context, token string) (*fleet.Invite, error) {
|
||||
return &fleet.Invite{
|
||||
Email: "some@email.com",
|
||||
Token: "ABCD",
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
CreateTimestamp: fleet.CreateTimestamp{CreatedAt: time.Now()},
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{UpdatedAt: time.Now()},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
ds.NewUserFunc = func(ctx context.Context, user *fleet.User) (*fleet.User, error) {
|
||||
return &fleet.User{}, nil
|
||||
}
|
||||
ds.DeleteInviteFunc = func(ctx context.Context, id uint) error {
|
||||
return nil
|
||||
}
|
||||
ds.InviteByEmailFunc = func(ctx context.Context, email string) (*fleet.Invite, error) {
|
||||
return nil, errors.New("AA")
|
||||
}
|
||||
ds.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
|
||||
if id == 999 {
|
||||
return &fleet.User{
|
||||
ID: 999,
|
||||
Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}},
|
||||
}, nil
|
||||
}
|
||||
return &fleet.User{
|
||||
ID: 888,
|
||||
GlobalRole: ptr.String(fleet.RoleMaintainer),
|
||||
}, nil
|
||||
}
|
||||
ds.SaveUserFunc = func(ctx context.Context, user *fleet.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
user *fleet.User
|
||||
shouldFailGlobalWrite bool
|
||||
shouldFailTeamWrite bool
|
||||
shouldFailRead bool
|
||||
}{
|
||||
{
|
||||
"global admin",
|
||||
&fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"global maintainer",
|
||||
&fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"global observer",
|
||||
&fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"team admin, belongs to team",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}},
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"team maintainer, belongs to team",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}},
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"team observer, belongs to team",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}},
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"team maintainer, DOES NOT belong to team",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleMaintainer}}},
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"team admin, DOES NOT belong to team",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleAdmin}}},
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"team observer, DOES NOT belong to team",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleObserver}}},
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: tt.user})
|
||||
|
||||
teams := []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}
|
||||
_, err := svc.CreateUser(ctx, fleet.UserPayload{
|
||||
Name: ptr.String("Some Name"),
|
||||
Email: ptr.String("some@email.com"),
|
||||
Password: ptr.String("passw0rd."),
|
||||
Teams: &teams,
|
||||
})
|
||||
checkAuthErr(t, tt.shouldFailTeamWrite, err)
|
||||
|
||||
_, err = svc.CreateUser(ctx, fleet.UserPayload{
|
||||
Name: ptr.String("Some Name"),
|
||||
Email: ptr.String("some@email.com"),
|
||||
Password: ptr.String("passw0rd."),
|
||||
GlobalRole: ptr.String(fleet.RoleAdmin),
|
||||
})
|
||||
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
|
||||
|
||||
_, err = svc.ModifyUser(ctx, 999, fleet.UserPayload{Teams: &teams})
|
||||
checkAuthErr(t, tt.shouldFailTeamWrite, err)
|
||||
|
||||
_, err = svc.ModifyUser(ctx, 888, fleet.UserPayload{Teams: &teams})
|
||||
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
|
||||
|
||||
_, err = svc.ModifyUser(ctx, 888, fleet.UserPayload{GlobalRole: ptr.String(fleet.RoleMaintainer)})
|
||||
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test that CreateUser creates a user that will be forced to
|
||||
// reset its password upon first login (see #2570).
|
||||
func TestCreateUserForcePasswdReset(t *testing.T) {
|
||||
ds := mysql.CreateMySQLDS(t)
|
||||
svc := newTestService(ds, nil, nil)
|
||||
|
||||
// Create admin user.
|
||||
admin := &fleet.User{
|
||||
Name: "Fleet Admin",
|
||||
Email: "admin@foo.com",
|
||||
GlobalRole: ptr.String(fleet.RoleAdmin),
|
||||
}
|
||||
err := admin.SetPassword("p4ssw0rd.", 10, 10)
|
||||
require.NoError(t, err)
|
||||
admin, err = ds.NewUser(context.Background(), admin)
|
||||
require.NoError(t, err)
|
||||
|
||||
// As the admin, create a new user.
|
||||
ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: admin})
|
||||
user, err := svc.CreateUser(ctx, fleet.UserPayload{
|
||||
Name: ptr.String("Some Observer"),
|
||||
Email: ptr.String("some-observer@email.com"),
|
||||
Password: ptr.String("passw0rd."),
|
||||
GlobalRole: ptr.String(fleet.RoleObserver),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err = ds.UserByID(context.Background(), user.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, user.AdminForcedPasswordReset)
|
||||
}
|
||||
|
|
|
|||
8
server/service/sessions_test.go
Normal file
8
server/service/sessions_test.go
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
package service
|
||||
|
||||
// TODO(mna): when migrating Session-related endpoints, add auth tests for those
|
||||
// endpoints (the auth is session-based, not user-based).
|
||||
//_, err = svc.GetInfoAboutSessionsForUser(ctx, 999)
|
||||
//checkAuthErr(t, tt.shouldFailTeamWrite, err)
|
||||
//_, err = svc.GetInfoAboutSessionsForUser(ctx, 888)
|
||||
//checkAuthErr(t, tt.shouldFailGlobalWrite, err)
|
||||
|
|
@ -18,14 +18,6 @@ func decodeGetInfoAboutSessionRequest(ctx context.Context, r *http.Request) (int
|
|||
return getInfoAboutSessionRequest{ID: uint(id)}, nil
|
||||
}
|
||||
|
||||
func decodeGetInfoAboutSessionsForUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
id, err := uintFromRequest(r, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return getInfoAboutSessionsForUserRequest{ID: uint(id)}, nil
|
||||
}
|
||||
|
||||
func decodeDeleteSessionRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
id, err := uintFromRequest(r, "id")
|
||||
if err != nil {
|
||||
|
|
@ -34,14 +26,6 @@ func decodeDeleteSessionRequest(ctx context.Context, r *http.Request) (interface
|
|||
return deleteSessionRequest{ID: uint(id)}, nil
|
||||
}
|
||||
|
||||
func decodeDeleteSessionsForUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
id, err := uintFromRequest(r, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return deleteSessionsForUserRequest{ID: uint(id)}, nil
|
||||
}
|
||||
|
||||
func decodeLoginRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
var req loginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
|
|
|
|||
|
|
@ -27,22 +27,6 @@ func TestDecodeGetInfoAboutSessionRequest(t *testing.T) {
|
|||
)
|
||||
}
|
||||
|
||||
func TestDecodeGetInfoAboutSessionsForUserRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/user/{id}/sessions", func(writer http.ResponseWriter, request *http.Request) {
|
||||
r, err := decodeGetInfoAboutSessionsForUserRequest(context.Background(), request)
|
||||
assert.Nil(t, err)
|
||||
|
||||
params := r.(getInfoAboutSessionsForUserRequest)
|
||||
assert.Equal(t, uint(1), params.ID)
|
||||
}).Methods("GET")
|
||||
|
||||
router.ServeHTTP(
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("GET", "/api/v1/fleet/users/1/sessions", nil),
|
||||
)
|
||||
}
|
||||
|
||||
func TestDecodeDeleteSessionRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/sessions/{id}", func(writer http.ResponseWriter, request *http.Request) {
|
||||
|
|
@ -59,22 +43,6 @@ func TestDecodeDeleteSessionRequest(t *testing.T) {
|
|||
)
|
||||
}
|
||||
|
||||
func TestDecodeDeleteSessionsForUserRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/user/{id}/sessions", func(writer http.ResponseWriter, request *http.Request) {
|
||||
r, err := decodeDeleteSessionsForUserRequest(context.Background(), request)
|
||||
assert.Nil(t, err)
|
||||
|
||||
params := r.(deleteSessionsForUserRequest)
|
||||
assert.Equal(t, uint(1), params.ID)
|
||||
}).Methods("DELETE")
|
||||
|
||||
router.ServeHTTP(
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("DELETE", "/api/v1/fleet/users/1/sessions", nil),
|
||||
)
|
||||
}
|
||||
|
||||
func TestDecodeLoginRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/login", func(writer http.ResponseWriter, request *http.Request) {
|
||||
|
|
|
|||
|
|
@ -10,69 +10,9 @@ import (
|
|||
|
||||
func decodeCreateUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
var req createUserRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req.payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeGetUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
id, err := uintFromRequest(r, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return getUserRequest{ID: uint(id)}, nil
|
||||
}
|
||||
|
||||
func decodeListUsersRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
opt, err := userListOptionsFromRequest(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return listUsersRequest{ListOptions: opt}, nil
|
||||
}
|
||||
|
||||
func decodeModifyUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
id, err := uintFromRequest(r, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var req modifyUserRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req.payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.ID = uint(id)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeDeleteUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
id, err := uintFromRequest(r, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return deleteUserRequest{ID: uint(id)}, nil
|
||||
}
|
||||
|
||||
func decodeChangePasswordRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
var req changePasswordRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeRequirePasswordResetRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
id, err := uintFromRequest(r, "id")
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "getting ID from request")
|
||||
}
|
||||
|
||||
var req requirePasswordResetRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "decoding JSON")
|
||||
}
|
||||
req.ID = uint(id)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,68 +11,6 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDecodeCreateUserRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/users", func(writer http.ResponseWriter, request *http.Request) {
|
||||
r, err := decodeCreateUserRequest(context.Background(), request)
|
||||
assert.Nil(t, err)
|
||||
|
||||
params := r.(createUserRequest)
|
||||
assert.Equal(t, "foo", *params.payload.Name)
|
||||
assert.Equal(t, "foo@fleet.co", *params.payload.Email)
|
||||
}).Methods("POST")
|
||||
|
||||
var body bytes.Buffer
|
||||
body.Write([]byte(`{
|
||||
"name": "foo",
|
||||
"email": "foo@fleet.co"
|
||||
}`))
|
||||
|
||||
router.ServeHTTP(
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("POST", "/api/v1/fleet/users", &body),
|
||||
)
|
||||
}
|
||||
|
||||
func TestDecodeGetUserRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/users/{id}", func(writer http.ResponseWriter, request *http.Request) {
|
||||
r, err := decodeGetUserRequest(context.Background(), request)
|
||||
assert.Nil(t, err)
|
||||
|
||||
params := r.(getUserRequest)
|
||||
assert.Equal(t, uint(1), params.ID)
|
||||
}).Methods("GET")
|
||||
|
||||
router.ServeHTTP(
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("GET", "/api/v1/fleet/users/1", nil),
|
||||
)
|
||||
}
|
||||
|
||||
func TestDecodeChangePasswordRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/change_password", func(writer http.ResponseWriter, request *http.Request) {
|
||||
r, err := decodeChangePasswordRequest(context.Background(), request)
|
||||
assert.Nil(t, err)
|
||||
|
||||
params := r.(changePasswordRequest)
|
||||
assert.Equal(t, "foo", params.OldPassword)
|
||||
assert.Equal(t, "bar", params.NewPassword)
|
||||
}).Methods("POST")
|
||||
|
||||
var body bytes.Buffer
|
||||
body.Write([]byte(`{
|
||||
"old_password": "foo",
|
||||
"new_password": "bar"
|
||||
}`))
|
||||
|
||||
router.ServeHTTP(
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("POST", "/api/v1/fleet/change_password", &body),
|
||||
)
|
||||
}
|
||||
|
||||
func TestDecodeResetPasswordRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/users/{id}/password", func(writer http.ResponseWriter, request *http.Request) {
|
||||
|
|
@ -95,28 +33,3 @@ func TestDecodeResetPasswordRequest(t *testing.T) {
|
|||
httptest.NewRequest("POST", "/api/v1/fleet/users/1/password", &body),
|
||||
)
|
||||
}
|
||||
|
||||
func TestDecodeModifyUserRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/users/{id}", func(writer http.ResponseWriter, request *http.Request) {
|
||||
r, err := decodeModifyUserRequest(context.Background(), request)
|
||||
assert.Nil(t, err)
|
||||
|
||||
params := r.(modifyUserRequest)
|
||||
assert.Equal(t, "foo", *params.payload.Name)
|
||||
assert.Equal(t, "foo@fleet.co", *params.payload.Email)
|
||||
assert.Equal(t, uint(1), params.ID)
|
||||
}).Methods("PATCH")
|
||||
|
||||
var body bytes.Buffer
|
||||
body.Write([]byte(`{
|
||||
"name": "foo",
|
||||
"email": "foo@fleet.co"
|
||||
}`))
|
||||
|
||||
request := httptest.NewRequest("PATCH", "/api/v1/fleet/users/1", &body)
|
||||
router.ServeHTTP(
|
||||
httptest.NewRecorder(),
|
||||
request,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
521
server/service/users.go
Normal file
521
server/service/users.go
Normal file
|
|
@ -0,0 +1,521 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"html/template"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server"
|
||||
"github.com/fleetdm/fleet/v4/server/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mail"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Create User
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type createUserRequest struct {
|
||||
fleet.UserPayload
|
||||
}
|
||||
|
||||
type createUserResponse struct {
|
||||
User *fleet.User `json:"user,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r createUserResponse) error() error { return r.Err }
|
||||
|
||||
func createUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*createUserRequest)
|
||||
user, err := svc.CreateUser(ctx, req.UserPayload)
|
||||
if err != nil {
|
||||
return createUserResponse{Err: err}, nil
|
||||
}
|
||||
return createUserResponse{User: user}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
|
||||
var teams []fleet.UserTeam
|
||||
if p.Teams != nil {
|
||||
teams = *p.Teams
|
||||
}
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{Teams: teams}, fleet.ActionWrite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if invite, err := svc.ds.InviteByEmail(ctx, *p.Email); err == nil && invite != nil {
|
||||
return nil, ctxerr.Errorf(ctx, "%s already invited", *p.Email)
|
||||
}
|
||||
|
||||
if p.AdminForcedPasswordReset == nil {
|
||||
// By default, force password reset for users created this way.
|
||||
p.AdminForcedPasswordReset = ptr.Bool(true)
|
||||
}
|
||||
|
||||
return svc.newUser(ctx, p)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// List Users
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type listUsersRequest struct {
|
||||
ListOptions fleet.UserListOptions `url:"user_options"`
|
||||
}
|
||||
|
||||
type listUsersResponse struct {
|
||||
Users []fleet.User `json:"users"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r listUsersResponse) error() error { return r.Err }
|
||||
|
||||
func listUsersEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*listUsersRequest)
|
||||
users, err := svc.ListUsers(ctx, req.ListOptions)
|
||||
if err != nil {
|
||||
return listUsersResponse{Err: err}, nil
|
||||
}
|
||||
|
||||
resp := listUsersResponse{Users: []fleet.User{}}
|
||||
for _, user := range users {
|
||||
resp.Users = append(resp.Users, *user)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (svc *Service) ListUsers(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return svc.ds.ListUsers(ctx, opt)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Get User
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type getUserRequest struct {
|
||||
ID uint `url:"id"`
|
||||
}
|
||||
|
||||
type getUserResponse struct {
|
||||
User *fleet.User `json:"user,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r getUserResponse) error() error { return r.Err }
|
||||
|
||||
func getUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*getUserRequest)
|
||||
user, err := svc.User(ctx, req.ID)
|
||||
if err != nil {
|
||||
return getUserResponse{Err: err}, nil
|
||||
}
|
||||
return getUserResponse{User: user}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) User(ctx context.Context, id uint) (*fleet.User, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{ID: id}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return svc.ds.UserByID(ctx, id)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Modify User
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type modifyUserRequest struct {
|
||||
ID uint `json:"-" url:"id"`
|
||||
fleet.UserPayload
|
||||
}
|
||||
|
||||
type modifyUserResponse struct {
|
||||
User *fleet.User `json:"user,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r modifyUserResponse) error() error { return r.Err }
|
||||
|
||||
func modifyUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*modifyUserRequest)
|
||||
user, err := svc.ModifyUser(ctx, req.ID, req.UserPayload)
|
||||
if err != nil {
|
||||
return modifyUserResponse{Err: err}, nil
|
||||
}
|
||||
|
||||
return modifyUserResponse{User: user}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) ModifyUser(ctx context.Context, userID uint, p fleet.UserPayload) (*fleet.User, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := svc.User(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, user, fleet.ActionWrite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.GlobalRole != nil || p.Teams != nil {
|
||||
if err := svc.authz.Authorize(ctx, user, fleet.ActionWriteRole); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if p.Name != nil {
|
||||
user.Name = *p.Name
|
||||
}
|
||||
|
||||
if p.Email != nil && *p.Email != user.Email {
|
||||
err = svc.modifyEmailAddress(ctx, user, *p.Email, p.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if p.Position != nil {
|
||||
user.Position = *p.Position
|
||||
}
|
||||
|
||||
if p.GravatarURL != nil {
|
||||
user.GravatarURL = *p.GravatarURL
|
||||
}
|
||||
|
||||
if p.SSOEnabled != nil {
|
||||
user.SSOEnabled = *p.SSOEnabled
|
||||
}
|
||||
|
||||
currentUser := authz.UserFromContext(ctx)
|
||||
|
||||
if p.GlobalRole != nil && *p.GlobalRole != "" {
|
||||
if currentUser.GlobalRole == nil {
|
||||
return nil, ctxerr.New(ctx, "Cannot edit global role as a team member")
|
||||
}
|
||||
|
||||
if p.Teams != nil && len(*p.Teams) > 0 {
|
||||
return nil, fleet.NewInvalidArgumentError("teams", "may not be specified with global_role")
|
||||
}
|
||||
user.GlobalRole = p.GlobalRole
|
||||
user.Teams = []fleet.UserTeam{}
|
||||
} else if p.Teams != nil {
|
||||
if !isAdminOfTheModifiedTeams(currentUser, user.Teams, *p.Teams) {
|
||||
return nil, ctxerr.New(ctx, "Cannot modify teams in that way")
|
||||
}
|
||||
user.Teams = *p.Teams
|
||||
user.GlobalRole = nil
|
||||
}
|
||||
|
||||
err = svc.saveUser(ctx, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Delete User
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type deleteUserRequest struct {
|
||||
ID uint `url:"id"`
|
||||
}
|
||||
|
||||
type deleteUserResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r deleteUserResponse) error() error { return r.Err }
|
||||
|
||||
func deleteUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*deleteUserRequest)
|
||||
err := svc.DeleteUser(ctx, req.ID)
|
||||
if err != nil {
|
||||
return deleteUserResponse{Err: err}, nil
|
||||
}
|
||||
return deleteUserResponse{}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) DeleteUser(ctx context.Context, id uint) error {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{ID: id}, fleet.ActionWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return svc.ds.DeleteUser(ctx, id)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Require Password Reset
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type requirePasswordResetRequest struct {
|
||||
Require bool `json:"require"`
|
||||
ID uint `json:"-" url:"id"`
|
||||
}
|
||||
|
||||
type requirePasswordResetResponse struct {
|
||||
User *fleet.User `json:"user,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r requirePasswordResetResponse) error() error { return r.Err }
|
||||
|
||||
func requirePasswordResetEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*requirePasswordResetRequest)
|
||||
user, err := svc.RequirePasswordReset(ctx, req.ID, req.Require)
|
||||
if err != nil {
|
||||
return requirePasswordResetResponse{Err: err}, nil
|
||||
}
|
||||
return requirePasswordResetResponse{User: user}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) RequirePasswordReset(ctx context.Context, uid uint, require bool) (*fleet.User, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{ID: uid}, fleet.ActionWrite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := svc.ds.UserByID(ctx, uid)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "loading user by ID")
|
||||
}
|
||||
if user.SSOEnabled {
|
||||
return nil, ctxerr.New(ctx, "password reset for single sign on user not allowed")
|
||||
}
|
||||
// Require reset on next login
|
||||
user.AdminForcedPasswordReset = require
|
||||
if err := svc.saveUser(ctx, user); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "saving user")
|
||||
}
|
||||
|
||||
if require {
|
||||
// Clear all of the existing sessions
|
||||
if err := svc.DeleteSessionsForUser(ctx, user.ID); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "deleting user sessions")
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Change Password
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type changePasswordRequest struct {
|
||||
OldPassword string `json:"old_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
type changePasswordResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r changePasswordResponse) error() error { return r.Err }
|
||||
|
||||
func changePasswordEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*changePasswordRequest)
|
||||
err := svc.ChangePassword(ctx, req.OldPassword, req.NewPassword)
|
||||
return changePasswordResponse{Err: err}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) ChangePassword(ctx context.Context, oldPass, newPass string) error {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
return fleet.ErrNoContext
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, vc.User, fleet.ActionWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if vc.User.SSOEnabled {
|
||||
return ctxerr.New(ctx, "change password for single sign on user not allowed")
|
||||
}
|
||||
if err := vc.User.ValidatePassword(newPass); err == nil {
|
||||
return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
|
||||
}
|
||||
|
||||
if err := vc.User.ValidatePassword(oldPass); err != nil {
|
||||
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("old_password", "old password does not match"))
|
||||
}
|
||||
|
||||
if err := svc.setNewPassword(ctx, vc.User, newPass); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "setting new password")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Get Info About Sessions For User
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type getInfoAboutSessionsForUserRequest struct {
|
||||
ID uint `url:"id"`
|
||||
}
|
||||
|
||||
type getInfoAboutSessionsForUserResponse struct {
|
||||
Sessions []getInfoAboutSessionResponse `json:"sessions"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r getInfoAboutSessionsForUserResponse) error() error { return r.Err }
|
||||
|
||||
func getInfoAboutSessionsForUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*getInfoAboutSessionsForUserRequest)
|
||||
sessions, err := svc.GetInfoAboutSessionsForUser(ctx, req.ID)
|
||||
if err != nil {
|
||||
return getInfoAboutSessionsForUserResponse{Err: err}, nil
|
||||
}
|
||||
var resp getInfoAboutSessionsForUserResponse
|
||||
for _, session := range sessions {
|
||||
resp.Sessions = append(resp.Sessions, getInfoAboutSessionResponse{
|
||||
SessionID: session.ID,
|
||||
UserID: session.UserID,
|
||||
CreatedAt: session.CreatedAt,
|
||||
})
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (svc *Service) GetInfoAboutSessionsForUser(ctx context.Context, id uint) ([]*fleet.Session, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionWrite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var validatedSessions []*fleet.Session
|
||||
|
||||
sessions, err := svc.ds.ListSessionsForUser(ctx, id)
|
||||
if err != nil {
|
||||
return validatedSessions, err
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
if svc.validateSession(ctx, session) == nil {
|
||||
validatedSessions = append(validatedSessions, session)
|
||||
}
|
||||
}
|
||||
|
||||
return validatedSessions, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Delete Sessions For User
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type deleteSessionsForUserRequest struct {
|
||||
ID uint `url:"id"`
|
||||
}
|
||||
|
||||
type deleteSessionsForUserResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r deleteSessionsForUserResponse) error() error { return r.Err }
|
||||
|
||||
func deleteSessionsForUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*deleteSessionsForUserRequest)
|
||||
err := svc.DeleteSessionsForUser(ctx, req.ID)
|
||||
if err != nil {
|
||||
return deleteSessionsForUserResponse{Err: err}, nil
|
||||
}
|
||||
return deleteSessionsForUserResponse{}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) DeleteSessionsForUser(ctx context.Context, id uint) error {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return svc.ds.DestroyAllSessionsForUser(ctx, id)
|
||||
}
|
||||
|
||||
func isAdminOfTheModifiedTeams(currentUser *fleet.User, originalUserTeams, newUserTeams []fleet.UserTeam) bool {
|
||||
// If the user is of the right global role, then they can modify the teams
|
||||
if currentUser.GlobalRole != nil && (*currentUser.GlobalRole == fleet.RoleAdmin || *currentUser.GlobalRole == fleet.RoleMaintainer) {
|
||||
return true
|
||||
}
|
||||
|
||||
// otherwise, gather the resulting teams
|
||||
resultingTeams := make(map[uint]string)
|
||||
for _, team := range newUserTeams {
|
||||
resultingTeams[team.ID] = team.Role
|
||||
}
|
||||
|
||||
// and see which ones were removed or changed from the original
|
||||
teamsAffected := make(map[uint]struct{})
|
||||
for _, team := range originalUserTeams {
|
||||
if resultingTeams[team.ID] != team.Role {
|
||||
teamsAffected[team.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// then gather the teams the current user is admin for
|
||||
currentUserTeamAdmin := make(map[uint]struct{})
|
||||
for _, team := range currentUser.Teams {
|
||||
if team.Role == fleet.RoleAdmin {
|
||||
currentUserTeamAdmin[team.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// and let's check that the teams that were either removed or changed are also teams this user is an admin of
|
||||
for teamID := range teamsAffected {
|
||||
if _, ok := currentUserTeamAdmin[teamID]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (svc *Service) modifyEmailAddress(ctx context.Context, user *fleet.User, email string, password *string) error {
|
||||
// password requirement handled in validation middleware
|
||||
if password != nil {
|
||||
err := user.ValidatePassword(*password)
|
||||
if err != nil {
|
||||
return fleet.NewPermissionError("incorrect password")
|
||||
}
|
||||
}
|
||||
random, err := server.GenerateRandomText(svc.config.App.TokenKeySize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
token := base64.URLEncoding.EncodeToString([]byte(random))
|
||||
err = svc.ds.PendingEmailChange(ctx, user.ID, email, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config, err := svc.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
changeEmail := fleet.Email{
|
||||
Subject: "Confirm Fleet Email Change",
|
||||
To: []string{email},
|
||||
Config: config,
|
||||
Mailer: &mail.ChangeEmailMailer{
|
||||
Token: token,
|
||||
BaseURL: template.URL(config.ServerSettings.ServerURL + svc.config.Server.URLPrefix),
|
||||
AssetURL: getAssetURL(),
|
||||
},
|
||||
}
|
||||
return svc.mailService.SendEmail(changeEmail)
|
||||
}
|
||||
|
||||
// saves user in datastore.
|
||||
// doesn't need to be exposed to the transport
|
||||
// the service should expose actions for modifying a user instead
|
||||
func (svc *Service) saveUser(ctx context.Context, user *fleet.User) error {
|
||||
return svc.ds.SaveUser(ctx, user)
|
||||
}
|
||||
530
server/service/users_test.go
Normal file
530
server/service/users_test.go
Normal file
|
|
@ -0,0 +1,530 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mock"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUserAuth(t *testing.T) {
|
||||
ds := new(mock.Store)
|
||||
svc := newTestService(ds, nil, nil)
|
||||
|
||||
ds.InviteByTokenFunc = func(ctx context.Context, token string) (*fleet.Invite, error) {
|
||||
return &fleet.Invite{
|
||||
Email: "some@email.com",
|
||||
Token: "ABCD",
|
||||
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
|
||||
CreateTimestamp: fleet.CreateTimestamp{CreatedAt: time.Now()},
|
||||
UpdateTimestamp: fleet.UpdateTimestamp{UpdatedAt: time.Now()},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
ds.NewUserFunc = func(ctx context.Context, user *fleet.User) (*fleet.User, error) {
|
||||
return &fleet.User{}, nil
|
||||
}
|
||||
ds.DeleteInviteFunc = func(ctx context.Context, id uint) error {
|
||||
return nil
|
||||
}
|
||||
ds.InviteByEmailFunc = func(ctx context.Context, email string) (*fleet.Invite, error) {
|
||||
return nil, errors.New("AA")
|
||||
}
|
||||
ds.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
|
||||
if id == 999 {
|
||||
return &fleet.User{
|
||||
ID: 999,
|
||||
Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}},
|
||||
}, nil
|
||||
}
|
||||
return &fleet.User{
|
||||
ID: 888,
|
||||
GlobalRole: ptr.String(fleet.RoleMaintainer),
|
||||
}, nil
|
||||
}
|
||||
ds.SaveUserFunc = func(ctx context.Context, user *fleet.User) error {
|
||||
return nil
|
||||
}
|
||||
ds.ListUsersFunc = func(ctx context.Context, opts fleet.UserListOptions) ([]*fleet.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
ds.DeleteUserFunc = func(ctx context.Context, id uint) error {
|
||||
return nil
|
||||
}
|
||||
ds.DestroyAllSessionsForUserFunc = func(ctx context.Context, id uint) error {
|
||||
return nil
|
||||
}
|
||||
ds.ListSessionsForUserFunc = func(ctx context.Context, id uint) ([]*fleet.Session, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
user *fleet.User
|
||||
shouldFailGlobalWrite bool
|
||||
shouldFailTeamWrite bool
|
||||
shouldFailRead bool
|
||||
shouldFailDeleteReset bool
|
||||
}{
|
||||
{
|
||||
"global admin",
|
||||
&fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"global maintainer",
|
||||
&fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"global observer",
|
||||
&fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"team admin, belongs to team",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}},
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"team maintainer, belongs to team",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}},
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"team observer, belongs to team",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}},
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"team maintainer, DOES NOT belong to team",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleMaintainer}}},
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"team admin, DOES NOT belong to team",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleAdmin}}},
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"team observer, DOES NOT belong to team",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleObserver}}},
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: tt.user})
|
||||
|
||||
teams := []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}
|
||||
_, err := svc.CreateUser(ctx, fleet.UserPayload{
|
||||
Name: ptr.String("Some Name"),
|
||||
Email: ptr.String("some@email.com"),
|
||||
Password: ptr.String("passw0rd."),
|
||||
Teams: &teams,
|
||||
})
|
||||
checkAuthErr(t, tt.shouldFailTeamWrite, err)
|
||||
|
||||
_, err = svc.CreateUser(ctx, fleet.UserPayload{
|
||||
Name: ptr.String("Some Name"),
|
||||
Email: ptr.String("some@email.com"),
|
||||
Password: ptr.String("passw0rd."),
|
||||
GlobalRole: ptr.String(fleet.RoleAdmin),
|
||||
})
|
||||
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
|
||||
|
||||
_, err = svc.ModifyUser(ctx, 999, fleet.UserPayload{Teams: &teams})
|
||||
checkAuthErr(t, tt.shouldFailTeamWrite, err)
|
||||
|
||||
_, err = svc.ModifyUser(ctx, 888, fleet.UserPayload{Teams: &teams})
|
||||
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
|
||||
|
||||
_, err = svc.ModifyUser(ctx, 888, fleet.UserPayload{GlobalRole: ptr.String(fleet.RoleMaintainer)})
|
||||
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
|
||||
|
||||
err = svc.DeleteUser(ctx, 999)
|
||||
checkAuthErr(t, tt.shouldFailDeleteReset, err)
|
||||
|
||||
_, err = svc.RequirePasswordReset(ctx, 999, false)
|
||||
checkAuthErr(t, tt.shouldFailDeleteReset, err)
|
||||
|
||||
_, err = svc.ListUsers(ctx, fleet.UserListOptions{})
|
||||
checkAuthErr(t, tt.shouldFailRead, err)
|
||||
|
||||
_, err = svc.User(ctx, 999)
|
||||
checkAuthErr(t, tt.shouldFailRead, err)
|
||||
|
||||
_, err = svc.User(ctx, 888)
|
||||
checkAuthErr(t, tt.shouldFailRead, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModifyUserEmail(t *testing.T) {
|
||||
user := &fleet.User{
|
||||
ID: 3,
|
||||
Email: "foo@bar.com",
|
||||
}
|
||||
user.SetPassword("password", 10, 10)
|
||||
ms := new(mock.Store)
|
||||
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
|
||||
return nil
|
||||
}
|
||||
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
|
||||
return user, nil
|
||||
}
|
||||
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
config := &fleet.AppConfig{
|
||||
SMTPSettings: fleet.SMTPSettings{
|
||||
SMTPConfigured: true,
|
||||
SMTPAuthenticationType: fleet.AuthTypeNameNone,
|
||||
SMTPPort: 1025,
|
||||
SMTPServer: "127.0.0.1",
|
||||
SMTPSenderAddress: "xxx@fleet.co",
|
||||
},
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
|
||||
// verify this isn't changed yet
|
||||
assert.Equal(t, "foo@bar.com", u.Email)
|
||||
// verify is changed per bug 1123
|
||||
assert.Equal(t, "minion", u.Position)
|
||||
return nil
|
||||
}
|
||||
svc := newTestService(ms, nil, nil)
|
||||
ctx := context.Background()
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
|
||||
payload := fleet.UserPayload{
|
||||
Email: ptr.String("zip@zap.com"),
|
||||
Password: ptr.String("password"),
|
||||
Position: ptr.String("minion"),
|
||||
}
|
||||
_, err := svc.ModifyUser(ctx, 3, payload)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, ms.PendingEmailChangeFuncInvoked)
|
||||
assert.True(t, ms.SaveUserFuncInvoked)
|
||||
}
|
||||
|
||||
func TestModifyUserEmailNoPassword(t *testing.T) {
|
||||
user := &fleet.User{
|
||||
ID: 3,
|
||||
Email: "foo@bar.com",
|
||||
}
|
||||
user.SetPassword("password", 10, 10)
|
||||
ms := new(mock.Store)
|
||||
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
|
||||
return nil
|
||||
}
|
||||
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
|
||||
return user, nil
|
||||
}
|
||||
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
config := &fleet.AppConfig{
|
||||
SMTPSettings: fleet.SMTPSettings{
|
||||
SMTPConfigured: true,
|
||||
SMTPAuthenticationType: fleet.AuthTypeNameNone,
|
||||
SMTPPort: 1025,
|
||||
SMTPServer: "127.0.0.1",
|
||||
SMTPSenderAddress: "xxx@fleet.co",
|
||||
},
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
|
||||
return nil
|
||||
}
|
||||
svc := newTestService(ms, nil, nil)
|
||||
ctx := context.Background()
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
|
||||
payload := fleet.UserPayload{
|
||||
Email: ptr.String("zip@zap.com"),
|
||||
// NO PASSWORD
|
||||
// Password: ptr.String("password"),
|
||||
}
|
||||
_, err := svc.ModifyUser(ctx, 3, payload)
|
||||
require.NotNil(t, err)
|
||||
var iae *fleet.InvalidArgumentError
|
||||
ok := errors.As(err, &iae)
|
||||
require.True(t, ok)
|
||||
require.Len(t, *iae, 1)
|
||||
assert.False(t, ms.PendingEmailChangeFuncInvoked)
|
||||
assert.False(t, ms.SaveUserFuncInvoked)
|
||||
}
|
||||
|
||||
func TestModifyAdminUserEmailNoPassword(t *testing.T) {
|
||||
user := &fleet.User{
|
||||
ID: 3,
|
||||
Email: "foo@bar.com",
|
||||
}
|
||||
user.SetPassword("password", 10, 10)
|
||||
ms := new(mock.Store)
|
||||
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
|
||||
return nil
|
||||
}
|
||||
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
|
||||
return user, nil
|
||||
}
|
||||
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
config := &fleet.AppConfig{
|
||||
SMTPSettings: fleet.SMTPSettings{
|
||||
SMTPConfigured: true,
|
||||
SMTPAuthenticationType: fleet.AuthTypeNameNone,
|
||||
SMTPPort: 1025,
|
||||
SMTPServer: "127.0.0.1",
|
||||
SMTPSenderAddress: "xxx@fleet.co",
|
||||
},
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
|
||||
return nil
|
||||
}
|
||||
svc := newTestService(ms, nil, nil)
|
||||
ctx := context.Background()
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
|
||||
payload := fleet.UserPayload{
|
||||
Email: ptr.String("zip@zap.com"),
|
||||
// NO PASSWORD
|
||||
// Password: ptr.String("password"),
|
||||
}
|
||||
_, err := svc.ModifyUser(ctx, 3, payload)
|
||||
require.NotNil(t, err)
|
||||
var iae *fleet.InvalidArgumentError
|
||||
ok := errors.As(err, &iae)
|
||||
require.True(t, ok)
|
||||
require.Len(t, *iae, 1)
|
||||
assert.False(t, ms.PendingEmailChangeFuncInvoked)
|
||||
assert.False(t, ms.SaveUserFuncInvoked)
|
||||
}
|
||||
|
||||
func TestModifyAdminUserEmailPassword(t *testing.T) {
|
||||
user := &fleet.User{
|
||||
ID: 3,
|
||||
Email: "foo@bar.com",
|
||||
}
|
||||
user.SetPassword("password", 10, 10)
|
||||
ms := new(mock.Store)
|
||||
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
|
||||
return nil
|
||||
}
|
||||
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
|
||||
return user, nil
|
||||
}
|
||||
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
config := &fleet.AppConfig{
|
||||
SMTPSettings: fleet.SMTPSettings{
|
||||
SMTPConfigured: true,
|
||||
SMTPAuthenticationType: fleet.AuthTypeNameNone,
|
||||
SMTPPort: 1025,
|
||||
SMTPServer: "127.0.0.1",
|
||||
SMTPSenderAddress: "xxx@fleet.co",
|
||||
},
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
|
||||
return nil
|
||||
}
|
||||
svc := newTestService(ms, nil, nil)
|
||||
ctx := context.Background()
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
|
||||
payload := fleet.UserPayload{
|
||||
Email: ptr.String("zip@zap.com"),
|
||||
Password: ptr.String("password"),
|
||||
}
|
||||
_, err := svc.ModifyUser(ctx, 3, payload)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, ms.PendingEmailChangeFuncInvoked)
|
||||
assert.True(t, ms.SaveUserFuncInvoked)
|
||||
}
|
||||
|
||||
func TestUsersWithDS(t *testing.T) {
|
||||
ds := mysql.CreateMySQLDS(t)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
fn func(t *testing.T, ds *mysql.Datastore)
|
||||
}{
|
||||
{"CreateUserForcePasswdReset", testUsersCreateUserForcePasswdReset},
|
||||
{"ChangePassword", testUsersChangePassword},
|
||||
{"RequirePasswordReset", testUsersRequirePasswordReset},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
defer mysql.TruncateTables(t, ds)
|
||||
c.fn(t, ds)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test that CreateUser creates a user that will be forced to
|
||||
// reset its password upon first login (see #2570).
|
||||
func testUsersCreateUserForcePasswdReset(t *testing.T, ds *mysql.Datastore) {
|
||||
svc := newTestService(ds, nil, nil)
|
||||
|
||||
// Create admin user.
|
||||
admin := &fleet.User{
|
||||
Name: "Fleet Admin",
|
||||
Email: "admin@foo.com",
|
||||
GlobalRole: ptr.String(fleet.RoleAdmin),
|
||||
}
|
||||
err := admin.SetPassword("p4ssw0rd.", 10, 10)
|
||||
require.NoError(t, err)
|
||||
admin, err = ds.NewUser(context.Background(), admin)
|
||||
require.NoError(t, err)
|
||||
|
||||
// As the admin, create a new user.
|
||||
ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: admin})
|
||||
user, err := svc.CreateUser(ctx, fleet.UserPayload{
|
||||
Name: ptr.String("Some Observer"),
|
||||
Email: ptr.String("some-observer@email.com"),
|
||||
Password: ptr.String("passw0rd."),
|
||||
GlobalRole: ptr.String(fleet.RoleObserver),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err = ds.UserByID(context.Background(), user.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, user.AdminForcedPasswordReset)
|
||||
}
|
||||
|
||||
func testUsersChangePassword(t *testing.T, ds *mysql.Datastore) {
|
||||
svc := newTestService(ds, nil, nil)
|
||||
users := createTestUsers(t, ds)
|
||||
passwordChangeTests := []struct {
|
||||
user fleet.User
|
||||
oldPassword string
|
||||
newPassword string
|
||||
anyErr bool
|
||||
wantErr error
|
||||
}{
|
||||
{ // all good
|
||||
user: users["admin1@example.com"],
|
||||
oldPassword: "foobarbaz1234!",
|
||||
newPassword: "12345cat!",
|
||||
},
|
||||
{ // prevent password reuse
|
||||
user: users["admin1@example.com"],
|
||||
oldPassword: "12345cat!",
|
||||
newPassword: "foobarbaz1234!",
|
||||
wantErr: fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"),
|
||||
},
|
||||
{ // all good
|
||||
user: users["user1@example.com"],
|
||||
oldPassword: "foobarbaz1234!",
|
||||
newPassword: "newpassa1234!",
|
||||
},
|
||||
{ // bad old password
|
||||
user: users["user1@example.com"],
|
||||
oldPassword: "wrong_password",
|
||||
newPassword: "12345cat!",
|
||||
anyErr: true,
|
||||
},
|
||||
{ // missing old password
|
||||
newPassword: "123cataaa!",
|
||||
wantErr: fleet.NewInvalidArgumentError("old_password", "Old password cannot be empty"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range passwordChangeTests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &tt.user})
|
||||
|
||||
err := svc.ChangePassword(ctx, tt.oldPassword, tt.newPassword)
|
||||
if tt.anyErr {
|
||||
require.NotNil(t, err)
|
||||
} else if tt.wantErr != nil {
|
||||
require.Equal(t, tt.wantErr, ctxerr.Cause(err))
|
||||
} else {
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt login after successful change
|
||||
_, _, err = svc.Login(context.Background(), tt.user.Email, tt.newPassword)
|
||||
require.Nil(t, err, "should be able to login with new password")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testUsersRequirePasswordReset(t *testing.T, ds *mysql.Datastore) {
|
||||
svc := newTestService(ds, nil, nil)
|
||||
createTestUsers(t, ds)
|
||||
|
||||
for _, tt := range testUsers {
|
||||
t.Run(tt.Email, func(t *testing.T) {
|
||||
user, err := ds.UserByEmail(context.Background(), tt.Email)
|
||||
require.Nil(t, err)
|
||||
|
||||
var sessions []*fleet.Session
|
||||
|
||||
// Log user in
|
||||
_, _, err = svc.Login(test.UserContext(test.UserAdmin), tt.Email, tt.PlaintextPassword)
|
||||
require.Nil(t, err, "login unsuccessful")
|
||||
sessions, err = svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), user.ID)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, sessions, 1, "user should have one session")
|
||||
|
||||
// Reset and verify sessions destroyed
|
||||
retUser, err := svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, true)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, retUser.AdminForcedPasswordReset)
|
||||
checkUser, err := ds.UserByEmail(context.Background(), tt.Email)
|
||||
require.Nil(t, err)
|
||||
assert.True(t, checkUser.AdminForcedPasswordReset)
|
||||
sessions, err = svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), user.ID)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, sessions, 0, "sessions should be destroyed")
|
||||
|
||||
// try undo
|
||||
retUser, err = svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, false)
|
||||
require.Nil(t, err)
|
||||
assert.False(t, retUser.AdminForcedPasswordReset)
|
||||
checkUser, err = ds.UserByEmail(context.Background(), tt.Email)
|
||||
require.Nil(t, err)
|
||||
assert.False(t, checkUser.AdminForcedPasswordReset)
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue