diff --git a/auth.go b/auth.go index 1ccabbc46e..77f993eed2 100644 --- a/auth.go +++ b/auth.go @@ -215,12 +215,7 @@ func Login(c *gin.Context) { return } - db, err := GetDB(c) - if err != nil { - logrus.Errorf("Could not open database: %s", err.Error()) - DatabaseError(c) - return - } + db := GetDB(c) var user User err = db.Where("username = ?", body.Username).First(&user).Error diff --git a/auth_test.go b/auth_test.go index 48d7aef13c..c5e85cd13d 100644 --- a/auth_test.go +++ b/auth_test.go @@ -18,10 +18,7 @@ func TestGenerateRandomText(t *testing.T) { } func TestGenerateVC(t *testing.T) { - db, err := openTestDB() - if err != nil { - t.Fatal(err) - } + db := openTestDB() user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false) if err != nil { @@ -50,15 +47,12 @@ func TestGenerateVC(t *testing.T) { } func TestVC(t *testing.T) { - r := createTestServer() + db := openTestDB() + r := createEmptyTestServer(db) + r.Use(testSessionMiddleware) r.Use(JWTRenewalMiddleware) - db, err := openTestDB() - if err != nil { - t.Fatal(err.Error()) - } - user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) if err != nil { t.Fatal(err.Error()) @@ -134,10 +128,7 @@ func TestVC(t *testing.T) { } func TestIsUserID(t *testing.T) { - db, err := openTestDB() - if err != nil { - t.Fatal(err) - } + db := openTestDB() user1, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) if err != nil { @@ -156,10 +147,7 @@ func TestIsUserID(t *testing.T) { } func TestCanPerformActionsOnUser(t *testing.T) { - db, err := openTestDB() - if err != nil { - t.Fatal(err) - } + db := openTestDB() user1, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) if err != nil { diff --git a/kolide.go b/kolide.go index b93258de15..d136935841 100644 --- a/kolide.go +++ b/kolide.go @@ -51,6 +51,8 @@ func init() { // populate the global config data structure with sane defaults setDefaultConfigValues() + + rand.Seed(time.Now().UnixNano()) } // logContextHook is a logrus hook which is used to contextualize application @@ -74,10 +76,6 @@ func (hook logContextHook) Fire(entry *logrus.Entry) error { return nil } -func init() { - rand.Seed(time.Now().UnixNano()) -} - func main() { // configure flag parsing and parse flags app.Version(version) @@ -124,10 +122,16 @@ func main() { dropTables(db) createTables(db) case serve.FullCommand(): + db, err := openDB(config.MySQL.Username, config.MySQL.Password, config.MySQL.Address, config.MySQL.Database) + if err != nil { + logrus.Fatalf("Error opening database: %s", err.Error()) + } + fmt.Printf("=> %s %s application starting on https://%s\n", app.Name, version, config.Server.Address) fmt.Println("=> Run `kolide help serve` for more startup options") fmt.Println("Use Ctrl-C to stop\n\n") - CreateServer().RunTLS( + + CreateServer(db).RunTLS( config.Server.Address, config.Server.Cert, config.Server.Key) diff --git a/models.go b/models.go index 8d22eec321..e16a677a1a 100644 --- a/models.go +++ b/models.go @@ -1,7 +1,6 @@ package main import ( - "errors" "fmt" "time" @@ -13,36 +12,9 @@ import ( _ "github.com/jinzhu/gorm/dialects/sqlite" ) -type DBEnvironment int - -const ( - ProductionDB DBEnvironment = iota - TestingDB DBEnvironment = iota -) - -var injectedTestDB *gorm.DB - -func GetDB(c *gin.Context) (*gorm.DB, error) { - f, ok := c.Get("DB") - if !ok { - return nil, errors.New("DB was not set on the supplied *gin.Context. Use a middleware to set it.") - } - switch f.(DBEnvironment) { - case ProductionDB: - return openDB(config.MySQL.Username, config.MySQL.Password, config.MySQL.Address, config.MySQL.Database) - case TestingDB: - if injectedTestDB != nil { - return injectedTestDB, nil - } - db, err := openTestDB() - if err != nil { - return nil, errors.New("Could not open a test database") - } - injectedTestDB = db - return injectedTestDB, nil - default: - return nil, errors.New("GetDB not implemented for DBEnvironment object") - } +// Get the database connection from the context, or panic +func GetDB(c *gin.Context) *gorm.DB { + return c.MustGet("DB").(*gorm.DB) } type BaseModel struct { @@ -197,43 +169,27 @@ func setDBSettings(db *gorm.DB) { func openDB(user, password, address, dbName string) (*gorm.DB, error) { connectionString := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local", user, password, address, dbName) - return gorm.Open("mysql", connectionString) db, err := gorm.Open("mysql", connectionString) if err != nil { return nil, err } + setDBSettings(db) return db, nil } -/// Open a database connection, or panic -func mustOpenDB(user, password, address, dbName string) *gorm.DB { - db, err := openDB(user, password, address, dbName) - if err != nil { - panic(fmt.Sprintf("Could not connect to DB: %s", err.Error())) - } - return db -} - -func openTestDB() (*gorm.DB, error) { +func openTestDB() *gorm.DB { db, err := gorm.Open("sqlite3", ":memory:") if err != nil { - return nil, err + panic(fmt.Sprintf("Error opening test DB: %s", err.Error())) } setDBSettings(db) createTables(db) - return db, db.Error -} - -func ProductionDatabaseMiddleware(c *gin.Context) { - c.Set("DB", ProductionDB) - c.Next() -} - -func TestingDatabaseMiddleware(c *gin.Context) { - c.Set("DB", TestingDB) - c.Next() + if db.Error != nil { + panic(fmt.Sprintf("Error creating test DB tables: %s", db.Error.Error())) + } + return db } func dropTables(db *gorm.DB) { diff --git a/server.go b/server.go index 5a0daad0d2..1e5981e12d 100644 --- a/server.go +++ b/server.go @@ -6,6 +6,7 @@ import ( "github.com/Sirupsen/logrus" "github.com/gin-gonic/contrib/ginrus" "github.com/gin-gonic/gin" + "github.com/jinzhu/gorm" ) // ServerError is a helper which accepts a string error and returns a map in @@ -36,17 +37,25 @@ func MalformedRequestError(c *gin.Context) { c.JSON(400, ServerError("Malformed request")) } -func createTestServer() *gin.Engine { +func createEmptyTestServer(db *gorm.DB) *gin.Engine { server := gin.New() - server.Use(TestingDatabaseMiddleware) + server.Use(DatabaseMiddleware(db)) return server } +// Adapted from https://goo.gl/03Qxiy +func DatabaseMiddleware(db *gorm.DB) gin.HandlerFunc { + return func(c *gin.Context) { + c.Set("DB", db) + c.Next() + } +} + // CreateServer creates a gin.Engine HTTP server and configures it to be in a // state such that it is ready to serve HTTP requests for the kolide application -func CreateServer() *gin.Engine { +func CreateServer(db *gorm.DB) *gin.Engine { server := gin.New() - server.Use(ProductionDatabaseMiddleware) + server.Use(DatabaseMiddleware(db)) // TODO: The following loggers are not synchronized with each other or // logrus.StandardLogger() used through the rest of the codebase. As @@ -74,7 +83,7 @@ func CreateServer() *gin.Engine { kolide.POST("/login", Login) kolide.GET("/logout", Logout) - kolide.GET("/user", GetUser) + kolide.POST("/user", GetUser) kolide.PUT("/user", CreateUser) kolide.PATCH("/user", ModifyUser) kolide.DELETE("/user", DeleteUser) diff --git a/sessions_test.go b/sessions_test.go index dca8e28e42..5cab6df40a 100644 --- a/sessions_test.go +++ b/sessions_test.go @@ -6,21 +6,11 @@ import ( "testing" "github.com/gin-gonic/gin" - "github.com/gorilla/sessions" ) -const testSessionName = "TestSession" - -func getTestStore() sessions.Store { - return sessions.NewCookieStore([]byte("test")) -} - -func testSessionMiddleware(c *gin.Context) { - CreateSession(testSessionName, getTestStore())(c) -} - func TestSessionGetSet(t *testing.T) { - r := createTestServer() + db := openTestDB() + r := createEmptyTestServer(db) r.Use(testSessionMiddleware) r.Use(JWTRenewalMiddleware) @@ -50,7 +40,8 @@ func TestSessionGetSet(t *testing.T) { } func TestSessionDeleteKey(t *testing.T) { - r := createTestServer() + db := openTestDB() + r := createEmptyTestServer(db) r.Use(testSessionMiddleware) r.Use(JWTRenewalMiddleware) @@ -92,7 +83,8 @@ func TestSessionDeleteKey(t *testing.T) { } func TestSessionFlashes(t *testing.T) { - r := createTestServer() + db := openTestDB() + r := createEmptyTestServer(db) r.Use(testSessionMiddleware) r.Use(JWTRenewalMiddleware) @@ -139,11 +131,13 @@ func TestSessionFlashes(t *testing.T) { } func TestSessionClear(t *testing.T) { + db := openTestDB() + r := createEmptyTestServer(db) + data := map[string]string{ "key": "val", "foo": "bar", } - r := createTestServer() store := getTestStore() r.Use(CreateSession(testSessionName, store)) r.Use(JWTRenewalMiddleware) diff --git a/story_test.go b/story_test.go index a8834f42a3..b094cd28e5 100644 --- a/story_test.go +++ b/story_test.go @@ -1,375 +1,13 @@ package main import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" "testing" - - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" ) -type integrationRequests struct { - r *gin.Engine - db *gorm.DB - t *testing.T -} - -func (req *integrationRequests) New(t *testing.T) { - req.t = t - - req.r = createTestServer() - req.r.Use(testSessionMiddleware) - req.r.Use(JWTRenewalMiddleware) - - req.db, _ = openTestDB() - injectedTestDB = req.db - - // Until we have a better solution for first-user onboarding, manually - // create an admin - _, err := NewUser(req.db, "admin", "foobar", "admin@kolide.co", true, false) - if err != nil { - panic(err.Error()) - } - - req.r.POST("/login", Login) - req.r.GET("/logout", Logout) - - req.r.POST("/user", GetUser) - req.r.PUT("/user", CreateUser) - req.r.PATCH("/user", ModifyUser) - req.r.DELETE("/user", DeleteUser) - - req.r.PATCH("/user/password", ChangeUserPassword) - req.r.PATCH("/user/admin", SetUserAdminState) - req.r.PATCH("/user/enabled", SetUserEnabledState) -} - -func (req *integrationRequests) Login(username, password string, sessionOut *string) { - response := httptest.NewRecorder() - body, err := json.Marshal(LoginRequestBody{ - Username: username, - Password: password, - }) - if err != nil { - req.t.Fatal(err.Error()) - return - } - - buff := new(bytes.Buffer) - buff.Write(body) - request, _ := http.NewRequest("POST", "/login", buff) - request.Header.Set("Content-Type", "application/json") - req.r.ServeHTTP(response, request) - - if response.Code != 200 { - req.t.Fatalf("Response code: %d", response.Code) - return - } - *sessionOut = response.Header().Get("Set-Cookie") - - return -} - -func (req *integrationRequests) CreateUser(username, password, email string, admin, reset bool, session *string) *GetUserResponseBody { - response := httptest.NewRecorder() - body, err := json.Marshal(CreateUserRequestBody{ - Username: username, - Password: password, - Email: email, - Admin: admin, - NeedsPasswordReset: reset, - }) - if err != nil { - req.t.Fatal(err.Error()) - return nil - } - - buff := new(bytes.Buffer) - buff.Write(body) - request, _ := http.NewRequest("PUT", "/user", buff) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Cookie", *session) - req.r.ServeHTTP(response, request) - - if response.Code != 200 { - req.t.Fatalf("Response code: %d", response.Code) - return nil - } - *session = response.Header().Get("Set-Cookie") - - var responseBody GetUserResponseBody - err = json.Unmarshal(response.Body.Bytes(), &responseBody) - if err != nil { - req.t.Fatal(err.Error()) - return nil - } - - return &responseBody -} - -func (req *integrationRequests) GetUser(username string, session *string) *GetUserResponseBody { - response := httptest.NewRecorder() - body, err := json.Marshal(GetUserRequestBody{ - Username: username, - }) - if err != nil { - req.t.Fatal(err.Error()) - return nil - } - - buff := new(bytes.Buffer) - buff.Write(body) - request, _ := http.NewRequest("POST", "/user", buff) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Cookie", *session) - req.r.ServeHTTP(response, request) - - if response.Code != 200 { - req.t.Fatalf("Response code: %d", response.Code) - return nil - } - *session = response.Header().Get("Set-Cookie") - - var responseBody GetUserResponseBody - err = json.Unmarshal(response.Body.Bytes(), &responseBody) - if err != nil { - req.t.Fatal(err.Error()) - return nil - } - - return &responseBody -} - -func (req *integrationRequests) ModifyUser(username, name, email string, session *string) *GetUserResponseBody { - response := httptest.NewRecorder() - body, err := json.Marshal(ModifyUserRequestBody{ - Username: username, - Name: name, - Email: email, - }) - if err != nil { - req.t.Fatal(err.Error()) - return nil - } - - buff := new(bytes.Buffer) - buff.Write(body) - request, _ := http.NewRequest("PATCH", "/user", buff) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Cookie", *session) - req.r.ServeHTTP(response, request) - - if response.Code != 200 { - req.t.Fatalf("Response code: %d", response.Code) - return nil - } - *session = response.Header().Get("Set-Cookie") - - var responseBody GetUserResponseBody - err = json.Unmarshal(response.Body.Bytes(), &responseBody) - if err != nil { - req.t.Fatal(err.Error()) - return nil - } - - return &responseBody -} - -func (req *integrationRequests) DeleteUser(username string, session *string) { - response := httptest.NewRecorder() - body, err := json.Marshal(DeleteUserRequestBody{ - Username: username, - }) - if err != nil { - req.t.Fatal(err.Error()) - return - } - - buff := new(bytes.Buffer) - buff.Write(body) - request, _ := http.NewRequest("DELETE", "/user", buff) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Cookie", *session) - req.r.ServeHTTP(response, request) - - if response.Code != 200 { - req.t.Fatalf("Response code: %d", response.Code) - return - } - *session = response.Header().Get("Set-Cookie") - - return -} - -func (req *integrationRequests) ChangePassword(username, currentPassword, newPassword string, session *string) *GetUserResponseBody { - response := httptest.NewRecorder() - body, err := json.Marshal(ChangePasswordRequestBody{ - Username: username, - CurrentPassword: currentPassword, - NewPassword: newPassword, - NewPasswordConfim: newPassword, - }) - if err != nil { - req.t.Fatal(err.Error()) - return nil - } - - buff := new(bytes.Buffer) - buff.Write(body) - request, _ := http.NewRequest("PATCH", "/user/password", buff) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Cookie", *session) - req.r.ServeHTTP(response, request) - - if response.Code != 200 { - req.t.Fatalf("Response code: %d", response.Code) - return nil - } - *session = response.Header().Get("Set-Cookie") - - var responseBody GetUserResponseBody - err = json.Unmarshal(response.Body.Bytes(), &responseBody) - if err != nil { - req.t.Fatal(err.Error()) - } - - return &responseBody -} - -func (req *integrationRequests) SetAdminState(username string, admin bool, session *string) *GetUserResponseBody { - response := httptest.NewRecorder() - body, err := json.Marshal(SetUserAdminStateRequestBody{ - Username: username, - Admin: admin, - }) - if err != nil { - req.t.Fatal(err.Error()) - return nil - } - - buff := new(bytes.Buffer) - buff.Write(body) - request, _ := http.NewRequest("PATCH", "/user/admin", buff) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Cookie", *session) - req.r.ServeHTTP(response, request) - - if response.Code != 200 { - req.t.Fatalf("Response code: %d", response.Code) - return nil - } - *session = response.Header().Get("Set-Cookie") - - var responseBody GetUserResponseBody - err = json.Unmarshal(response.Body.Bytes(), &responseBody) - if err != nil { - req.t.Fatal(err.Error()) - } - - return &responseBody -} - -func (req *integrationRequests) SetEnabledState(username string, enabled bool, session *string) *GetUserResponseBody { - response := httptest.NewRecorder() - body, err := json.Marshal(SetUserEnabledStateRequestBody{ - Username: username, - Enabled: enabled, - }) - if err != nil { - req.t.Fatal(err.Error()) - return nil - } - - buff := new(bytes.Buffer) - buff.Write(body) - request, _ := http.NewRequest("PATCH", "/user/enabled", buff) - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Cookie", *session) - req.r.ServeHTTP(response, request) - - if response.Code != 200 { - req.t.Fatalf("Response code: %d", response.Code) - return nil - } - *session = response.Header().Get("Set-Cookie") - - var responseBody GetUserResponseBody - err = json.Unmarshal(response.Body.Bytes(), &responseBody) - if err != nil { - req.t.Fatal(err.Error()) - return nil - } - - return &responseBody -} - -func (req *integrationRequests) CheckUser(username, email, name string, admin, reset, enabled bool) { - var user User - err := req.db.Where("username = ?", username).First(&user).Error - if err != nil { - req.t.Fatal(err.Error()) - return - } - if user.Email != email { - req.t.Fatalf("user's email was not set in the DB: %s", user.Email) - } - if user.Admin != admin { - req.t.Fatal("user admin settings don't match") - } - if user.NeedsPasswordReset != reset { - req.t.Fatal("user reset settings don't match") - } - if user.Enabled != enabled { - req.t.Fatal("user enabled settings don't match") - } - if user.Name != name { - req.t.Fatalf("user names don't match: %s and %s", user.Name, name) - } - return -} - -func (req *integrationRequests) GetAndCheckUser(username string, session *string) { - resp := req.GetUser(username, session) - req.CheckUser(username, resp.Email, resp.Name, resp.Admin, resp.NeedsPasswordReset, resp.Enabled) -} - -func (req *integrationRequests) CreateAndCheckUser(username, password, email, name string, admin, reset bool, session *string) { - resp := req.CreateUser(username, password, email, admin, reset, session) - req.CheckUser(username, email, name, admin, reset, resp.Enabled) -} - -func (req *integrationRequests) ModifyAndCheckUser(username, email, name string, admin, reset bool, session *string) { - resp := req.ModifyUser(username, name, email, session) - req.CheckUser(username, email, name, admin, reset, resp.Enabled) -} - -func (req *integrationRequests) DeleteAndCheckUser(username string, session *string) { - req.DeleteUser(username, session) - - var user User - err := req.db.Where("username = ?", username).First(&user).Error - if err == nil { - req.t.Fatal("User should have been deleted.") - } -} - -func (req *integrationRequests) SetEnabledStateAndCheckUser(username string, enabled bool, session *string) { - resp := req.SetEnabledState(username, enabled, session) - req.CheckUser(username, resp.Email, resp.Name, resp.Admin, resp.NeedsPasswordReset, enabled) -} - -func (req *integrationRequests) SetAdminStateAndCheckUser(username string, admin bool, session *string) { - resp := req.SetAdminState(username, admin, session) - req.CheckUser(username, resp.Email, resp.Name, admin, resp.NeedsPasswordReset, resp.Enabled) -} - func TestUserAndAccountManagement(t *testing.T) { // Create and configure the webserver which will be used to handle the tests - var req integrationRequests + var req IntegrationRequests req.New(t) // Instantiate the variables that will store the most recent session cookie diff --git a/test_util.go b/test_util.go new file mode 100644 index 0000000000..4937d749ae --- /dev/null +++ b/test_util.go @@ -0,0 +1,364 @@ +package main + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/gorilla/sessions" + "github.com/jinzhu/gorm" +) + +const testSessionName = "TestSession" + +func getTestStore() sessions.Store { + return sessions.NewCookieStore([]byte("test")) +} + +func testSessionMiddleware(c *gin.Context) { + CreateSession(testSessionName, getTestStore())(c) +} + +type IntegrationRequests struct { + r *gin.Engine + db *gorm.DB + t *testing.T +} + +func (req *IntegrationRequests) New(t *testing.T) { + req.t = t + + *debug = false + req.db = openTestDB() + + // Until we have a better solution for first-user onboarding, manually + // create an admin + _, err := NewUser(req.db, "admin", "foobar", "admin@kolide.co", true, false) + if err != nil { + t.Fatalf("Error opening DB: %s", err.Error()) + } + + req.r = CreateServer(req.db) +} + +func (req *IntegrationRequests) Login(username, password string, sessionOut *string) { + response := httptest.NewRecorder() + body, err := json.Marshal(LoginRequestBody{ + Username: username, + Password: password, + }) + if err != nil { + req.t.Fatal(err.Error()) + return + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("POST", "/api/v1/kolide/login", buff) + request.Header.Set("Content-Type", "application/json") + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return + } + *sessionOut = response.Header().Get("Set-Cookie") + + return +} + +func (req *IntegrationRequests) CreateUser(username, password, email string, admin, reset bool, session *string) *GetUserResponseBody { + response := httptest.NewRecorder() + body, err := json.Marshal(CreateUserRequestBody{ + Username: username, + Password: password, + Email: email, + Admin: admin, + NeedsPasswordReset: reset, + }) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("PUT", "/api/v1/kolide/user", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return nil + } + *session = response.Header().Get("Set-Cookie") + + var responseBody GetUserResponseBody + err = json.Unmarshal(response.Body.Bytes(), &responseBody) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + return &responseBody +} + +func (req *IntegrationRequests) GetUser(username string, session *string) *GetUserResponseBody { + response := httptest.NewRecorder() + body, err := json.Marshal(GetUserRequestBody{ + Username: username, + }) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("POST", "/api/v1/kolide/user", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return nil + } + *session = response.Header().Get("Set-Cookie") + + var responseBody GetUserResponseBody + err = json.Unmarshal(response.Body.Bytes(), &responseBody) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + return &responseBody +} + +func (req *IntegrationRequests) ModifyUser(username, name, email string, session *string) *GetUserResponseBody { + response := httptest.NewRecorder() + body, err := json.Marshal(ModifyUserRequestBody{ + Username: username, + Name: name, + Email: email, + }) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("PATCH", "/api/v1/kolide/user", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return nil + } + *session = response.Header().Get("Set-Cookie") + + var responseBody GetUserResponseBody + err = json.Unmarshal(response.Body.Bytes(), &responseBody) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + return &responseBody +} + +func (req *IntegrationRequests) DeleteUser(username string, session *string) { + response := httptest.NewRecorder() + body, err := json.Marshal(DeleteUserRequestBody{ + Username: username, + }) + if err != nil { + req.t.Fatal(err.Error()) + return + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("DELETE", "/api/v1/kolide/user", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return + } + *session = response.Header().Get("Set-Cookie") + + return +} + +func (req *IntegrationRequests) ChangePassword(username, currentPassword, newPassword string, session *string) *GetUserResponseBody { + response := httptest.NewRecorder() + body, err := json.Marshal(ChangePasswordRequestBody{ + Username: username, + CurrentPassword: currentPassword, + NewPassword: newPassword, + NewPasswordConfim: newPassword, + }) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("PATCH", "/api/v1/kolide/user/password", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return nil + } + *session = response.Header().Get("Set-Cookie") + + var responseBody GetUserResponseBody + err = json.Unmarshal(response.Body.Bytes(), &responseBody) + if err != nil { + req.t.Fatal(err.Error()) + } + + return &responseBody +} + +func (req *IntegrationRequests) SetAdminState(username string, admin bool, session *string) *GetUserResponseBody { + response := httptest.NewRecorder() + body, err := json.Marshal(SetUserAdminStateRequestBody{ + Username: username, + Admin: admin, + }) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("PATCH", "/api/v1/kolide/user/admin", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return nil + } + *session = response.Header().Get("Set-Cookie") + + var responseBody GetUserResponseBody + err = json.Unmarshal(response.Body.Bytes(), &responseBody) + if err != nil { + req.t.Fatal(err.Error()) + } + + return &responseBody +} + +func (req *IntegrationRequests) SetEnabledState(username string, enabled bool, session *string) *GetUserResponseBody { + response := httptest.NewRecorder() + body, err := json.Marshal(SetUserEnabledStateRequestBody{ + Username: username, + Enabled: enabled, + }) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + buff := new(bytes.Buffer) + buff.Write(body) + request, _ := http.NewRequest("PATCH", "/api/v1/kolide/user/enabled", buff) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Cookie", *session) + req.r.ServeHTTP(response, request) + + if response.Code != 200 { + req.t.Fatalf("Response code: %d", response.Code) + return nil + } + *session = response.Header().Get("Set-Cookie") + + var responseBody GetUserResponseBody + err = json.Unmarshal(response.Body.Bytes(), &responseBody) + if err != nil { + req.t.Fatal(err.Error()) + return nil + } + + return &responseBody +} + +func (req *IntegrationRequests) CheckUser(username, email, name string, admin, reset, enabled bool) { + var user User + err := req.db.Where("username = ?", username).First(&user).Error + if err != nil { + req.t.Fatal(err.Error()) + return + } + if user.Email != email { + req.t.Fatalf("user's email was not set in the DB: %s", user.Email) + } + if user.Admin != admin { + req.t.Fatal("user admin settings don't match") + } + if user.NeedsPasswordReset != reset { + req.t.Fatal("user reset settings don't match") + } + if user.Enabled != enabled { + req.t.Fatal("user enabled settings don't match") + } + if user.Name != name { + req.t.Fatalf("user names don't match: %s and %s", user.Name, name) + } + return +} + +func (req *IntegrationRequests) GetAndCheckUser(username string, session *string) { + resp := req.GetUser(username, session) + req.CheckUser(username, resp.Email, resp.Name, resp.Admin, resp.NeedsPasswordReset, resp.Enabled) +} + +func (req *IntegrationRequests) CreateAndCheckUser(username, password, email, name string, admin, reset bool, session *string) { + resp := req.CreateUser(username, password, email, admin, reset, session) + req.CheckUser(username, email, name, admin, reset, resp.Enabled) +} + +func (req *IntegrationRequests) ModifyAndCheckUser(username, email, name string, admin, reset bool, session *string) { + resp := req.ModifyUser(username, name, email, session) + req.CheckUser(username, email, name, admin, reset, resp.Enabled) +} + +func (req *IntegrationRequests) DeleteAndCheckUser(username string, session *string) { + req.DeleteUser(username, session) + + var user User + err := req.db.Where("username = ?", username).First(&user).Error + if err == nil { + req.t.Fatal("User should have been deleted.") + } +} + +func (req *IntegrationRequests) SetEnabledStateAndCheckUser(username string, enabled bool, session *string) { + resp := req.SetEnabledState(username, enabled, session) + req.CheckUser(username, resp.Email, resp.Name, resp.Admin, resp.NeedsPasswordReset, enabled) +} + +func (req *IntegrationRequests) SetAdminStateAndCheckUser(username string, admin bool, session *string) { + resp := req.SetAdminState(username, admin, session) + req.CheckUser(username, resp.Email, resp.Name, admin, resp.NeedsPasswordReset, resp.Enabled) +} diff --git a/users.go b/users.go index 04c5b05204..83f5cacca4 100644 --- a/users.go +++ b/users.go @@ -103,12 +103,7 @@ func GetUser(c *gin.Context) { return } - db, err := GetDB(c) - if err != nil { - logrus.Errorf("Could not open database: %s", err.Error()) - DatabaseError(c) - return - } + db := GetDB(c) vc, err := VC(c, db) if err != nil { @@ -158,12 +153,7 @@ func CreateUser(c *gin.Context) { return } - db, err := GetDB(c) - if err != nil { - logrus.Errorf("Could not open database: %s", err.Error()) - DatabaseError(c) - return - } + db := GetDB(c) vc, err := VC(c, db) if err != nil { @@ -210,12 +200,7 @@ func ModifyUser(c *gin.Context) { return } - db, err := GetDB(c) - if err != nil { - logrus.Errorf("Could not open database: %s", err.Error()) - DatabaseError(c) - return - } + db := GetDB(c) vc, err := VC(c, db) if err != nil { @@ -275,12 +260,7 @@ func DeleteUser(c *gin.Context) { return } - db, err := GetDB(c) - if err != nil { - logrus.Errorf("Could not open database: %s", err.Error()) - DatabaseError(c) - return - } + db := GetDB(c) vc, err := VC(c, db) if err != nil { @@ -340,12 +320,7 @@ func ChangeUserPassword(c *gin.Context) { return } - db, err := GetDB(c) - if err != nil { - logrus.Errorf("Could not open database: %s", err.Error()) - DatabaseError(c) - return - } + db := GetDB(c) var user User user.ID = body.ID @@ -411,12 +386,7 @@ func SetUserAdminState(c *gin.Context) { return } - db, err := GetDB(c) - if err != nil { - logrus.Errorf("Could not open database: %s", err.Error()) - DatabaseError(c) - return - } + db := GetDB(c) vc, err := VC(c, db) if err != nil { @@ -471,12 +441,7 @@ func SetUserEnabledState(c *gin.Context) { return } - db, err := GetDB(c) - if err != nil { - logrus.Errorf("Could not open database: %s", err.Error()) - DatabaseError(c) - return - } + db := GetDB(c) vc, err := VC(c, db) if err != nil { diff --git a/users_test.go b/users_test.go index 83b22294b2..c55e495f77 100644 --- a/users_test.go +++ b/users_test.go @@ -7,10 +7,7 @@ import ( ) func TestNewUser(t *testing.T) { - db, err := openTestDB() - if err != nil { - t.Fatal(err) - } + db := openTestDB() user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false) if err != nil { @@ -37,10 +34,7 @@ func TestNewUser(t *testing.T) { } func TestValidatePassword(t *testing.T) { - db, err := openTestDB() - if err != nil { - t.Fatal(err.Error()) - } + db := openTestDB() user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false) if err != nil { @@ -59,10 +53,7 @@ func TestValidatePassword(t *testing.T) { } func TestMakeAdmin(t *testing.T) { - db, err := openTestDB() - if err != nil { - t.Fatal(err.Error()) - } + db := openTestDB() user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) if err != nil { @@ -96,10 +87,7 @@ func TestMakeAdmin(t *testing.T) { } func TestUpdatingUser(t *testing.T) { - db, err := openTestDB() - if err != nil { - t.Fatal(err.Error()) - } + db := openTestDB() user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) if err != nil { @@ -129,10 +117,7 @@ func TestUpdatingUser(t *testing.T) { } func TestDeletingUser(t *testing.T) { - db, err := openTestDB() - if err != nil { - t.Fatal(err.Error()) - } + db := openTestDB() user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) if err != nil { @@ -161,10 +146,7 @@ func TestDeletingUser(t *testing.T) { } func TestSetPassword(t *testing.T) { - db, err := openTestDB() - if err != nil { - t.Fatal(err.Error()) - } + db := openTestDB() user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) if err != nil {