Authentication, authorization and user management (#10)

This commit is contained in:
Mike Arpaia 2016-08-01 16:32:20 -07:00 committed by Zachary Wasserman
parent 91e78d276f
commit eee370e127
12 changed files with 1716 additions and 37 deletions

View file

@ -8,6 +8,14 @@ To build the code, run the following from the root of the repository:
go build
```
## Testing
To run the application's tests, run the following from the root of the repository:
```
go test
```
## Development Environment
To set up the development environment via docker, run the following frmo the root of the repository:

302
auth.go Normal file
View file

@ -0,0 +1,302 @@
package main
import (
"errors"
"fmt"
"math/rand"
"time"
"github.com/Sirupsen/logrus"
"github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"golang.org/x/crypto/bcrypt"
)
// ViewerContext is a struct which represents the ability for an execution
// context to participate in certain actions. Most often, a ViewerContext is
// associated with an application user, but a ViewerContext can represent a
// variety of other execution contexts as well (script, test, etc). The main
// purpose of a ViewerContext is to assist in the authorization of sensitive
// actions.
type ViewerContext struct {
user *User
}
// JWT returns a JWT token in serialized string form given a ViewerContext as
// well as a potential error in the event that things have gone wrong.
func (vc *ViewerContext) JWT() (string, error) {
return GenerateJWT(vc.user.ID)
}
// IsAdmin indicates whether or not the current user can perform administrative
// actions.
func (vc *ViewerContext) IsAdmin() bool {
if vc.user != nil {
return vc.user.Admin && vc.user.Enabled
}
return false
}
// UserID is a helper that enables quick access to the user ID of the current
// user.
func (vc *ViewerContext) UserID() (uint, error) {
if vc.user != nil {
return vc.user.ID, nil
}
return 0, errors.New("No user set")
}
func (vc *ViewerContext) CanPerformActions(db *gorm.DB) bool {
if vc.user == nil {
return false
}
if !vc.user.Enabled {
return false
}
return true
}
func (vc *ViewerContext) IsUserID(id uint) bool {
userID, err := vc.UserID()
if err != nil {
return false
}
if userID == id {
return true
}
return false
}
func (vc *ViewerContext) CanPerformWriteActionOnUser(db *gorm.DB, u *User) bool {
return vc.CanPerformActions(db) && (vc.IsUserID(u.ID) || vc.IsAdmin())
}
func (vc *ViewerContext) CanPerformReadActionOnUser(db *gorm.DB, u *User) bool {
return vc.CanPerformActions(db) && (vc.IsUserID(u.ID) || vc.IsAdmin())
}
// GenerateJWT generates a JWT token in serialized string form given a
// ViewerContext as well as a potential error in the event that things have
// gone wrong.
func GenerateJWT(userID uint) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": userID,
// "Not Before": https://tools.ietf.org/html/rfc7519#section-4.1.5
"nbf": time.Now().UTC().Unix(),
// "Expiration Time": https://tools.ietf.org/html/rfc7519#section-4.1.4
"exp": time.Now().UTC().AddDate(0, 2, 0).Unix(),
})
return token.SignedString([]byte(config.App.JWTKey))
}
// ParseJWT attempts to parse a JWT token in serialized string form into a
// JWT token in a deserialized jwt.Token struct.
func ParseJWT(token string) (*jwt.Token, error) {
return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
method, ok := t.Method.(*jwt.SigningMethodHMAC)
if !ok || method != jwt.SigningMethodHS256 {
return nil, errors.New("Unexpected signing method")
}
return []byte(config.App.JWTKey), nil
})
}
// JWTRenewalMiddleware optimistically tries to renew the user's JWT token.
// This allows kolide to have sessions that last forever, assuming that a user
// logs in and uses the application within a reasonable time window (which is
// defined in the JWT token generation method). If anything goes wrong, this
// middleware will back off and defer recovery of the situation to the
// downstream web request.
func JWTRenewalMiddleware(c *gin.Context) {
session := GetSession(c)
tokenCookie := session.Get("jwt")
if tokenCookie == nil {
c.Next()
return
}
tokenString, ok := tokenCookie.(string)
if !ok {
c.Next()
return
}
token, err := ParseJWT(tokenString)
if err != nil {
c.Next()
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
c.Next()
return
}
userID := uint(claims["user_id"].(float64))
jwt, err := GenerateJWT(userID)
if err != nil {
c.Next()
return
}
session.Set("jwt", jwt)
c.Next()
}
// GenerateVC generates a ViewerContext given a user struct
func GenerateVC(user *User) *ViewerContext {
return &ViewerContext{
user: user,
}
}
// EmptyVC is a utility which generates an empty ViewerContext. This is often
// used to represent users which are not logged in.
func EmptyVC() *ViewerContext {
return &ViewerContext{
user: nil,
}
}
// VC accepts a web request context and a database handler and attempts
// to parse a user's jwt token out of the active session, validate the token,
// and generate an appropriate ViewerContext given the data in the session.
func VC(c *gin.Context, db *gorm.DB) (*ViewerContext, error) {
session := GetSession(c)
tokenCookie := session.Get("jwt")
if tokenCookie == nil {
return nil, errors.New("jwt session attribute not set")
}
tokenString, ok := tokenCookie.(string)
if !ok {
return nil, errors.New("jwt token was not string")
}
token, err := ParseJWT(tokenString)
if err != nil {
return nil, err
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return nil, errors.New("Invalid token")
}
userID := uint(claims["user_id"].(float64))
var user User
err = db.Where("id = ?", userID).First(&user).Error
if err != nil {
return nil, err
}
return GenerateVC(&user), nil
}
type LoginRequestBody struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
func Login(c *gin.Context) {
var body LoginRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf("Error parsing Login post body: %s", err.Error())
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
var user User
err = db.Where("username = ?", body.Username).First(&user).Error
if err != nil {
logrus.Debugf("User not found: %s", body.Username)
UnauthorizedError(c)
return
}
err = user.ValidatePassword(body.Password)
if err != nil {
logrus.Debugf("Invalid password for user: %s", body.Username)
UnauthorizedError(c)
return
}
token, err := GenerateVC(&user).JWT()
if err != nil {
logrus.Fatalf("Error generating token: %s", err.Error())
DatabaseError(c)
return
}
session := GetSession(c)
session.Set("jwt", token)
session.Save()
c.JSON(200, map[string]interface{}{
"id": user.ID,
"username": user.Username,
"email": user.Email,
"name": user.Name,
"admin": user.Admin,
})
}
func Logout(c *gin.Context) {
session := GetSession(c)
session.Clear()
c.JSON(200, nil)
}
const (
letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
letterIndex = 6 // 6 bits to represent a letter index
letterIndexMask = 1<<letterIndex - 1 // All 1-bits, as many as letterIndex
letterIndexMax = 63 / letterIndex // # of letter indices fitting in 63 bits
)
var psrngSource = rand.NewSource(time.Now().UnixNano())
func generateRandomText(length int) string {
text := make([]byte, length)
for i, cache, remain := length-1, psrngSource.Int63(), letterIndexMax; i >= 0; {
if remain == 0 {
cache, remain = psrngSource.Int63(), letterIndexMax
}
if idx := int(cache & letterIndexMask); idx < len(letters) {
text[i] = letters[idx]
i--
}
cache >>= letterIndex
remain--
}
return string(text)
}
func HashPassword(salt, password string) ([]byte, error) {
return bcrypt.GenerateFromPassword(
[]byte(fmt.Sprintf("%s%s", salt, password)),
config.App.BcryptCost,
)
}
func SaltAndHashPassword(password string) (string, []byte, error) {
salt := generateRandomText(config.App.SaltLength)
hashed, err := HashPassword(salt, password)
return salt, hashed, err
}

198
auth_test.go Normal file
View file

@ -0,0 +1,198 @@
package main
import (
"net/http"
"net/http/httptest"
"testing"
"unicode/utf8"
"github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin"
)
func TestGenerateRandomText(t *testing.T) {
text := generateRandomText(12)
if utf8.RuneCountInString(text) != 12 {
t.Fatal("generateRandomText generated the wrong length string")
}
}
func TestGenerateVC(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err)
}
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false)
if err != nil {
t.Fatal(err.Error())
}
tokenString, err := GenerateVC(user).JWT()
if err != nil {
t.Fatal(err.Error())
}
token, err := ParseJWT(tokenString)
if err != nil {
t.Fatal(err.Error())
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
t.Fatal("Token is invalid")
}
userID := uint(claims["user_id"].(float64))
if userID != user.ID {
t.Fatal("Claims are incorrect. userID is %d", userID)
}
}
func TestVC(t *testing.T) {
r := createTestServer()
r.Use(testSessionMiddleware)
r.Use(JWTRenewalMiddleware)
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {
t.Fatal(err.Error())
}
admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false)
if err != nil {
t.Fatal(err.Error())
}
r.GET("/admin_login", func(c *gin.Context) {
token, err := GenerateVC(admin).JWT()
if err != nil {
t.Fatal(err.Error())
}
session := GetSession(c)
session.Set("jwt", token)
session.Save()
c.JSON(200, nil)
})
r.GET("/user_login", func(c *gin.Context) {
token, err := GenerateVC(user).JWT()
if err != nil {
t.Fatal(err.Error())
}
session := GetSession(c)
session.Set("jwt", token)
session.Save()
c.JSON(200, nil)
})
r.GET("/admin", func(c *gin.Context) {
vc, err := VC(c, db)
if err != nil {
t.Fatal(err.Error())
}
if !vc.IsAdmin() {
t.Fatal("Not admin")
}
c.String(200, "OK")
})
r.GET("/user", func(c *gin.Context) {
vc, err := VC(c, db)
if err != nil {
t.Fatal(err.Error())
}
if vc.IsAdmin() {
t.Fatal("Not user")
}
c.String(200, "OK")
})
res1 := httptest.NewRecorder()
req1, _ := http.NewRequest("GET", "/admin_login", nil)
r.ServeHTTP(res1, req1)
res2 := httptest.NewRecorder()
req2, _ := http.NewRequest("GET", "/admin", nil)
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
r.ServeHTTP(res2, req2)
res3 := httptest.NewRecorder()
req3, _ := http.NewRequest("GET", "/user_login", nil)
r.ServeHTTP(res3, req3)
res4 := httptest.NewRecorder()
req4, _ := http.NewRequest("GET", "/user", nil)
req4.Header.Set("Cookie", res3.Header().Get("Set-Cookie"))
r.ServeHTTP(res4, req4)
}
func TestIsUserID(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err)
}
user1, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {
t.Fatal(err.Error())
}
vc := GenerateVC(user1)
if !vc.IsUserID(user1.ID) {
t.Fatal("IsUserID failed on same user object")
}
if vc.IsUserID(user1.ID + 1) {
t.Fatal("IsUserID passed for incorrect ID")
}
}
func TestCanPerformActionsOnUser(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err)
}
user1, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {
t.Fatal(err.Error())
}
user2, err := NewUser(db, "zwass", "foobar", "zwass@kolide.co", false, false)
if err != nil {
t.Fatal(err.Error())
}
admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false)
if err != nil {
t.Fatal(err.Error())
}
adminVC := GenerateVC(admin)
user1VC := GenerateVC(user1)
if !adminVC.CanPerformWriteActionOnUser(db, user1) || !adminVC.CanPerformWriteActionOnUser(db, user2) {
t.Fatal("Admin should be able to perform writes on users")
}
if !adminVC.CanPerformReadActionOnUser(db, user1) || !adminVC.CanPerformReadActionOnUser(db, user2) {
t.Fatal("Admin should be able to perform reads on users")
}
if user1VC.CanPerformWriteActionOnUser(db, user2) {
t.Fatal("user1 shouldn't be able to perform writes on user2")
}
if user1VC.CanPerformReadActionOnUser(db, user2) {
t.Fatal("user1 should be able to perform reads on user2")
}
}

View file

@ -5,24 +5,64 @@ import (
"io/ioutil"
)
type mysqlConfigData struct {
Address string `json:"address"`
Username string `json:"username"`
Password string `json:"password"`
Database string `json:"database"`
}
type serverConfigData struct {
Address string `json:"address"`
Cert string `json:"cert"`
Key string `json:"key"`
}
type appConfigData struct {
BcryptCost int `json:"bcrypt_cost"`
SaltLength int `json:"salt_length"`
JWTKey string `json:"jwt_key"`
}
type configData struct {
MySQL struct {
Address string `json:"address"`
Username string `json:"username"`
Password string `json:"password"`
Database string `json:"database"`
} `json:"mysql"`
Server struct {
Address string `json:"address"`
Cert string `json:"cert"`
Key string `json:"key"`
} `json:"server"`
MySQL mysqlConfigData `json:"mysql"`
Server serverConfigData `json:"server"`
App appConfigData `json:"app"`
}
var (
config configData
)
var defaultMysqlConfigData = mysqlConfigData{
Address: "127.0.0.1:3306",
Username: "kolide",
Password: "kolide",
Database: "kolide",
}
var defaultServerConfigData = serverConfigData{
Address: ":8080",
Cert: "./tools/kolide.crt",
Key: "./tools/kolide.key",
}
var defaultAppConfigData = appConfigData{
BcryptCost: 12,
SaltLength: 32,
JWTKey: "very secure",
}
var defaultConfigData = configData{
MySQL: defaultMysqlConfigData,
Server: defaultServerConfigData,
App: defaultAppConfigData,
}
func setDefaultConfigValues() {
config = defaultConfigData
}
func loadConfig(path string) error {
content, err := ioutil.ReadFile(path)
if err != nil {

View file

@ -2,11 +2,14 @@ package main
import (
"fmt"
"math/rand"
"os"
"path"
"runtime"
"time"
"github.com/Sirupsen/logrus"
"github.com/gin-gonic/gin"
"gopkg.in/alecthomas/kingpin.v2"
)
@ -39,12 +42,27 @@ var (
serve = app.Command("serve", "Run the Kolide server")
)
func init() {
// set gin mode to release to silence some superfluous logging
gin.SetMode(gin.ReleaseMode)
// configure logging
logrus.AddHook(logContextHook{})
// populate the global config data structure with sane defaults
setDefaultConfigValues()
}
// logContextHook is a logrus hook which is used to contextualize application
// logs to include data stuch as line numbers, file names, etc.
type logContextHook struct{}
// Levels defines which levels the logContextHook logrus hook should apply to
func (hook logContextHook) Levels() []logrus.Level {
return logrus.AllLevels
}
// Fire defines what the logContextHook should actually do when it is triggered
func (hook logContextHook) Fire(entry *logrus.Entry) error {
if pc, file, line, ok := runtime.Caller(8); ok {
funcName := runtime.FuncForPC(pc).Name()
@ -57,14 +75,15 @@ func (hook logContextHook) Fire(entry *logrus.Entry) error {
return nil
}
func init() {
rand.Seed(time.Now().UnixNano())
}
func main() {
// configure flag parsing and parse flags
app.Version(version)
args, err := app.Parse(os.Args[1:])
// configure logging
logrus.AddHook(logContextHook{})
// configure the application based on the flags that have been set
if *debug {
logrus.SetLevel(logrus.DebugLevel)
@ -76,18 +95,21 @@ func main() {
// if config hasn't been defined and the example config exists relative to
// the binary, it's likely that the tool is being ran right after building
// from source so we auto-populate the config path.
// from source so we auto-populate the example config path.
if *configPath == "" {
if _, err = os.Stat("./tools/example_config.json"); err == nil {
*configPath = "./tools/example_config.json"
} else {
logrus.Fatalln("No config file specified. Use --config to specify config path.")
}
}
err = loadConfig(*configPath)
if err != nil {
logrus.Fatalf("Error loading config: %s", err.Error())
// if the user has defined a config path OR the example config is found
// relative to the binary, load config content from the file. any content
// in the config file will overwrite the default values
if *configPath != "" {
err = loadConfig(*configPath)
if err != nil {
logrus.Fatalf("Error loading config: %s", err.Error())
}
}
// route the executable based on the sub-command
@ -100,7 +122,7 @@ func main() {
dropTables(db)
createTables(db)
case serve.FullCommand():
createServer().RunTLS(
CreateServer().RunTLS(
config.Server.Address,
config.Server.Cert,
config.Server.Key)

View file

@ -1,13 +1,49 @@
package main
import (
"errors"
"fmt"
"time"
"github.com/jinzhu/gorm"
"github.com/gin-gonic/gin"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/sqlite"
)
type DBEnvironment int
const (
ProductionDB DBEnvironment = iota
TestingDB DBEnvironment = iota
)
var injectedTestDB *gorm.DB
func GetDB(c *gin.Context) (*gorm.DB, error) {
f, ok := c.Get("DB")
if !ok {
return nil, errors.New("DB was not set on the supplied *gin.Context. Use a middleware to set it.")
}
switch f.(DBEnvironment) {
case ProductionDB:
return openDB(config.MySQL.Username, config.MySQL.Password, config.MySQL.Address, config.MySQL.Database)
case TestingDB:
if injectedTestDB != nil {
return injectedTestDB, nil
}
db, err := openTestDB()
if err != nil {
return nil, errors.New("Could not open a test database")
}
injectedTestDB = db
return injectedTestDB, nil
default:
return nil, errors.New("GetDB not implemented for DBEnvironment object")
}
}
type BaseModel struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
@ -15,16 +51,6 @@ type BaseModel struct {
DeletedAt *time.Time `sql:"index"`
}
type User struct {
BaseModel
Username string `gorm:"not null;unique_index:idx_user_unique_username"`
Password string `gorm:"not null"`
Salt string `gorm:"not null"`
Name string
Email string `gorm:"unique_index:idx_user_unique_email"`
Admin bool `gorm:"not null"`
}
type ScheduledQuery struct {
BaseModel
Name string `gorm:"not null"`
@ -162,6 +188,26 @@ func openDB(user, password, address, dbName string) (*gorm.DB, error) {
return gorm.Open("mysql", connectionString)
}
func openTestDB() (*gorm.DB, error) {
db, err := gorm.Open("sqlite3", ":memory:")
if err != nil {
return nil, err
}
createTables(db)
return db, db.Error
}
func ProductionDatabaseMiddleware(c *gin.Context) {
c.Set("DB", ProductionDB)
c.Next()
}
func TestingDatabaseMiddleware(c *gin.Context) {
c.Set("DB", TestingDB)
c.Next()
}
func dropTables(db *gorm.DB) {
for _, table := range tables {
db.DropTableIfExists(table)

View file

@ -4,8 +4,57 @@ import (
"github.com/gin-gonic/gin"
)
func attachRoutes(router *gin.Engine) {
v1 := router.Group("/api/v1")
// ServerError is a helper which accepts a string error and returns a map in
// format that is required by gin.Context.JSON
func ServerError(e string) *map[string]interface{} {
return &map[string]interface{}{
"error": e,
}
}
// DatabaseError emits a response that is appropriate in the event that a
// database failure occurs, a record is not found in the database, etc
func DatabaseError(c *gin.Context) {
c.JSON(500, ServerError("Database error"))
}
// UnauthorizedError emits a response that is appropriate in the event that a
// request is received by a user which is not authorized to carry out the
// requested action
func UnauthorizedError(c *gin.Context) {
c.JSON(401, ServerError("Unauthorized"))
}
func createTestServer() *gin.Engine {
server := gin.New()
server.Use(TestingDatabaseMiddleware)
return server
}
// CreateServer creates a gin.Engine HTTP server and configures it to be in a
// state such that it is ready to serve HTTP requests for the kolide application
func CreateServer() *gin.Engine {
server := gin.New()
server.Use(ProductionDatabaseMiddleware)
v1 := server.Group("/api/v1")
// Kolide application API endpoints
kolide := v1.Group("/kolide")
kolide.Use(SessionMiddleware)
kolide.Use(JWTRenewalMiddleware)
kolide.POST("/login", Login)
kolide.GET("/logout", Logout)
kolide.GET("/user", GetUser)
kolide.PUT("/user", CreateUser)
kolide.PATCH("/user", ModifyUser)
kolide.DELETE("/user", DeleteUser)
kolide.PATCH("/user/password", ResetUserPassword)
kolide.PATCH("/user/admin", SetUserAdminState)
kolide.PATCH("/user/enabled", SetUserEnabledState)
// osquery API endpoints
osquery := v1.Group("/osquery")
@ -14,10 +63,6 @@ func attachRoutes(router *gin.Engine) {
osquery.POST("/log", OsqueryLog)
osquery.POST("/distributed/read", OsqueryDistributedRead)
osquery.POST("/distributed/write", OsqueryDistributedWrite)
}
func createServer() *gin.Engine {
server := gin.New()
attachRoutes(server)
return server
}

90
sessions.go Normal file
View file

@ -0,0 +1,90 @@
package main
import (
"net/http"
"github.com/Sirupsen/logrus"
"github.com/gin-gonic/gin"
"github.com/gorilla/context"
"github.com/gorilla/sessions"
)
// GetSession allows you to get the Session object given a web request. This
// is often used in HTTP handlers as the main entry point into managing and
// manipulating the session
func GetSession(c *gin.Context) *Session {
return c.MustGet("Session").(*Session)
}
// SessionMiddleware is the middleware used for production session management.
// Tests should use `testSessionMiddleware`, which follows the same pattern,
// but creates a session configured for testing.
func SessionMiddleware(c *gin.Context) {
CreateSession("Session", sessions.NewCookieStore([]byte("c")))(c)
}
// CreateSessions is a helper which returns a gin.HandlerFunc which creates
// a new session management middleware given the name of the session to manage
// and the session storage mechanism. This is commonly used to generate session
// middleware given a variety of settings in both production and testing
// environments
func CreateSession(name string, store sessions.Store) gin.HandlerFunc {
return func(c *gin.Context) {
s := &Session{name, c.Request, store, nil, c.Writer}
c.Set("Session", s)
defer context.Clear(c.Request)
c.Next()
}
}
// Session is a convenience wrapper around gorilla sessions, which is provided
// by github.com/gorilla/sessions
type Session struct {
name string
request *http.Request
store sessions.Store
session *sessions.Session
writer http.ResponseWriter
}
// Session returns the gorilla session from the Session struct and allows you
// to use any of the functionality of the underlying sessions.Session struct
func (s *Session) Session() *sessions.Session {
if s.session == nil {
var err error
s.session, err = s.store.Get(s.request, s.name)
if err != nil {
logrus.Error(err.Error())
}
}
return s.session
}
// Set simply sets a session key value pair which will be stored in the
// current session for later usage
func (s *Session) Set(key interface{}, val interface{}) {
s.Session().Values[key] = val
}
// Get retrieves a session key value pair which has previously been set
func (s *Session) Get(key interface{}) interface{} {
return s.Session().Values[key]
}
// Delete deletes a session key value pair which has previously been set
func (s *Session) Delete(key interface{}) {
delete(s.Session().Values, key)
}
// Clear deletes all session key value pairs that are set
func (s *Session) Clear() {
for key := range s.Session().Values {
s.Delete(key)
}
}
// Save writes the session, which is required after altering the session in any
// way
func (s *Session) Save() error {
return s.Session().Save(s.request, s.writer)
}

179
sessions_test.go Normal file
View file

@ -0,0 +1,179 @@
package main
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/gorilla/sessions"
)
const testSessionName = "TestSession"
func getTestStore() sessions.Store {
return sessions.NewCookieStore([]byte("test"))
}
func testSessionMiddleware(c *gin.Context) {
CreateSession(testSessionName, getTestStore())(c)
}
func TestSessionGetSet(t *testing.T) {
r := createTestServer()
r.Use(testSessionMiddleware)
r.Use(JWTRenewalMiddleware)
r.GET("/set", func(c *gin.Context) {
session := GetSession(c)
session.Set("key", "foobar")
session.Save()
c.JSON(200, nil)
})
r.GET("/get", func(c *gin.Context) {
session := GetSession(c)
if session.Get("key") != "foobar" {
t.Fatal("Session writing failed")
}
c.String(200, "OK")
})
res1 := httptest.NewRecorder()
req1, _ := http.NewRequest("GET", "/set", nil)
r.ServeHTTP(res1, req1)
res2 := httptest.NewRecorder()
req2, _ := http.NewRequest("GET", "/get", nil)
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
r.ServeHTTP(res2, req2)
}
func TestSessionDeleteKey(t *testing.T) {
r := createTestServer()
r.Use(testSessionMiddleware)
r.Use(JWTRenewalMiddleware)
r.GET("/set", func(c *gin.Context) {
session := GetSession(c)
session.Set("key", "foobar")
session.Save()
c.JSON(200, nil)
})
r.GET("/delete", func(c *gin.Context) {
session := GetSession(c)
session.Delete("key")
session.Save()
c.JSON(200, nil)
})
r.GET("/get", func(c *gin.Context) {
session := GetSession(c)
if session.Get("key") != nil {
t.Fatal("Session deleting failed")
}
c.JSON(200, nil)
})
res1 := httptest.NewRecorder()
req1, _ := http.NewRequest("GET", "/set", nil)
r.ServeHTTP(res1, req1)
res2 := httptest.NewRecorder()
req2, _ := http.NewRequest("GET", "/delete", nil)
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
r.ServeHTTP(res2, req2)
res3 := httptest.NewRecorder()
req3, _ := http.NewRequest("GET", "/get", nil)
req3.Header.Set("Cookie", res2.Header().Get("Set-Cookie"))
r.ServeHTTP(res3, req3)
}
func TestSessionFlashes(t *testing.T) {
r := createTestServer()
r.Use(testSessionMiddleware)
r.Use(JWTRenewalMiddleware)
r.GET("/set", func(c *gin.Context) {
session := GetSession(c)
session.Session().AddFlash("foobar")
session.Save()
c.JSON(200, nil)
})
r.GET("/flash", func(c *gin.Context) {
session := GetSession(c)
l := len(session.Session().Flashes())
if l != 1 {
t.Fatal("Flashes count does not equal 1. Equals ", l)
}
session.Save()
c.JSON(200, nil)
})
r.GET("/check", func(c *gin.Context) {
session := GetSession(c)
l := len(session.Session().Flashes())
if l != 0 {
t.Fatal("flashes count is not 0 after reading. Equals ", l)
}
session.Save()
c.JSON(200, nil)
})
res1 := httptest.NewRecorder()
req1, _ := http.NewRequest("GET", "/set", nil)
r.ServeHTTP(res1, req1)
res2 := httptest.NewRecorder()
req2, _ := http.NewRequest("GET", "/flash", nil)
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
r.ServeHTTP(res2, req2)
res3 := httptest.NewRecorder()
req3, _ := http.NewRequest("GET", "/check", nil)
req3.Header.Set("Cookie", res2.Header().Get("Set-Cookie"))
r.ServeHTTP(res3, req3)
}
func TestSessionClear(t *testing.T) {
data := map[string]string{
"key": "val",
"foo": "bar",
}
r := createTestServer()
store := getTestStore()
r.Use(CreateSession(testSessionName, store))
r.Use(JWTRenewalMiddleware)
r.GET("/set", func(c *gin.Context) {
session := GetSession(c)
for k, v := range data {
session.Set(k, v)
}
session.Clear()
session.Save()
c.JSON(200, nil)
})
r.GET("/check", func(c *gin.Context) {
session := GetSession(c)
for k, v := range data {
if session.Get(k) == v {
t.Fatal("Session clear failed")
}
}
c.JSON(200, nil)
})
res1 := httptest.NewRecorder()
req1, _ := http.NewRequest("GET", "/set", nil)
r.ServeHTTP(res1, req1)
res2 := httptest.NewRecorder()
req2, _ := http.NewRequest("GET", "/check", nil)
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
r.ServeHTTP(res2, req2)
}

View file

@ -9,5 +9,10 @@
"address": ":8080",
"cert": "./tools/kolide.crt",
"key": "./tools/kolide.key"
},
"app": {
"bcrypt_cost": 12,
"salt_length": 32,
"jwt_key": "very secure"
}
}

435
users.go Normal file
View file

@ -0,0 +1,435 @@
package main
import (
"fmt"
"github.com/Sirupsen/logrus"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"golang.org/x/crypto/bcrypt"
)
// User is the model struct which represents a kolide user
type User struct {
BaseModel
Username string `gorm:"not null;unique_index:idx_user_unique_username"`
Password []byte `gorm:"not null"`
Salt string `gorm:"not null"`
Name string
Email string `gorm:"not null;unique_index:idx_user_unique_email"`
Admin bool `gorm:"not null"`
Enabled bool `gorm:"not null"`
NeedsPasswordReset bool
}
// NewUser is a wrapper around the creation of a new user.
// NewUser exists largely to allow the API to simply accept a string password
// while using the applications password hashing mechanisms to salt and hash the
// password.
func NewUser(db *gorm.DB, username, password, email string, admin, needsPasswordReset bool) (*User, error) {
salt, hash, err := SaltAndHashPassword(password)
if err != nil {
return nil, err
}
user := &User{
Username: username,
Password: hash,
Salt: salt,
Email: email,
Admin: admin,
Enabled: true,
NeedsPasswordReset: needsPasswordReset,
}
err = db.Create(&user).Error
if err != nil {
return nil, err
}
return user, nil
}
// ValidatePassword accepts a potential password for a given user and attempts
// to validate it against the hash stored in the database after joining the
// supplied password with the stored password salt
func (u *User) ValidatePassword(password string) error {
saltAndPass := []byte(fmt.Sprintf("%s%s", u.Salt, password))
return bcrypt.CompareHashAndPassword(u.Password, saltAndPass)
}
// SetPassword accepts a new password for a user object and updates the salt
// and hash for that user in the database. This function assumes that the
// appropriate authorization checks have already occurred by the caller.
func (u *User) SetPassword(db *gorm.DB, password string) error {
salt, hash, err := SaltAndHashPassword(password)
if err != nil {
return err
}
u.Salt = salt
u.Password = hash
return db.Save(u).Error
}
// MakeAdmin is a simple wrapper around promoting a user to an administrator.
// If the user is already an admin, this function will return without modifying
// the database
func (u *User) MakeAdmin(db *gorm.DB) error {
if !u.Admin {
u.Admin = true
return db.Save(&u).Error
}
return nil
}
type GetUserRequestBody struct {
ID uint `json:"id" binding:"required"`
}
func GetUser(c *gin.Context) {
var body GetUserRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf("Error parsing GetUser post body: %s", err.Error())
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
vc, err := VC(c, db)
if err != nil {
logrus.Errorf("Could not create VC: %s", err.Error())
DatabaseError(c) // TODO tampered?
return
}
var user User
err = db.Where("id = ?", body.ID).First(&user).Error
if err != nil {
logrus.Errorf("Error finding user in database: %s", err.Error())
DatabaseError(c)
return
}
if !vc.CanPerformReadActionOnUser(db, &user) {
UnauthorizedError(c)
return
}
c.JSON(200, map[string]interface{}{
"id": user.ID,
"username": user.Username,
"name": user.Name,
"email": user.Email,
"admin": user.Admin,
"enabled": user.Enabled,
"needs_password_reset": user.NeedsPasswordReset,
})
}
type CreateUserRequestBody struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
Email string `json:"email" binding:"required"`
Admin bool `json:"admin"`
NeedsPasswordReset bool `json:"needs_password_reset"`
}
func CreateUser(c *gin.Context) {
var body CreateUserRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf("Error parsing CreateUser post body: %s", err.Error())
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
vc, err := VC(c, db)
if err != nil {
logrus.Errorf("Could not create VC: %s", err.Error())
DatabaseError(c)
return
}
if !vc.IsAdmin() {
UnauthorizedError(c)
return
}
_, err = NewUser(db, body.Username, body.Password, body.Email, body.Admin, body.NeedsPasswordReset)
if err != nil {
logrus.Errorf("Error creating new user: %s", err.Error())
DatabaseError(c)
return
}
c.JSON(200, nil)
}
type ModifyUserRequestBody struct {
ID uint `json:"id" binding:"required"`
Username string `json:"username"`
Name string `json:"name"`
Email string `json:"email"`
}
func ModifyUser(c *gin.Context) {
var body ModifyUserRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf("Error parsing ModifyUser post body: %s", err.Error())
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
vc, err := VC(c, db)
if err != nil {
logrus.Errorf("Could not create VC: %s", err.Error())
DatabaseError(c)
return
}
var user User
err = db.Where("id = ?", body.ID).First(&user).Error
if err != nil {
logrus.Errorf("Error finding user in database: %s", err.Error())
DatabaseError(c)
return
}
if !vc.CanPerformWriteActionOnUser(db, &user) {
UnauthorizedError(c)
return
}
err = db.Save(&user).Error
if err != nil {
logrus.Errorf("Error updating user in database: %s", err.Error())
DatabaseError(c)
return
}
c.JSON(200, nil)
}
type DeleteUserRequestBody struct {
ID uint `json:"id" binding:"required"`
}
func DeleteUser(c *gin.Context) {
var body DeleteUserRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf("Error parsing DeleteUser post body: %s", err.Error())
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
vc, err := VC(c, db)
if err != nil {
logrus.Errorf("Could not create VC: %s", err.Error())
DatabaseError(c)
return
}
if !vc.IsAdmin() {
UnauthorizedError(c)
return
}
var user User
err = db.Where("id = ?", body.ID).First(&user).Error
if err != nil {
logrus.Errorf("Error finding user in database: %s", err.Error())
DatabaseError(c)
return
}
err = db.Delete(&user).Error
if err != nil {
logrus.Errorf("Error deleting user from database: %s", err.Error())
DatabaseError(c)
return
}
c.JSON(200, nil)
}
type ResetPasswordRequestBody struct {
ID uint `json:"id" binding:"required"`
Password string `json:"password" binding:"required"`
PasswordConfim string `json:"password_confirm" binding:"required"`
}
func ResetUserPassword(c *gin.Context) {
var body ResetPasswordRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf("Error parsing ResetPassword post body: %s", err.Error())
return
}
if body.Password != body.PasswordConfim {
c.JSON(406, map[string]interface{}{"error": "Passwords do not match"})
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
vc, err := VC(c, db)
if err != nil {
logrus.Errorf("Could not create VC: %s", err.Error())
DatabaseError(c)
return
}
var user User
err = db.Where("id = ?", body.ID).First(&user).Error
if err != nil {
logrus.Errorf("Error finding user in database: %s", err.Error())
DatabaseError(c)
return
}
if !vc.CanPerformWriteActionOnUser(db, &user) {
UnauthorizedError(c)
return
}
err = user.SetPassword(db, body.Password)
if err != nil {
logrus.Errorf("Error setting user password: %s", err.Error())
// xxx don't try to write to the db?
}
err = db.Save(&user).Error
if err != nil {
logrus.Errorf("Error updating user in database: %s", err.Error())
DatabaseError(c)
return
}
c.JSON(200, nil)
}
type SetUserAdminStateRequestBody struct {
ID uint `json:"id" binding:"required"`
Admin bool `json:"admin" binding:"required"`
}
func SetUserAdminState(c *gin.Context) {
var body SetUserAdminStateRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf("Error parsing SetUserAdminState post body: %s", err.Error())
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
vc, err := VC(c, db)
if err != nil {
logrus.Errorf("Could not create VC: %s", err.Error())
DatabaseError(c)
return
}
if !vc.IsAdmin() {
UnauthorizedError(c)
return
}
var user User
err = db.Where("id = ?", body.ID).First(&user).Error
if err != nil {
logrus.Errorf("Error finding user in database: %s", err.Error())
DatabaseError(c)
return
}
user.Admin = body.Admin
err = db.Save(&user).Error
if err != nil {
logrus.Errorf("Error updating user in database: %s", err.Error())
DatabaseError(c)
return
}
c.JSON(200, nil)
}
type SetUserEnabledStateRequestBody struct {
ID uint `json:"id" binding:"required"`
Enabled bool `json:"enabled" binding:"required"`
}
func SetUserEnabledState(c *gin.Context) {
var body SetUserEnabledStateRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf("Error parsing SetUserEnabledState post body: %s", err.Error())
return
}
db, err := GetDB(c)
if err != nil {
logrus.Errorf("Could not open database: %s", err.Error())
DatabaseError(c)
return
}
vc, err := VC(c, db)
if err != nil {
logrus.Errorf("Could not create VC: %s", err.Error())
DatabaseError(c)
return
}
if !vc.IsAdmin() {
UnauthorizedError(c)
return
}
var user User
err = db.Where("id = ?", body.ID).First(&user).Error
if err != nil {
logrus.Errorf("Error finding user in database: %s", err.Error())
DatabaseError(c)
return
}
user.Enabled = body.Enabled
err = db.Save(&user).Error
if err != nil {
logrus.Errorf("Error updating user in database: %s", err.Error())
DatabaseError(c)
return
}
c.JSON(200, nil)
}

309
users_test.go Normal file
View file

@ -0,0 +1,309 @@
package main
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/jinzhu/gorm"
)
func TestNewUser(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err)
}
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false)
if err != nil {
t.Fatal(err.Error())
}
if user.Username != "marpaia" {
t.Fatalf("Username is not what's expected: %s", user.Username)
}
if user.Email != "mike@kolide.co" {
t.Fatalf("Email is not what's expected: %s", user.Email)
}
if !user.Admin {
t.Fatal("User is not an admin")
}
var verify User
db.Where("username = ?", "marpaia").First(&verify)
if verify.ID != user.ID {
t.Fatal("Couldn't select user back from database")
}
}
func TestValidatePassword(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false)
if err != nil {
t.Fatal(err.Error())
}
err = user.ValidatePassword("foobar")
if err != nil {
t.Fatal("Password validation failed")
}
err = user.ValidatePassword("not correct")
if err == nil {
t.Fatal("Incorrect password worked")
}
}
func TestMakeAdmin(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {
t.Fatal(err.Error())
}
if user.Admin {
t.Fatal("Admin should be false")
}
err = user.MakeAdmin(db)
if err != nil {
t.Fatal(err.Error())
}
if !user.Admin {
t.Fatal("Admin should be true")
}
var verify User
db.Where("admin = ?", true).First(&verify)
if user.ID != verify.ID {
t.Fatal("Users don't match")
}
if !verify.Admin {
t.Fatal("User wasn't set as admin in the database")
}
}
func TestUpdatingUser(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {
t.Fatal(err.Error())
}
user.Email = "marpaia@kolide.co"
err = db.Save(user).Error
if err != nil {
t.Fatal(err.Error())
}
if user.Email != "marpaia@kolide.co" {
t.Fatal("user.Email was reset")
}
var verify User
err = db.Where("id = ?", user.ID).First(&verify).Error
if err != nil {
t.Fatal(err.Error())
}
if verify.Email != "marpaia@kolide.co" {
t.Fatalf("user's email was not updated in the DB: %s", verify.Email)
}
}
func TestDeletingUser(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {
t.Fatal(err.Error())
}
var verify1 User
err = db.Where("username = ?", "marpaia").First(&verify1).Error
if err != nil {
t.Fatal(err.Error())
}
if verify1.ID != user.ID {
t.Fatal("users are not the same")
}
err = db.Delete(&user).Error
if err != nil {
t.Fatal(err.Error())
}
var verify2 User
err = db.Where("username = ?", "marpaia").First(&verify2).Error
if err != gorm.ErrRecordNotFound {
t.Fatal("Record was not deleted")
}
}
func TestSetPassword(t *testing.T) {
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
if err != nil {
t.Fatal(err.Error())
}
err = user.ValidatePassword("foobar")
if err != nil {
t.Fatal(err.Error())
}
err = user.SetPassword(db, "baz")
if err != nil {
t.Fatal(err.Error())
}
err = user.ValidatePassword("baz")
if err != nil {
t.Fatal(err.Error())
}
var verify User
err = db.Where("username = ?", "marpaia").First(&verify).Error
if err != nil {
t.Fatal(err.Error())
}
err = verify.ValidatePassword("baz")
if err != nil {
t.Fatal(err.Error())
}
}
func TestUserManagementIntegration(t *testing.T) {
r := createTestServer()
r.Use(testSessionMiddleware)
r.Use(JWTRenewalMiddleware)
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
injectedTestDB = db
admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false)
if err != nil {
t.Fatal(err.Error())
}
_ = admin
r.POST("/login", Login)
r.GET("/logout", Logout)
r.GET("/user", GetUser)
r.PUT("/user", CreateUser)
r.PATCH("/user", ModifyUser)
r.DELETE("/user", DeleteUser)
res1 := httptest.NewRecorder()
body1, err := json.Marshal(LoginRequestBody{
Username: "admin",
Password: "foobar",
})
if err != nil {
t.Fatal(err.Error())
}
buff1 := new(bytes.Buffer)
buff1.Write(body1)
req1, _ := http.NewRequest("POST", "/login", buff1)
req1.Header.Set("Content-Type", "application/json")
r.ServeHTTP(res1, req1)
if res1.Code != 200 {
t.Fatalf("Response code: %d", res1.Code)
}
res2 := httptest.NewRecorder()
body2, err := json.Marshal(CreateUserRequestBody{
Username: "marpaia",
Password: "foobar",
Email: "mike@kolide.co",
Admin: false,
NeedsPasswordReset: false,
})
if err != nil {
t.Fatal(err.Error())
}
buff2 := new(bytes.Buffer)
buff2.Write(body2)
req2, _ := http.NewRequest("PUT", "/user", buff2)
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
r.ServeHTTP(res2, req2)
res3 := httptest.NewRecorder()
body3, err := json.Marshal(CreateUserRequestBody{
Username: "admin2",
Password: "foobar",
Email: "admin2@kolide.co",
Admin: true,
NeedsPasswordReset: false,
})
if err != nil {
t.Fatal(err.Error())
}
buff3 := new(bytes.Buffer)
buff3.Write(body3)
req3, _ := http.NewRequest("PUT", "/user", buff3)
req3.Header.Set("Content-Type", "application/json")
req3.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
r.ServeHTTP(res3, req3)
var user User
err = db.Where("username = ?", "marpaia").First(&user).Error
if err != nil {
t.Fatal(err.Error())
}
if user.Email != "mike@kolide.co" {
t.Fatalf("user's email was not set in the DB: %s", user.Email)
}
if user.Admin {
t.Fatal("user shouldn't be admin")
}
var admin2 User
err = db.Where("username = ?", "admin2").First(&admin2).Error
if err != nil {
t.Fatal(err.Error())
}
if admin2.Email != "admin2@kolide.co" {
t.Fatalf("admin2's email was not set in the DB: %s", admin2.Email)
}
if !admin2.Admin {
t.Fatal("admin2 should be admin")
}
}