diff --git a/auth.go b/auth.go index 9579504829..1ccabbc46e 100644 --- a/auth.go +++ b/auth.go @@ -47,7 +47,7 @@ func (vc *ViewerContext) UserID() (uint, error) { return 0, errors.New("No user set") } -func (vc *ViewerContext) CanPerformActions(db *gorm.DB) bool { +func (vc *ViewerContext) CanPerformActions() bool { if vc.user == nil { return false } @@ -70,12 +70,12 @@ func (vc *ViewerContext) IsUserID(id uint) bool { return false } -func (vc *ViewerContext) CanPerformWriteActionOnUser(db *gorm.DB, u *User) bool { - return vc.CanPerformActions(db) && (vc.IsUserID(u.ID) || vc.IsAdmin()) +func (vc *ViewerContext) CanPerformWriteActionOnUser(u *User) bool { + return vc.CanPerformActions() && (vc.IsUserID(u.ID) || vc.IsAdmin()) } -func (vc *ViewerContext) CanPerformReadActionOnUser(db *gorm.DB, u *User) bool { - return vc.CanPerformActions(db) && (vc.IsUserID(u.ID) || vc.IsAdmin()) +func (vc *ViewerContext) CanPerformReadActionOnUser(u *User) bool { + return vc.CanPerformActions() } // GenerateJWT generates a JWT token in serialized string form given a @@ -147,6 +147,7 @@ func JWTRenewalMiddleware(c *gin.Context) { } session.Set("jwt", jwt) + session.Save() c.Next() } @@ -246,12 +247,14 @@ func Login(c *gin.Context) { session.Set("jwt", token) session.Save() - c.JSON(200, map[string]interface{}{ - "id": user.ID, - "username": user.Username, - "email": user.Email, - "name": user.Name, - "admin": user.Admin, + c.JSON(200, GetUserResponseBody{ + ID: user.ID, + Username: user.Username, + Name: user.Name, + Email: user.Email, + Admin: user.Admin, + Enabled: user.Enabled, + NeedsPasswordReset: user.NeedsPasswordReset, }) } diff --git a/auth_test.go b/auth_test.go index 2a60da0358..48d7aef13c 100644 --- a/auth_test.go +++ b/auth_test.go @@ -179,19 +179,19 @@ func TestCanPerformActionsOnUser(t *testing.T) { adminVC := GenerateVC(admin) user1VC := GenerateVC(user1) - if !adminVC.CanPerformWriteActionOnUser(db, user1) || !adminVC.CanPerformWriteActionOnUser(db, user2) { + if !adminVC.CanPerformWriteActionOnUser(user1) || !adminVC.CanPerformWriteActionOnUser(user2) { t.Fatal("Admin should be able to perform writes on users") } - if !adminVC.CanPerformReadActionOnUser(db, user1) || !adminVC.CanPerformReadActionOnUser(db, user2) { + if !adminVC.CanPerformReadActionOnUser(user1) || !adminVC.CanPerformReadActionOnUser(user2) { t.Fatal("Admin should be able to perform reads on users") } - if user1VC.CanPerformWriteActionOnUser(db, user2) { + if user1VC.CanPerformWriteActionOnUser(user2) { t.Fatal("user1 shouldn't be able to perform writes on user2") } - if user1VC.CanPerformReadActionOnUser(db, user2) { + if !user1VC.CanPerformReadActionOnUser(user2) { t.Fatal("user1 should be able to perform reads on user2") } diff --git a/server.go b/server.go index 2e507124bd..dcd009b88b 100644 --- a/server.go +++ b/server.go @@ -25,6 +25,13 @@ func UnauthorizedError(c *gin.Context) { c.JSON(401, ServerError("Unauthorized")) } +// MalformedRequestError emits a response that is appropriate in the event that +// a request is received by a user which does not have required fields or is in +// some way malformed +func MalformedRequestError(c *gin.Context) { + c.JSON(400, ServerError("Malformed request")) +} + func createTestServer() *gin.Engine { server := gin.New() server.Use(TestingDatabaseMiddleware) @@ -52,7 +59,7 @@ func CreateServer() *gin.Engine { kolide.PATCH("/user", ModifyUser) kolide.DELETE("/user", DeleteUser) - kolide.PATCH("/user/password", ResetUserPassword) + kolide.PATCH("/user/password", ChangeUserPassword) kolide.PATCH("/user/admin", SetUserAdminState) kolide.PATCH("/user/enabled", SetUserEnabledState) diff --git a/story_test.go b/story_test.go new file mode 100644 index 0000000000..a8834f42a3 --- /dev/null +++ b/story_test.go @@ -0,0 +1,469 @@ +package main + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/jinzhu/gorm" +) + +type integrationRequests struct { + r *gin.Engine + db *gorm.DB + t *testing.T +} + +func (req *integrationRequests) New(t *testing.T) { + req.t = t + + req.r = createTestServer() + req.r.Use(testSessionMiddleware) + req.r.Use(JWTRenewalMiddleware) + + req.db, _ = openTestDB() + injectedTestDB = req.db + + // Until we have a better solution for first-user onboarding, manually + // create an admin + _, err := NewUser(req.db, "admin", "foobar", "admin@kolide.co", true, false) + if err != nil { + panic(err.Error()) + } + + req.r.POST("/login", Login) + req.r.GET("/logout", Logout) + + req.r.POST("/user", GetUser) + req.r.PUT("/user", CreateUser) + req.r.PATCH("/user", ModifyUser) + req.r.DELETE("/user", DeleteUser) + + req.r.PATCH("/user/password", ChangeUserPassword) + req.r.PATCH("/user/admin", SetUserAdminState) + req.r.PATCH("/user/enabled", SetUserEnabledState) +} + +func (req *integrationRequests) Login(username, password string, sessionOut *string) { + response := httptest.NewRecorder() + body, err := json.Marshal(LoginRequestBody{ + Username: username, + Password: password, + }) + if err != nil { + req.t.Fatal(err.Error()) + return + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("POST", "/login", buff) + request.Header.Set("Content-Type", "application/json") + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return + } + *sessionOut = response.Header().Get("Set-Cookie") + + return +} + +func (req *integrationRequests) CreateUser(username, password, email string, admin, reset bool, session *string) *GetUserResponseBody { + response := httptest.NewRecorder() + body, err := json.Marshal(CreateUserRequestBody{ + Username: username, + Password: password, + Email: email, + Admin: admin, + NeedsPasswordReset: reset, + }) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("PUT", "/user", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return nil + } + *session = response.Header().Get("Set-Cookie") + + var responseBody GetUserResponseBody + err = json.Unmarshal(response.Body.Bytes(), &responseBody) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + return &responseBody +} + +func (req *integrationRequests) GetUser(username string, session *string) *GetUserResponseBody { + response := httptest.NewRecorder() + body, err := json.Marshal(GetUserRequestBody{ + Username: username, + }) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("POST", "/user", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return nil + } + *session = response.Header().Get("Set-Cookie") + + var responseBody GetUserResponseBody + err = json.Unmarshal(response.Body.Bytes(), &responseBody) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + return &responseBody +} + +func (req *integrationRequests) ModifyUser(username, name, email string, session *string) *GetUserResponseBody { + response := httptest.NewRecorder() + body, err := json.Marshal(ModifyUserRequestBody{ + Username: username, + Name: name, + Email: email, + }) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("PATCH", "/user", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return nil + } + *session = response.Header().Get("Set-Cookie") + + var responseBody GetUserResponseBody + err = json.Unmarshal(response.Body.Bytes(), &responseBody) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + return &responseBody +} + +func (req *integrationRequests) DeleteUser(username string, session *string) { + response := httptest.NewRecorder() + body, err := json.Marshal(DeleteUserRequestBody{ + Username: username, + }) + if err != nil { + req.t.Fatal(err.Error()) + return + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("DELETE", "/user", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return + } + *session = response.Header().Get("Set-Cookie") + + return +} + +func (req *integrationRequests) ChangePassword(username, currentPassword, newPassword string, session *string) *GetUserResponseBody { + response := httptest.NewRecorder() + body, err := json.Marshal(ChangePasswordRequestBody{ + Username: username, + CurrentPassword: currentPassword, + NewPassword: newPassword, + NewPasswordConfim: newPassword, + }) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("PATCH", "/user/password", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return nil + } + *session = response.Header().Get("Set-Cookie") + + var responseBody GetUserResponseBody + err = json.Unmarshal(response.Body.Bytes(), &responseBody) + if err != nil { + req.t.Fatal(err.Error()) + } + + return &responseBody +} + +func (req *integrationRequests) SetAdminState(username string, admin bool, session *string) *GetUserResponseBody { + response := httptest.NewRecorder() + body, err := json.Marshal(SetUserAdminStateRequestBody{ + Username: username, + Admin: admin, + }) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("PATCH", "/user/admin", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return nil + } + *session = response.Header().Get("Set-Cookie") + + var responseBody GetUserResponseBody + err = json.Unmarshal(response.Body.Bytes(), &responseBody) + if err != nil { + req.t.Fatal(err.Error()) + } + + return &responseBody +} + +func (req *integrationRequests) SetEnabledState(username string, enabled bool, session *string) *GetUserResponseBody { + response := httptest.NewRecorder() + body, err := json.Marshal(SetUserEnabledStateRequestBody{ + Username: username, + Enabled: enabled, + }) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("PATCH", "/user/enabled", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return nil + } + *session = response.Header().Get("Set-Cookie") + + var responseBody GetUserResponseBody + err = json.Unmarshal(response.Body.Bytes(), &responseBody) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + return &responseBody +} + +func (req *integrationRequests) CheckUser(username, email, name string, admin, reset, enabled bool) { + var user User + err := req.db.Where("username = ?", username).First(&user).Error + if err != nil { + req.t.Fatal(err.Error()) + return + } + if user.Email != email { + req.t.Fatalf("user's email was not set in the DB: %s", user.Email) + } + if user.Admin != admin { + req.t.Fatal("user admin settings don't match") + } + if user.NeedsPasswordReset != reset { + req.t.Fatal("user reset settings don't match") + } + if user.Enabled != enabled { + req.t.Fatal("user enabled settings don't match") + } + if user.Name != name { + req.t.Fatalf("user names don't match: %s and %s", user.Name, name) + } + return +} + +func (req *integrationRequests) GetAndCheckUser(username string, session *string) { + resp := req.GetUser(username, session) + req.CheckUser(username, resp.Email, resp.Name, resp.Admin, resp.NeedsPasswordReset, resp.Enabled) +} + +func (req *integrationRequests) CreateAndCheckUser(username, password, email, name string, admin, reset bool, session *string) { + resp := req.CreateUser(username, password, email, admin, reset, session) + req.CheckUser(username, email, name, admin, reset, resp.Enabled) +} + +func (req *integrationRequests) ModifyAndCheckUser(username, email, name string, admin, reset bool, session *string) { + resp := req.ModifyUser(username, name, email, session) + req.CheckUser(username, email, name, admin, reset, resp.Enabled) +} + +func (req *integrationRequests) DeleteAndCheckUser(username string, session *string) { + req.DeleteUser(username, session) + + var user User + err := req.db.Where("username = ?", username).First(&user).Error + if err == nil { + req.t.Fatal("User should have been deleted.") + } +} + +func (req *integrationRequests) SetEnabledStateAndCheckUser(username string, enabled bool, session *string) { + resp := req.SetEnabledState(username, enabled, session) + req.CheckUser(username, resp.Email, resp.Name, resp.Admin, resp.NeedsPasswordReset, enabled) +} + +func (req *integrationRequests) SetAdminStateAndCheckUser(username string, admin bool, session *string) { + resp := req.SetAdminState(username, admin, session) + req.CheckUser(username, resp.Email, resp.Name, admin, resp.NeedsPasswordReset, resp.Enabled) +} + +func TestUserAndAccountManagement(t *testing.T) { + + // Create and configure the webserver which will be used to handle the tests + var req integrationRequests + req.New(t) + + // Instantiate the variables that will store the most recent session cookie + // for each user context that will be created + var adminSession string + var admin2Session string + var user1Session string + var user2Session string + + // Test logging in with the first admin + req.Login("admin", "foobar", &adminSession) + + // Once admin is logged in, create a user using a valid admin session + req.CreateAndCheckUser("user1", "foobar", "user1@kolide.co", "", false, false, &adminSession) + + // Once admin is logged in, create another admin account using a valid + // admin session + req.CreateAndCheckUser("admin2", "foobar", "admin2@kolide.co", "", true, false, &adminSession) + + // Once admin has created admin2, log in with admin2 to get a session + // context for admin2 + req.Login("admin2", "foobar", &admin2Session) + + // Use an admin created via the API to create a user via the API + req.CreateAndCheckUser("user2", "foobar", "user2@kolide.co", "", false, false, &admin2Session) + + // Once admin has created user1, log in with user1 to get a session context + // for user1 + req.Login("user1", "foobar", &user1Session) + + // Once admin2 has created user2, log in with user1 to get a session context + // for user2 + req.Login("user2", "foobar", &user2Session) + + // Get info on user2 as admin2 + req.GetAndCheckUser("user2", &admin2Session) + + // Get info on admin2 as user2 + req.GetAndCheckUser("admin2", &user2Session) + + // Modify user1 as admin + req.ModifyAndCheckUser("user1", "user1@kolide.co", "User One", false, false, &adminSession) + + // Modify user2 as user2 + req.ModifyAndCheckUser("user2", "user2@kolide.co", "User Two", false, false, &user2Session) + + // admin resets user1 password + req.ChangePassword("user1", "", "bazz1", &adminSession) + + // user1 logs in with new password + req.Login("user1", "bazz1", &user1Session) + + // user2 resets user2 password + req.ChangePassword("user2", "foobar", "bazz2", &user2Session) + + // user2 logs in with new password + req.Login("user2", "bazz2", &user2Session) + + // admin2 promotes user2 to admin + req.SetAdminStateAndCheckUser("user2", true, &admin2Session) + + // user2 is admin + resp := req.GetUser("user2", &user2Session) + if !resp.Admin { + t.Fatal("user2 should be an admin") + } + + // admin demotes user2 from admin + req.SetAdminStateAndCheckUser("user2", false, &adminSession) + + // user2 is no longer an admin + resp = req.GetUser("user2", &user2Session) + if resp.Admin { + t.Fatal("user2 shouldn't be an admin") + } + + // admin sets user1 as no longer enabled + req.SetEnabledStateAndCheckUser("user1", false, &adminSession) + + // user1 is no longer enabled + resp = req.GetUser("user1", &user2Session) + if resp.Enabled { + t.Fatal("user1 shouldn't be enabled") + } + + // admin2 re-enables user1 + req.SetEnabledStateAndCheckUser("user1", true, &admin2Session) + + // user1 can view user2 + req.GetUser("user2", &user2Session) + + // Delete admin2 as admin1 + req.DeleteAndCheckUser("admin2", &adminSession) + + // Delete user2 as admin + req.DeleteAndCheckUser("user2", &adminSession) +} diff --git a/users.go b/users.go index 92d38919ad..04c5b05204 100644 --- a/users.go +++ b/users.go @@ -81,7 +81,18 @@ func (u *User) MakeAdmin(db *gorm.DB) error { } type GetUserRequestBody struct { - ID uint `json:"id" binding:"required"` + ID uint `json:"id"` + Username string `json:"username"` +} + +type GetUserResponseBody struct { + ID uint `json:"id"` + Username string `json:"username"` + Email string `json:"email"` + Name string `json:"name"` + Admin bool `json:"admin"` + Enabled bool `json:"enabled"` + NeedsPasswordReset bool `json:"needs_password_reset"` } func GetUser(c *gin.Context) { @@ -107,26 +118,27 @@ func GetUser(c *gin.Context) { } var user User - err = db.Where("id = ?", body.ID).First(&user).Error + user.ID = body.ID + user.Username = body.Username + err = db.Where(&user).First(&user).Error if err != nil { - logrus.Errorf("Error finding user in database: %s", err.Error()) DatabaseError(c) return } - if !vc.CanPerformReadActionOnUser(db, &user) { + if !vc.CanPerformReadActionOnUser(&user) { UnauthorizedError(c) return } - c.JSON(200, map[string]interface{}{ - "id": user.ID, - "username": user.Username, - "name": user.Name, - "email": user.Email, - "admin": user.Admin, - "enabled": user.Enabled, - "needs_password_reset": user.NeedsPasswordReset, + c.JSON(200, GetUserResponseBody{ + ID: user.ID, + Username: user.Username, + Name: user.Name, + Email: user.Email, + Admin: user.Admin, + Enabled: user.Enabled, + NeedsPasswordReset: user.NeedsPasswordReset, }) } @@ -165,18 +177,26 @@ func CreateUser(c *gin.Context) { return } - _, err = NewUser(db, body.Username, body.Password, body.Email, body.Admin, body.NeedsPasswordReset) + user, err := NewUser(db, body.Username, body.Password, body.Email, body.Admin, body.NeedsPasswordReset) if err != nil { logrus.Errorf("Error creating new user: %s", err.Error()) DatabaseError(c) return } - c.JSON(200, nil) + c.JSON(200, GetUserResponseBody{ + ID: user.ID, + Username: user.Username, + Name: user.Name, + Email: user.Email, + Admin: user.Admin, + Enabled: user.Enabled, + NeedsPasswordReset: user.NeedsPasswordReset, + }) } type ModifyUserRequestBody struct { - ID uint `json:"id" binding:"required"` + ID uint `json:"id"` Username string `json:"username"` Name string `json:"name"` Email string `json:"email"` @@ -205,29 +225,46 @@ func ModifyUser(c *gin.Context) { } var user User - err = db.Where("id = ?", body.ID).First(&user).Error + user.ID = body.ID + user.Username = body.Username + + err = db.Where(&user).First(&user).Error if err != nil { - logrus.Errorf("Error finding user in database: %s", err.Error()) DatabaseError(c) return } - if !vc.CanPerformWriteActionOnUser(db, &user) { + if !vc.CanPerformWriteActionOnUser(&user) { UnauthorizedError(c) return } + if body.Name != "" { + user.Name = body.Name + } + if body.Email != "" { + user.Email = body.Email + } err = db.Save(&user).Error if err != nil { logrus.Errorf("Error updating user in database: %s", err.Error()) DatabaseError(c) return } - c.JSON(200, nil) + c.JSON(200, GetUserResponseBody{ + ID: user.ID, + Username: user.Username, + Name: user.Name, + Email: user.Email, + Admin: user.Admin, + Enabled: user.Enabled, + NeedsPasswordReset: user.NeedsPasswordReset, + }) } type DeleteUserRequestBody struct { - ID uint `json:"id" binding:"required"` + ID uint `json:"id"` + Username string `json:"username"` } func DeleteUser(c *gin.Context) { @@ -258,9 +295,10 @@ func DeleteUser(c *gin.Context) { } var user User - err = db.Where("id = ?", body.ID).First(&user).Error + user.ID = body.ID + user.Username = body.Username + err = db.Where(&user).First(&user).Error if err != nil { - logrus.Errorf("Error finding user in database: %s", err.Error()) DatabaseError(c) return } @@ -275,20 +313,29 @@ func DeleteUser(c *gin.Context) { } type ResetPasswordRequestBody struct { - ID uint `json:"id" binding:"required"` + ID uint `json:"id"` + Username string `json:"username"` Password string `json:"password" binding:"required"` PasswordConfim string `json:"password_confirm" binding:"required"` } -func ResetUserPassword(c *gin.Context) { - var body ResetPasswordRequestBody +type ChangePasswordRequestBody struct { + ID uint `json:"id"` + Username string `json:"username"` + CurrentPassword string `json:"current_password"` + NewPassword string `json:"new_password" binding:"required"` + NewPasswordConfim string `json:"new_password_confirm" binding:"required"` +} + +func ChangeUserPassword(c *gin.Context) { + var body ChangePasswordRequestBody err := c.BindJSON(&body) if err != nil { logrus.Errorf("Error parsing ResetPassword post body: %s", err.Error()) return } - if body.Password != body.PasswordConfim { + if body.NewPassword != body.NewPasswordConfim { c.JSON(406, map[string]interface{}{"error": "Passwords do not match"}) return } @@ -300,26 +347,34 @@ func ResetUserPassword(c *gin.Context) { return } + var user User + user.ID = body.ID + user.Username = body.Username + err = db.Where(&user).First(&user).Error + if err != nil { + DatabaseError(c) + return + } + vc, err := VC(c, db) if err != nil { logrus.Errorf("Could not create VC: %s", err.Error()) DatabaseError(c) return } - var user User - err = db.Where("id = ?", body.ID).First(&user).Error - if err != nil { - logrus.Errorf("Error finding user in database: %s", err.Error()) - DatabaseError(c) - return + + if !vc.IsAdmin() { + if !vc.IsUserID(user.ID) { + UnauthorizedError(c) + return + } + if user.ValidatePassword(body.CurrentPassword) != nil { + UnauthorizedError(c) + return + } } - if !vc.CanPerformWriteActionOnUser(db, &user) { - UnauthorizedError(c) - return - } - - err = user.SetPassword(db, body.Password) + err = user.SetPassword(db, body.NewPassword) if err != nil { logrus.Errorf("Error setting user password: %s", err.Error()) // xxx don't try to write to the db? @@ -331,12 +386,21 @@ func ResetUserPassword(c *gin.Context) { DatabaseError(c) return } - c.JSON(200, nil) + c.JSON(200, GetUserResponseBody{ + ID: user.ID, + Username: user.Username, + Name: user.Name, + Email: user.Email, + Admin: user.Admin, + Enabled: user.Enabled, + NeedsPasswordReset: user.NeedsPasswordReset, + }) } type SetUserAdminStateRequestBody struct { - ID uint `json:"id" binding:"required"` - Admin bool `json:"admin" binding:"required"` + ID uint `json:"id"` + Username string `json:"username"` + Admin bool `json:"admin"` } func SetUserAdminState(c *gin.Context) { @@ -367,9 +431,10 @@ func SetUserAdminState(c *gin.Context) { } var user User - err = db.Where("id = ?", body.ID).First(&user).Error + user.ID = body.ID + user.Username = body.Username + err = db.Where(&user).First(&user).Error if err != nil { - logrus.Errorf("Error finding user in database: %s", err.Error()) DatabaseError(c) return } @@ -381,12 +446,21 @@ func SetUserAdminState(c *gin.Context) { DatabaseError(c) return } - c.JSON(200, nil) + c.JSON(200, GetUserResponseBody{ + ID: user.ID, + Username: user.Username, + Name: user.Name, + Email: user.Email, + Admin: user.Admin, + Enabled: user.Enabled, + NeedsPasswordReset: user.NeedsPasswordReset, + }) } type SetUserEnabledStateRequestBody struct { - ID uint `json:"id" binding:"required"` - Enabled bool `json:"enabled" binding:"required"` + ID uint `json:"id"` + Username string `json:"username"` + Enabled bool `json:"enabled"` } func SetUserEnabledState(c *gin.Context) { @@ -417,9 +491,10 @@ func SetUserEnabledState(c *gin.Context) { } var user User - err = db.Where("id = ?", body.ID).First(&user).Error + user.ID = body.ID + user.Username = body.Username + err = db.Where(&user).First(&user).Error if err != nil { - logrus.Errorf("Error finding user in database: %s", err.Error()) DatabaseError(c) return } @@ -431,5 +506,13 @@ func SetUserEnabledState(c *gin.Context) { DatabaseError(c) return } - c.JSON(200, nil) + c.JSON(200, GetUserResponseBody{ + ID: user.ID, + Username: user.Username, + Name: user.Name, + Email: user.Email, + Admin: user.Admin, + Enabled: user.Enabled, + NeedsPasswordReset: user.NeedsPasswordReset, + }) } diff --git a/users_test.go b/users_test.go index bc51d65423..83b22294b2 100644 --- a/users_test.go +++ b/users_test.go @@ -1,10 +1,6 @@ package main import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" "testing" "github.com/jinzhu/gorm" @@ -201,109 +197,3 @@ func TestSetPassword(t *testing.T) { t.Fatal(err.Error()) } } - -func TestUserManagementIntegration(t *testing.T) { - r := createTestServer() - r.Use(testSessionMiddleware) - r.Use(JWTRenewalMiddleware) - - db, err := openTestDB() - if err != nil { - t.Fatal(err.Error()) - } - injectedTestDB = db - - admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false) - if err != nil { - t.Fatal(err.Error()) - } - _ = admin - - r.POST("/login", Login) - r.GET("/logout", Logout) - - r.GET("/user", GetUser) - r.PUT("/user", CreateUser) - r.PATCH("/user", ModifyUser) - r.DELETE("/user", DeleteUser) - - res1 := httptest.NewRecorder() - body1, err := json.Marshal(LoginRequestBody{ - Username: "admin", - Password: "foobar", - }) - if err != nil { - t.Fatal(err.Error()) - } - buff1 := new(bytes.Buffer) - buff1.Write(body1) - req1, _ := http.NewRequest("POST", "/login", buff1) - req1.Header.Set("Content-Type", "application/json") - r.ServeHTTP(res1, req1) - if res1.Code != 200 { - t.Fatalf("Response code: %d", res1.Code) - } - - res2 := httptest.NewRecorder() - body2, err := json.Marshal(CreateUserRequestBody{ - Username: "marpaia", - Password: "foobar", - Email: "mike@kolide.co", - Admin: false, - NeedsPasswordReset: false, - }) - if err != nil { - t.Fatal(err.Error()) - } - buff2 := new(bytes.Buffer) - buff2.Write(body2) - req2, _ := http.NewRequest("PUT", "/user", buff2) - req2.Header.Set("Content-Type", "application/json") - req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie")) - r.ServeHTTP(res2, req2) - - res3 := httptest.NewRecorder() - body3, err := json.Marshal(CreateUserRequestBody{ - Username: "admin2", - Password: "foobar", - Email: "admin2@kolide.co", - Admin: true, - NeedsPasswordReset: false, - }) - if err != nil { - t.Fatal(err.Error()) - } - buff3 := new(bytes.Buffer) - buff3.Write(body3) - req3, _ := http.NewRequest("PUT", "/user", buff3) - req3.Header.Set("Content-Type", "application/json") - req3.Header.Set("Cookie", res1.Header().Get("Set-Cookie")) - r.ServeHTTP(res3, req3) - - var user User - err = db.Where("username = ?", "marpaia").First(&user).Error - if err != nil { - t.Fatal(err.Error()) - } - - if user.Email != "mike@kolide.co" { - t.Fatalf("user's email was not set in the DB: %s", user.Email) - } - if user.Admin { - t.Fatal("user shouldn't be admin") - } - - var admin2 User - err = db.Where("username = ?", "admin2").First(&admin2).Error - if err != nil { - t.Fatal(err.Error()) - } - - if admin2.Email != "admin2@kolide.co" { - t.Fatalf("admin2's email was not set in the DB: %s", admin2.Email) - } - if !admin2.Admin { - t.Fatal("admin2 should be admin") - } - -}