Migrate most users endpoints to the new pattern (#3366)

This commit is contained in:
Martin Angers 2022-01-10 14:43:39 -05:00 committed by GitHub
parent a5bef8a990
commit 597144bfac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 1178 additions and 1252 deletions

View file

@ -91,40 +91,6 @@ func makeGetInfoAboutSessionEndpoint(svc fleet.Service) endpoint.Endpoint {
}
}
////////////////////////////////////////////////////////////////////////////////
// Get Info About Sessions For User
////////////////////////////////////////////////////////////////////////////////
type getInfoAboutSessionsForUserRequest struct {
ID uint
}
type getInfoAboutSessionsForUserResponse struct {
Sessions []getInfoAboutSessionResponse `json:"sessions"`
Err error `json:"error,omitempty"`
}
func (r getInfoAboutSessionsForUserResponse) error() error { return r.Err }
func makeGetInfoAboutSessionsForUserEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(getInfoAboutSessionsForUserRequest)
sessions, err := svc.GetInfoAboutSessionsForUser(ctx, req.ID)
if err != nil {
return getInfoAboutSessionsForUserResponse{Err: err}, nil
}
var resp getInfoAboutSessionsForUserResponse
for _, session := range sessions {
resp.Sessions = append(resp.Sessions, getInfoAboutSessionResponse{
SessionID: session.ID,
UserID: session.UserID,
CreatedAt: session.CreatedAt,
})
}
return resp, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Delete Session
////////////////////////////////////////////////////////////////////////////////
@ -150,31 +116,6 @@ func makeDeleteSessionEndpoint(svc fleet.Service) endpoint.Endpoint {
}
}
////////////////////////////////////////////////////////////////////////////////
// Delete Sessions For User
////////////////////////////////////////////////////////////////////////////////
type deleteSessionsForUserRequest struct {
ID uint
}
type deleteSessionsForUserResponse struct {
Err error `json:"error,omitempty"`
}
func (r deleteSessionsForUserResponse) error() error { return r.Err }
func makeDeleteSessionsForUserEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(deleteSessionsForUserRequest)
err := svc.DeleteSessionsForUser(ctx, req.ID)
if err != nil {
return deleteSessionsForUserResponse{Err: err}, nil
}
return deleteSessionsForUserResponse{}, nil
}
}
type initiateSSORequest struct {
RelayURL string `json:"relay_url"`
}

View file

@ -12,21 +12,10 @@ import (
// Create User With Invite
////////////////////////////////////////////////////////////////////////////////
type createUserRequest struct {
payload fleet.UserPayload
}
type createUserResponse struct {
User *fleet.User `json:"user,omitempty"`
Err error `json:"error,omitempty"`
}
func (r createUserResponse) error() error { return r.Err }
func makeCreateUserFromInviteEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(createUserRequest)
user, err := svc.CreateUserFromInvite(ctx, req.payload)
user, err := svc.CreateUserFromInvite(ctx, req.UserPayload)
if err != nil {
return createUserResponse{Err: err}, nil
}
@ -34,47 +23,6 @@ func makeCreateUserFromInviteEndpoint(svc fleet.Service) endpoint.Endpoint {
}
}
////////////////////////////////////////////////////////////////////////////////
// Create User
////////////////////////////////////////////////////////////////////////////////
func makeCreateUserEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(createUserRequest)
user, err := svc.CreateUser(ctx, req.payload)
if err != nil {
return createUserResponse{Err: err}, nil
}
return createUserResponse{User: user}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Get User
////////////////////////////////////////////////////////////////////////////////
type getUserRequest struct {
ID uint `json:"id"`
}
type getUserResponse struct {
User *fleet.User `json:"user,omitempty"`
Err error `json:"error,omitempty"`
}
func (r getUserResponse) error() error { return r.Err }
func makeGetUserEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(getUserRequest)
user, err := svc.User(ctx, req.ID)
if err != nil {
return getUserResponse{Err: err}, nil
}
return getUserResponse{User: user}, nil
}
}
func makeGetSessionUserEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
user, err := svc.AuthenticatedUser(ctx)
@ -85,60 +33,6 @@ func makeGetSessionUserEndpoint(svc fleet.Service) endpoint.Endpoint {
}
}
////////////////////////////////////////////////////////////////////////////////
// List Users
////////////////////////////////////////////////////////////////////////////////
type listUsersRequest struct {
ListOptions fleet.UserListOptions
}
type listUsersResponse struct {
Users []fleet.User `json:"users"`
Err error `json:"error,omitempty"`
}
func (r listUsersResponse) error() error { return r.Err }
func makeListUsersEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(listUsersRequest)
users, err := svc.ListUsers(ctx, req.ListOptions)
if err != nil {
return listUsersResponse{Err: err}, nil
}
resp := listUsersResponse{Users: []fleet.User{}}
for _, user := range users {
resp.Users = append(resp.Users, *user)
}
return resp, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Change Password
////////////////////////////////////////////////////////////////////////////////
type changePasswordRequest struct {
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password"`
}
type changePasswordResponse struct {
Err error `json:"error,omitempty"`
}
func (r changePasswordResponse) error() error { return r.Err }
func makeChangePasswordEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(changePasswordRequest)
err := svc.ChangePassword(ctx, req.OldPassword, req.NewPassword)
return changePasswordResponse{Err: err}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Reset Password
////////////////////////////////////////////////////////////////////////////////
@ -162,59 +56,6 @@ func makeResetPasswordEndpoint(svc fleet.Service) endpoint.Endpoint {
}
}
////////////////////////////////////////////////////////////////////////////////
// Modify User
////////////////////////////////////////////////////////////////////////////////
type modifyUserRequest struct {
ID uint
payload fleet.UserPayload
}
type modifyUserResponse struct {
User *fleet.User `json:"user,omitempty"`
Err error `json:"error,omitempty"`
}
func (r modifyUserResponse) error() error { return r.Err }
func makeModifyUserEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(modifyUserRequest)
user, err := svc.ModifyUser(ctx, req.ID, req.payload)
if err != nil {
return modifyUserResponse{Err: err}, nil
}
return modifyUserResponse{User: user}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Delete User
////////////////////////////////////////////////////////////////////////////////
type deleteUserRequest struct {
ID uint `json:"id"`
}
type deleteUserResponse struct {
Err error `json:"error,omitempty"`
}
func (r deleteUserResponse) error() error { return r.Err }
func makeDeleteUserEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(deleteUserRequest)
err := svc.DeleteUser(ctx, req.ID)
if err != nil {
return deleteUserResponse{Err: err}, nil
}
return deleteUserResponse{}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Perform Required Password Reset
////////////////////////////////////////////////////////////////////////////////
@ -242,33 +83,6 @@ func makePerformRequiredPasswordResetEndpoint(svc fleet.Service) endpoint.Endpoi
}
}
////////////////////////////////////////////////////////////////////////////////
// Require Password Reset
////////////////////////////////////////////////////////////////////////////////
type requirePasswordResetRequest struct {
Require bool `json:"require"`
ID uint `json:"id"`
}
type requirePasswordResetResponse struct {
User *fleet.User `json:"user,omitempty"`
Err error `json:"error,omitempty"`
}
func (r requirePasswordResetResponse) error() error { return r.Err }
func makeRequirePasswordResetEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(requirePasswordResetRequest)
user, err := svc.RequirePasswordReset(ctx, req.ID, req.Require)
if err != nil {
return requirePasswordResetResponse{Err: err}, nil
}
return requirePasswordResetResponse{User: user}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Forgot Password
////////////////////////////////////////////////////////////////////////////////

View file

@ -67,8 +67,8 @@ func allFields(ifv reflect.Value) []reflect.StructField {
// makeDecoder creates a decoder for the type for the struct passed on. If the
// struct has at least 1 json tag it'll unmarshall the body. If the struct has
// a `url` tag with value list_options it'll gather fleet.ListOptions from the
// URL (similarly for host_options, carve_options that derive from the common
// list_options).
// URL (similarly for host_options, carve_options, user_options that derive
// from the common list_options).
//
// Finally, any other `url` tag will be treated as a path variable (of the form
// /path/{name} in the route's path) from the URL path pattern, and it'll be
@ -123,6 +123,13 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
}
field.Set(reflect.ValueOf(opts))
case "user_options":
opts, err := userListOptionsFromRequest(r)
if err != nil {
return nil, err
}
field.Set(reflect.ValueOf(opts))
case "host_options":
opts, err := hostListOptionsFromRequest(r)
if err != nil {

View file

@ -28,17 +28,8 @@ type FleetEndpoints struct {
ForgotPassword endpoint.Endpoint
ResetPassword endpoint.Endpoint
Me endpoint.Endpoint
ChangePassword endpoint.Endpoint
CreateUserWithInvite endpoint.Endpoint
CreateUser endpoint.Endpoint
GetUser endpoint.Endpoint
ListUsers endpoint.Endpoint
ModifyUser endpoint.Endpoint
DeleteUser endpoint.Endpoint
RequirePasswordReset endpoint.Endpoint
PerformRequiredPasswordReset endpoint.Endpoint
GetSessionsForUserInfo endpoint.Endpoint
DeleteSessionsForUser endpoint.Endpoint
GetSessionInfo endpoint.Endpoint
DeleteSession endpoint.Endpoint
GetAppConfig endpoint.Endpoint
@ -115,15 +106,6 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th
// Standard user authentication routes
Me: authenticatedUser(svc, makeGetSessionUserEndpoint(svc)),
ChangePassword: authenticatedUser(svc, makeChangePasswordEndpoint(svc)),
GetUser: authenticatedUser(svc, makeGetUserEndpoint(svc)),
ListUsers: authenticatedUser(svc, makeListUsersEndpoint(svc)),
ModifyUser: authenticatedUser(svc, makeModifyUserEndpoint(svc)),
DeleteUser: authenticatedUser(svc, makeDeleteUserEndpoint(svc)),
RequirePasswordReset: authenticatedUser(svc, makeRequirePasswordResetEndpoint(svc)),
CreateUser: authenticatedUser(svc, makeCreateUserEndpoint(svc)),
GetSessionsForUserInfo: authenticatedUser(svc, makeGetInfoAboutSessionsForUserEndpoint(svc)),
DeleteSessionsForUser: authenticatedUser(svc, makeDeleteSessionsForUserEndpoint(svc)),
GetSessionInfo: authenticatedUser(svc, makeGetInfoAboutSessionEndpoint(svc)),
DeleteSession: authenticatedUser(svc, makeDeleteSessionEndpoint(svc)),
GetAppConfig: authenticatedUser(svc, makeGetAppConfigEndpoint(svc)),
@ -184,17 +166,8 @@ type fleetHandlers struct {
ForgotPassword http.Handler
ResetPassword http.Handler
Me http.Handler
ChangePassword http.Handler
CreateUserWithInvite http.Handler
CreateUser http.Handler
GetUser http.Handler
ListUsers http.Handler
ModifyUser http.Handler
DeleteUser http.Handler
RequirePasswordReset http.Handler
PerformRequiredPasswordReset http.Handler
GetSessionsForUserInfo http.Handler
DeleteSessionsForUser http.Handler
GetSessionInfo http.Handler
DeleteSession http.Handler
GetAppConfig http.Handler
@ -255,17 +228,8 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle
ForgotPassword: newServer(e.ForgotPassword, decodeForgotPasswordRequest),
ResetPassword: newServer(e.ResetPassword, decodeResetPasswordRequest),
Me: newServer(e.Me, decodeNoParamsRequest),
ChangePassword: newServer(e.ChangePassword, decodeChangePasswordRequest),
CreateUserWithInvite: newServer(e.CreateUserWithInvite, decodeCreateUserRequest),
CreateUser: newServer(e.CreateUser, decodeCreateUserRequest),
GetUser: newServer(e.GetUser, decodeGetUserRequest),
ListUsers: newServer(e.ListUsers, decodeListUsersRequest),
ModifyUser: newServer(e.ModifyUser, decodeModifyUserRequest),
DeleteUser: newServer(e.DeleteUser, decodeDeleteUserRequest),
RequirePasswordReset: newServer(e.RequirePasswordReset, decodeRequirePasswordResetRequest),
PerformRequiredPasswordReset: newServer(e.PerformRequiredPasswordReset, decodePerformRequiredPasswordResetRequest),
GetSessionsForUserInfo: newServer(e.GetSessionsForUserInfo, decodeGetInfoAboutSessionsForUserRequest),
DeleteSessionsForUser: newServer(e.DeleteSessionsForUser, decodeDeleteSessionsForUserRequest),
GetSessionInfo: newServer(e.GetSessionInfo, decodeGetInfoAboutSessionRequest),
DeleteSession: newServer(e.DeleteSession, decodeDeleteSessionRequest),
GetAppConfig: newServer(e.GetAppConfig, decodeNoParamsRequest),
@ -488,20 +452,12 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
r.Handle("/api/v1/fleet/forgot_password", h.ForgotPassword).Methods("POST").Name("forgot_password")
r.Handle("/api/v1/fleet/reset_password", h.ResetPassword).Methods("POST").Name("reset_password")
r.Handle("/api/v1/fleet/me", h.Me).Methods("GET").Name("me")
r.Handle("/api/v1/fleet/change_password", h.ChangePassword).Methods("POST").Name("change_password")
r.Handle("/api/v1/fleet/perform_required_password_reset", h.PerformRequiredPasswordReset).Methods("POST").Name("perform_required_password_reset")
r.Handle("/api/v1/fleet/sso", h.InitiateSSO).Methods("POST").Name("intiate_sso")
r.Handle("/api/v1/fleet/sso", h.SettingsSSO).Methods("GET").Name("sso_config")
r.Handle("/api/v1/fleet/sso/callback", h.CallbackSSO).Methods("POST").Name("callback_sso")
r.Handle("/api/v1/fleet/users", h.ListUsers).Methods("GET").Name("list_users")
r.Handle("/api/v1/fleet/users", h.CreateUserWithInvite).Methods("POST").Name("create_user_with_invite")
r.Handle("/api/v1/fleet/users/admin", h.CreateUser).Methods("POST").Name("create_user")
r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.GetUser).Methods("GET").Name("get_user")
r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.ModifyUser).Methods("PATCH").Name("modify_user")
r.Handle("/api/v1/fleet/users/{id:[0-9]+}", h.DeleteUser).Methods("DELETE").Name("delete_user")
r.Handle("/api/v1/fleet/users/{id:[0-9]+}/require_password_reset", h.RequirePasswordReset).Methods("POST").Name("require_password_reset")
r.Handle("/api/v1/fleet/users/{id:[0-9]+}/sessions", h.GetSessionsForUserInfo).Methods("GET").Name("get_session_for_user")
r.Handle("/api/v1/fleet/users/{id:[0-9]+}/sessions", h.DeleteSessionsForUser).Methods("DELETE").Name("delete_session_for_user")
r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.GetSessionInfo).Methods("GET").Name("get_session_info")
r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.DeleteSession).Methods("DELETE").Name("delete_session")
@ -570,6 +526,16 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht
e.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule/{scheduled_query_id}").PATCH("/api/_version_/fleet/teams/{team_id}/schedule/{scheduled_query_id}", modifyTeamScheduleEndpoint, modifyTeamScheduleRequest{})
e.WithAltPaths("/api/_version_/fleet/team/{team_id}/schedule/{scheduled_query_id}").DELETE("/api/_version_/fleet/teams/{team_id}/schedule/{scheduled_query_id}", deleteTeamScheduleEndpoint, deleteTeamScheduleRequest{})
e.GET("/api/_version_/fleet/users", listUsersEndpoint, listUsersRequest{})
e.POST("/api/_version_/fleet/users/admin", createUserEndpoint, createUserRequest{})
e.GET("/api/_version_/fleet/users/{id:[0-9]+}", getUserEndpoint, getUserRequest{})
e.PATCH("/api/_version_/fleet/users/{id:[0-9]+}", modifyUserEndpoint, modifyUserRequest{})
e.DELETE("/api/_version_/fleet/users/{id:[0-9]+}", deleteUserEndpoint, deleteUserRequest{})
e.POST("/api/_version_/fleet/users/{id:[0-9]+}/require_password_reset", requirePasswordResetEndpoint, requirePasswordResetRequest{})
e.GET("/api/_version_/fleet/users/{id:[0-9]+}/sessions", getInfoAboutSessionsForUserEndpoint, getInfoAboutSessionsForUserRequest{})
e.DELETE("/api/_version_/fleet/users/{id:[0-9]+}/sessions", deleteSessionsForUserEndpoint, deleteSessionsForUserRequest{})
e.POST("/api/_version_/fleet/change_password", changePasswordEndpoint, changePasswordRequest{})
e.POST("/api/_version_/fleet/global/policies", globalPolicyEndpoint, globalPolicyRequest{})
e.GET("/api/_version_/fleet/global/policies", listGlobalPoliciesEndpoint, nil)
e.GET("/api/_version_/fleet/global/policies/{policy_id}", getPolicyByIDEndpoint, getPolicyByIDRequest{})

View file

@ -42,18 +42,6 @@ func TestAPIRoutes(t *testing.T) {
verb: "POST",
uri: "/api/v1/fleet/users",
},
{
verb: "GET",
uri: "/api/v1/fleet/users",
},
{
verb: "GET",
uri: "/api/v1/fleet/users/1",
},
{
verb: "PATCH",
uri: "/api/v1/fleet/users/1",
},
{
verb: "POST",
uri: "/api/v1/fleet/login",

View file

@ -35,21 +35,36 @@ func (s *integrationTestSuite) SetupSuite() {
}
func (s *integrationTestSuite) TearDownTest() {
t := s.T()
ctx := context.Background()
u := s.users["admin1@example.com"]
filter := fleet.TeamFilter{User: &u}
hosts, _ := s.ds.ListHosts(context.Background(), filter, fleet.HostListOptions{})
hosts, err := s.ds.ListHosts(ctx, filter, fleet.HostListOptions{})
require.NoError(t, err)
var ids []uint
for _, host := range hosts {
ids = append(ids, host.ID)
}
s.ds.DeleteHosts(context.Background(), ids)
if len(ids) > 0 {
require.NoError(t, s.ds.DeleteHosts(ctx, ids))
}
lbls, err := s.ds.ListLabels(context.Background(), fleet.TeamFilter{}, fleet.ListOptions{})
require.NoError(s.T(), err)
lbls, err := s.ds.ListLabels(ctx, fleet.TeamFilter{}, fleet.ListOptions{})
require.NoError(t, err)
for _, lbl := range lbls {
if lbl.LabelType != fleet.LabelTypeBuiltIn {
err := s.ds.DeleteLabel(context.Background(), lbl.Name)
require.NoError(s.T(), err)
err := s.ds.DeleteLabel(ctx, lbl.Name)
require.NoError(t, err)
}
}
users, err := s.ds.ListUsers(ctx, fleet.UserListOptions{})
require.NoError(t, err)
for _, u := range users {
if _, ok := s.users[u.Email]; !ok {
err := s.ds.DeleteUser(ctx, u.ID)
require.NoError(t, err)
}
}
}
@ -1867,6 +1882,82 @@ func (s *integrationTestSuite) TestLabelSpecs() {
s.DoJSON("GET", "/api/v1/fleet/spec/labels/zzz", nil, http.StatusNotFound, &getResp)
}
func (s *integrationTestSuite) TestUsers() {
t := s.T()
// list existing users
var listResp listUsersResponse
s.DoJSON("GET", "/api/v1/fleet/users", nil, http.StatusOK, &listResp)
assert.Len(t, listResp.Users, len(s.users))
// create a new user
var createResp createUserResponse
params := fleet.UserPayload{
Name: ptr.String("extra"),
Email: ptr.String("extra@asd.com"),
Password: ptr.String("pass"),
GlobalRole: ptr.String(fleet.RoleObserver),
}
s.DoJSON("POST", "/api/v1/fleet/users/admin", params, http.StatusOK, &createResp)
assert.NotZero(t, createResp.User.ID)
u := *createResp.User
// get that user
var getResp getUserResponse
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), nil, http.StatusOK, &getResp)
assert.Equal(t, u.ID, getResp.User.ID)
// get non-existing user
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID+1), nil, http.StatusNotFound, &getResp)
// modify that user - simple name change
var modResp modifyUserResponse
params = fleet.UserPayload{
Name: ptr.String("extraz"),
}
s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), params, http.StatusOK, &modResp)
assert.Equal(t, u.ID, modResp.User.ID)
assert.Equal(t, u.Name+"z", modResp.User.Name)
// modify user - email change, password does not match
params = fleet.UserPayload{
Email: ptr.String("extra2@asd.com"),
Password: ptr.String("wrongpass"),
}
s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), params, http.StatusForbidden, &modResp)
// modify user - email change, password ok
params = fleet.UserPayload{
Email: ptr.String("extra2@asd.com"),
Password: ptr.String("pass"),
}
s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), params, http.StatusOK, &modResp)
assert.Equal(t, u.ID, modResp.User.ID)
assert.NotEqual(t, u.ID, modResp.User.Email)
// modify invalid user
params = fleet.UserPayload{
Name: ptr.String("nosuchuser"),
}
s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID+1), params, http.StatusNotFound, &modResp)
// require a password reset
var reqResetResp requirePasswordResetResponse
s.DoJSON("POST", fmt.Sprintf("/api/v1/fleet/users/%d/require_password_reset", u.ID), map[string]bool{"require": true}, http.StatusOK, &reqResetResp)
assert.Equal(t, u.ID, reqResetResp.User.ID)
assert.True(t, reqResetResp.User.AdminForcedPasswordReset)
// require a password reset to invalid user
s.DoJSON("POST", fmt.Sprintf("/api/v1/fleet/users/%d/require_password_reset", u.ID+1), map[string]bool{"require": true}, http.StatusNotFound, &reqResetResp)
// delete user
var delResp deleteUserResponse
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), nil, http.StatusOK, &delResp)
// delete invalid user
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/users/%d", u.ID), nil, http.StatusNotFound, &delResp)
}
func (s *integrationTestSuite) TestGlobalPoliciesAutomationConfig() {
t := s.T()

View file

@ -12,8 +12,7 @@ import (
////////////////////////////////////////////////////////////////////////////////
type getScheduledQueriesInPackRequest struct {
ID uint `url:"id"`
// TODO(mna): was not set in the old pattern
ID uint `url:"id"`
ListOptions fleet.ListOptions `url:"list_options"`
}

View file

@ -284,35 +284,6 @@ func (svc *Service) DestroySession(ctx context.Context) error {
return svc.ds.DestroySession(ctx, session)
}
func (svc *Service) GetInfoAboutSessionsForUser(ctx context.Context, id uint) ([]*fleet.Session, error) {
if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionWrite); err != nil {
return nil, err
}
var validatedSessions []*fleet.Session
sessions, err := svc.ds.ListSessionsForUser(ctx, id)
if err != nil {
return validatedSessions, err
}
for _, session := range sessions {
if svc.validateSession(ctx, session) == nil {
validatedSessions = append(validatedSessions, session)
}
}
return validatedSessions, nil
}
func (svc *Service) DeleteSessionsForUser(ctx context.Context, id uint) error {
if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionWrite); err != nil {
return err
}
return svc.ds.DestroyAllSessionsForUser(ctx, id)
}
func (svc *Service) GetInfoAboutSession(ctx context.Context, id uint) (*fleet.Session, error) {
session, err := svc.ds.SessionByID(ctx, id)
if err != nil {

View file

@ -7,7 +7,6 @@ import (
"time"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
@ -42,27 +41,6 @@ func (svc *Service) CreateUserFromInvite(ctx context.Context, p fleet.UserPayloa
return user, nil
}
func (svc *Service) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
var teams []fleet.UserTeam
if p.Teams != nil {
teams = *p.Teams
}
if err := svc.authz.Authorize(ctx, &fleet.User{Teams: teams}, fleet.ActionWrite); err != nil {
return nil, err
}
if invite, err := svc.ds.InviteByEmail(ctx, *p.Email); err == nil && invite != nil {
return nil, ctxerr.Errorf(ctx, "%s already invited", *p.Email)
}
if p.AdminForcedPasswordReset == nil {
// By default, force password reset for users created this way.
p.AdminForcedPasswordReset = ptr.Bool(true)
}
return svc.newUser(ctx, p)
}
func (svc *Service) CreateInitialUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
// skipauth: Only the initial user creation should be allowed to skip
// authorization (because there is not yet a user context to check against).
@ -106,157 +84,6 @@ func (svc *Service) newUser(ctx context.Context, p fleet.UserPayload) (*fleet.Us
return user, nil
}
func (svc *Service) ModifyUser(ctx context.Context, userID uint, p fleet.UserPayload) (*fleet.User, error) {
if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionRead); err != nil {
return nil, err
}
user, err := svc.User(ctx, userID)
if err != nil {
return nil, err
}
if err := svc.authz.Authorize(ctx, user, fleet.ActionWrite); err != nil {
return nil, err
}
if p.GlobalRole != nil || p.Teams != nil {
if err := svc.authz.Authorize(ctx, user, fleet.ActionWriteRole); err != nil {
return nil, err
}
}
if p.Name != nil {
user.Name = *p.Name
}
if p.Email != nil && *p.Email != user.Email {
err = svc.modifyEmailAddress(ctx, user, *p.Email, p.Password)
if err != nil {
return nil, err
}
}
if p.Position != nil {
user.Position = *p.Position
}
if p.GravatarURL != nil {
user.GravatarURL = *p.GravatarURL
}
if p.SSOEnabled != nil {
user.SSOEnabled = *p.SSOEnabled
}
currentUser := authz.UserFromContext(ctx)
if p.GlobalRole != nil && *p.GlobalRole != "" {
if currentUser.GlobalRole == nil {
return nil, ctxerr.New(ctx, "Cannot edit global role as a team member")
}
if p.Teams != nil && len(*p.Teams) > 0 {
return nil, fleet.NewInvalidArgumentError("teams", "may not be specified with global_role")
}
user.GlobalRole = p.GlobalRole
user.Teams = []fleet.UserTeam{}
} else if p.Teams != nil {
if !isAdminOfTheModifiedTeams(currentUser, user.Teams, *p.Teams) {
return nil, ctxerr.New(ctx, "Cannot modify teams in that way")
}
user.Teams = *p.Teams
user.GlobalRole = nil
}
err = svc.saveUser(ctx, user)
if err != nil {
return nil, err
}
return user, nil
}
func isAdminOfTheModifiedTeams(currentUser *fleet.User, originalUserTeams, newUserTeams []fleet.UserTeam) bool {
// If the user is of the right global role, then they can modify the teams
if currentUser.GlobalRole != nil && (*currentUser.GlobalRole == fleet.RoleAdmin || *currentUser.GlobalRole == fleet.RoleMaintainer) {
return true
}
// otherwise, gather the resulting teams
resultingTeams := make(map[uint]string)
for _, team := range newUserTeams {
resultingTeams[team.ID] = team.Role
}
// and see which ones were removed or changed from the original
teamsAffected := make(map[uint]struct{})
for _, team := range originalUserTeams {
if resultingTeams[team.ID] != team.Role {
teamsAffected[team.ID] = struct{}{}
}
}
// then gather the teams the current user is admin for
currentUserTeamAdmin := make(map[uint]struct{})
for _, team := range currentUser.Teams {
if team.Role == fleet.RoleAdmin {
currentUserTeamAdmin[team.ID] = struct{}{}
}
}
// and let's check that the teams that were either removed or changed are also teams this user is an admin of
for teamID := range teamsAffected {
if _, ok := currentUserTeamAdmin[teamID]; !ok {
return false
}
}
return true
}
func (svc *Service) modifyEmailAddress(ctx context.Context, user *fleet.User, email string, password *string) error {
// password requirement handled in validation middleware
if password != nil {
err := user.ValidatePassword(*password)
if err != nil {
return fleet.NewPermissionError("incorrect password")
}
}
random, err := server.GenerateRandomText(svc.config.App.TokenKeySize)
if err != nil {
return err
}
token := base64.URLEncoding.EncodeToString([]byte(random))
err = svc.ds.PendingEmailChange(ctx, user.ID, email, token)
if err != nil {
return err
}
config, err := svc.AppConfig(ctx)
if err != nil {
return err
}
changeEmail := fleet.Email{
Subject: "Confirm Fleet Email Change",
To: []string{email},
Config: config,
Mailer: &mail.ChangeEmailMailer{
Token: token,
BaseURL: template.URL(config.ServerSettings.ServerURL + svc.config.Server.URLPrefix),
AssetURL: getAssetURL(),
},
}
return svc.mailService.SendEmail(changeEmail)
}
func (svc *Service) DeleteUser(ctx context.Context, id uint) error {
if err := svc.authz.Authorize(ctx, &fleet.User{ID: id}, fleet.ActionWrite); err != nil {
return err
}
return svc.ds.DeleteUser(ctx, id)
}
func (svc *Service) ChangeUserEmail(ctx context.Context, token string) (string, error) {
vc, ok := viewer.FromContext(ctx)
if !ok {
@ -270,14 +97,6 @@ func (svc *Service) ChangeUserEmail(ctx context.Context, token string) (string,
return svc.ds.ConfirmPendingEmailChange(ctx, vc.UserID(), token)
}
func (svc *Service) User(ctx context.Context, id uint) (*fleet.User, error) {
if err := svc.authz.Authorize(ctx, &fleet.User{ID: id}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.UserByID(ctx, id)
}
func (svc *Service) UserUnauthorized(ctx context.Context, id uint) (*fleet.User, error) {
// Explicitly no authorization check. Should only be used by middleware.
return svc.ds.UserByID(ctx, id)
@ -299,14 +118,6 @@ func (svc *Service) AuthenticatedUser(ctx context.Context) (*fleet.User, error)
return vc.User, nil
}
func (svc *Service) ListUsers(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) {
if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.ListUsers(ctx, opt)
}
// setNewPassword is a helper for changing a user's password. It should be
// called to set the new password after proper authorization has been
// performed.
@ -326,33 +137,6 @@ func (svc *Service) setNewPassword(ctx context.Context, user *fleet.User, passwo
return nil
}
func (svc *Service) ChangePassword(ctx context.Context, oldPass, newPass string) error {
vc, ok := viewer.FromContext(ctx)
if !ok {
return fleet.ErrNoContext
}
if err := svc.authz.Authorize(ctx, vc.User, fleet.ActionWrite); err != nil {
return err
}
if vc.User.SSOEnabled {
return ctxerr.New(ctx, "change password for single sign on user not allowed")
}
if err := vc.User.ValidatePassword(newPass); err == nil {
return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
}
if err := vc.User.ValidatePassword(oldPass); err != nil {
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("old_password", "old password does not match"))
}
if err := svc.setNewPassword(ctx, vc.User, newPass); err != nil {
return ctxerr.Wrap(ctx, err, "setting new password")
}
return nil
}
func (svc *Service) ResetPassword(ctx context.Context, token, password string) error {
// skipauth: No viewer context available. The user is locked out of their
// account and authNZ is performed entirely by providing a valid password
@ -431,34 +215,6 @@ func (svc *Service) PerformRequiredPasswordReset(ctx context.Context, password s
return user, nil
}
func (svc *Service) RequirePasswordReset(ctx context.Context, uid uint, require bool) (*fleet.User, error) {
if err := svc.authz.Authorize(ctx, &fleet.User{ID: uid}, fleet.ActionWrite); err != nil {
return nil, err
}
user, err := svc.ds.UserByID(ctx, uid)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "loading user by ID")
}
if user.SSOEnabled {
return nil, ctxerr.New(ctx, "password reset for single sign on user not allowed")
}
// Require reset on next login
user.AdminForcedPasswordReset = require
if err := svc.saveUser(ctx, user); err != nil {
return nil, ctxerr.Wrap(ctx, err, "saving user")
}
if require {
// Clear all of the existing sessions
if err := svc.DeleteSessionsForUser(ctx, user.ID); err != nil {
return nil, ctxerr.Wrap(ctx, err, "deleting user sessions")
}
}
return user, nil
}
func (svc *Service) RequestPasswordReset(ctx context.Context, email string) error {
// skipauth: No viewer context available. The user is locked out of their
// account and trying to reset their password.
@ -512,10 +268,3 @@ func (svc *Service) RequestPasswordReset(ctx context.Context, email string) erro
return svc.mailService.SendEmail(resetEmail)
}
// saves user in datastore.
// doesn't need to be exposed to the transport
// the service should expose actions for modifying a user instead
func (svc *Service) saveUser(ctx context.Context, user *fleet.User) error {
return svc.ds.SaveUser(ctx, user)
}

View file

@ -3,7 +3,6 @@ package service
import (
"context"
"database/sql"
"errors"
"testing"
"time"
@ -11,8 +10,6 @@ import (
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -38,250 +35,6 @@ func TestAuthenticatedUser(t *testing.T) {
assert.Equal(t, user, admin1)
}
func TestModifyUserEmail(t *testing.T) {
user := &fleet.User{
ID: 3,
Email: "foo@bar.com",
}
user.SetPassword("password", 10, 10)
ms := new(mock.Store)
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
return nil
}
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
return user, nil
}
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
config := &fleet.AppConfig{
SMTPSettings: fleet.SMTPSettings{
SMTPConfigured: true,
SMTPAuthenticationType: fleet.AuthTypeNameNone,
SMTPPort: 1025,
SMTPServer: "127.0.0.1",
SMTPSenderAddress: "xxx@fleet.co",
},
}
return config, nil
}
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
// verify this isn't changed yet
assert.Equal(t, "foo@bar.com", u.Email)
// verify is changed per bug 1123
assert.Equal(t, "minion", u.Position)
return nil
}
svc := newTestService(ms, nil, nil)
ctx := context.Background()
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
payload := fleet.UserPayload{
Email: ptr.String("zip@zap.com"),
Password: ptr.String("password"),
Position: ptr.String("minion"),
}
_, err := svc.ModifyUser(ctx, 3, payload)
require.Nil(t, err)
assert.True(t, ms.PendingEmailChangeFuncInvoked)
assert.True(t, ms.SaveUserFuncInvoked)
}
func TestModifyUserEmailNoPassword(t *testing.T) {
user := &fleet.User{
ID: 3,
Email: "foo@bar.com",
}
user.SetPassword("password", 10, 10)
ms := new(mock.Store)
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
return nil
}
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
return user, nil
}
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
config := &fleet.AppConfig{
SMTPSettings: fleet.SMTPSettings{
SMTPConfigured: true,
SMTPAuthenticationType: fleet.AuthTypeNameNone,
SMTPPort: 1025,
SMTPServer: "127.0.0.1",
SMTPSenderAddress: "xxx@fleet.co",
},
}
return config, nil
}
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
return nil
}
svc := newTestService(ms, nil, nil)
ctx := context.Background()
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
payload := fleet.UserPayload{
Email: ptr.String("zip@zap.com"),
// NO PASSWORD
// Password: ptr.String("password"),
}
_, err := svc.ModifyUser(ctx, 3, payload)
require.NotNil(t, err)
var iae *fleet.InvalidArgumentError
ok := errors.As(err, &iae)
require.True(t, ok)
require.Len(t, *iae, 1)
assert.False(t, ms.PendingEmailChangeFuncInvoked)
assert.False(t, ms.SaveUserFuncInvoked)
}
func TestModifyAdminUserEmailNoPassword(t *testing.T) {
user := &fleet.User{
ID: 3,
Email: "foo@bar.com",
}
user.SetPassword("password", 10, 10)
ms := new(mock.Store)
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
return nil
}
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
return user, nil
}
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
config := &fleet.AppConfig{
SMTPSettings: fleet.SMTPSettings{
SMTPConfigured: true,
SMTPAuthenticationType: fleet.AuthTypeNameNone,
SMTPPort: 1025,
SMTPServer: "127.0.0.1",
SMTPSenderAddress: "xxx@fleet.co",
},
}
return config, nil
}
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
return nil
}
svc := newTestService(ms, nil, nil)
ctx := context.Background()
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
payload := fleet.UserPayload{
Email: ptr.String("zip@zap.com"),
// NO PASSWORD
// Password: ptr.String("password"),
}
_, err := svc.ModifyUser(ctx, 3, payload)
require.NotNil(t, err)
var iae *fleet.InvalidArgumentError
ok := errors.As(err, &iae)
require.True(t, ok)
require.Len(t, *iae, 1)
assert.False(t, ms.PendingEmailChangeFuncInvoked)
assert.False(t, ms.SaveUserFuncInvoked)
}
func TestModifyAdminUserEmailPassword(t *testing.T) {
user := &fleet.User{
ID: 3,
Email: "foo@bar.com",
}
user.SetPassword("password", 10, 10)
ms := new(mock.Store)
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
return nil
}
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
return user, nil
}
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
config := &fleet.AppConfig{
SMTPSettings: fleet.SMTPSettings{
SMTPConfigured: true,
SMTPAuthenticationType: fleet.AuthTypeNameNone,
SMTPPort: 1025,
SMTPServer: "127.0.0.1",
SMTPSenderAddress: "xxx@fleet.co",
},
}
return config, nil
}
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
return nil
}
svc := newTestService(ms, nil, nil)
ctx := context.Background()
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
payload := fleet.UserPayload{
Email: ptr.String("zip@zap.com"),
Password: ptr.String("password"),
}
_, err := svc.ModifyUser(ctx, 3, payload)
require.Nil(t, err)
assert.True(t, ms.PendingEmailChangeFuncInvoked)
assert.True(t, ms.SaveUserFuncInvoked)
}
func TestChangePassword(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
svc := newTestService(ds, nil, nil)
users := createTestUsers(t, ds)
passwordChangeTests := []struct {
user fleet.User
oldPassword string
newPassword string
anyErr bool
wantErr error
}{
{ // all good
user: users["admin1@example.com"],
oldPassword: "foobarbaz1234!",
newPassword: "12345cat!",
},
{ // prevent password reuse
user: users["admin1@example.com"],
oldPassword: "12345cat!",
newPassword: "foobarbaz1234!",
wantErr: fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"),
},
{ // all good
user: users["user1@example.com"],
oldPassword: "foobarbaz1234!",
newPassword: "newpassa1234!",
},
{ // bad old password
user: users["user1@example.com"],
oldPassword: "wrong_password",
newPassword: "12345cat!",
anyErr: true,
},
{ // missing old password
newPassword: "123cataaa!",
wantErr: fleet.NewInvalidArgumentError("old_password", "Old password cannot be empty"),
},
}
for _, tt := range passwordChangeTests {
t.Run("", func(t *testing.T) {
ctx := context.Background()
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &tt.user})
err := svc.ChangePassword(ctx, tt.oldPassword, tt.newPassword)
if tt.anyErr {
require.NotNil(t, err)
} else if tt.wantErr != nil {
require.Equal(t, tt.wantErr, ctxerr.Cause(err))
} else {
require.Nil(t, err)
}
if err != nil {
return
}
// Attempt login after successful change
_, _, err = svc.Login(context.Background(), tt.user.Email, tt.newPassword)
require.Nil(t, err, "should be able to login with new password")
})
}
}
func TestResetPassword(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
@ -340,48 +93,6 @@ func TestResetPassword(t *testing.T) {
}
}
func TestRequirePasswordReset(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
svc := newTestService(ds, nil, nil)
createTestUsers(t, ds)
for _, tt := range testUsers {
t.Run(tt.Email, func(t *testing.T) {
user, err := ds.UserByEmail(context.Background(), tt.Email)
require.Nil(t, err)
var sessions []*fleet.Session
// Log user in
_, _, err = svc.Login(test.UserContext(test.UserAdmin), tt.Email, tt.PlaintextPassword)
require.Nil(t, err, "login unsuccessful")
sessions, err = svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), user.ID)
require.Nil(t, err)
require.Len(t, sessions, 1, "user should have one session")
// Reset and verify sessions destroyed
retUser, err := svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, true)
require.Nil(t, err)
assert.True(t, retUser.AdminForcedPasswordReset)
checkUser, err := ds.UserByEmail(context.Background(), tt.Email)
require.Nil(t, err)
assert.True(t, checkUser.AdminForcedPasswordReset)
sessions, err = svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), user.ID)
require.Nil(t, err)
require.Len(t, sessions, 0, "sessions should be destroyed")
// try undo
retUser, err = svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, false)
require.Nil(t, err)
assert.False(t, retUser.AdminForcedPasswordReset)
checkUser, err = ds.UserByEmail(context.Background(), tt.Email)
require.Nil(t, err)
assert.False(t, checkUser.AdminForcedPasswordReset)
})
}
}
func refreshCtx(t *testing.T, ctx context.Context, user *fleet.User, ds fleet.Datastore, session *fleet.Session) context.Context {
reloadedUser, err := ds.UserByEmail(ctx, user.Email)
require.NoError(t, err)
@ -475,178 +186,3 @@ func TestUserPasswordRequirements(t *testing.T) {
})
}
}
func TestUserAuth(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
ds.InviteByTokenFunc = func(ctx context.Context, token string) (*fleet.Invite, error) {
return &fleet.Invite{
Email: "some@email.com",
Token: "ABCD",
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
CreateTimestamp: fleet.CreateTimestamp{CreatedAt: time.Now()},
UpdateTimestamp: fleet.UpdateTimestamp{UpdatedAt: time.Now()},
},
}, nil
}
ds.NewUserFunc = func(ctx context.Context, user *fleet.User) (*fleet.User, error) {
return &fleet.User{}, nil
}
ds.DeleteInviteFunc = func(ctx context.Context, id uint) error {
return nil
}
ds.InviteByEmailFunc = func(ctx context.Context, email string) (*fleet.Invite, error) {
return nil, errors.New("AA")
}
ds.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
if id == 999 {
return &fleet.User{
ID: 999,
Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}},
}, nil
}
return &fleet.User{
ID: 888,
GlobalRole: ptr.String(fleet.RoleMaintainer),
}, nil
}
ds.SaveUserFunc = func(ctx context.Context, user *fleet.User) error {
return nil
}
testCases := []struct {
name string
user *fleet.User
shouldFailGlobalWrite bool
shouldFailTeamWrite bool
shouldFailRead bool
}{
{
"global admin",
&fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
false,
false,
false,
},
{
"global maintainer",
&fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
true,
true,
true,
},
{
"global observer",
&fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
true,
true,
true,
},
{
"team admin, belongs to team",
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}},
true,
false,
false,
},
{
"team maintainer, belongs to team",
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}},
true,
true,
false,
},
{
"team observer, belongs to team",
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}},
true,
true,
true,
},
{
"team maintainer, DOES NOT belong to team",
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleMaintainer}}},
true,
true,
true,
},
{
"team admin, DOES NOT belong to team",
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleAdmin}}},
true,
true,
true,
},
{
"team observer, DOES NOT belong to team",
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleObserver}}},
true,
true,
true,
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: tt.user})
teams := []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}
_, err := svc.CreateUser(ctx, fleet.UserPayload{
Name: ptr.String("Some Name"),
Email: ptr.String("some@email.com"),
Password: ptr.String("passw0rd."),
Teams: &teams,
})
checkAuthErr(t, tt.shouldFailTeamWrite, err)
_, err = svc.CreateUser(ctx, fleet.UserPayload{
Name: ptr.String("Some Name"),
Email: ptr.String("some@email.com"),
Password: ptr.String("passw0rd."),
GlobalRole: ptr.String(fleet.RoleAdmin),
})
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
_, err = svc.ModifyUser(ctx, 999, fleet.UserPayload{Teams: &teams})
checkAuthErr(t, tt.shouldFailTeamWrite, err)
_, err = svc.ModifyUser(ctx, 888, fleet.UserPayload{Teams: &teams})
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
_, err = svc.ModifyUser(ctx, 888, fleet.UserPayload{GlobalRole: ptr.String(fleet.RoleMaintainer)})
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
})
}
}
// Test that CreateUser creates a user that will be forced to
// reset its password upon first login (see #2570).
func TestCreateUserForcePasswdReset(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
svc := newTestService(ds, nil, nil)
// Create admin user.
admin := &fleet.User{
Name: "Fleet Admin",
Email: "admin@foo.com",
GlobalRole: ptr.String(fleet.RoleAdmin),
}
err := admin.SetPassword("p4ssw0rd.", 10, 10)
require.NoError(t, err)
admin, err = ds.NewUser(context.Background(), admin)
require.NoError(t, err)
// As the admin, create a new user.
ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: admin})
user, err := svc.CreateUser(ctx, fleet.UserPayload{
Name: ptr.String("Some Observer"),
Email: ptr.String("some-observer@email.com"),
Password: ptr.String("passw0rd."),
GlobalRole: ptr.String(fleet.RoleObserver),
})
require.NoError(t, err)
user, err = ds.UserByID(context.Background(), user.ID)
require.NoError(t, err)
require.True(t, user.AdminForcedPasswordReset)
}

View file

@ -0,0 +1,8 @@
package service
// TODO(mna): when migrating Session-related endpoints, add auth tests for those
// endpoints (the auth is session-based, not user-based).
//_, err = svc.GetInfoAboutSessionsForUser(ctx, 999)
//checkAuthErr(t, tt.shouldFailTeamWrite, err)
//_, err = svc.GetInfoAboutSessionsForUser(ctx, 888)
//checkAuthErr(t, tt.shouldFailGlobalWrite, err)

View file

@ -18,14 +18,6 @@ func decodeGetInfoAboutSessionRequest(ctx context.Context, r *http.Request) (int
return getInfoAboutSessionRequest{ID: uint(id)}, nil
}
func decodeGetInfoAboutSessionsForUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
id, err := uintFromRequest(r, "id")
if err != nil {
return nil, err
}
return getInfoAboutSessionsForUserRequest{ID: uint(id)}, nil
}
func decodeDeleteSessionRequest(ctx context.Context, r *http.Request) (interface{}, error) {
id, err := uintFromRequest(r, "id")
if err != nil {
@ -34,14 +26,6 @@ func decodeDeleteSessionRequest(ctx context.Context, r *http.Request) (interface
return deleteSessionRequest{ID: uint(id)}, nil
}
func decodeDeleteSessionsForUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
id, err := uintFromRequest(r, "id")
if err != nil {
return nil, err
}
return deleteSessionsForUserRequest{ID: uint(id)}, nil
}
func decodeLoginRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var req loginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {

View file

@ -27,22 +27,6 @@ func TestDecodeGetInfoAboutSessionRequest(t *testing.T) {
)
}
func TestDecodeGetInfoAboutSessionsForUserRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/user/{id}/sessions", func(writer http.ResponseWriter, request *http.Request) {
r, err := decodeGetInfoAboutSessionsForUserRequest(context.Background(), request)
assert.Nil(t, err)
params := r.(getInfoAboutSessionsForUserRequest)
assert.Equal(t, uint(1), params.ID)
}).Methods("GET")
router.ServeHTTP(
httptest.NewRecorder(),
httptest.NewRequest("GET", "/api/v1/fleet/users/1/sessions", nil),
)
}
func TestDecodeDeleteSessionRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/sessions/{id}", func(writer http.ResponseWriter, request *http.Request) {
@ -59,22 +43,6 @@ func TestDecodeDeleteSessionRequest(t *testing.T) {
)
}
func TestDecodeDeleteSessionsForUserRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/user/{id}/sessions", func(writer http.ResponseWriter, request *http.Request) {
r, err := decodeDeleteSessionsForUserRequest(context.Background(), request)
assert.Nil(t, err)
params := r.(deleteSessionsForUserRequest)
assert.Equal(t, uint(1), params.ID)
}).Methods("DELETE")
router.ServeHTTP(
httptest.NewRecorder(),
httptest.NewRequest("DELETE", "/api/v1/fleet/users/1/sessions", nil),
)
}
func TestDecodeLoginRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/login", func(writer http.ResponseWriter, request *http.Request) {

View file

@ -10,69 +10,9 @@ import (
func decodeCreateUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var req createUserRequest
if err := json.NewDecoder(r.Body).Decode(&req.payload); err != nil {
return nil, err
}
return req, nil
}
func decodeGetUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
id, err := uintFromRequest(r, "id")
if err != nil {
return nil, err
}
return getUserRequest{ID: uint(id)}, nil
}
func decodeListUsersRequest(ctx context.Context, r *http.Request) (interface{}, error) {
opt, err := userListOptionsFromRequest(r)
if err != nil {
return nil, err
}
return listUsersRequest{ListOptions: opt}, nil
}
func decodeModifyUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
id, err := uintFromRequest(r, "id")
if err != nil {
return nil, err
}
var req modifyUserRequest
if err := json.NewDecoder(r.Body).Decode(&req.payload); err != nil {
return nil, err
}
req.ID = uint(id)
return req, nil
}
func decodeDeleteUserRequest(ctx context.Context, r *http.Request) (interface{}, error) {
id, err := uintFromRequest(r, "id")
if err != nil {
return nil, err
}
return deleteUserRequest{ID: uint(id)}, nil
}
func decodeChangePasswordRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var req changePasswordRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, err
}
return req, nil
}
func decodeRequirePasswordResetRequest(ctx context.Context, r *http.Request) (interface{}, error) {
id, err := uintFromRequest(r, "id")
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "getting ID from request")
}
var req requirePasswordResetRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, ctxerr.Wrap(ctx, err, "decoding JSON")
}
req.ID = uint(id)
return req, nil
}

View file

@ -11,68 +11,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestDecodeCreateUserRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/users", func(writer http.ResponseWriter, request *http.Request) {
r, err := decodeCreateUserRequest(context.Background(), request)
assert.Nil(t, err)
params := r.(createUserRequest)
assert.Equal(t, "foo", *params.payload.Name)
assert.Equal(t, "foo@fleet.co", *params.payload.Email)
}).Methods("POST")
var body bytes.Buffer
body.Write([]byte(`{
"name": "foo",
"email": "foo@fleet.co"
}`))
router.ServeHTTP(
httptest.NewRecorder(),
httptest.NewRequest("POST", "/api/v1/fleet/users", &body),
)
}
func TestDecodeGetUserRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/users/{id}", func(writer http.ResponseWriter, request *http.Request) {
r, err := decodeGetUserRequest(context.Background(), request)
assert.Nil(t, err)
params := r.(getUserRequest)
assert.Equal(t, uint(1), params.ID)
}).Methods("GET")
router.ServeHTTP(
httptest.NewRecorder(),
httptest.NewRequest("GET", "/api/v1/fleet/users/1", nil),
)
}
func TestDecodeChangePasswordRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/change_password", func(writer http.ResponseWriter, request *http.Request) {
r, err := decodeChangePasswordRequest(context.Background(), request)
assert.Nil(t, err)
params := r.(changePasswordRequest)
assert.Equal(t, "foo", params.OldPassword)
assert.Equal(t, "bar", params.NewPassword)
}).Methods("POST")
var body bytes.Buffer
body.Write([]byte(`{
"old_password": "foo",
"new_password": "bar"
}`))
router.ServeHTTP(
httptest.NewRecorder(),
httptest.NewRequest("POST", "/api/v1/fleet/change_password", &body),
)
}
func TestDecodeResetPasswordRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/users/{id}/password", func(writer http.ResponseWriter, request *http.Request) {
@ -95,28 +33,3 @@ func TestDecodeResetPasswordRequest(t *testing.T) {
httptest.NewRequest("POST", "/api/v1/fleet/users/1/password", &body),
)
}
func TestDecodeModifyUserRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/users/{id}", func(writer http.ResponseWriter, request *http.Request) {
r, err := decodeModifyUserRequest(context.Background(), request)
assert.Nil(t, err)
params := r.(modifyUserRequest)
assert.Equal(t, "foo", *params.payload.Name)
assert.Equal(t, "foo@fleet.co", *params.payload.Email)
assert.Equal(t, uint(1), params.ID)
}).Methods("PATCH")
var body bytes.Buffer
body.Write([]byte(`{
"name": "foo",
"email": "foo@fleet.co"
}`))
request := httptest.NewRequest("PATCH", "/api/v1/fleet/users/1", &body)
router.ServeHTTP(
httptest.NewRecorder(),
request,
)
}

521
server/service/users.go Normal file
View file

@ -0,0 +1,521 @@
package service
import (
"context"
"encoding/base64"
"html/template"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mail"
"github.com/fleetdm/fleet/v4/server/ptr"
)
////////////////////////////////////////////////////////////////////////////////
// Create User
////////////////////////////////////////////////////////////////////////////////
type createUserRequest struct {
fleet.UserPayload
}
type createUserResponse struct {
User *fleet.User `json:"user,omitempty"`
Err error `json:"error,omitempty"`
}
func (r createUserResponse) error() error { return r.Err }
func createUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*createUserRequest)
user, err := svc.CreateUser(ctx, req.UserPayload)
if err != nil {
return createUserResponse{Err: err}, nil
}
return createUserResponse{User: user}, nil
}
func (svc *Service) CreateUser(ctx context.Context, p fleet.UserPayload) (*fleet.User, error) {
var teams []fleet.UserTeam
if p.Teams != nil {
teams = *p.Teams
}
if err := svc.authz.Authorize(ctx, &fleet.User{Teams: teams}, fleet.ActionWrite); err != nil {
return nil, err
}
if invite, err := svc.ds.InviteByEmail(ctx, *p.Email); err == nil && invite != nil {
return nil, ctxerr.Errorf(ctx, "%s already invited", *p.Email)
}
if p.AdminForcedPasswordReset == nil {
// By default, force password reset for users created this way.
p.AdminForcedPasswordReset = ptr.Bool(true)
}
return svc.newUser(ctx, p)
}
////////////////////////////////////////////////////////////////////////////////
// List Users
////////////////////////////////////////////////////////////////////////////////
type listUsersRequest struct {
ListOptions fleet.UserListOptions `url:"user_options"`
}
type listUsersResponse struct {
Users []fleet.User `json:"users"`
Err error `json:"error,omitempty"`
}
func (r listUsersResponse) error() error { return r.Err }
func listUsersEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*listUsersRequest)
users, err := svc.ListUsers(ctx, req.ListOptions)
if err != nil {
return listUsersResponse{Err: err}, nil
}
resp := listUsersResponse{Users: []fleet.User{}}
for _, user := range users {
resp.Users = append(resp.Users, *user)
}
return resp, nil
}
func (svc *Service) ListUsers(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) {
if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.ListUsers(ctx, opt)
}
////////////////////////////////////////////////////////////////////////////////
// Get User
////////////////////////////////////////////////////////////////////////////////
type getUserRequest struct {
ID uint `url:"id"`
}
type getUserResponse struct {
User *fleet.User `json:"user,omitempty"`
Err error `json:"error,omitempty"`
}
func (r getUserResponse) error() error { return r.Err }
func getUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*getUserRequest)
user, err := svc.User(ctx, req.ID)
if err != nil {
return getUserResponse{Err: err}, nil
}
return getUserResponse{User: user}, nil
}
func (svc *Service) User(ctx context.Context, id uint) (*fleet.User, error) {
if err := svc.authz.Authorize(ctx, &fleet.User{ID: id}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.UserByID(ctx, id)
}
////////////////////////////////////////////////////////////////////////////////
// Modify User
////////////////////////////////////////////////////////////////////////////////
type modifyUserRequest struct {
ID uint `json:"-" url:"id"`
fleet.UserPayload
}
type modifyUserResponse struct {
User *fleet.User `json:"user,omitempty"`
Err error `json:"error,omitempty"`
}
func (r modifyUserResponse) error() error { return r.Err }
func modifyUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*modifyUserRequest)
user, err := svc.ModifyUser(ctx, req.ID, req.UserPayload)
if err != nil {
return modifyUserResponse{Err: err}, nil
}
return modifyUserResponse{User: user}, nil
}
func (svc *Service) ModifyUser(ctx context.Context, userID uint, p fleet.UserPayload) (*fleet.User, error) {
if err := svc.authz.Authorize(ctx, &fleet.User{}, fleet.ActionRead); err != nil {
return nil, err
}
user, err := svc.User(ctx, userID)
if err != nil {
return nil, err
}
if err := svc.authz.Authorize(ctx, user, fleet.ActionWrite); err != nil {
return nil, err
}
if p.GlobalRole != nil || p.Teams != nil {
if err := svc.authz.Authorize(ctx, user, fleet.ActionWriteRole); err != nil {
return nil, err
}
}
if p.Name != nil {
user.Name = *p.Name
}
if p.Email != nil && *p.Email != user.Email {
err = svc.modifyEmailAddress(ctx, user, *p.Email, p.Password)
if err != nil {
return nil, err
}
}
if p.Position != nil {
user.Position = *p.Position
}
if p.GravatarURL != nil {
user.GravatarURL = *p.GravatarURL
}
if p.SSOEnabled != nil {
user.SSOEnabled = *p.SSOEnabled
}
currentUser := authz.UserFromContext(ctx)
if p.GlobalRole != nil && *p.GlobalRole != "" {
if currentUser.GlobalRole == nil {
return nil, ctxerr.New(ctx, "Cannot edit global role as a team member")
}
if p.Teams != nil && len(*p.Teams) > 0 {
return nil, fleet.NewInvalidArgumentError("teams", "may not be specified with global_role")
}
user.GlobalRole = p.GlobalRole
user.Teams = []fleet.UserTeam{}
} else if p.Teams != nil {
if !isAdminOfTheModifiedTeams(currentUser, user.Teams, *p.Teams) {
return nil, ctxerr.New(ctx, "Cannot modify teams in that way")
}
user.Teams = *p.Teams
user.GlobalRole = nil
}
err = svc.saveUser(ctx, user)
if err != nil {
return nil, err
}
return user, nil
}
////////////////////////////////////////////////////////////////////////////////
// Delete User
////////////////////////////////////////////////////////////////////////////////
type deleteUserRequest struct {
ID uint `url:"id"`
}
type deleteUserResponse struct {
Err error `json:"error,omitempty"`
}
func (r deleteUserResponse) error() error { return r.Err }
func deleteUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*deleteUserRequest)
err := svc.DeleteUser(ctx, req.ID)
if err != nil {
return deleteUserResponse{Err: err}, nil
}
return deleteUserResponse{}, nil
}
func (svc *Service) DeleteUser(ctx context.Context, id uint) error {
if err := svc.authz.Authorize(ctx, &fleet.User{ID: id}, fleet.ActionWrite); err != nil {
return err
}
return svc.ds.DeleteUser(ctx, id)
}
////////////////////////////////////////////////////////////////////////////////
// Require Password Reset
////////////////////////////////////////////////////////////////////////////////
type requirePasswordResetRequest struct {
Require bool `json:"require"`
ID uint `json:"-" url:"id"`
}
type requirePasswordResetResponse struct {
User *fleet.User `json:"user,omitempty"`
Err error `json:"error,omitempty"`
}
func (r requirePasswordResetResponse) error() error { return r.Err }
func requirePasswordResetEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*requirePasswordResetRequest)
user, err := svc.RequirePasswordReset(ctx, req.ID, req.Require)
if err != nil {
return requirePasswordResetResponse{Err: err}, nil
}
return requirePasswordResetResponse{User: user}, nil
}
func (svc *Service) RequirePasswordReset(ctx context.Context, uid uint, require bool) (*fleet.User, error) {
if err := svc.authz.Authorize(ctx, &fleet.User{ID: uid}, fleet.ActionWrite); err != nil {
return nil, err
}
user, err := svc.ds.UserByID(ctx, uid)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "loading user by ID")
}
if user.SSOEnabled {
return nil, ctxerr.New(ctx, "password reset for single sign on user not allowed")
}
// Require reset on next login
user.AdminForcedPasswordReset = require
if err := svc.saveUser(ctx, user); err != nil {
return nil, ctxerr.Wrap(ctx, err, "saving user")
}
if require {
// Clear all of the existing sessions
if err := svc.DeleteSessionsForUser(ctx, user.ID); err != nil {
return nil, ctxerr.Wrap(ctx, err, "deleting user sessions")
}
}
return user, nil
}
////////////////////////////////////////////////////////////////////////////////
// Change Password
////////////////////////////////////////////////////////////////////////////////
type changePasswordRequest struct {
OldPassword string `json:"old_password"`
NewPassword string `json:"new_password"`
}
type changePasswordResponse struct {
Err error `json:"error,omitempty"`
}
func (r changePasswordResponse) error() error { return r.Err }
func changePasswordEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*changePasswordRequest)
err := svc.ChangePassword(ctx, req.OldPassword, req.NewPassword)
return changePasswordResponse{Err: err}, nil
}
func (svc *Service) ChangePassword(ctx context.Context, oldPass, newPass string) error {
vc, ok := viewer.FromContext(ctx)
if !ok {
return fleet.ErrNoContext
}
if err := svc.authz.Authorize(ctx, vc.User, fleet.ActionWrite); err != nil {
return err
}
if vc.User.SSOEnabled {
return ctxerr.New(ctx, "change password for single sign on user not allowed")
}
if err := vc.User.ValidatePassword(newPass); err == nil {
return fleet.NewInvalidArgumentError("new_password", "cannot reuse old password")
}
if err := vc.User.ValidatePassword(oldPass); err != nil {
return ctxerr.Wrap(ctx, fleet.NewInvalidArgumentError("old_password", "old password does not match"))
}
if err := svc.setNewPassword(ctx, vc.User, newPass); err != nil {
return ctxerr.Wrap(ctx, err, "setting new password")
}
return nil
}
////////////////////////////////////////////////////////////////////////////////
// Get Info About Sessions For User
////////////////////////////////////////////////////////////////////////////////
type getInfoAboutSessionsForUserRequest struct {
ID uint `url:"id"`
}
type getInfoAboutSessionsForUserResponse struct {
Sessions []getInfoAboutSessionResponse `json:"sessions"`
Err error `json:"error,omitempty"`
}
func (r getInfoAboutSessionsForUserResponse) error() error { return r.Err }
func getInfoAboutSessionsForUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*getInfoAboutSessionsForUserRequest)
sessions, err := svc.GetInfoAboutSessionsForUser(ctx, req.ID)
if err != nil {
return getInfoAboutSessionsForUserResponse{Err: err}, nil
}
var resp getInfoAboutSessionsForUserResponse
for _, session := range sessions {
resp.Sessions = append(resp.Sessions, getInfoAboutSessionResponse{
SessionID: session.ID,
UserID: session.UserID,
CreatedAt: session.CreatedAt,
})
}
return resp, nil
}
func (svc *Service) GetInfoAboutSessionsForUser(ctx context.Context, id uint) ([]*fleet.Session, error) {
if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionWrite); err != nil {
return nil, err
}
var validatedSessions []*fleet.Session
sessions, err := svc.ds.ListSessionsForUser(ctx, id)
if err != nil {
return validatedSessions, err
}
for _, session := range sessions {
if svc.validateSession(ctx, session) == nil {
validatedSessions = append(validatedSessions, session)
}
}
return validatedSessions, nil
}
////////////////////////////////////////////////////////////////////////////////
// Delete Sessions For User
////////////////////////////////////////////////////////////////////////////////
type deleteSessionsForUserRequest struct {
ID uint `url:"id"`
}
type deleteSessionsForUserResponse struct {
Err error `json:"error,omitempty"`
}
func (r deleteSessionsForUserResponse) error() error { return r.Err }
func deleteSessionsForUserEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*deleteSessionsForUserRequest)
err := svc.DeleteSessionsForUser(ctx, req.ID)
if err != nil {
return deleteSessionsForUserResponse{Err: err}, nil
}
return deleteSessionsForUserResponse{}, nil
}
func (svc *Service) DeleteSessionsForUser(ctx context.Context, id uint) error {
if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionWrite); err != nil {
return err
}
return svc.ds.DestroyAllSessionsForUser(ctx, id)
}
func isAdminOfTheModifiedTeams(currentUser *fleet.User, originalUserTeams, newUserTeams []fleet.UserTeam) bool {
// If the user is of the right global role, then they can modify the teams
if currentUser.GlobalRole != nil && (*currentUser.GlobalRole == fleet.RoleAdmin || *currentUser.GlobalRole == fleet.RoleMaintainer) {
return true
}
// otherwise, gather the resulting teams
resultingTeams := make(map[uint]string)
for _, team := range newUserTeams {
resultingTeams[team.ID] = team.Role
}
// and see which ones were removed or changed from the original
teamsAffected := make(map[uint]struct{})
for _, team := range originalUserTeams {
if resultingTeams[team.ID] != team.Role {
teamsAffected[team.ID] = struct{}{}
}
}
// then gather the teams the current user is admin for
currentUserTeamAdmin := make(map[uint]struct{})
for _, team := range currentUser.Teams {
if team.Role == fleet.RoleAdmin {
currentUserTeamAdmin[team.ID] = struct{}{}
}
}
// and let's check that the teams that were either removed or changed are also teams this user is an admin of
for teamID := range teamsAffected {
if _, ok := currentUserTeamAdmin[teamID]; !ok {
return false
}
}
return true
}
func (svc *Service) modifyEmailAddress(ctx context.Context, user *fleet.User, email string, password *string) error {
// password requirement handled in validation middleware
if password != nil {
err := user.ValidatePassword(*password)
if err != nil {
return fleet.NewPermissionError("incorrect password")
}
}
random, err := server.GenerateRandomText(svc.config.App.TokenKeySize)
if err != nil {
return err
}
token := base64.URLEncoding.EncodeToString([]byte(random))
err = svc.ds.PendingEmailChange(ctx, user.ID, email, token)
if err != nil {
return err
}
config, err := svc.AppConfig(ctx)
if err != nil {
return err
}
changeEmail := fleet.Email{
Subject: "Confirm Fleet Email Change",
To: []string{email},
Config: config,
Mailer: &mail.ChangeEmailMailer{
Token: token,
BaseURL: template.URL(config.ServerSettings.ServerURL + svc.config.Server.URLPrefix),
AssetURL: getAssetURL(),
},
}
return svc.mailService.SendEmail(changeEmail)
}
// saves user in datastore.
// doesn't need to be exposed to the transport
// the service should expose actions for modifying a user instead
func (svc *Service) saveUser(ctx context.Context, user *fleet.User) error {
return svc.ds.SaveUser(ctx, user)
}

View file

@ -0,0 +1,530 @@
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUserAuth(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
ds.InviteByTokenFunc = func(ctx context.Context, token string) (*fleet.Invite, error) {
return &fleet.Invite{
Email: "some@email.com",
Token: "ABCD",
UpdateCreateTimestamps: fleet.UpdateCreateTimestamps{
CreateTimestamp: fleet.CreateTimestamp{CreatedAt: time.Now()},
UpdateTimestamp: fleet.UpdateTimestamp{UpdatedAt: time.Now()},
},
}, nil
}
ds.NewUserFunc = func(ctx context.Context, user *fleet.User) (*fleet.User, error) {
return &fleet.User{}, nil
}
ds.DeleteInviteFunc = func(ctx context.Context, id uint) error {
return nil
}
ds.InviteByEmailFunc = func(ctx context.Context, email string) (*fleet.Invite, error) {
return nil, errors.New("AA")
}
ds.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
if id == 999 {
return &fleet.User{
ID: 999,
Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}},
}, nil
}
return &fleet.User{
ID: 888,
GlobalRole: ptr.String(fleet.RoleMaintainer),
}, nil
}
ds.SaveUserFunc = func(ctx context.Context, user *fleet.User) error {
return nil
}
ds.ListUsersFunc = func(ctx context.Context, opts fleet.UserListOptions) ([]*fleet.User, error) {
return nil, nil
}
ds.DeleteUserFunc = func(ctx context.Context, id uint) error {
return nil
}
ds.DestroyAllSessionsForUserFunc = func(ctx context.Context, id uint) error {
return nil
}
ds.ListSessionsForUserFunc = func(ctx context.Context, id uint) ([]*fleet.Session, error) {
return nil, nil
}
testCases := []struct {
name string
user *fleet.User
shouldFailGlobalWrite bool
shouldFailTeamWrite bool
shouldFailRead bool
shouldFailDeleteReset bool
}{
{
"global admin",
&fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
false,
false,
false,
false,
},
{
"global maintainer",
&fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
true,
true,
false,
true,
},
{
"global observer",
&fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
true,
true,
false,
true,
},
{
"team admin, belongs to team",
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}},
true,
false,
false,
true,
},
{
"team maintainer, belongs to team",
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}},
true,
true,
false,
true,
},
{
"team observer, belongs to team",
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}},
true,
true,
false,
true,
},
{
"team maintainer, DOES NOT belong to team",
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleMaintainer}}},
true,
true,
false,
true,
},
{
"team admin, DOES NOT belong to team",
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleAdmin}}},
true,
true,
false,
true,
},
{
"team observer, DOES NOT belong to team",
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleObserver}}},
true,
true,
false,
true,
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: tt.user})
teams := []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}
_, err := svc.CreateUser(ctx, fleet.UserPayload{
Name: ptr.String("Some Name"),
Email: ptr.String("some@email.com"),
Password: ptr.String("passw0rd."),
Teams: &teams,
})
checkAuthErr(t, tt.shouldFailTeamWrite, err)
_, err = svc.CreateUser(ctx, fleet.UserPayload{
Name: ptr.String("Some Name"),
Email: ptr.String("some@email.com"),
Password: ptr.String("passw0rd."),
GlobalRole: ptr.String(fleet.RoleAdmin),
})
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
_, err = svc.ModifyUser(ctx, 999, fleet.UserPayload{Teams: &teams})
checkAuthErr(t, tt.shouldFailTeamWrite, err)
_, err = svc.ModifyUser(ctx, 888, fleet.UserPayload{Teams: &teams})
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
_, err = svc.ModifyUser(ctx, 888, fleet.UserPayload{GlobalRole: ptr.String(fleet.RoleMaintainer)})
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
err = svc.DeleteUser(ctx, 999)
checkAuthErr(t, tt.shouldFailDeleteReset, err)
_, err = svc.RequirePasswordReset(ctx, 999, false)
checkAuthErr(t, tt.shouldFailDeleteReset, err)
_, err = svc.ListUsers(ctx, fleet.UserListOptions{})
checkAuthErr(t, tt.shouldFailRead, err)
_, err = svc.User(ctx, 999)
checkAuthErr(t, tt.shouldFailRead, err)
_, err = svc.User(ctx, 888)
checkAuthErr(t, tt.shouldFailRead, err)
})
}
}
func TestModifyUserEmail(t *testing.T) {
user := &fleet.User{
ID: 3,
Email: "foo@bar.com",
}
user.SetPassword("password", 10, 10)
ms := new(mock.Store)
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
return nil
}
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
return user, nil
}
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
config := &fleet.AppConfig{
SMTPSettings: fleet.SMTPSettings{
SMTPConfigured: true,
SMTPAuthenticationType: fleet.AuthTypeNameNone,
SMTPPort: 1025,
SMTPServer: "127.0.0.1",
SMTPSenderAddress: "xxx@fleet.co",
},
}
return config, nil
}
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
// verify this isn't changed yet
assert.Equal(t, "foo@bar.com", u.Email)
// verify is changed per bug 1123
assert.Equal(t, "minion", u.Position)
return nil
}
svc := newTestService(ms, nil, nil)
ctx := context.Background()
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
payload := fleet.UserPayload{
Email: ptr.String("zip@zap.com"),
Password: ptr.String("password"),
Position: ptr.String("minion"),
}
_, err := svc.ModifyUser(ctx, 3, payload)
require.Nil(t, err)
assert.True(t, ms.PendingEmailChangeFuncInvoked)
assert.True(t, ms.SaveUserFuncInvoked)
}
func TestModifyUserEmailNoPassword(t *testing.T) {
user := &fleet.User{
ID: 3,
Email: "foo@bar.com",
}
user.SetPassword("password", 10, 10)
ms := new(mock.Store)
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
return nil
}
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
return user, nil
}
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
config := &fleet.AppConfig{
SMTPSettings: fleet.SMTPSettings{
SMTPConfigured: true,
SMTPAuthenticationType: fleet.AuthTypeNameNone,
SMTPPort: 1025,
SMTPServer: "127.0.0.1",
SMTPSenderAddress: "xxx@fleet.co",
},
}
return config, nil
}
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
return nil
}
svc := newTestService(ms, nil, nil)
ctx := context.Background()
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
payload := fleet.UserPayload{
Email: ptr.String("zip@zap.com"),
// NO PASSWORD
// Password: ptr.String("password"),
}
_, err := svc.ModifyUser(ctx, 3, payload)
require.NotNil(t, err)
var iae *fleet.InvalidArgumentError
ok := errors.As(err, &iae)
require.True(t, ok)
require.Len(t, *iae, 1)
assert.False(t, ms.PendingEmailChangeFuncInvoked)
assert.False(t, ms.SaveUserFuncInvoked)
}
func TestModifyAdminUserEmailNoPassword(t *testing.T) {
user := &fleet.User{
ID: 3,
Email: "foo@bar.com",
}
user.SetPassword("password", 10, 10)
ms := new(mock.Store)
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
return nil
}
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
return user, nil
}
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
config := &fleet.AppConfig{
SMTPSettings: fleet.SMTPSettings{
SMTPConfigured: true,
SMTPAuthenticationType: fleet.AuthTypeNameNone,
SMTPPort: 1025,
SMTPServer: "127.0.0.1",
SMTPSenderAddress: "xxx@fleet.co",
},
}
return config, nil
}
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
return nil
}
svc := newTestService(ms, nil, nil)
ctx := context.Background()
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
payload := fleet.UserPayload{
Email: ptr.String("zip@zap.com"),
// NO PASSWORD
// Password: ptr.String("password"),
}
_, err := svc.ModifyUser(ctx, 3, payload)
require.NotNil(t, err)
var iae *fleet.InvalidArgumentError
ok := errors.As(err, &iae)
require.True(t, ok)
require.Len(t, *iae, 1)
assert.False(t, ms.PendingEmailChangeFuncInvoked)
assert.False(t, ms.SaveUserFuncInvoked)
}
func TestModifyAdminUserEmailPassword(t *testing.T) {
user := &fleet.User{
ID: 3,
Email: "foo@bar.com",
}
user.SetPassword("password", 10, 10)
ms := new(mock.Store)
ms.PendingEmailChangeFunc = func(ctx context.Context, id uint, em, tk string) error {
return nil
}
ms.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
return user, nil
}
ms.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
config := &fleet.AppConfig{
SMTPSettings: fleet.SMTPSettings{
SMTPConfigured: true,
SMTPAuthenticationType: fleet.AuthTypeNameNone,
SMTPPort: 1025,
SMTPServer: "127.0.0.1",
SMTPSenderAddress: "xxx@fleet.co",
},
}
return config, nil
}
ms.SaveUserFunc = func(ctx context.Context, u *fleet.User) error {
return nil
}
svc := newTestService(ms, nil, nil)
ctx := context.Background()
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
payload := fleet.UserPayload{
Email: ptr.String("zip@zap.com"),
Password: ptr.String("password"),
}
_, err := svc.ModifyUser(ctx, 3, payload)
require.Nil(t, err)
assert.True(t, ms.PendingEmailChangeFuncInvoked)
assert.True(t, ms.SaveUserFuncInvoked)
}
func TestUsersWithDS(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
cases := []struct {
name string
fn func(t *testing.T, ds *mysql.Datastore)
}{
{"CreateUserForcePasswdReset", testUsersCreateUserForcePasswdReset},
{"ChangePassword", testUsersChangePassword},
{"RequirePasswordReset", testUsersRequirePasswordReset},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
defer mysql.TruncateTables(t, ds)
c.fn(t, ds)
})
}
}
// Test that CreateUser creates a user that will be forced to
// reset its password upon first login (see #2570).
func testUsersCreateUserForcePasswdReset(t *testing.T, ds *mysql.Datastore) {
svc := newTestService(ds, nil, nil)
// Create admin user.
admin := &fleet.User{
Name: "Fleet Admin",
Email: "admin@foo.com",
GlobalRole: ptr.String(fleet.RoleAdmin),
}
err := admin.SetPassword("p4ssw0rd.", 10, 10)
require.NoError(t, err)
admin, err = ds.NewUser(context.Background(), admin)
require.NoError(t, err)
// As the admin, create a new user.
ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: admin})
user, err := svc.CreateUser(ctx, fleet.UserPayload{
Name: ptr.String("Some Observer"),
Email: ptr.String("some-observer@email.com"),
Password: ptr.String("passw0rd."),
GlobalRole: ptr.String(fleet.RoleObserver),
})
require.NoError(t, err)
user, err = ds.UserByID(context.Background(), user.ID)
require.NoError(t, err)
require.True(t, user.AdminForcedPasswordReset)
}
func testUsersChangePassword(t *testing.T, ds *mysql.Datastore) {
svc := newTestService(ds, nil, nil)
users := createTestUsers(t, ds)
passwordChangeTests := []struct {
user fleet.User
oldPassword string
newPassword string
anyErr bool
wantErr error
}{
{ // all good
user: users["admin1@example.com"],
oldPassword: "foobarbaz1234!",
newPassword: "12345cat!",
},
{ // prevent password reuse
user: users["admin1@example.com"],
oldPassword: "12345cat!",
newPassword: "foobarbaz1234!",
wantErr: fleet.NewInvalidArgumentError("new_password", "cannot reuse old password"),
},
{ // all good
user: users["user1@example.com"],
oldPassword: "foobarbaz1234!",
newPassword: "newpassa1234!",
},
{ // bad old password
user: users["user1@example.com"],
oldPassword: "wrong_password",
newPassword: "12345cat!",
anyErr: true,
},
{ // missing old password
newPassword: "123cataaa!",
wantErr: fleet.NewInvalidArgumentError("old_password", "Old password cannot be empty"),
},
}
for _, tt := range passwordChangeTests {
t.Run("", func(t *testing.T) {
ctx := context.Background()
ctx = viewer.NewContext(ctx, viewer.Viewer{User: &tt.user})
err := svc.ChangePassword(ctx, tt.oldPassword, tt.newPassword)
if tt.anyErr {
require.NotNil(t, err)
} else if tt.wantErr != nil {
require.Equal(t, tt.wantErr, ctxerr.Cause(err))
} else {
require.Nil(t, err)
}
if err != nil {
return
}
// Attempt login after successful change
_, _, err = svc.Login(context.Background(), tt.user.Email, tt.newPassword)
require.Nil(t, err, "should be able to login with new password")
})
}
}
func testUsersRequirePasswordReset(t *testing.T, ds *mysql.Datastore) {
svc := newTestService(ds, nil, nil)
createTestUsers(t, ds)
for _, tt := range testUsers {
t.Run(tt.Email, func(t *testing.T) {
user, err := ds.UserByEmail(context.Background(), tt.Email)
require.Nil(t, err)
var sessions []*fleet.Session
// Log user in
_, _, err = svc.Login(test.UserContext(test.UserAdmin), tt.Email, tt.PlaintextPassword)
require.Nil(t, err, "login unsuccessful")
sessions, err = svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), user.ID)
require.Nil(t, err)
require.Len(t, sessions, 1, "user should have one session")
// Reset and verify sessions destroyed
retUser, err := svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, true)
require.Nil(t, err)
assert.True(t, retUser.AdminForcedPasswordReset)
checkUser, err := ds.UserByEmail(context.Background(), tt.Email)
require.Nil(t, err)
assert.True(t, checkUser.AdminForcedPasswordReset)
sessions, err = svc.GetInfoAboutSessionsForUser(test.UserContext(test.UserAdmin), user.ID)
require.Nil(t, err)
require.Len(t, sessions, 0, "sessions should be destroyed")
// try undo
retUser, err = svc.RequirePasswordReset(test.UserContext(test.UserAdmin), user.ID, false)
require.Nil(t, err)
assert.False(t, retUser.AdminForcedPasswordReset)
checkUser, err = ds.UserByEmail(context.Background(), tt.Email)
require.Nil(t, err)
assert.False(t, checkUser.AdminForcedPasswordReset)
})
}
}