From eee370e12786857e638e470fc234602b2c387f67 Mon Sep 17 00:00:00 2001 From: Mike Arpaia Date: Mon, 1 Aug 2016 16:32:20 -0700 Subject: [PATCH] Authentication, authorization and user management (#10) --- README.md | 8 + auth.go | 302 ++++++++++++++++++++++++++ auth_test.go | 198 +++++++++++++++++ config.go | 62 +++++- kolide.go | 42 +++- models.go | 66 +++++- server.go | 57 ++++- sessions.go | 90 ++++++++ sessions_test.go | 179 ++++++++++++++++ tools/example_config.json | 5 + users.go | 435 ++++++++++++++++++++++++++++++++++++++ users_test.go | 309 +++++++++++++++++++++++++++ 12 files changed, 1716 insertions(+), 37 deletions(-) create mode 100644 auth.go create mode 100644 auth_test.go create mode 100644 sessions.go create mode 100644 sessions_test.go create mode 100644 users.go create mode 100644 users_test.go diff --git a/README.md b/README.md index 88c5a628d8..2d2b74aa36 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,14 @@ To build the code, run the following from the root of the repository: go build ``` +## Testing + +To run the application's tests, run the following from the root of the repository: + +``` +go test +``` + ## Development Environment To set up the development environment via docker, run the following frmo the root of the repository: diff --git a/auth.go b/auth.go new file mode 100644 index 0000000000..dfbe9190b5 --- /dev/null +++ b/auth.go @@ -0,0 +1,302 @@ +package main + +import ( + "errors" + "fmt" + "math/rand" + "time" + + "github.com/Sirupsen/logrus" + "github.com/dgrijalva/jwt-go" + "github.com/gin-gonic/gin" + "github.com/jinzhu/gorm" + "golang.org/x/crypto/bcrypt" +) + +// ViewerContext is a struct which represents the ability for an execution +// context to participate in certain actions. Most often, a ViewerContext is +// associated with an application user, but a ViewerContext can represent a +// variety of other execution contexts as well (script, test, etc). The main +// purpose of a ViewerContext is to assist in the authorization of sensitive +// actions. +type ViewerContext struct { + user *User +} + +// JWT returns a JWT token in serialized string form given a ViewerContext as +// well as a potential error in the event that things have gone wrong. +func (vc *ViewerContext) JWT() (string, error) { + return GenerateJWT(vc.user.ID) +} + +// IsAdmin indicates whether or not the current user can perform administrative +// actions. +func (vc *ViewerContext) IsAdmin() bool { + if vc.user != nil { + return vc.user.Admin && vc.user.Enabled + } + return false +} + +// UserID is a helper that enables quick access to the user ID of the current +// user. +func (vc *ViewerContext) UserID() (uint, error) { + if vc.user != nil { + return vc.user.ID, nil + } + return 0, errors.New("No user set") +} + +func (vc *ViewerContext) CanPerformActions(db *gorm.DB) bool { + if vc.user == nil { + return false + } + + if !vc.user.Enabled { + return false + } + + return true +} + +func (vc *ViewerContext) IsUserID(id uint) bool { + userID, err := vc.UserID() + if err != nil { + return false + } + if userID == id { + return true + } + return false +} + +func (vc *ViewerContext) CanPerformWriteActionOnUser(db *gorm.DB, u *User) bool { + return vc.CanPerformActions(db) && (vc.IsUserID(u.ID) || vc.IsAdmin()) +} + +func (vc *ViewerContext) CanPerformReadActionOnUser(db *gorm.DB, u *User) bool { + return vc.CanPerformActions(db) && (vc.IsUserID(u.ID) || vc.IsAdmin()) +} + +// GenerateJWT generates a JWT token in serialized string form given a +// ViewerContext as well as a potential error in the event that things have +// gone wrong. +func GenerateJWT(userID uint) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "user_id": userID, + // "Not Before": https://tools.ietf.org/html/rfc7519#section-4.1.5 + "nbf": time.Now().UTC().Unix(), + // "Expiration Time": https://tools.ietf.org/html/rfc7519#section-4.1.4 + "exp": time.Now().UTC().AddDate(0, 2, 0).Unix(), + }) + + return token.SignedString([]byte(config.App.JWTKey)) +} + +// ParseJWT attempts to parse a JWT token in serialized string form into a +// JWT token in a deserialized jwt.Token struct. +func ParseJWT(token string) (*jwt.Token, error) { + return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { + method, ok := t.Method.(*jwt.SigningMethodHMAC) + if !ok || method != jwt.SigningMethodHS256 { + return nil, errors.New("Unexpected signing method") + } + return []byte(config.App.JWTKey), nil + }) +} + +// JWTRenewalMiddleware optimistically tries to renew the user's JWT token. +// This allows kolide to have sessions that last forever, assuming that a user +// logs in and uses the application within a reasonable time window (which is +// defined in the JWT token generation method). If anything goes wrong, this +// middleware will back off and defer recovery of the situation to the +// downstream web request. +func JWTRenewalMiddleware(c *gin.Context) { + session := GetSession(c) + tokenCookie := session.Get("jwt") + if tokenCookie == nil { + c.Next() + return + } + + tokenString, ok := tokenCookie.(string) + if !ok { + c.Next() + return + } + + token, err := ParseJWT(tokenString) + if err != nil { + c.Next() + return + } + + claims, ok := token.Claims.(jwt.MapClaims) + + if !ok || !token.Valid { + c.Next() + return + } + + userID := uint(claims["user_id"].(float64)) + + jwt, err := GenerateJWT(userID) + if err != nil { + c.Next() + return + } + + session.Set("jwt", jwt) + + c.Next() +} + +// GenerateVC generates a ViewerContext given a user struct +func GenerateVC(user *User) *ViewerContext { + return &ViewerContext{ + user: user, + } +} + +// EmptyVC is a utility which generates an empty ViewerContext. This is often +// used to represent users which are not logged in. +func EmptyVC() *ViewerContext { + return &ViewerContext{ + user: nil, + } +} + +// VC accepts a web request context and a database handler and attempts +// to parse a user's jwt token out of the active session, validate the token, +// and generate an appropriate ViewerContext given the data in the session. +func VC(c *gin.Context, db *gorm.DB) (*ViewerContext, error) { + session := GetSession(c) + tokenCookie := session.Get("jwt") + if tokenCookie == nil { + return nil, errors.New("jwt session attribute not set") + } + + tokenString, ok := tokenCookie.(string) + if !ok { + return nil, errors.New("jwt token was not string") + } + + token, err := ParseJWT(tokenString) + if err != nil { + return nil, err + } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return nil, errors.New("Invalid token") + } + + userID := uint(claims["user_id"].(float64)) + var user User + err = db.Where("id = ?", userID).First(&user).Error + if err != nil { + return nil, err + } + + return GenerateVC(&user), nil + +} + +type LoginRequestBody struct { + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` +} + +func Login(c *gin.Context) { + var body LoginRequestBody + err := c.BindJSON(&body) + if err != nil { + logrus.Errorf("Error parsing Login post body: %s", err.Error()) + return + } + + db, err := GetDB(c) + if err != nil { + logrus.Errorf("Could not open database: %s", err.Error()) + DatabaseError(c) + return + } + + var user User + err = db.Where("username = ?", body.Username).First(&user).Error + if err != nil { + logrus.Debugf("User not found: %s", body.Username) + UnauthorizedError(c) + return + } + + err = user.ValidatePassword(body.Password) + if err != nil { + logrus.Debugf("Invalid password for user: %s", body.Username) + UnauthorizedError(c) + return + } + + token, err := GenerateVC(&user).JWT() + if err != nil { + logrus.Fatalf("Error generating token: %s", err.Error()) + DatabaseError(c) + return + } + session := GetSession(c) + session.Set("jwt", token) + session.Save() + + c.JSON(200, map[string]interface{}{ + "id": user.ID, + "username": user.Username, + "email": user.Email, + "name": user.Name, + "admin": user.Admin, + }) +} + +func Logout(c *gin.Context) { + session := GetSession(c) + session.Clear() + c.JSON(200, nil) +} + +const ( + letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + letterIndex = 6 // 6 bits to represent a letter index + letterIndexMask = 1<= 0; { + if remain == 0 { + cache, remain = psrngSource.Int63(), letterIndexMax + } + if idx := int(cache & letterIndexMask); idx < len(letters) { + text[i] = letters[idx] + i-- + } + cache >>= letterIndex + remain-- + } + + return string(text) +} + +func HashPassword(salt, password string) ([]byte, error) { + return bcrypt.GenerateFromPassword( + []byte(fmt.Sprintf("%s%s", salt, password)), + config.App.BcryptCost, + ) +} + +func SaltAndHashPassword(password string) (string, []byte, error) { + salt := generateRandomText(config.App.SaltLength) + hashed, err := HashPassword(salt, password) + return salt, hashed, err +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000000..2a60da0358 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,198 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" + "unicode/utf8" + + "github.com/dgrijalva/jwt-go" + "github.com/gin-gonic/gin" +) + +func TestGenerateRandomText(t *testing.T) { + text := generateRandomText(12) + if utf8.RuneCountInString(text) != 12 { + t.Fatal("generateRandomText generated the wrong length string") + } +} + +func TestGenerateVC(t *testing.T) { + db, err := openTestDB() + if err != nil { + t.Fatal(err) + } + + user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false) + if err != nil { + t.Fatal(err.Error()) + } + + tokenString, err := GenerateVC(user).JWT() + if err != nil { + t.Fatal(err.Error()) + } + + token, err := ParseJWT(tokenString) + if err != nil { + t.Fatal(err.Error()) + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + t.Fatal("Token is invalid") + } + + userID := uint(claims["user_id"].(float64)) + if userID != user.ID { + t.Fatal("Claims are incorrect. userID is %d", userID) + } +} + +func TestVC(t *testing.T) { + r := createTestServer() + r.Use(testSessionMiddleware) + r.Use(JWTRenewalMiddleware) + + db, err := openTestDB() + if err != nil { + t.Fatal(err.Error()) + } + + user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) + if err != nil { + t.Fatal(err.Error()) + } + + admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false) + if err != nil { + t.Fatal(err.Error()) + } + + r.GET("/admin_login", func(c *gin.Context) { + token, err := GenerateVC(admin).JWT() + if err != nil { + t.Fatal(err.Error()) + } + session := GetSession(c) + session.Set("jwt", token) + session.Save() + c.JSON(200, nil) + }) + + r.GET("/user_login", func(c *gin.Context) { + token, err := GenerateVC(user).JWT() + if err != nil { + t.Fatal(err.Error()) + } + session := GetSession(c) + session.Set("jwt", token) + session.Save() + c.JSON(200, nil) + }) + + r.GET("/admin", func(c *gin.Context) { + vc, err := VC(c, db) + if err != nil { + t.Fatal(err.Error()) + } + if !vc.IsAdmin() { + t.Fatal("Not admin") + } + c.String(200, "OK") + }) + + r.GET("/user", func(c *gin.Context) { + vc, err := VC(c, db) + if err != nil { + t.Fatal(err.Error()) + } + if vc.IsAdmin() { + t.Fatal("Not user") + } + c.String(200, "OK") + }) + + res1 := httptest.NewRecorder() + req1, _ := http.NewRequest("GET", "/admin_login", nil) + r.ServeHTTP(res1, req1) + + res2 := httptest.NewRecorder() + req2, _ := http.NewRequest("GET", "/admin", nil) + req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie")) + r.ServeHTTP(res2, req2) + + res3 := httptest.NewRecorder() + req3, _ := http.NewRequest("GET", "/user_login", nil) + r.ServeHTTP(res3, req3) + + res4 := httptest.NewRecorder() + req4, _ := http.NewRequest("GET", "/user", nil) + req4.Header.Set("Cookie", res3.Header().Get("Set-Cookie")) + r.ServeHTTP(res4, req4) + +} + +func TestIsUserID(t *testing.T) { + db, err := openTestDB() + if err != nil { + t.Fatal(err) + } + + user1, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) + if err != nil { + t.Fatal(err.Error()) + } + + vc := GenerateVC(user1) + + if !vc.IsUserID(user1.ID) { + t.Fatal("IsUserID failed on same user object") + } + + if vc.IsUserID(user1.ID + 1) { + t.Fatal("IsUserID passed for incorrect ID") + } +} + +func TestCanPerformActionsOnUser(t *testing.T) { + db, err := openTestDB() + if err != nil { + t.Fatal(err) + } + + user1, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) + if err != nil { + t.Fatal(err.Error()) + } + + user2, err := NewUser(db, "zwass", "foobar", "zwass@kolide.co", false, false) + if err != nil { + t.Fatal(err.Error()) + } + + admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false) + if err != nil { + t.Fatal(err.Error()) + } + + adminVC := GenerateVC(admin) + user1VC := GenerateVC(user1) + + if !adminVC.CanPerformWriteActionOnUser(db, user1) || !adminVC.CanPerformWriteActionOnUser(db, user2) { + t.Fatal("Admin should be able to perform writes on users") + } + + if !adminVC.CanPerformReadActionOnUser(db, user1) || !adminVC.CanPerformReadActionOnUser(db, user2) { + t.Fatal("Admin should be able to perform reads on users") + } + + if user1VC.CanPerformWriteActionOnUser(db, user2) { + t.Fatal("user1 shouldn't be able to perform writes on user2") + } + + if user1VC.CanPerformReadActionOnUser(db, user2) { + t.Fatal("user1 should be able to perform reads on user2") + } + +} diff --git a/config.go b/config.go index 9a81632e08..7c0ae239f9 100644 --- a/config.go +++ b/config.go @@ -5,24 +5,64 @@ import ( "io/ioutil" ) +type mysqlConfigData struct { + Address string `json:"address"` + Username string `json:"username"` + Password string `json:"password"` + Database string `json:"database"` +} + +type serverConfigData struct { + Address string `json:"address"` + Cert string `json:"cert"` + Key string `json:"key"` +} + +type appConfigData struct { + BcryptCost int `json:"bcrypt_cost"` + SaltLength int `json:"salt_length"` + JWTKey string `json:"jwt_key"` +} + type configData struct { - MySQL struct { - Address string `json:"address"` - Username string `json:"username"` - Password string `json:"password"` - Database string `json:"database"` - } `json:"mysql"` - Server struct { - Address string `json:"address"` - Cert string `json:"cert"` - Key string `json:"key"` - } `json:"server"` + MySQL mysqlConfigData `json:"mysql"` + Server serverConfigData `json:"server"` + App appConfigData `json:"app"` } var ( config configData ) +var defaultMysqlConfigData = mysqlConfigData{ + Address: "127.0.0.1:3306", + Username: "kolide", + Password: "kolide", + Database: "kolide", +} + +var defaultServerConfigData = serverConfigData{ + Address: ":8080", + Cert: "./tools/kolide.crt", + Key: "./tools/kolide.key", +} + +var defaultAppConfigData = appConfigData{ + BcryptCost: 12, + SaltLength: 32, + JWTKey: "very secure", +} + +var defaultConfigData = configData{ + MySQL: defaultMysqlConfigData, + Server: defaultServerConfigData, + App: defaultAppConfigData, +} + +func setDefaultConfigValues() { + config = defaultConfigData +} + func loadConfig(path string) error { content, err := ioutil.ReadFile(path) if err != nil { diff --git a/kolide.go b/kolide.go index e86cdee093..e0de5bc4f9 100644 --- a/kolide.go +++ b/kolide.go @@ -2,11 +2,14 @@ package main import ( "fmt" + "math/rand" "os" "path" "runtime" + "time" "github.com/Sirupsen/logrus" + "github.com/gin-gonic/gin" "gopkg.in/alecthomas/kingpin.v2" ) @@ -39,12 +42,27 @@ var ( serve = app.Command("serve", "Run the Kolide server") ) +func init() { + // set gin mode to release to silence some superfluous logging + gin.SetMode(gin.ReleaseMode) + + // configure logging + logrus.AddHook(logContextHook{}) + + // populate the global config data structure with sane defaults + setDefaultConfigValues() +} + +// logContextHook is a logrus hook which is used to contextualize application +// logs to include data stuch as line numbers, file names, etc. type logContextHook struct{} +// Levels defines which levels the logContextHook logrus hook should apply to func (hook logContextHook) Levels() []logrus.Level { return logrus.AllLevels } +// Fire defines what the logContextHook should actually do when it is triggered func (hook logContextHook) Fire(entry *logrus.Entry) error { if pc, file, line, ok := runtime.Caller(8); ok { funcName := runtime.FuncForPC(pc).Name() @@ -57,14 +75,15 @@ func (hook logContextHook) Fire(entry *logrus.Entry) error { return nil } +func init() { + rand.Seed(time.Now().UnixNano()) +} + func main() { // configure flag parsing and parse flags app.Version(version) args, err := app.Parse(os.Args[1:]) - // configure logging - logrus.AddHook(logContextHook{}) - // configure the application based on the flags that have been set if *debug { logrus.SetLevel(logrus.DebugLevel) @@ -76,18 +95,21 @@ func main() { // if config hasn't been defined and the example config exists relative to // the binary, it's likely that the tool is being ran right after building - // from source so we auto-populate the config path. + // from source so we auto-populate the example config path. if *configPath == "" { if _, err = os.Stat("./tools/example_config.json"); err == nil { *configPath = "./tools/example_config.json" - } else { - logrus.Fatalln("No config file specified. Use --config to specify config path.") } } - err = loadConfig(*configPath) - if err != nil { - logrus.Fatalf("Error loading config: %s", err.Error()) + // if the user has defined a config path OR the example config is found + // relative to the binary, load config content from the file. any content + // in the config file will overwrite the default values + if *configPath != "" { + err = loadConfig(*configPath) + if err != nil { + logrus.Fatalf("Error loading config: %s", err.Error()) + } } // route the executable based on the sub-command @@ -100,7 +122,7 @@ func main() { dropTables(db) createTables(db) case serve.FullCommand(): - createServer().RunTLS( + CreateServer().RunTLS( config.Server.Address, config.Server.Cert, config.Server.Key) diff --git a/models.go b/models.go index 04419234fb..8c9d650d31 100644 --- a/models.go +++ b/models.go @@ -1,13 +1,49 @@ package main import ( + "errors" "fmt" "time" "github.com/jinzhu/gorm" + + "github.com/gin-gonic/gin" _ "github.com/jinzhu/gorm/dialects/mysql" + _ "github.com/jinzhu/gorm/dialects/sqlite" ) +type DBEnvironment int + +const ( + ProductionDB DBEnvironment = iota + TestingDB DBEnvironment = iota +) + +var injectedTestDB *gorm.DB + +func GetDB(c *gin.Context) (*gorm.DB, error) { + f, ok := c.Get("DB") + if !ok { + return nil, errors.New("DB was not set on the supplied *gin.Context. Use a middleware to set it.") + } + switch f.(DBEnvironment) { + case ProductionDB: + return openDB(config.MySQL.Username, config.MySQL.Password, config.MySQL.Address, config.MySQL.Database) + case TestingDB: + if injectedTestDB != nil { + return injectedTestDB, nil + } + db, err := openTestDB() + if err != nil { + return nil, errors.New("Could not open a test database") + } + injectedTestDB = db + return injectedTestDB, nil + default: + return nil, errors.New("GetDB not implemented for DBEnvironment object") + } +} + type BaseModel struct { ID uint `gorm:"primary_key"` CreatedAt time.Time @@ -15,16 +51,6 @@ type BaseModel struct { DeletedAt *time.Time `sql:"index"` } -type User struct { - BaseModel - Username string `gorm:"not null;unique_index:idx_user_unique_username"` - Password string `gorm:"not null"` - Salt string `gorm:"not null"` - Name string - Email string `gorm:"unique_index:idx_user_unique_email"` - Admin bool `gorm:"not null"` -} - type ScheduledQuery struct { BaseModel Name string `gorm:"not null"` @@ -162,6 +188,26 @@ func openDB(user, password, address, dbName string) (*gorm.DB, error) { return gorm.Open("mysql", connectionString) } +func openTestDB() (*gorm.DB, error) { + db, err := gorm.Open("sqlite3", ":memory:") + if err != nil { + return nil, err + } + + createTables(db) + return db, db.Error +} + +func ProductionDatabaseMiddleware(c *gin.Context) { + c.Set("DB", ProductionDB) + c.Next() +} + +func TestingDatabaseMiddleware(c *gin.Context) { + c.Set("DB", TestingDB) + c.Next() +} + func dropTables(db *gorm.DB) { for _, table := range tables { db.DropTableIfExists(table) diff --git a/server.go b/server.go index 15963dd475..2e507124bd 100644 --- a/server.go +++ b/server.go @@ -4,8 +4,57 @@ import ( "github.com/gin-gonic/gin" ) -func attachRoutes(router *gin.Engine) { - v1 := router.Group("/api/v1") +// ServerError is a helper which accepts a string error and returns a map in +// format that is required by gin.Context.JSON +func ServerError(e string) *map[string]interface{} { + return &map[string]interface{}{ + "error": e, + } +} + +// DatabaseError emits a response that is appropriate in the event that a +// database failure occurs, a record is not found in the database, etc +func DatabaseError(c *gin.Context) { + c.JSON(500, ServerError("Database error")) +} + +// UnauthorizedError emits a response that is appropriate in the event that a +// request is received by a user which is not authorized to carry out the +// requested action +func UnauthorizedError(c *gin.Context) { + c.JSON(401, ServerError("Unauthorized")) +} + +func createTestServer() *gin.Engine { + server := gin.New() + server.Use(TestingDatabaseMiddleware) + return server +} + +// CreateServer creates a gin.Engine HTTP server and configures it to be in a +// state such that it is ready to serve HTTP requests for the kolide application +func CreateServer() *gin.Engine { + server := gin.New() + server.Use(ProductionDatabaseMiddleware) + + v1 := server.Group("/api/v1") + + // Kolide application API endpoints + kolide := v1.Group("/kolide") + kolide.Use(SessionMiddleware) + kolide.Use(JWTRenewalMiddleware) + + kolide.POST("/login", Login) + kolide.GET("/logout", Logout) + + kolide.GET("/user", GetUser) + kolide.PUT("/user", CreateUser) + kolide.PATCH("/user", ModifyUser) + kolide.DELETE("/user", DeleteUser) + + kolide.PATCH("/user/password", ResetUserPassword) + kolide.PATCH("/user/admin", SetUserAdminState) + kolide.PATCH("/user/enabled", SetUserEnabledState) // osquery API endpoints osquery := v1.Group("/osquery") @@ -14,10 +63,6 @@ func attachRoutes(router *gin.Engine) { osquery.POST("/log", OsqueryLog) osquery.POST("/distributed/read", OsqueryDistributedRead) osquery.POST("/distributed/write", OsqueryDistributedWrite) -} -func createServer() *gin.Engine { - server := gin.New() - attachRoutes(server) return server } diff --git a/sessions.go b/sessions.go new file mode 100644 index 0000000000..00719409bb --- /dev/null +++ b/sessions.go @@ -0,0 +1,90 @@ +package main + +import ( + "net/http" + + "github.com/Sirupsen/logrus" + "github.com/gin-gonic/gin" + "github.com/gorilla/context" + "github.com/gorilla/sessions" +) + +// GetSession allows you to get the Session object given a web request. This +// is often used in HTTP handlers as the main entry point into managing and +// manipulating the session +func GetSession(c *gin.Context) *Session { + return c.MustGet("Session").(*Session) +} + +// SessionMiddleware is the middleware used for production session management. +// Tests should use `testSessionMiddleware`, which follows the same pattern, +// but creates a session configured for testing. +func SessionMiddleware(c *gin.Context) { + CreateSession("Session", sessions.NewCookieStore([]byte("c")))(c) +} + +// CreateSessions is a helper which returns a gin.HandlerFunc which creates +// a new session management middleware given the name of the session to manage +// and the session storage mechanism. This is commonly used to generate session +// middleware given a variety of settings in both production and testing +// environments +func CreateSession(name string, store sessions.Store) gin.HandlerFunc { + return func(c *gin.Context) { + s := &Session{name, c.Request, store, nil, c.Writer} + c.Set("Session", s) + defer context.Clear(c.Request) + c.Next() + } +} + +// Session is a convenience wrapper around gorilla sessions, which is provided +// by github.com/gorilla/sessions +type Session struct { + name string + request *http.Request + store sessions.Store + session *sessions.Session + writer http.ResponseWriter +} + +// Session returns the gorilla session from the Session struct and allows you +// to use any of the functionality of the underlying sessions.Session struct +func (s *Session) Session() *sessions.Session { + if s.session == nil { + var err error + s.session, err = s.store.Get(s.request, s.name) + if err != nil { + logrus.Error(err.Error()) + } + } + return s.session +} + +// Set simply sets a session key value pair which will be stored in the +// current session for later usage +func (s *Session) Set(key interface{}, val interface{}) { + s.Session().Values[key] = val +} + +// Get retrieves a session key value pair which has previously been set +func (s *Session) Get(key interface{}) interface{} { + return s.Session().Values[key] +} + +// Delete deletes a session key value pair which has previously been set +func (s *Session) Delete(key interface{}) { + delete(s.Session().Values, key) +} + +// Clear deletes all session key value pairs that are set +func (s *Session) Clear() { + for key := range s.Session().Values { + s.Delete(key) + } +} + +// Save writes the session, which is required after altering the session in any +// way +func (s *Session) Save() error { + return s.Session().Save(s.request, s.writer) +} diff --git a/sessions_test.go b/sessions_test.go new file mode 100644 index 0000000000..dca8e28e42 --- /dev/null +++ b/sessions_test.go @@ -0,0 +1,179 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/gorilla/sessions" +) + +const testSessionName = "TestSession" + +func getTestStore() sessions.Store { + return sessions.NewCookieStore([]byte("test")) +} + +func testSessionMiddleware(c *gin.Context) { + CreateSession(testSessionName, getTestStore())(c) +} + +func TestSessionGetSet(t *testing.T) { + r := createTestServer() + r.Use(testSessionMiddleware) + r.Use(JWTRenewalMiddleware) + + r.GET("/set", func(c *gin.Context) { + session := GetSession(c) + session.Set("key", "foobar") + session.Save() + c.JSON(200, nil) + }) + + r.GET("/get", func(c *gin.Context) { + session := GetSession(c) + if session.Get("key") != "foobar" { + t.Fatal("Session writing failed") + } + c.String(200, "OK") + }) + + res1 := httptest.NewRecorder() + req1, _ := http.NewRequest("GET", "/set", nil) + r.ServeHTTP(res1, req1) + + res2 := httptest.NewRecorder() + req2, _ := http.NewRequest("GET", "/get", nil) + req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie")) + r.ServeHTTP(res2, req2) +} + +func TestSessionDeleteKey(t *testing.T) { + r := createTestServer() + r.Use(testSessionMiddleware) + r.Use(JWTRenewalMiddleware) + + r.GET("/set", func(c *gin.Context) { + session := GetSession(c) + session.Set("key", "foobar") + session.Save() + c.JSON(200, nil) + }) + + r.GET("/delete", func(c *gin.Context) { + session := GetSession(c) + session.Delete("key") + session.Save() + c.JSON(200, nil) + }) + + r.GET("/get", func(c *gin.Context) { + session := GetSession(c) + if session.Get("key") != nil { + t.Fatal("Session deleting failed") + } + c.JSON(200, nil) + }) + + res1 := httptest.NewRecorder() + req1, _ := http.NewRequest("GET", "/set", nil) + r.ServeHTTP(res1, req1) + + res2 := httptest.NewRecorder() + req2, _ := http.NewRequest("GET", "/delete", nil) + req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie")) + r.ServeHTTP(res2, req2) + + res3 := httptest.NewRecorder() + req3, _ := http.NewRequest("GET", "/get", nil) + req3.Header.Set("Cookie", res2.Header().Get("Set-Cookie")) + r.ServeHTTP(res3, req3) +} + +func TestSessionFlashes(t *testing.T) { + r := createTestServer() + r.Use(testSessionMiddleware) + r.Use(JWTRenewalMiddleware) + + r.GET("/set", func(c *gin.Context) { + session := GetSession(c) + session.Session().AddFlash("foobar") + session.Save() + c.JSON(200, nil) + }) + + r.GET("/flash", func(c *gin.Context) { + session := GetSession(c) + l := len(session.Session().Flashes()) + if l != 1 { + t.Fatal("Flashes count does not equal 1. Equals ", l) + } + session.Save() + c.JSON(200, nil) + }) + + r.GET("/check", func(c *gin.Context) { + session := GetSession(c) + l := len(session.Session().Flashes()) + if l != 0 { + t.Fatal("flashes count is not 0 after reading. Equals ", l) + } + session.Save() + c.JSON(200, nil) + }) + + res1 := httptest.NewRecorder() + req1, _ := http.NewRequest("GET", "/set", nil) + r.ServeHTTP(res1, req1) + + res2 := httptest.NewRecorder() + req2, _ := http.NewRequest("GET", "/flash", nil) + req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie")) + r.ServeHTTP(res2, req2) + + res3 := httptest.NewRecorder() + req3, _ := http.NewRequest("GET", "/check", nil) + req3.Header.Set("Cookie", res2.Header().Get("Set-Cookie")) + r.ServeHTTP(res3, req3) +} + +func TestSessionClear(t *testing.T) { + data := map[string]string{ + "key": "val", + "foo": "bar", + } + r := createTestServer() + store := getTestStore() + r.Use(CreateSession(testSessionName, store)) + r.Use(JWTRenewalMiddleware) + + r.GET("/set", func(c *gin.Context) { + session := GetSession(c) + for k, v := range data { + session.Set(k, v) + } + session.Clear() + session.Save() + c.JSON(200, nil) + }) + + r.GET("/check", func(c *gin.Context) { + session := GetSession(c) + for k, v := range data { + if session.Get(k) == v { + t.Fatal("Session clear failed") + } + } + c.JSON(200, nil) + }) + + res1 := httptest.NewRecorder() + req1, _ := http.NewRequest("GET", "/set", nil) + r.ServeHTTP(res1, req1) + + res2 := httptest.NewRecorder() + req2, _ := http.NewRequest("GET", "/check", nil) + req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie")) + r.ServeHTTP(res2, req2) +} diff --git a/tools/example_config.json b/tools/example_config.json index 5b262de606..1b94b3978f 100644 --- a/tools/example_config.json +++ b/tools/example_config.json @@ -9,5 +9,10 @@ "address": ":8080", "cert": "./tools/kolide.crt", "key": "./tools/kolide.key" + }, + "app": { + "bcrypt_cost": 12, + "salt_length": 32, + "jwt_key": "very secure" } } \ No newline at end of file diff --git a/users.go b/users.go new file mode 100644 index 0000000000..92d38919ad --- /dev/null +++ b/users.go @@ -0,0 +1,435 @@ +package main + +import ( + "fmt" + + "github.com/Sirupsen/logrus" + "github.com/gin-gonic/gin" + "github.com/jinzhu/gorm" + "golang.org/x/crypto/bcrypt" +) + +// User is the model struct which represents a kolide user +type User struct { + BaseModel + Username string `gorm:"not null;unique_index:idx_user_unique_username"` + Password []byte `gorm:"not null"` + Salt string `gorm:"not null"` + Name string + Email string `gorm:"not null;unique_index:idx_user_unique_email"` + Admin bool `gorm:"not null"` + Enabled bool `gorm:"not null"` + NeedsPasswordReset bool +} + +// NewUser is a wrapper around the creation of a new user. +// NewUser exists largely to allow the API to simply accept a string password +// while using the applications password hashing mechanisms to salt and hash the +// password. +func NewUser(db *gorm.DB, username, password, email string, admin, needsPasswordReset bool) (*User, error) { + salt, hash, err := SaltAndHashPassword(password) + if err != nil { + return nil, err + } + user := &User{ + Username: username, + Password: hash, + Salt: salt, + Email: email, + Admin: admin, + Enabled: true, + NeedsPasswordReset: needsPasswordReset, + } + + err = db.Create(&user).Error + if err != nil { + return nil, err + } + return user, nil +} + +// ValidatePassword accepts a potential password for a given user and attempts +// to validate it against the hash stored in the database after joining the +// supplied password with the stored password salt +func (u *User) ValidatePassword(password string) error { + saltAndPass := []byte(fmt.Sprintf("%s%s", u.Salt, password)) + return bcrypt.CompareHashAndPassword(u.Password, saltAndPass) +} + +// SetPassword accepts a new password for a user object and updates the salt +// and hash for that user in the database. This function assumes that the +// appropriate authorization checks have already occurred by the caller. +func (u *User) SetPassword(db *gorm.DB, password string) error { + salt, hash, err := SaltAndHashPassword(password) + if err != nil { + return err + } + u.Salt = salt + u.Password = hash + return db.Save(u).Error +} + +// MakeAdmin is a simple wrapper around promoting a user to an administrator. +// If the user is already an admin, this function will return without modifying +// the database +func (u *User) MakeAdmin(db *gorm.DB) error { + if !u.Admin { + u.Admin = true + return db.Save(&u).Error + } + return nil +} + +type GetUserRequestBody struct { + ID uint `json:"id" binding:"required"` +} + +func GetUser(c *gin.Context) { + var body GetUserRequestBody + err := c.BindJSON(&body) + if err != nil { + logrus.Errorf("Error parsing GetUser post body: %s", err.Error()) + return + } + + db, err := GetDB(c) + if err != nil { + logrus.Errorf("Could not open database: %s", err.Error()) + DatabaseError(c) + return + } + + vc, err := VC(c, db) + if err != nil { + logrus.Errorf("Could not create VC: %s", err.Error()) + DatabaseError(c) // TODO tampered? + return + } + + var user User + err = db.Where("id = ?", body.ID).First(&user).Error + if err != nil { + logrus.Errorf("Error finding user in database: %s", err.Error()) + DatabaseError(c) + return + } + + if !vc.CanPerformReadActionOnUser(db, &user) { + UnauthorizedError(c) + return + } + + c.JSON(200, map[string]interface{}{ + "id": user.ID, + "username": user.Username, + "name": user.Name, + "email": user.Email, + "admin": user.Admin, + "enabled": user.Enabled, + "needs_password_reset": user.NeedsPasswordReset, + }) +} + +type CreateUserRequestBody struct { + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` + Email string `json:"email" binding:"required"` + Admin bool `json:"admin"` + NeedsPasswordReset bool `json:"needs_password_reset"` +} + +func CreateUser(c *gin.Context) { + var body CreateUserRequestBody + err := c.BindJSON(&body) + if err != nil { + logrus.Errorf("Error parsing CreateUser post body: %s", err.Error()) + return + } + + db, err := GetDB(c) + if err != nil { + logrus.Errorf("Could not open database: %s", err.Error()) + DatabaseError(c) + return + } + + vc, err := VC(c, db) + if err != nil { + logrus.Errorf("Could not create VC: %s", err.Error()) + DatabaseError(c) + return + } + + if !vc.IsAdmin() { + UnauthorizedError(c) + return + } + + _, err = NewUser(db, body.Username, body.Password, body.Email, body.Admin, body.NeedsPasswordReset) + if err != nil { + logrus.Errorf("Error creating new user: %s", err.Error()) + DatabaseError(c) + return + } + + c.JSON(200, nil) +} + +type ModifyUserRequestBody struct { + ID uint `json:"id" binding:"required"` + Username string `json:"username"` + Name string `json:"name"` + Email string `json:"email"` +} + +func ModifyUser(c *gin.Context) { + var body ModifyUserRequestBody + err := c.BindJSON(&body) + if err != nil { + logrus.Errorf("Error parsing ModifyUser post body: %s", err.Error()) + return + } + + db, err := GetDB(c) + if err != nil { + logrus.Errorf("Could not open database: %s", err.Error()) + DatabaseError(c) + return + } + + vc, err := VC(c, db) + if err != nil { + logrus.Errorf("Could not create VC: %s", err.Error()) + DatabaseError(c) + return + } + + var user User + err = db.Where("id = ?", body.ID).First(&user).Error + if err != nil { + logrus.Errorf("Error finding user in database: %s", err.Error()) + DatabaseError(c) + return + } + + if !vc.CanPerformWriteActionOnUser(db, &user) { + UnauthorizedError(c) + return + } + + err = db.Save(&user).Error + if err != nil { + logrus.Errorf("Error updating user in database: %s", err.Error()) + DatabaseError(c) + return + } + c.JSON(200, nil) +} + +type DeleteUserRequestBody struct { + ID uint `json:"id" binding:"required"` +} + +func DeleteUser(c *gin.Context) { + var body DeleteUserRequestBody + err := c.BindJSON(&body) + if err != nil { + logrus.Errorf("Error parsing DeleteUser post body: %s", err.Error()) + return + } + + db, err := GetDB(c) + if err != nil { + logrus.Errorf("Could not open database: %s", err.Error()) + DatabaseError(c) + return + } + + vc, err := VC(c, db) + if err != nil { + logrus.Errorf("Could not create VC: %s", err.Error()) + DatabaseError(c) + return + } + + if !vc.IsAdmin() { + UnauthorizedError(c) + return + } + + var user User + err = db.Where("id = ?", body.ID).First(&user).Error + if err != nil { + logrus.Errorf("Error finding user in database: %s", err.Error()) + DatabaseError(c) + return + } + + err = db.Delete(&user).Error + if err != nil { + logrus.Errorf("Error deleting user from database: %s", err.Error()) + DatabaseError(c) + return + } + c.JSON(200, nil) +} + +type ResetPasswordRequestBody struct { + ID uint `json:"id" binding:"required"` + Password string `json:"password" binding:"required"` + PasswordConfim string `json:"password_confirm" binding:"required"` +} + +func ResetUserPassword(c *gin.Context) { + var body ResetPasswordRequestBody + err := c.BindJSON(&body) + if err != nil { + logrus.Errorf("Error parsing ResetPassword post body: %s", err.Error()) + return + } + + if body.Password != body.PasswordConfim { + c.JSON(406, map[string]interface{}{"error": "Passwords do not match"}) + return + } + + db, err := GetDB(c) + if err != nil { + logrus.Errorf("Could not open database: %s", err.Error()) + DatabaseError(c) + return + } + + vc, err := VC(c, db) + if err != nil { + logrus.Errorf("Could not create VC: %s", err.Error()) + DatabaseError(c) + return + } + var user User + err = db.Where("id = ?", body.ID).First(&user).Error + if err != nil { + logrus.Errorf("Error finding user in database: %s", err.Error()) + DatabaseError(c) + return + } + + if !vc.CanPerformWriteActionOnUser(db, &user) { + UnauthorizedError(c) + return + } + + err = user.SetPassword(db, body.Password) + if err != nil { + logrus.Errorf("Error setting user password: %s", err.Error()) + // xxx don't try to write to the db? + } + + err = db.Save(&user).Error + if err != nil { + logrus.Errorf("Error updating user in database: %s", err.Error()) + DatabaseError(c) + return + } + c.JSON(200, nil) +} + +type SetUserAdminStateRequestBody struct { + ID uint `json:"id" binding:"required"` + Admin bool `json:"admin" binding:"required"` +} + +func SetUserAdminState(c *gin.Context) { + var body SetUserAdminStateRequestBody + err := c.BindJSON(&body) + if err != nil { + logrus.Errorf("Error parsing SetUserAdminState post body: %s", err.Error()) + return + } + + db, err := GetDB(c) + if err != nil { + logrus.Errorf("Could not open database: %s", err.Error()) + DatabaseError(c) + return + } + + vc, err := VC(c, db) + if err != nil { + logrus.Errorf("Could not create VC: %s", err.Error()) + DatabaseError(c) + return + } + + if !vc.IsAdmin() { + UnauthorizedError(c) + return + } + + var user User + err = db.Where("id = ?", body.ID).First(&user).Error + if err != nil { + logrus.Errorf("Error finding user in database: %s", err.Error()) + DatabaseError(c) + return + } + + user.Admin = body.Admin + err = db.Save(&user).Error + if err != nil { + logrus.Errorf("Error updating user in database: %s", err.Error()) + DatabaseError(c) + return + } + c.JSON(200, nil) +} + +type SetUserEnabledStateRequestBody struct { + ID uint `json:"id" binding:"required"` + Enabled bool `json:"enabled" binding:"required"` +} + +func SetUserEnabledState(c *gin.Context) { + var body SetUserEnabledStateRequestBody + err := c.BindJSON(&body) + if err != nil { + logrus.Errorf("Error parsing SetUserEnabledState post body: %s", err.Error()) + return + } + + db, err := GetDB(c) + if err != nil { + logrus.Errorf("Could not open database: %s", err.Error()) + DatabaseError(c) + return + } + + vc, err := VC(c, db) + if err != nil { + logrus.Errorf("Could not create VC: %s", err.Error()) + DatabaseError(c) + return + } + + if !vc.IsAdmin() { + UnauthorizedError(c) + return + } + + var user User + err = db.Where("id = ?", body.ID).First(&user).Error + if err != nil { + logrus.Errorf("Error finding user in database: %s", err.Error()) + DatabaseError(c) + return + } + + user.Enabled = body.Enabled + err = db.Save(&user).Error + if err != nil { + logrus.Errorf("Error updating user in database: %s", err.Error()) + DatabaseError(c) + return + } + c.JSON(200, nil) +} diff --git a/users_test.go b/users_test.go new file mode 100644 index 0000000000..bc51d65423 --- /dev/null +++ b/users_test.go @@ -0,0 +1,309 @@ +package main + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/jinzhu/gorm" +) + +func TestNewUser(t *testing.T) { + db, err := openTestDB() + if err != nil { + t.Fatal(err) + } + + user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false) + if err != nil { + t.Fatal(err.Error()) + } + + if user.Username != "marpaia" { + t.Fatalf("Username is not what's expected: %s", user.Username) + } + + if user.Email != "mike@kolide.co" { + t.Fatalf("Email is not what's expected: %s", user.Email) + } + + if !user.Admin { + t.Fatal("User is not an admin") + } + + var verify User + db.Where("username = ?", "marpaia").First(&verify) + if verify.ID != user.ID { + t.Fatal("Couldn't select user back from database") + } +} + +func TestValidatePassword(t *testing.T) { + db, err := openTestDB() + if err != nil { + t.Fatal(err.Error()) + } + + user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false) + if err != nil { + t.Fatal(err.Error()) + } + + err = user.ValidatePassword("foobar") + if err != nil { + t.Fatal("Password validation failed") + } + + err = user.ValidatePassword("not correct") + if err == nil { + t.Fatal("Incorrect password worked") + } +} + +func TestMakeAdmin(t *testing.T) { + db, err := openTestDB() + if err != nil { + t.Fatal(err.Error()) + } + + user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) + if err != nil { + t.Fatal(err.Error()) + } + + if user.Admin { + t.Fatal("Admin should be false") + } + + err = user.MakeAdmin(db) + if err != nil { + t.Fatal(err.Error()) + } + + if !user.Admin { + t.Fatal("Admin should be true") + } + + var verify User + db.Where("admin = ?", true).First(&verify) + + if user.ID != verify.ID { + t.Fatal("Users don't match") + } + + if !verify.Admin { + t.Fatal("User wasn't set as admin in the database") + } + +} + +func TestUpdatingUser(t *testing.T) { + db, err := openTestDB() + if err != nil { + t.Fatal(err.Error()) + } + + user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) + if err != nil { + t.Fatal(err.Error()) + } + + user.Email = "marpaia@kolide.co" + err = db.Save(user).Error + if err != nil { + t.Fatal(err.Error()) + } + + if user.Email != "marpaia@kolide.co" { + t.Fatal("user.Email was reset") + } + + var verify User + err = db.Where("id = ?", user.ID).First(&verify).Error + if err != nil { + t.Fatal(err.Error()) + } + + if verify.Email != "marpaia@kolide.co" { + t.Fatalf("user's email was not updated in the DB: %s", verify.Email) + } + +} + +func TestDeletingUser(t *testing.T) { + db, err := openTestDB() + if err != nil { + t.Fatal(err.Error()) + } + + user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) + if err != nil { + t.Fatal(err.Error()) + } + + var verify1 User + err = db.Where("username = ?", "marpaia").First(&verify1).Error + if err != nil { + t.Fatal(err.Error()) + } + if verify1.ID != user.ID { + t.Fatal("users are not the same") + } + + err = db.Delete(&user).Error + if err != nil { + t.Fatal(err.Error()) + } + + var verify2 User + err = db.Where("username = ?", "marpaia").First(&verify2).Error + if err != gorm.ErrRecordNotFound { + t.Fatal("Record was not deleted") + } +} + +func TestSetPassword(t *testing.T) { + db, err := openTestDB() + if err != nil { + t.Fatal(err.Error()) + } + + user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false) + if err != nil { + t.Fatal(err.Error()) + } + + err = user.ValidatePassword("foobar") + if err != nil { + t.Fatal(err.Error()) + } + + err = user.SetPassword(db, "baz") + if err != nil { + t.Fatal(err.Error()) + } + + err = user.ValidatePassword("baz") + if err != nil { + t.Fatal(err.Error()) + } + + var verify User + err = db.Where("username = ?", "marpaia").First(&verify).Error + if err != nil { + t.Fatal(err.Error()) + } + + err = verify.ValidatePassword("baz") + if err != nil { + t.Fatal(err.Error()) + } +} + +func TestUserManagementIntegration(t *testing.T) { + r := createTestServer() + r.Use(testSessionMiddleware) + r.Use(JWTRenewalMiddleware) + + db, err := openTestDB() + if err != nil { + t.Fatal(err.Error()) + } + injectedTestDB = db + + admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false) + if err != nil { + t.Fatal(err.Error()) + } + _ = admin + + r.POST("/login", Login) + r.GET("/logout", Logout) + + r.GET("/user", GetUser) + r.PUT("/user", CreateUser) + r.PATCH("/user", ModifyUser) + r.DELETE("/user", DeleteUser) + + res1 := httptest.NewRecorder() + body1, err := json.Marshal(LoginRequestBody{ + Username: "admin", + Password: "foobar", + }) + if err != nil { + t.Fatal(err.Error()) + } + buff1 := new(bytes.Buffer) + buff1.Write(body1) + req1, _ := http.NewRequest("POST", "/login", buff1) + req1.Header.Set("Content-Type", "application/json") + r.ServeHTTP(res1, req1) + if res1.Code != 200 { + t.Fatalf("Response code: %d", res1.Code) + } + + res2 := httptest.NewRecorder() + body2, err := json.Marshal(CreateUserRequestBody{ + Username: "marpaia", + Password: "foobar", + Email: "mike@kolide.co", + Admin: false, + NeedsPasswordReset: false, + }) + if err != nil { + t.Fatal(err.Error()) + } + buff2 := new(bytes.Buffer) + buff2.Write(body2) + req2, _ := http.NewRequest("PUT", "/user", buff2) + req2.Header.Set("Content-Type", "application/json") + req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie")) + r.ServeHTTP(res2, req2) + + res3 := httptest.NewRecorder() + body3, err := json.Marshal(CreateUserRequestBody{ + Username: "admin2", + Password: "foobar", + Email: "admin2@kolide.co", + Admin: true, + NeedsPasswordReset: false, + }) + if err != nil { + t.Fatal(err.Error()) + } + buff3 := new(bytes.Buffer) + buff3.Write(body3) + req3, _ := http.NewRequest("PUT", "/user", buff3) + req3.Header.Set("Content-Type", "application/json") + req3.Header.Set("Cookie", res1.Header().Get("Set-Cookie")) + r.ServeHTTP(res3, req3) + + var user User + err = db.Where("username = ?", "marpaia").First(&user).Error + if err != nil { + t.Fatal(err.Error()) + } + + if user.Email != "mike@kolide.co" { + t.Fatalf("user's email was not set in the DB: %s", user.Email) + } + if user.Admin { + t.Fatal("user shouldn't be admin") + } + + var admin2 User + err = db.Where("username = ?", "admin2").First(&admin2).Error + if err != nil { + t.Fatal(err.Error()) + } + + if admin2.Email != "admin2@kolide.co" { + t.Fatalf("admin2's email was not set in the DB: %s", admin2.Email) + } + if !admin2.Admin { + t.Fatal("admin2 should be admin") + } + +}