Eliminate DB globals + refactor tests (#39)

* Eliminate global DB connections. Instead, one connection is established (with the underlying object supporting pooling) and passed through the gin.Context. This allows test/prod to inject the appropriate DB object into the context.
* Refactor tests appropriately for this new style of DB connection
* Fix a bug in the routing caught by refactoring of tests
This commit is contained in:
Zachary Wasserman 2016-08-04 11:41:47 -07:00 committed by GitHub
parent 5ad7c07e0c
commit 5c349a458d
10 changed files with 427 additions and 532 deletions

View file

@ -215,12 +215,7 @@ func Login(c *gin.Context) {
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
db := GetDB(c)
var user User
err = db.Where("username = ?", body.Username).First(&user).Error

View file

@ -18,10 +18,7 @@ func TestGenerateRandomText(t *testing.T) {
}
func TestGenerateVC(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err)
}
db := openTestDB()
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false)
if err != nil {
@ -50,15 +47,12 @@ func TestGenerateVC(t *testing.T) {
}
func TestVC(t *testing.T) {
r := createTestServer()
db := openTestDB()
r := createEmptyTestServer(db)
r.Use(testSessionMiddleware)
r.Use(JWTRenewalMiddleware)
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {
t.Fatal(err.Error())
@ -134,10 +128,7 @@ func TestVC(t *testing.T) {
}
func TestIsUserID(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err)
}
db := openTestDB()
user1, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {
@ -156,10 +147,7 @@ func TestIsUserID(t *testing.T) {
}
func TestCanPerformActionsOnUser(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err)
}
db := openTestDB()
user1, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {

View file

@ -51,6 +51,8 @@ func init() {
// populate the global config data structure with sane defaults
setDefaultConfigValues()
rand.Seed(time.Now().UnixNano())
}
// logContextHook is a logrus hook which is used to contextualize application
@ -74,10 +76,6 @@ func (hook logContextHook) Fire(entry *logrus.Entry) error {
return nil
}
func init() {
rand.Seed(time.Now().UnixNano())
}
func main() {
// configure flag parsing and parse flags
app.Version(version)
@ -124,10 +122,16 @@ func main() {
dropTables(db)
createTables(db)
case serve.FullCommand():
db, err := openDB(config.MySQL.Username, config.MySQL.Password, config.MySQL.Address, config.MySQL.Database)
if err != nil {
logrus.Fatalf("Error opening database: %s", err.Error())
}
fmt.Printf("=> %s %s application starting on https://%s\n", app.Name, version, config.Server.Address)
fmt.Println("=> Run `kolide help serve` for more startup options")
fmt.Println("Use Ctrl-C to stop\n\n")
CreateServer().RunTLS(
CreateServer(db).RunTLS(
config.Server.Address,
config.Server.Cert,
config.Server.Key)

View file

@ -1,7 +1,6 @@
package main
import (
"errors"
"fmt"
"time"
@ -13,36 +12,9 @@ import (
_ "github.com/jinzhu/gorm/dialects/sqlite"
)
type DBEnvironment int
const (
ProductionDB DBEnvironment = iota
TestingDB DBEnvironment = iota
)
var injectedTestDB *gorm.DB
func GetDB(c *gin.Context) (*gorm.DB, error) {
f, ok := c.Get("DB")
if !ok {
return nil, errors.New("DB was not set on the supplied *gin.Context. Use a middleware to set it.")
}
switch f.(DBEnvironment) {
case ProductionDB:
return openDB(config.MySQL.Username, config.MySQL.Password, config.MySQL.Address, config.MySQL.Database)
case TestingDB:
if injectedTestDB != nil {
return injectedTestDB, nil
}
db, err := openTestDB()
if err != nil {
return nil, errors.New("Could not open a test database")
}
injectedTestDB = db
return injectedTestDB, nil
default:
return nil, errors.New("GetDB not implemented for DBEnvironment object")
}
// Get the database connection from the context, or panic
func GetDB(c *gin.Context) *gorm.DB {
return c.MustGet("DB").(*gorm.DB)
}
type BaseModel struct {
@ -197,43 +169,27 @@ func setDBSettings(db *gorm.DB) {
func openDB(user, password, address, dbName string) (*gorm.DB, error) {
connectionString := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local", user, password, address, dbName)
return gorm.Open("mysql", connectionString)
db, err := gorm.Open("mysql", connectionString)
if err != nil {
return nil, err
}
setDBSettings(db)
return db, nil
}
/// Open a database connection, or panic
func mustOpenDB(user, password, address, dbName string) *gorm.DB {
db, err := openDB(user, password, address, dbName)
if err != nil {
panic(fmt.Sprintf("Could not connect to DB: %s", err.Error()))
}
return db
}
func openTestDB() (*gorm.DB, error) {
func openTestDB() *gorm.DB {
db, err := gorm.Open("sqlite3", ":memory:")
if err != nil {
return nil, err
panic(fmt.Sprintf("Error opening test DB: %s", err.Error()))
}
setDBSettings(db)
createTables(db)
return db, db.Error
}
func ProductionDatabaseMiddleware(c *gin.Context) {
c.Set("DB", ProductionDB)
c.Next()
}
func TestingDatabaseMiddleware(c *gin.Context) {
c.Set("DB", TestingDB)
c.Next()
if db.Error != nil {
panic(fmt.Sprintf("Error creating test DB tables: %s", db.Error.Error()))
}
return db
}
func dropTables(db *gorm.DB) {

View file

@ -6,6 +6,7 @@ import (
"github.com/Sirupsen/logrus"
"github.com/gin-gonic/contrib/ginrus"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
)
// ServerError is a helper which accepts a string error and returns a map in
@ -36,17 +37,25 @@ func MalformedRequestError(c *gin.Context) {
c.JSON(400, ServerError("Malformed request"))
}
func createTestServer() *gin.Engine {
func createEmptyTestServer(db *gorm.DB) *gin.Engine {
server := gin.New()
server.Use(TestingDatabaseMiddleware)
server.Use(DatabaseMiddleware(db))
return server
}
// Adapted from https://goo.gl/03Qxiy
func DatabaseMiddleware(db *gorm.DB) gin.HandlerFunc {
return func(c *gin.Context) {
c.Set("DB", db)
c.Next()
}
}
// CreateServer creates a gin.Engine HTTP server and configures it to be in a
// state such that it is ready to serve HTTP requests for the kolide application
func CreateServer() *gin.Engine {
func CreateServer(db *gorm.DB) *gin.Engine {
server := gin.New()
server.Use(ProductionDatabaseMiddleware)
server.Use(DatabaseMiddleware(db))
// TODO: The following loggers are not synchronized with each other or
// logrus.StandardLogger() used through the rest of the codebase. As
@ -74,7 +83,7 @@ func CreateServer() *gin.Engine {
kolide.POST("/login", Login)
kolide.GET("/logout", Logout)
kolide.GET("/user", GetUser)
kolide.POST("/user", GetUser)
kolide.PUT("/user", CreateUser)
kolide.PATCH("/user", ModifyUser)
kolide.DELETE("/user", DeleteUser)

View file

@ -6,21 +6,11 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/gorilla/sessions"
)
const testSessionName = "TestSession"
func getTestStore() sessions.Store {
return sessions.NewCookieStore([]byte("test"))
}
func testSessionMiddleware(c *gin.Context) {
CreateSession(testSessionName, getTestStore())(c)
}
func TestSessionGetSet(t *testing.T) {
r := createTestServer()
db := openTestDB()
r := createEmptyTestServer(db)
r.Use(testSessionMiddleware)
r.Use(JWTRenewalMiddleware)
@ -50,7 +40,8 @@ func TestSessionGetSet(t *testing.T) {
}
func TestSessionDeleteKey(t *testing.T) {
r := createTestServer()
db := openTestDB()
r := createEmptyTestServer(db)
r.Use(testSessionMiddleware)
r.Use(JWTRenewalMiddleware)
@ -92,7 +83,8 @@ func TestSessionDeleteKey(t *testing.T) {
}
func TestSessionFlashes(t *testing.T) {
r := createTestServer()
db := openTestDB()
r := createEmptyTestServer(db)
r.Use(testSessionMiddleware)
r.Use(JWTRenewalMiddleware)
@ -139,11 +131,13 @@ func TestSessionFlashes(t *testing.T) {
}
func TestSessionClear(t *testing.T) {
db := openTestDB()
r := createEmptyTestServer(db)
data := map[string]string{
"key": "val",
"foo": "bar",
}
r := createTestServer()
store := getTestStore()
r.Use(CreateSession(testSessionName, store))
r.Use(JWTRenewalMiddleware)

View file

@ -1,375 +1,13 @@
package main
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
)
type integrationRequests struct {
r *gin.Engine
db *gorm.DB
t *testing.T
}
func (req *integrationRequests) New(t *testing.T) {
req.t = t
req.r = createTestServer()
req.r.Use(testSessionMiddleware)
req.r.Use(JWTRenewalMiddleware)
req.db, _ = openTestDB()
injectedTestDB = req.db
// Until we have a better solution for first-user onboarding, manually
// create an admin
_, err := NewUser(req.db, "admin", "foobar", "admin@kolide.co", true, false)
if err != nil {
panic(err.Error())
}
req.r.POST("/login", Login)
req.r.GET("/logout", Logout)
req.r.POST("/user", GetUser)
req.r.PUT("/user", CreateUser)
req.r.PATCH("/user", ModifyUser)
req.r.DELETE("/user", DeleteUser)
req.r.PATCH("/user/password", ChangeUserPassword)
req.r.PATCH("/user/admin", SetUserAdminState)
req.r.PATCH("/user/enabled", SetUserEnabledState)
}
func (req *integrationRequests) Login(username, password string, sessionOut *string) {
response := httptest.NewRecorder()
body, err := json.Marshal(LoginRequestBody{
Username: username,
Password: password,
})
if err != nil {
req.t.Fatal(err.Error())
return
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("POST", "/login", buff)
request.Header.Set("Content-Type", "application/json")
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return
}
*sessionOut = response.Header().Get("Set-Cookie")
return
}
func (req *integrationRequests) CreateUser(username, password, email string, admin, reset bool, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(CreateUserRequestBody{
Username: username,
Password: password,
Email: email,
Admin: admin,
NeedsPasswordReset: reset,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PUT", "/user", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
return nil
}
return &responseBody
}
func (req *integrationRequests) GetUser(username string, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(GetUserRequestBody{
Username: username,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("POST", "/user", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
return nil
}
return &responseBody
}
func (req *integrationRequests) ModifyUser(username, name, email string, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(ModifyUserRequestBody{
Username: username,
Name: name,
Email: email,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PATCH", "/user", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
return nil
}
return &responseBody
}
func (req *integrationRequests) DeleteUser(username string, session *string) {
response := httptest.NewRecorder()
body, err := json.Marshal(DeleteUserRequestBody{
Username: username,
})
if err != nil {
req.t.Fatal(err.Error())
return
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("DELETE", "/user", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return
}
*session = response.Header().Get("Set-Cookie")
return
}
func (req *integrationRequests) ChangePassword(username, currentPassword, newPassword string, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(ChangePasswordRequestBody{
Username: username,
CurrentPassword: currentPassword,
NewPassword: newPassword,
NewPasswordConfim: newPassword,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PATCH", "/user/password", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
}
return &responseBody
}
func (req *integrationRequests) SetAdminState(username string, admin bool, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(SetUserAdminStateRequestBody{
Username: username,
Admin: admin,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PATCH", "/user/admin", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
}
return &responseBody
}
func (req *integrationRequests) SetEnabledState(username string, enabled bool, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(SetUserEnabledStateRequestBody{
Username: username,
Enabled: enabled,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PATCH", "/user/enabled", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
return nil
}
return &responseBody
}
func (req *integrationRequests) CheckUser(username, email, name string, admin, reset, enabled bool) {
var user User
err := req.db.Where("username = ?", username).First(&user).Error
if err != nil {
req.t.Fatal(err.Error())
return
}
if user.Email != email {
req.t.Fatalf("user's email was not set in the DB: %s", user.Email)
}
if user.Admin != admin {
req.t.Fatal("user admin settings don't match")
}
if user.NeedsPasswordReset != reset {
req.t.Fatal("user reset settings don't match")
}
if user.Enabled != enabled {
req.t.Fatal("user enabled settings don't match")
}
if user.Name != name {
req.t.Fatalf("user names don't match: %s and %s", user.Name, name)
}
return
}
func (req *integrationRequests) GetAndCheckUser(username string, session *string) {
resp := req.GetUser(username, session)
req.CheckUser(username, resp.Email, resp.Name, resp.Admin, resp.NeedsPasswordReset, resp.Enabled)
}
func (req *integrationRequests) CreateAndCheckUser(username, password, email, name string, admin, reset bool, session *string) {
resp := req.CreateUser(username, password, email, admin, reset, session)
req.CheckUser(username, email, name, admin, reset, resp.Enabled)
}
func (req *integrationRequests) ModifyAndCheckUser(username, email, name string, admin, reset bool, session *string) {
resp := req.ModifyUser(username, name, email, session)
req.CheckUser(username, email, name, admin, reset, resp.Enabled)
}
func (req *integrationRequests) DeleteAndCheckUser(username string, session *string) {
req.DeleteUser(username, session)
var user User
err := req.db.Where("username = ?", username).First(&user).Error
if err == nil {
req.t.Fatal("User should have been deleted.")
}
}
func (req *integrationRequests) SetEnabledStateAndCheckUser(username string, enabled bool, session *string) {
resp := req.SetEnabledState(username, enabled, session)
req.CheckUser(username, resp.Email, resp.Name, resp.Admin, resp.NeedsPasswordReset, enabled)
}
func (req *integrationRequests) SetAdminStateAndCheckUser(username string, admin bool, session *string) {
resp := req.SetAdminState(username, admin, session)
req.CheckUser(username, resp.Email, resp.Name, admin, resp.NeedsPasswordReset, resp.Enabled)
}
func TestUserAndAccountManagement(t *testing.T) {
// Create and configure the webserver which will be used to handle the tests
var req integrationRequests
var req IntegrationRequests
req.New(t)
// Instantiate the variables that will store the most recent session cookie

364
test_util.go Normal file
View file

@ -0,0 +1,364 @@
package main
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/gorilla/sessions"
"github.com/jinzhu/gorm"
)
const testSessionName = "TestSession"
func getTestStore() sessions.Store {
return sessions.NewCookieStore([]byte("test"))
}
func testSessionMiddleware(c *gin.Context) {
CreateSession(testSessionName, getTestStore())(c)
}
type IntegrationRequests struct {
r *gin.Engine
db *gorm.DB
t *testing.T
}
func (req *IntegrationRequests) New(t *testing.T) {
req.t = t
*debug = false
req.db = openTestDB()
// Until we have a better solution for first-user onboarding, manually
// create an admin
_, err := NewUser(req.db, "admin", "foobar", "admin@kolide.co", true, false)
if err != nil {
t.Fatalf("Error opening DB: %s", err.Error())
}
req.r = CreateServer(req.db)
}
func (req *IntegrationRequests) Login(username, password string, sessionOut *string) {
response := httptest.NewRecorder()
body, err := json.Marshal(LoginRequestBody{
Username: username,
Password: password,
})
if err != nil {
req.t.Fatal(err.Error())
return
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("POST", "/api/v1/kolide/login", buff)
request.Header.Set("Content-Type", "application/json")
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return
}
*sessionOut = response.Header().Get("Set-Cookie")
return
}
func (req *IntegrationRequests) CreateUser(username, password, email string, admin, reset bool, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(CreateUserRequestBody{
Username: username,
Password: password,
Email: email,
Admin: admin,
NeedsPasswordReset: reset,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PUT", "/api/v1/kolide/user", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
return nil
}
return &responseBody
}
func (req *IntegrationRequests) GetUser(username string, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(GetUserRequestBody{
Username: username,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("POST", "/api/v1/kolide/user", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
return nil
}
return &responseBody
}
func (req *IntegrationRequests) ModifyUser(username, name, email string, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(ModifyUserRequestBody{
Username: username,
Name: name,
Email: email,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PATCH", "/api/v1/kolide/user", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
return nil
}
return &responseBody
}
func (req *IntegrationRequests) DeleteUser(username string, session *string) {
response := httptest.NewRecorder()
body, err := json.Marshal(DeleteUserRequestBody{
Username: username,
})
if err != nil {
req.t.Fatal(err.Error())
return
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("DELETE", "/api/v1/kolide/user", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return
}
*session = response.Header().Get("Set-Cookie")
return
}
func (req *IntegrationRequests) ChangePassword(username, currentPassword, newPassword string, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(ChangePasswordRequestBody{
Username: username,
CurrentPassword: currentPassword,
NewPassword: newPassword,
NewPasswordConfim: newPassword,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PATCH", "/api/v1/kolide/user/password", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
}
return &responseBody
}
func (req *IntegrationRequests) SetAdminState(username string, admin bool, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(SetUserAdminStateRequestBody{
Username: username,
Admin: admin,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PATCH", "/api/v1/kolide/user/admin", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
}
return &responseBody
}
func (req *IntegrationRequests) SetEnabledState(username string, enabled bool, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(SetUserEnabledStateRequestBody{
Username: username,
Enabled: enabled,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PATCH", "/api/v1/kolide/user/enabled", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
return nil
}
return &responseBody
}
func (req *IntegrationRequests) CheckUser(username, email, name string, admin, reset, enabled bool) {
var user User
err := req.db.Where("username = ?", username).First(&user).Error
if err != nil {
req.t.Fatal(err.Error())
return
}
if user.Email != email {
req.t.Fatalf("user's email was not set in the DB: %s", user.Email)
}
if user.Admin != admin {
req.t.Fatal("user admin settings don't match")
}
if user.NeedsPasswordReset != reset {
req.t.Fatal("user reset settings don't match")
}
if user.Enabled != enabled {
req.t.Fatal("user enabled settings don't match")
}
if user.Name != name {
req.t.Fatalf("user names don't match: %s and %s", user.Name, name)
}
return
}
func (req *IntegrationRequests) GetAndCheckUser(username string, session *string) {
resp := req.GetUser(username, session)
req.CheckUser(username, resp.Email, resp.Name, resp.Admin, resp.NeedsPasswordReset, resp.Enabled)
}
func (req *IntegrationRequests) CreateAndCheckUser(username, password, email, name string, admin, reset bool, session *string) {
resp := req.CreateUser(username, password, email, admin, reset, session)
req.CheckUser(username, email, name, admin, reset, resp.Enabled)
}
func (req *IntegrationRequests) ModifyAndCheckUser(username, email, name string, admin, reset bool, session *string) {
resp := req.ModifyUser(username, name, email, session)
req.CheckUser(username, email, name, admin, reset, resp.Enabled)
}
func (req *IntegrationRequests) DeleteAndCheckUser(username string, session *string) {
req.DeleteUser(username, session)
var user User
err := req.db.Where("username = ?", username).First(&user).Error
if err == nil {
req.t.Fatal("User should have been deleted.")
}
}
func (req *IntegrationRequests) SetEnabledStateAndCheckUser(username string, enabled bool, session *string) {
resp := req.SetEnabledState(username, enabled, session)
req.CheckUser(username, resp.Email, resp.Name, resp.Admin, resp.NeedsPasswordReset, enabled)
}
func (req *IntegrationRequests) SetAdminStateAndCheckUser(username string, admin bool, session *string) {
resp := req.SetAdminState(username, admin, session)
req.CheckUser(username, resp.Email, resp.Name, admin, resp.NeedsPasswordReset, resp.Enabled)
}

View file

@ -103,12 +103,7 @@ func GetUser(c *gin.Context) {
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
db := GetDB(c)
vc, err := VC(c, db)
if err != nil {
@ -158,12 +153,7 @@ func CreateUser(c *gin.Context) {
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
db := GetDB(c)
vc, err := VC(c, db)
if err != nil {
@ -210,12 +200,7 @@ func ModifyUser(c *gin.Context) {
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
db := GetDB(c)
vc, err := VC(c, db)
if err != nil {
@ -275,12 +260,7 @@ func DeleteUser(c *gin.Context) {
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
db := GetDB(c)
vc, err := VC(c, db)
if err != nil {
@ -340,12 +320,7 @@ func ChangeUserPassword(c *gin.Context) {
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
db := GetDB(c)
var user User
user.ID = body.ID
@ -411,12 +386,7 @@ func SetUserAdminState(c *gin.Context) {
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
db := GetDB(c)
vc, err := VC(c, db)
if err != nil {
@ -471,12 +441,7 @@ func SetUserEnabledState(c *gin.Context) {
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
db := GetDB(c)
vc, err := VC(c, db)
if err != nil {

View file

@ -7,10 +7,7 @@ import (
)
func TestNewUser(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err)
}
db := openTestDB()
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false)
if err != nil {
@ -37,10 +34,7 @@ func TestNewUser(t *testing.T) {
}
func TestValidatePassword(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
db := openTestDB()
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false)
if err != nil {
@ -59,10 +53,7 @@ func TestValidatePassword(t *testing.T) {
}
func TestMakeAdmin(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
db := openTestDB()
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {
@ -96,10 +87,7 @@ func TestMakeAdmin(t *testing.T) {
}
func TestUpdatingUser(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
db := openTestDB()
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {
@ -129,10 +117,7 @@ func TestUpdatingUser(t *testing.T) {
}
func TestDeletingUser(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
db := openTestDB()
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {
@ -161,10 +146,7 @@ func TestDeletingUser(t *testing.T) {
}
func TestSetPassword(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
db := openTestDB()
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {