mirror of
https://github.com/fleetdm/fleet
synced 2026-05-24 01:18:42 +00:00
Authentication, authorization and user management (#10)
This commit is contained in:
parent
91e78d276f
commit
eee370e127
12 changed files with 1716 additions and 37 deletions
|
|
@ -8,6 +8,14 @@ To build the code, run the following from the root of the repository:
|
|||
go build
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
To run the application's tests, run the following from the root of the repository:
|
||||
|
||||
```
|
||||
go test
|
||||
```
|
||||
|
||||
## Development Environment
|
||||
|
||||
To set up the development environment via docker, run the following frmo the root of the repository:
|
||||
|
|
|
|||
302
auth.go
Normal file
302
auth.go
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// ViewerContext is a struct which represents the ability for an execution
|
||||
// context to participate in certain actions. Most often, a ViewerContext is
|
||||
// associated with an application user, but a ViewerContext can represent a
|
||||
// variety of other execution contexts as well (script, test, etc). The main
|
||||
// purpose of a ViewerContext is to assist in the authorization of sensitive
|
||||
// actions.
|
||||
type ViewerContext struct {
|
||||
user *User
|
||||
}
|
||||
|
||||
// JWT returns a JWT token in serialized string form given a ViewerContext as
|
||||
// well as a potential error in the event that things have gone wrong.
|
||||
func (vc *ViewerContext) JWT() (string, error) {
|
||||
return GenerateJWT(vc.user.ID)
|
||||
}
|
||||
|
||||
// IsAdmin indicates whether or not the current user can perform administrative
|
||||
// actions.
|
||||
func (vc *ViewerContext) IsAdmin() bool {
|
||||
if vc.user != nil {
|
||||
return vc.user.Admin && vc.user.Enabled
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UserID is a helper that enables quick access to the user ID of the current
|
||||
// user.
|
||||
func (vc *ViewerContext) UserID() (uint, error) {
|
||||
if vc.user != nil {
|
||||
return vc.user.ID, nil
|
||||
}
|
||||
return 0, errors.New("No user set")
|
||||
}
|
||||
|
||||
func (vc *ViewerContext) CanPerformActions(db *gorm.DB) bool {
|
||||
if vc.user == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !vc.user.Enabled {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (vc *ViewerContext) IsUserID(id uint) bool {
|
||||
userID, err := vc.UserID()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if userID == id {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (vc *ViewerContext) CanPerformWriteActionOnUser(db *gorm.DB, u *User) bool {
|
||||
return vc.CanPerformActions(db) && (vc.IsUserID(u.ID) || vc.IsAdmin())
|
||||
}
|
||||
|
||||
func (vc *ViewerContext) CanPerformReadActionOnUser(db *gorm.DB, u *User) bool {
|
||||
return vc.CanPerformActions(db) && (vc.IsUserID(u.ID) || vc.IsAdmin())
|
||||
}
|
||||
|
||||
// GenerateJWT generates a JWT token in serialized string form given a
|
||||
// ViewerContext as well as a potential error in the event that things have
|
||||
// gone wrong.
|
||||
func GenerateJWT(userID uint) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": userID,
|
||||
// "Not Before": https://tools.ietf.org/html/rfc7519#section-4.1.5
|
||||
"nbf": time.Now().UTC().Unix(),
|
||||
// "Expiration Time": https://tools.ietf.org/html/rfc7519#section-4.1.4
|
||||
"exp": time.Now().UTC().AddDate(0, 2, 0).Unix(),
|
||||
})
|
||||
|
||||
return token.SignedString([]byte(config.App.JWTKey))
|
||||
}
|
||||
|
||||
// ParseJWT attempts to parse a JWT token in serialized string form into a
|
||||
// JWT token in a deserialized jwt.Token struct.
|
||||
func ParseJWT(token string) (*jwt.Token, error) {
|
||||
return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
|
||||
method, ok := t.Method.(*jwt.SigningMethodHMAC)
|
||||
if !ok || method != jwt.SigningMethodHS256 {
|
||||
return nil, errors.New("Unexpected signing method")
|
||||
}
|
||||
return []byte(config.App.JWTKey), nil
|
||||
})
|
||||
}
|
||||
|
||||
// JWTRenewalMiddleware optimistically tries to renew the user's JWT token.
|
||||
// This allows kolide to have sessions that last forever, assuming that a user
|
||||
// logs in and uses the application within a reasonable time window (which is
|
||||
// defined in the JWT token generation method). If anything goes wrong, this
|
||||
// middleware will back off and defer recovery of the situation to the
|
||||
// downstream web request.
|
||||
func JWTRenewalMiddleware(c *gin.Context) {
|
||||
session := GetSession(c)
|
||||
tokenCookie := session.Get("jwt")
|
||||
if tokenCookie == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString, ok := tokenCookie.(string)
|
||||
if !ok {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
token, err := ParseJWT(tokenString)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
|
||||
if !ok || !token.Valid {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
userID := uint(claims["user_id"].(float64))
|
||||
|
||||
jwt, err := GenerateJWT(userID)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
session.Set("jwt", jwt)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
|
||||
// GenerateVC generates a ViewerContext given a user struct
|
||||
func GenerateVC(user *User) *ViewerContext {
|
||||
return &ViewerContext{
|
||||
user: user,
|
||||
}
|
||||
}
|
||||
|
||||
// EmptyVC is a utility which generates an empty ViewerContext. This is often
|
||||
// used to represent users which are not logged in.
|
||||
func EmptyVC() *ViewerContext {
|
||||
return &ViewerContext{
|
||||
user: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// VC accepts a web request context and a database handler and attempts
|
||||
// to parse a user's jwt token out of the active session, validate the token,
|
||||
// and generate an appropriate ViewerContext given the data in the session.
|
||||
func VC(c *gin.Context, db *gorm.DB) (*ViewerContext, error) {
|
||||
session := GetSession(c)
|
||||
tokenCookie := session.Get("jwt")
|
||||
if tokenCookie == nil {
|
||||
return nil, errors.New("jwt session attribute not set")
|
||||
}
|
||||
|
||||
tokenString, ok := tokenCookie.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("jwt token was not string")
|
||||
}
|
||||
|
||||
token, err := ParseJWT(tokenString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, errors.New("Invalid token")
|
||||
}
|
||||
|
||||
userID := uint(claims["user_id"].(float64))
|
||||
var user User
|
||||
err = db.Where("id = ?", userID).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return GenerateVC(&user), nil
|
||||
|
||||
}
|
||||
|
||||
type LoginRequestBody struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
func Login(c *gin.Context) {
|
||||
var body LoginRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Error parsing Login post body: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
db, err := GetDB(c)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not open database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
var user User
|
||||
err = db.Where("username = ?", body.Username).First(&user).Error
|
||||
if err != nil {
|
||||
logrus.Debugf("User not found: %s", body.Username)
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
err = user.ValidatePassword(body.Password)
|
||||
if err != nil {
|
||||
logrus.Debugf("Invalid password for user: %s", body.Username)
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := GenerateVC(&user).JWT()
|
||||
if err != nil {
|
||||
logrus.Fatalf("Error generating token: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
session := GetSession(c)
|
||||
session.Set("jwt", token)
|
||||
session.Save()
|
||||
|
||||
c.JSON(200, map[string]interface{}{
|
||||
"id": user.ID,
|
||||
"username": user.Username,
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
"admin": user.Admin,
|
||||
})
|
||||
}
|
||||
|
||||
func Logout(c *gin.Context) {
|
||||
session := GetSession(c)
|
||||
session.Clear()
|
||||
c.JSON(200, nil)
|
||||
}
|
||||
|
||||
const (
|
||||
letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
letterIndex = 6 // 6 bits to represent a letter index
|
||||
letterIndexMask = 1<<letterIndex - 1 // All 1-bits, as many as letterIndex
|
||||
letterIndexMax = 63 / letterIndex // # of letter indices fitting in 63 bits
|
||||
)
|
||||
|
||||
var psrngSource = rand.NewSource(time.Now().UnixNano())
|
||||
|
||||
func generateRandomText(length int) string {
|
||||
|
||||
text := make([]byte, length)
|
||||
for i, cache, remain := length-1, psrngSource.Int63(), letterIndexMax; i >= 0; {
|
||||
if remain == 0 {
|
||||
cache, remain = psrngSource.Int63(), letterIndexMax
|
||||
}
|
||||
if idx := int(cache & letterIndexMask); idx < len(letters) {
|
||||
text[i] = letters[idx]
|
||||
i--
|
||||
}
|
||||
cache >>= letterIndex
|
||||
remain--
|
||||
}
|
||||
|
||||
return string(text)
|
||||
}
|
||||
|
||||
func HashPassword(salt, password string) ([]byte, error) {
|
||||
return bcrypt.GenerateFromPassword(
|
||||
[]byte(fmt.Sprintf("%s%s", salt, password)),
|
||||
config.App.BcryptCost,
|
||||
)
|
||||
}
|
||||
|
||||
func SaltAndHashPassword(password string) (string, []byte, error) {
|
||||
salt := generateRandomText(config.App.SaltLength)
|
||||
hashed, err := HashPassword(salt, password)
|
||||
return salt, hashed, err
|
||||
}
|
||||
198
auth_test.go
Normal file
198
auth_test.go
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestGenerateRandomText(t *testing.T) {
|
||||
text := generateRandomText(12)
|
||||
if utf8.RuneCountInString(text) != 12 {
|
||||
t.Fatal("generateRandomText generated the wrong length string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateVC(t *testing.T) {
|
||||
db, err := openTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
tokenString, err := GenerateVC(user).JWT()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
token, err := ParseJWT(tokenString)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
t.Fatal("Token is invalid")
|
||||
}
|
||||
|
||||
userID := uint(claims["user_id"].(float64))
|
||||
if userID != user.ID {
|
||||
t.Fatal("Claims are incorrect. userID is %d", userID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVC(t *testing.T) {
|
||||
r := createTestServer()
|
||||
r.Use(testSessionMiddleware)
|
||||
r.Use(JWTRenewalMiddleware)
|
||||
|
||||
db, err := openTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
r.GET("/admin_login", func(c *gin.Context) {
|
||||
token, err := GenerateVC(admin).JWT()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
session := GetSession(c)
|
||||
session.Set("jwt", token)
|
||||
session.Save()
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
r.GET("/user_login", func(c *gin.Context) {
|
||||
token, err := GenerateVC(user).JWT()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
session := GetSession(c)
|
||||
session.Set("jwt", token)
|
||||
session.Save()
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
r.GET("/admin", func(c *gin.Context) {
|
||||
vc, err := VC(c, db)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if !vc.IsAdmin() {
|
||||
t.Fatal("Not admin")
|
||||
}
|
||||
c.String(200, "OK")
|
||||
})
|
||||
|
||||
r.GET("/user", func(c *gin.Context) {
|
||||
vc, err := VC(c, db)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if vc.IsAdmin() {
|
||||
t.Fatal("Not user")
|
||||
}
|
||||
c.String(200, "OK")
|
||||
})
|
||||
|
||||
res1 := httptest.NewRecorder()
|
||||
req1, _ := http.NewRequest("GET", "/admin_login", nil)
|
||||
r.ServeHTTP(res1, req1)
|
||||
|
||||
res2 := httptest.NewRecorder()
|
||||
req2, _ := http.NewRequest("GET", "/admin", nil)
|
||||
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
|
||||
r.ServeHTTP(res2, req2)
|
||||
|
||||
res3 := httptest.NewRecorder()
|
||||
req3, _ := http.NewRequest("GET", "/user_login", nil)
|
||||
r.ServeHTTP(res3, req3)
|
||||
|
||||
res4 := httptest.NewRecorder()
|
||||
req4, _ := http.NewRequest("GET", "/user", nil)
|
||||
req4.Header.Set("Cookie", res3.Header().Get("Set-Cookie"))
|
||||
r.ServeHTTP(res4, req4)
|
||||
|
||||
}
|
||||
|
||||
func TestIsUserID(t *testing.T) {
|
||||
db, err := openTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user1, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
vc := GenerateVC(user1)
|
||||
|
||||
if !vc.IsUserID(user1.ID) {
|
||||
t.Fatal("IsUserID failed on same user object")
|
||||
}
|
||||
|
||||
if vc.IsUserID(user1.ID + 1) {
|
||||
t.Fatal("IsUserID passed for incorrect ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanPerformActionsOnUser(t *testing.T) {
|
||||
db, err := openTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user1, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
user2, err := NewUser(db, "zwass", "foobar", "zwass@kolide.co", false, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
adminVC := GenerateVC(admin)
|
||||
user1VC := GenerateVC(user1)
|
||||
|
||||
if !adminVC.CanPerformWriteActionOnUser(db, user1) || !adminVC.CanPerformWriteActionOnUser(db, user2) {
|
||||
t.Fatal("Admin should be able to perform writes on users")
|
||||
}
|
||||
|
||||
if !adminVC.CanPerformReadActionOnUser(db, user1) || !adminVC.CanPerformReadActionOnUser(db, user2) {
|
||||
t.Fatal("Admin should be able to perform reads on users")
|
||||
}
|
||||
|
||||
if user1VC.CanPerformWriteActionOnUser(db, user2) {
|
||||
t.Fatal("user1 shouldn't be able to perform writes on user2")
|
||||
}
|
||||
|
||||
if user1VC.CanPerformReadActionOnUser(db, user2) {
|
||||
t.Fatal("user1 should be able to perform reads on user2")
|
||||
}
|
||||
|
||||
}
|
||||
62
config.go
62
config.go
|
|
@ -5,24 +5,64 @@ import (
|
|||
"io/ioutil"
|
||||
)
|
||||
|
||||
type mysqlConfigData struct {
|
||||
Address string `json:"address"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Database string `json:"database"`
|
||||
}
|
||||
|
||||
type serverConfigData struct {
|
||||
Address string `json:"address"`
|
||||
Cert string `json:"cert"`
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
type appConfigData struct {
|
||||
BcryptCost int `json:"bcrypt_cost"`
|
||||
SaltLength int `json:"salt_length"`
|
||||
JWTKey string `json:"jwt_key"`
|
||||
}
|
||||
|
||||
type configData struct {
|
||||
MySQL struct {
|
||||
Address string `json:"address"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Database string `json:"database"`
|
||||
} `json:"mysql"`
|
||||
Server struct {
|
||||
Address string `json:"address"`
|
||||
Cert string `json:"cert"`
|
||||
Key string `json:"key"`
|
||||
} `json:"server"`
|
||||
MySQL mysqlConfigData `json:"mysql"`
|
||||
Server serverConfigData `json:"server"`
|
||||
App appConfigData `json:"app"`
|
||||
}
|
||||
|
||||
var (
|
||||
config configData
|
||||
)
|
||||
|
||||
var defaultMysqlConfigData = mysqlConfigData{
|
||||
Address: "127.0.0.1:3306",
|
||||
Username: "kolide",
|
||||
Password: "kolide",
|
||||
Database: "kolide",
|
||||
}
|
||||
|
||||
var defaultServerConfigData = serverConfigData{
|
||||
Address: ":8080",
|
||||
Cert: "./tools/kolide.crt",
|
||||
Key: "./tools/kolide.key",
|
||||
}
|
||||
|
||||
var defaultAppConfigData = appConfigData{
|
||||
BcryptCost: 12,
|
||||
SaltLength: 32,
|
||||
JWTKey: "very secure",
|
||||
}
|
||||
|
||||
var defaultConfigData = configData{
|
||||
MySQL: defaultMysqlConfigData,
|
||||
Server: defaultServerConfigData,
|
||||
App: defaultAppConfigData,
|
||||
}
|
||||
|
||||
func setDefaultConfigValues() {
|
||||
config = defaultConfigData
|
||||
}
|
||||
|
||||
func loadConfig(path string) error {
|
||||
content, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
|
|
|
|||
42
kolide.go
42
kolide.go
|
|
@ -2,11 +2,14 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
)
|
||||
|
||||
|
|
@ -39,12 +42,27 @@ var (
|
|||
serve = app.Command("serve", "Run the Kolide server")
|
||||
)
|
||||
|
||||
func init() {
|
||||
// set gin mode to release to silence some superfluous logging
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
|
||||
// configure logging
|
||||
logrus.AddHook(logContextHook{})
|
||||
|
||||
// populate the global config data structure with sane defaults
|
||||
setDefaultConfigValues()
|
||||
}
|
||||
|
||||
// logContextHook is a logrus hook which is used to contextualize application
|
||||
// logs to include data stuch as line numbers, file names, etc.
|
||||
type logContextHook struct{}
|
||||
|
||||
// Levels defines which levels the logContextHook logrus hook should apply to
|
||||
func (hook logContextHook) Levels() []logrus.Level {
|
||||
return logrus.AllLevels
|
||||
}
|
||||
|
||||
// Fire defines what the logContextHook should actually do when it is triggered
|
||||
func (hook logContextHook) Fire(entry *logrus.Entry) error {
|
||||
if pc, file, line, ok := runtime.Caller(8); ok {
|
||||
funcName := runtime.FuncForPC(pc).Name()
|
||||
|
|
@ -57,14 +75,15 @@ func (hook logContextHook) Fire(entry *logrus.Entry) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func main() {
|
||||
// configure flag parsing and parse flags
|
||||
app.Version(version)
|
||||
args, err := app.Parse(os.Args[1:])
|
||||
|
||||
// configure logging
|
||||
logrus.AddHook(logContextHook{})
|
||||
|
||||
// configure the application based on the flags that have been set
|
||||
if *debug {
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
|
|
@ -76,18 +95,21 @@ func main() {
|
|||
|
||||
// if config hasn't been defined and the example config exists relative to
|
||||
// the binary, it's likely that the tool is being ran right after building
|
||||
// from source so we auto-populate the config path.
|
||||
// from source so we auto-populate the example config path.
|
||||
if *configPath == "" {
|
||||
if _, err = os.Stat("./tools/example_config.json"); err == nil {
|
||||
*configPath = "./tools/example_config.json"
|
||||
} else {
|
||||
logrus.Fatalln("No config file specified. Use --config to specify config path.")
|
||||
}
|
||||
}
|
||||
|
||||
err = loadConfig(*configPath)
|
||||
if err != nil {
|
||||
logrus.Fatalf("Error loading config: %s", err.Error())
|
||||
// if the user has defined a config path OR the example config is found
|
||||
// relative to the binary, load config content from the file. any content
|
||||
// in the config file will overwrite the default values
|
||||
if *configPath != "" {
|
||||
err = loadConfig(*configPath)
|
||||
if err != nil {
|
||||
logrus.Fatalf("Error loading config: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// route the executable based on the sub-command
|
||||
|
|
@ -100,7 +122,7 @@ func main() {
|
|||
dropTables(db)
|
||||
createTables(db)
|
||||
case serve.FullCommand():
|
||||
createServer().RunTLS(
|
||||
CreateServer().RunTLS(
|
||||
config.Server.Address,
|
||||
config.Server.Cert,
|
||||
config.Server.Key)
|
||||
|
|
|
|||
66
models.go
66
models.go
|
|
@ -1,13 +1,49 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
)
|
||||
|
||||
type DBEnvironment int
|
||||
|
||||
const (
|
||||
ProductionDB DBEnvironment = iota
|
||||
TestingDB DBEnvironment = iota
|
||||
)
|
||||
|
||||
var injectedTestDB *gorm.DB
|
||||
|
||||
func GetDB(c *gin.Context) (*gorm.DB, error) {
|
||||
f, ok := c.Get("DB")
|
||||
if !ok {
|
||||
return nil, errors.New("DB was not set on the supplied *gin.Context. Use a middleware to set it.")
|
||||
}
|
||||
switch f.(DBEnvironment) {
|
||||
case ProductionDB:
|
||||
return openDB(config.MySQL.Username, config.MySQL.Password, config.MySQL.Address, config.MySQL.Database)
|
||||
case TestingDB:
|
||||
if injectedTestDB != nil {
|
||||
return injectedTestDB, nil
|
||||
}
|
||||
db, err := openTestDB()
|
||||
if err != nil {
|
||||
return nil, errors.New("Could not open a test database")
|
||||
}
|
||||
injectedTestDB = db
|
||||
return injectedTestDB, nil
|
||||
default:
|
||||
return nil, errors.New("GetDB not implemented for DBEnvironment object")
|
||||
}
|
||||
}
|
||||
|
||||
type BaseModel struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
CreatedAt time.Time
|
||||
|
|
@ -15,16 +51,6 @@ type BaseModel struct {
|
|||
DeletedAt *time.Time `sql:"index"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
BaseModel
|
||||
Username string `gorm:"not null;unique_index:idx_user_unique_username"`
|
||||
Password string `gorm:"not null"`
|
||||
Salt string `gorm:"not null"`
|
||||
Name string
|
||||
Email string `gorm:"unique_index:idx_user_unique_email"`
|
||||
Admin bool `gorm:"not null"`
|
||||
}
|
||||
|
||||
type ScheduledQuery struct {
|
||||
BaseModel
|
||||
Name string `gorm:"not null"`
|
||||
|
|
@ -162,6 +188,26 @@ func openDB(user, password, address, dbName string) (*gorm.DB, error) {
|
|||
return gorm.Open("mysql", connectionString)
|
||||
}
|
||||
|
||||
func openTestDB() (*gorm.DB, error) {
|
||||
db, err := gorm.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
createTables(db)
|
||||
return db, db.Error
|
||||
}
|
||||
|
||||
func ProductionDatabaseMiddleware(c *gin.Context) {
|
||||
c.Set("DB", ProductionDB)
|
||||
c.Next()
|
||||
}
|
||||
|
||||
func TestingDatabaseMiddleware(c *gin.Context) {
|
||||
c.Set("DB", TestingDB)
|
||||
c.Next()
|
||||
}
|
||||
|
||||
func dropTables(db *gorm.DB) {
|
||||
for _, table := range tables {
|
||||
db.DropTableIfExists(table)
|
||||
|
|
|
|||
57
server.go
57
server.go
|
|
@ -4,8 +4,57 @@ import (
|
|||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func attachRoutes(router *gin.Engine) {
|
||||
v1 := router.Group("/api/v1")
|
||||
// ServerError is a helper which accepts a string error and returns a map in
|
||||
// format that is required by gin.Context.JSON
|
||||
func ServerError(e string) *map[string]interface{} {
|
||||
return &map[string]interface{}{
|
||||
"error": e,
|
||||
}
|
||||
}
|
||||
|
||||
// DatabaseError emits a response that is appropriate in the event that a
|
||||
// database failure occurs, a record is not found in the database, etc
|
||||
func DatabaseError(c *gin.Context) {
|
||||
c.JSON(500, ServerError("Database error"))
|
||||
}
|
||||
|
||||
// UnauthorizedError emits a response that is appropriate in the event that a
|
||||
// request is received by a user which is not authorized to carry out the
|
||||
// requested action
|
||||
func UnauthorizedError(c *gin.Context) {
|
||||
c.JSON(401, ServerError("Unauthorized"))
|
||||
}
|
||||
|
||||
func createTestServer() *gin.Engine {
|
||||
server := gin.New()
|
||||
server.Use(TestingDatabaseMiddleware)
|
||||
return server
|
||||
}
|
||||
|
||||
// CreateServer creates a gin.Engine HTTP server and configures it to be in a
|
||||
// state such that it is ready to serve HTTP requests for the kolide application
|
||||
func CreateServer() *gin.Engine {
|
||||
server := gin.New()
|
||||
server.Use(ProductionDatabaseMiddleware)
|
||||
|
||||
v1 := server.Group("/api/v1")
|
||||
|
||||
// Kolide application API endpoints
|
||||
kolide := v1.Group("/kolide")
|
||||
kolide.Use(SessionMiddleware)
|
||||
kolide.Use(JWTRenewalMiddleware)
|
||||
|
||||
kolide.POST("/login", Login)
|
||||
kolide.GET("/logout", Logout)
|
||||
|
||||
kolide.GET("/user", GetUser)
|
||||
kolide.PUT("/user", CreateUser)
|
||||
kolide.PATCH("/user", ModifyUser)
|
||||
kolide.DELETE("/user", DeleteUser)
|
||||
|
||||
kolide.PATCH("/user/password", ResetUserPassword)
|
||||
kolide.PATCH("/user/admin", SetUserAdminState)
|
||||
kolide.PATCH("/user/enabled", SetUserEnabledState)
|
||||
|
||||
// osquery API endpoints
|
||||
osquery := v1.Group("/osquery")
|
||||
|
|
@ -14,10 +63,6 @@ func attachRoutes(router *gin.Engine) {
|
|||
osquery.POST("/log", OsqueryLog)
|
||||
osquery.POST("/distributed/read", OsqueryDistributedRead)
|
||||
osquery.POST("/distributed/write", OsqueryDistributedWrite)
|
||||
}
|
||||
|
||||
func createServer() *gin.Engine {
|
||||
server := gin.New()
|
||||
attachRoutes(server)
|
||||
return server
|
||||
}
|
||||
|
|
|
|||
90
sessions.go
Normal file
90
sessions.go
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/context"
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// GetSession allows you to get the Session object given a web request. This
|
||||
// is often used in HTTP handlers as the main entry point into managing and
|
||||
// manipulating the session
|
||||
func GetSession(c *gin.Context) *Session {
|
||||
return c.MustGet("Session").(*Session)
|
||||
}
|
||||
|
||||
// SessionMiddleware is the middleware used for production session management.
|
||||
// Tests should use `testSessionMiddleware`, which follows the same pattern,
|
||||
// but creates a session configured for testing.
|
||||
func SessionMiddleware(c *gin.Context) {
|
||||
CreateSession("Session", sessions.NewCookieStore([]byte("c")))(c)
|
||||
}
|
||||
|
||||
// CreateSessions is a helper which returns a gin.HandlerFunc which creates
|
||||
// a new session management middleware given the name of the session to manage
|
||||
// and the session storage mechanism. This is commonly used to generate session
|
||||
// middleware given a variety of settings in both production and testing
|
||||
// environments
|
||||
func CreateSession(name string, store sessions.Store) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
s := &Session{name, c.Request, store, nil, c.Writer}
|
||||
c.Set("Session", s)
|
||||
defer context.Clear(c.Request)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Session is a convenience wrapper around gorilla sessions, which is provided
|
||||
// by github.com/gorilla/sessions
|
||||
type Session struct {
|
||||
name string
|
||||
request *http.Request
|
||||
store sessions.Store
|
||||
session *sessions.Session
|
||||
writer http.ResponseWriter
|
||||
}
|
||||
|
||||
// Session returns the gorilla session from the Session struct and allows you
|
||||
// to use any of the functionality of the underlying sessions.Session struct
|
||||
func (s *Session) Session() *sessions.Session {
|
||||
if s.session == nil {
|
||||
var err error
|
||||
s.session, err = s.store.Get(s.request, s.name)
|
||||
if err != nil {
|
||||
logrus.Error(err.Error())
|
||||
}
|
||||
}
|
||||
return s.session
|
||||
}
|
||||
|
||||
// Set simply sets a session key value pair which will be stored in the
|
||||
// current session for later usage
|
||||
func (s *Session) Set(key interface{}, val interface{}) {
|
||||
s.Session().Values[key] = val
|
||||
}
|
||||
|
||||
// Get retrieves a session key value pair which has previously been set
|
||||
func (s *Session) Get(key interface{}) interface{} {
|
||||
return s.Session().Values[key]
|
||||
}
|
||||
|
||||
// Delete deletes a session key value pair which has previously been set
|
||||
func (s *Session) Delete(key interface{}) {
|
||||
delete(s.Session().Values, key)
|
||||
}
|
||||
|
||||
// Clear deletes all session key value pairs that are set
|
||||
func (s *Session) Clear() {
|
||||
for key := range s.Session().Values {
|
||||
s.Delete(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Save writes the session, which is required after altering the session in any
|
||||
// way
|
||||
func (s *Session) Save() error {
|
||||
return s.Session().Save(s.request, s.writer)
|
||||
}
|
||||
179
sessions_test.go
Normal file
179
sessions_test.go
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
const testSessionName = "TestSession"
|
||||
|
||||
func getTestStore() sessions.Store {
|
||||
return sessions.NewCookieStore([]byte("test"))
|
||||
}
|
||||
|
||||
func testSessionMiddleware(c *gin.Context) {
|
||||
CreateSession(testSessionName, getTestStore())(c)
|
||||
}
|
||||
|
||||
func TestSessionGetSet(t *testing.T) {
|
||||
r := createTestServer()
|
||||
r.Use(testSessionMiddleware)
|
||||
r.Use(JWTRenewalMiddleware)
|
||||
|
||||
r.GET("/set", func(c *gin.Context) {
|
||||
session := GetSession(c)
|
||||
session.Set("key", "foobar")
|
||||
session.Save()
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
r.GET("/get", func(c *gin.Context) {
|
||||
session := GetSession(c)
|
||||
if session.Get("key") != "foobar" {
|
||||
t.Fatal("Session writing failed")
|
||||
}
|
||||
c.String(200, "OK")
|
||||
})
|
||||
|
||||
res1 := httptest.NewRecorder()
|
||||
req1, _ := http.NewRequest("GET", "/set", nil)
|
||||
r.ServeHTTP(res1, req1)
|
||||
|
||||
res2 := httptest.NewRecorder()
|
||||
req2, _ := http.NewRequest("GET", "/get", nil)
|
||||
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
|
||||
r.ServeHTTP(res2, req2)
|
||||
}
|
||||
|
||||
func TestSessionDeleteKey(t *testing.T) {
|
||||
r := createTestServer()
|
||||
r.Use(testSessionMiddleware)
|
||||
r.Use(JWTRenewalMiddleware)
|
||||
|
||||
r.GET("/set", func(c *gin.Context) {
|
||||
session := GetSession(c)
|
||||
session.Set("key", "foobar")
|
||||
session.Save()
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
r.GET("/delete", func(c *gin.Context) {
|
||||
session := GetSession(c)
|
||||
session.Delete("key")
|
||||
session.Save()
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
r.GET("/get", func(c *gin.Context) {
|
||||
session := GetSession(c)
|
||||
if session.Get("key") != nil {
|
||||
t.Fatal("Session deleting failed")
|
||||
}
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
res1 := httptest.NewRecorder()
|
||||
req1, _ := http.NewRequest("GET", "/set", nil)
|
||||
r.ServeHTTP(res1, req1)
|
||||
|
||||
res2 := httptest.NewRecorder()
|
||||
req2, _ := http.NewRequest("GET", "/delete", nil)
|
||||
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
|
||||
r.ServeHTTP(res2, req2)
|
||||
|
||||
res3 := httptest.NewRecorder()
|
||||
req3, _ := http.NewRequest("GET", "/get", nil)
|
||||
req3.Header.Set("Cookie", res2.Header().Get("Set-Cookie"))
|
||||
r.ServeHTTP(res3, req3)
|
||||
}
|
||||
|
||||
func TestSessionFlashes(t *testing.T) {
|
||||
r := createTestServer()
|
||||
r.Use(testSessionMiddleware)
|
||||
r.Use(JWTRenewalMiddleware)
|
||||
|
||||
r.GET("/set", func(c *gin.Context) {
|
||||
session := GetSession(c)
|
||||
session.Session().AddFlash("foobar")
|
||||
session.Save()
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
r.GET("/flash", func(c *gin.Context) {
|
||||
session := GetSession(c)
|
||||
l := len(session.Session().Flashes())
|
||||
if l != 1 {
|
||||
t.Fatal("Flashes count does not equal 1. Equals ", l)
|
||||
}
|
||||
session.Save()
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
r.GET("/check", func(c *gin.Context) {
|
||||
session := GetSession(c)
|
||||
l := len(session.Session().Flashes())
|
||||
if l != 0 {
|
||||
t.Fatal("flashes count is not 0 after reading. Equals ", l)
|
||||
}
|
||||
session.Save()
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
res1 := httptest.NewRecorder()
|
||||
req1, _ := http.NewRequest("GET", "/set", nil)
|
||||
r.ServeHTTP(res1, req1)
|
||||
|
||||
res2 := httptest.NewRecorder()
|
||||
req2, _ := http.NewRequest("GET", "/flash", nil)
|
||||
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
|
||||
r.ServeHTTP(res2, req2)
|
||||
|
||||
res3 := httptest.NewRecorder()
|
||||
req3, _ := http.NewRequest("GET", "/check", nil)
|
||||
req3.Header.Set("Cookie", res2.Header().Get("Set-Cookie"))
|
||||
r.ServeHTTP(res3, req3)
|
||||
}
|
||||
|
||||
func TestSessionClear(t *testing.T) {
|
||||
data := map[string]string{
|
||||
"key": "val",
|
||||
"foo": "bar",
|
||||
}
|
||||
r := createTestServer()
|
||||
store := getTestStore()
|
||||
r.Use(CreateSession(testSessionName, store))
|
||||
r.Use(JWTRenewalMiddleware)
|
||||
|
||||
r.GET("/set", func(c *gin.Context) {
|
||||
session := GetSession(c)
|
||||
for k, v := range data {
|
||||
session.Set(k, v)
|
||||
}
|
||||
session.Clear()
|
||||
session.Save()
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
r.GET("/check", func(c *gin.Context) {
|
||||
session := GetSession(c)
|
||||
for k, v := range data {
|
||||
if session.Get(k) == v {
|
||||
t.Fatal("Session clear failed")
|
||||
}
|
||||
}
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
res1 := httptest.NewRecorder()
|
||||
req1, _ := http.NewRequest("GET", "/set", nil)
|
||||
r.ServeHTTP(res1, req1)
|
||||
|
||||
res2 := httptest.NewRecorder()
|
||||
req2, _ := http.NewRequest("GET", "/check", nil)
|
||||
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
|
||||
r.ServeHTTP(res2, req2)
|
||||
}
|
||||
|
|
@ -9,5 +9,10 @@
|
|||
"address": ":8080",
|
||||
"cert": "./tools/kolide.crt",
|
||||
"key": "./tools/kolide.key"
|
||||
},
|
||||
"app": {
|
||||
"bcrypt_cost": 12,
|
||||
"salt_length": 32,
|
||||
"jwt_key": "very secure"
|
||||
}
|
||||
}
|
||||
435
users.go
Normal file
435
users.go
Normal file
|
|
@ -0,0 +1,435 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// User is the model struct which represents a kolide user
|
||||
type User struct {
|
||||
BaseModel
|
||||
Username string `gorm:"not null;unique_index:idx_user_unique_username"`
|
||||
Password []byte `gorm:"not null"`
|
||||
Salt string `gorm:"not null"`
|
||||
Name string
|
||||
Email string `gorm:"not null;unique_index:idx_user_unique_email"`
|
||||
Admin bool `gorm:"not null"`
|
||||
Enabled bool `gorm:"not null"`
|
||||
NeedsPasswordReset bool
|
||||
}
|
||||
|
||||
// NewUser is a wrapper around the creation of a new user.
|
||||
// NewUser exists largely to allow the API to simply accept a string password
|
||||
// while using the applications password hashing mechanisms to salt and hash the
|
||||
// password.
|
||||
func NewUser(db *gorm.DB, username, password, email string, admin, needsPasswordReset bool) (*User, error) {
|
||||
salt, hash, err := SaltAndHashPassword(password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user := &User{
|
||||
Username: username,
|
||||
Password: hash,
|
||||
Salt: salt,
|
||||
Email: email,
|
||||
Admin: admin,
|
||||
Enabled: true,
|
||||
NeedsPasswordReset: needsPasswordReset,
|
||||
}
|
||||
|
||||
err = db.Create(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ValidatePassword accepts a potential password for a given user and attempts
|
||||
// to validate it against the hash stored in the database after joining the
|
||||
// supplied password with the stored password salt
|
||||
func (u *User) ValidatePassword(password string) error {
|
||||
saltAndPass := []byte(fmt.Sprintf("%s%s", u.Salt, password))
|
||||
return bcrypt.CompareHashAndPassword(u.Password, saltAndPass)
|
||||
}
|
||||
|
||||
// SetPassword accepts a new password for a user object and updates the salt
|
||||
// and hash for that user in the database. This function assumes that the
|
||||
// appropriate authorization checks have already occurred by the caller.
|
||||
func (u *User) SetPassword(db *gorm.DB, password string) error {
|
||||
salt, hash, err := SaltAndHashPassword(password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.Salt = salt
|
||||
u.Password = hash
|
||||
return db.Save(u).Error
|
||||
}
|
||||
|
||||
// MakeAdmin is a simple wrapper around promoting a user to an administrator.
|
||||
// If the user is already an admin, this function will return without modifying
|
||||
// the database
|
||||
func (u *User) MakeAdmin(db *gorm.DB) error {
|
||||
if !u.Admin {
|
||||
u.Admin = true
|
||||
return db.Save(&u).Error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type GetUserRequestBody struct {
|
||||
ID uint `json:"id" binding:"required"`
|
||||
}
|
||||
|
||||
func GetUser(c *gin.Context) {
|
||||
var body GetUserRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Error parsing GetUser post body: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
db, err := GetDB(c)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not open database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
vc, err := VC(c, db)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not create VC: %s", err.Error())
|
||||
DatabaseError(c) // TODO tampered?
|
||||
return
|
||||
}
|
||||
|
||||
var user User
|
||||
err = db.Where("id = ?", body.ID).First(&user).Error
|
||||
if err != nil {
|
||||
logrus.Errorf("Error finding user in database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.CanPerformReadActionOnUser(db, &user) {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, map[string]interface{}{
|
||||
"id": user.ID,
|
||||
"username": user.Username,
|
||||
"name": user.Name,
|
||||
"email": user.Email,
|
||||
"admin": user.Admin,
|
||||
"enabled": user.Enabled,
|
||||
"needs_password_reset": user.NeedsPasswordReset,
|
||||
})
|
||||
}
|
||||
|
||||
type CreateUserRequestBody struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
Email string `json:"email" binding:"required"`
|
||||
Admin bool `json:"admin"`
|
||||
NeedsPasswordReset bool `json:"needs_password_reset"`
|
||||
}
|
||||
|
||||
func CreateUser(c *gin.Context) {
|
||||
var body CreateUserRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Error parsing CreateUser post body: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
db, err := GetDB(c)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not open database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
vc, err := VC(c, db)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not create VC: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.IsAdmin() {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = NewUser(db, body.Username, body.Password, body.Email, body.Admin, body.NeedsPasswordReset)
|
||||
if err != nil {
|
||||
logrus.Errorf("Error creating new user: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, nil)
|
||||
}
|
||||
|
||||
type ModifyUserRequestBody struct {
|
||||
ID uint `json:"id" binding:"required"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func ModifyUser(c *gin.Context) {
|
||||
var body ModifyUserRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Error parsing ModifyUser post body: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
db, err := GetDB(c)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not open database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
vc, err := VC(c, db)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not create VC: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
var user User
|
||||
err = db.Where("id = ?", body.ID).First(&user).Error
|
||||
if err != nil {
|
||||
logrus.Errorf("Error finding user in database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.CanPerformWriteActionOnUser(db, &user) {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
err = db.Save(&user).Error
|
||||
if err != nil {
|
||||
logrus.Errorf("Error updating user in database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
c.JSON(200, nil)
|
||||
}
|
||||
|
||||
type DeleteUserRequestBody struct {
|
||||
ID uint `json:"id" binding:"required"`
|
||||
}
|
||||
|
||||
func DeleteUser(c *gin.Context) {
|
||||
var body DeleteUserRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Error parsing DeleteUser post body: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
db, err := GetDB(c)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not open database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
vc, err := VC(c, db)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not create VC: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.IsAdmin() {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
var user User
|
||||
err = db.Where("id = ?", body.ID).First(&user).Error
|
||||
if err != nil {
|
||||
logrus.Errorf("Error finding user in database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
err = db.Delete(&user).Error
|
||||
if err != nil {
|
||||
logrus.Errorf("Error deleting user from database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
c.JSON(200, nil)
|
||||
}
|
||||
|
||||
type ResetPasswordRequestBody struct {
|
||||
ID uint `json:"id" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
PasswordConfim string `json:"password_confirm" binding:"required"`
|
||||
}
|
||||
|
||||
func ResetUserPassword(c *gin.Context) {
|
||||
var body ResetPasswordRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Error parsing ResetPassword post body: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if body.Password != body.PasswordConfim {
|
||||
c.JSON(406, map[string]interface{}{"error": "Passwords do not match"})
|
||||
return
|
||||
}
|
||||
|
||||
db, err := GetDB(c)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not open database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
vc, err := VC(c, db)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not create VC: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
var user User
|
||||
err = db.Where("id = ?", body.ID).First(&user).Error
|
||||
if err != nil {
|
||||
logrus.Errorf("Error finding user in database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.CanPerformWriteActionOnUser(db, &user) {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
err = user.SetPassword(db, body.Password)
|
||||
if err != nil {
|
||||
logrus.Errorf("Error setting user password: %s", err.Error())
|
||||
// xxx don't try to write to the db?
|
||||
}
|
||||
|
||||
err = db.Save(&user).Error
|
||||
if err != nil {
|
||||
logrus.Errorf("Error updating user in database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
c.JSON(200, nil)
|
||||
}
|
||||
|
||||
type SetUserAdminStateRequestBody struct {
|
||||
ID uint `json:"id" binding:"required"`
|
||||
Admin bool `json:"admin" binding:"required"`
|
||||
}
|
||||
|
||||
func SetUserAdminState(c *gin.Context) {
|
||||
var body SetUserAdminStateRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Error parsing SetUserAdminState post body: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
db, err := GetDB(c)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not open database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
vc, err := VC(c, db)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not create VC: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.IsAdmin() {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
var user User
|
||||
err = db.Where("id = ?", body.ID).First(&user).Error
|
||||
if err != nil {
|
||||
logrus.Errorf("Error finding user in database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
user.Admin = body.Admin
|
||||
err = db.Save(&user).Error
|
||||
if err != nil {
|
||||
logrus.Errorf("Error updating user in database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
c.JSON(200, nil)
|
||||
}
|
||||
|
||||
type SetUserEnabledStateRequestBody struct {
|
||||
ID uint `json:"id" binding:"required"`
|
||||
Enabled bool `json:"enabled" binding:"required"`
|
||||
}
|
||||
|
||||
func SetUserEnabledState(c *gin.Context) {
|
||||
var body SetUserEnabledStateRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Error parsing SetUserEnabledState post body: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
db, err := GetDB(c)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not open database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
vc, err := VC(c, db)
|
||||
if err != nil {
|
||||
logrus.Errorf("Could not create VC: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.IsAdmin() {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
var user User
|
||||
err = db.Where("id = ?", body.ID).First(&user).Error
|
||||
if err != nil {
|
||||
logrus.Errorf("Error finding user in database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
user.Enabled = body.Enabled
|
||||
err = db.Save(&user).Error
|
||||
if err != nil {
|
||||
logrus.Errorf("Error updating user in database: %s", err.Error())
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
c.JSON(200, nil)
|
||||
}
|
||||
309
users_test.go
Normal file
309
users_test.go
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func TestNewUser(t *testing.T) {
|
||||
db, err := openTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if user.Username != "marpaia" {
|
||||
t.Fatalf("Username is not what's expected: %s", user.Username)
|
||||
}
|
||||
|
||||
if user.Email != "mike@kolide.co" {
|
||||
t.Fatalf("Email is not what's expected: %s", user.Email)
|
||||
}
|
||||
|
||||
if !user.Admin {
|
||||
t.Fatal("User is not an admin")
|
||||
}
|
||||
|
||||
var verify User
|
||||
db.Where("username = ?", "marpaia").First(&verify)
|
||||
if verify.ID != user.ID {
|
||||
t.Fatal("Couldn't select user back from database")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePassword(t *testing.T) {
|
||||
db, err := openTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", true, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
err = user.ValidatePassword("foobar")
|
||||
if err != nil {
|
||||
t.Fatal("Password validation failed")
|
||||
}
|
||||
|
||||
err = user.ValidatePassword("not correct")
|
||||
if err == nil {
|
||||
t.Fatal("Incorrect password worked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeAdmin(t *testing.T) {
|
||||
db, err := openTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if user.Admin {
|
||||
t.Fatal("Admin should be false")
|
||||
}
|
||||
|
||||
err = user.MakeAdmin(db)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if !user.Admin {
|
||||
t.Fatal("Admin should be true")
|
||||
}
|
||||
|
||||
var verify User
|
||||
db.Where("admin = ?", true).First(&verify)
|
||||
|
||||
if user.ID != verify.ID {
|
||||
t.Fatal("Users don't match")
|
||||
}
|
||||
|
||||
if !verify.Admin {
|
||||
t.Fatal("User wasn't set as admin in the database")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestUpdatingUser(t *testing.T) {
|
||||
db, err := openTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
user.Email = "marpaia@kolide.co"
|
||||
err = db.Save(user).Error
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if user.Email != "marpaia@kolide.co" {
|
||||
t.Fatal("user.Email was reset")
|
||||
}
|
||||
|
||||
var verify User
|
||||
err = db.Where("id = ?", user.ID).First(&verify).Error
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if verify.Email != "marpaia@kolide.co" {
|
||||
t.Fatalf("user's email was not updated in the DB: %s", verify.Email)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestDeletingUser(t *testing.T) {
|
||||
db, err := openTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
var verify1 User
|
||||
err = db.Where("username = ?", "marpaia").First(&verify1).Error
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if verify1.ID != user.ID {
|
||||
t.Fatal("users are not the same")
|
||||
}
|
||||
|
||||
err = db.Delete(&user).Error
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
var verify2 User
|
||||
err = db.Where("username = ?", "marpaia").First(&verify2).Error
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
t.Fatal("Record was not deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetPassword(t *testing.T) {
|
||||
db, err := openTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
user, err := NewUser(db, "marpaia", "foobar", "mike@kolide.co", false, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
err = user.ValidatePassword("foobar")
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
err = user.SetPassword(db, "baz")
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
err = user.ValidatePassword("baz")
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
var verify User
|
||||
err = db.Where("username = ?", "marpaia").First(&verify).Error
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
err = verify.ValidatePassword("baz")
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserManagementIntegration(t *testing.T) {
|
||||
r := createTestServer()
|
||||
r.Use(testSessionMiddleware)
|
||||
r.Use(JWTRenewalMiddleware)
|
||||
|
||||
db, err := openTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
injectedTestDB = db
|
||||
|
||||
admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
_ = admin
|
||||
|
||||
r.POST("/login", Login)
|
||||
r.GET("/logout", Logout)
|
||||
|
||||
r.GET("/user", GetUser)
|
||||
r.PUT("/user", CreateUser)
|
||||
r.PATCH("/user", ModifyUser)
|
||||
r.DELETE("/user", DeleteUser)
|
||||
|
||||
res1 := httptest.NewRecorder()
|
||||
body1, err := json.Marshal(LoginRequestBody{
|
||||
Username: "admin",
|
||||
Password: "foobar",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
buff1 := new(bytes.Buffer)
|
||||
buff1.Write(body1)
|
||||
req1, _ := http.NewRequest("POST", "/login", buff1)
|
||||
req1.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(res1, req1)
|
||||
if res1.Code != 200 {
|
||||
t.Fatalf("Response code: %d", res1.Code)
|
||||
}
|
||||
|
||||
res2 := httptest.NewRecorder()
|
||||
body2, err := json.Marshal(CreateUserRequestBody{
|
||||
Username: "marpaia",
|
||||
Password: "foobar",
|
||||
Email: "mike@kolide.co",
|
||||
Admin: false,
|
||||
NeedsPasswordReset: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
buff2 := new(bytes.Buffer)
|
||||
buff2.Write(body2)
|
||||
req2, _ := http.NewRequest("PUT", "/user", buff2)
|
||||
req2.Header.Set("Content-Type", "application/json")
|
||||
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
|
||||
r.ServeHTTP(res2, req2)
|
||||
|
||||
res3 := httptest.NewRecorder()
|
||||
body3, err := json.Marshal(CreateUserRequestBody{
|
||||
Username: "admin2",
|
||||
Password: "foobar",
|
||||
Email: "admin2@kolide.co",
|
||||
Admin: true,
|
||||
NeedsPasswordReset: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
buff3 := new(bytes.Buffer)
|
||||
buff3.Write(body3)
|
||||
req3, _ := http.NewRequest("PUT", "/user", buff3)
|
||||
req3.Header.Set("Content-Type", "application/json")
|
||||
req3.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
|
||||
r.ServeHTTP(res3, req3)
|
||||
|
||||
var user User
|
||||
err = db.Where("username = ?", "marpaia").First(&user).Error
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if user.Email != "mike@kolide.co" {
|
||||
t.Fatalf("user's email was not set in the DB: %s", user.Email)
|
||||
}
|
||||
if user.Admin {
|
||||
t.Fatal("user shouldn't be admin")
|
||||
}
|
||||
|
||||
var admin2 User
|
||||
err = db.Where("username = ?", "admin2").First(&admin2).Error
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if admin2.Email != "admin2@kolide.co" {
|
||||
t.Fatalf("admin2's email was not set in the DB: %s", admin2.Email)
|
||||
}
|
||||
if !admin2.Admin {
|
||||
t.Fatal("admin2 should be admin")
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
Reference in a new issue