diff --git a/auth.go b/auth.go index 1a04384549..80b6e6cbec 100644 --- a/auth.go +++ b/auth.go @@ -7,8 +7,8 @@ import ( "fmt" "github.com/Sirupsen/logrus" - "github.com/dgrijalva/jwt-go" "github.com/gin-gonic/gin" + "github.com/jinzhu/gorm" "golang.org/x/crypto/bcrypt" ) @@ -99,32 +99,23 @@ func EmptyVC() *ViewerContext { // and generate an appropriate ViewerContext given the data in the session. func VC(c *gin.Context) *ViewerContext { sm := NewSessionManager(c) - return sm.VC() + session, err := sm.Session() + if err != nil { + return EmptyVC() + } + return VCForID(GetDB(c), session.UserID) } -//////////////////////////////////////////////////////////////////////////////// -// JSON Web Tokens -//////////////////////////////////////////////////////////////////////////////// +func VCForID(db *gorm.DB, id uint) *ViewerContext { + // Generating a VC requires a user struct. Attempt to populate one using + // the user id of the current session holder + user := &User{BaseModel: BaseModel{ID: id}} + err := db.Where(user).First(user).Error + if err != nil { + return EmptyVC() + } -// Given a session key create a JWT to be delivered to the client -func GenerateJWT(sessionKey string) (string, error) { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "session_key": sessionKey, - }) - - return token.SignedString([]byte(config.App.JWTKey)) -} - -// ParseJWT attempts to parse a JWT token in serialized string form into a -// JWT token in a deserialized jwt.Token struct. -func ParseJWT(token string) (*jwt.Token, error) { - return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { - method, ok := t.Method.(*jwt.SigningMethodHMAC) - if !ok || method != jwt.SigningMethodHS256 { - return nil, errors.New("Unexpected signing method") - } - return []byte(config.App.JWTKey), nil - }) + return GenerateVC(user) } //////////////////////////////////////////////////////////////////////////////// @@ -194,7 +185,7 @@ func Login(c *gin.Context) { } sm := NewSessionManager(c) - sm.MakeSessionForUser(user) + sm.MakeSessionForUserID(user.ID) err = sm.Save() if err != nil { DatabaseError(c) diff --git a/auth_test.go b/auth_test.go index 861593636c..3751a708fe 100644 --- a/auth_test.go +++ b/auth_test.go @@ -5,7 +5,6 @@ import ( "net/http/httptest" "testing" - "github.com/dgrijalva/jwt-go" "github.com/gin-gonic/gin" ) @@ -24,23 +23,6 @@ func TestGenerateVC(t *testing.T) { } -func TestGenerateJWT(t *testing.T) { - tokenString, err := GenerateJWT("4") - token, err := ParseJWT(tokenString) - if err != nil { - t.Fatal(err.Error()) - } - claims, ok := token.Claims.(jwt.MapClaims) - if !ok || !token.Valid { - t.Fatal("Token is invalid") - } - - sessionKey := claims["session_key"].(string) - if sessionKey != "4" { - t.Fatalf("Claims are incorrect. session key is %s", sessionKey) - } -} - func TestVC(t *testing.T) { db := openTestDB(t) r := createEmptyTestServer(db) @@ -57,7 +39,7 @@ func TestVC(t *testing.T) { r.GET("/admin_login", func(c *gin.Context) { sm := NewSessionManager(c) - sm.MakeSessionForUser(admin) + sm.MakeSessionForUserID(admin.ID) err := sm.Save() if err != nil { t.Fatal(err.Error()) @@ -67,7 +49,7 @@ func TestVC(t *testing.T) { r.GET("/user_login", func(c *gin.Context) { sm := NewSessionManager(c) - sm.MakeSessionForUser(user) + sm.MakeSessionForUserID(user.ID) err := sm.Save() if err != nil { t.Fatal(err.Error()) diff --git a/models.go b/models.go index 3c9257203c..8263c8ed1b 100644 --- a/models.go +++ b/models.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" _ "github.com/jinzhu/gorm/dialects/mysql" _ "github.com/jinzhu/gorm/dialects/sqlite" + "github.com/kolide/kolide-ose/sessions" ) // Get the database connection from the context, or panic @@ -143,7 +144,7 @@ type Decorator struct { var tables = [...]interface{}{ &User{}, - &Session{}, + &sessions.Session{}, &ScheduledQuery{}, &Pack{}, &DiscoveryQuery{}, diff --git a/server.go b/server.go index 006b1ae19b..4c10332f5f 100644 --- a/server.go +++ b/server.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/contrib/ginrus" "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" + "github.com/kolide/kolide-ose/sessions" ) // ServerError is a helper which accepts a string error and returns a map in @@ -54,6 +55,17 @@ func DatabaseMiddleware(db *gorm.DB) gin.HandlerFunc { } } +// NewSessionManager allows you to get a SessionManager instance for a given +// web request. Unless you're interacting with login, logout, or core auth +// code, this should be abstracted by the ViewerContext pattern. +func NewSessionManager(c *gin.Context) *sessions.SessionManager { + return &sessions.SessionManager{ + Request: c.Request, + Backend: GetSessionBackend(c), + Writer: c.Writer, + } +} + // CreateServer creates a gin.Engine HTTP server and configures it to be in a // state such that it is ready to serve HTTP requests for the kolide application func CreateServer(db *gorm.DB, w io.Writer) *gin.Engine { @@ -61,6 +73,13 @@ func CreateServer(db *gorm.DB, w io.Writer) *gin.Engine { server.Use(DatabaseMiddleware(db)) server.Use(SessionBackendMiddleware) + sessions.Configure(&sessions.SessionConfiguration{ + CookieName: "KolideSession", + JWTKey: config.App.JWTKey, + SessionKeySize: config.App.SessionKeySize, + Lifespan: config.App.SessionExpirationSeconds, + }) + // TODO: The following loggers are not synchronized with each other or // logrus.StandardLogger() used through the rest of the codebase. As // such, their output may become intermingled. diff --git a/sessions.go b/sessions.go deleted file mode 100644 index 2f31669c45..0000000000 --- a/sessions.go +++ /dev/null @@ -1,570 +0,0 @@ -package main - -import ( - "errors" - "net/http" - "time" - - "github.com/Sirupsen/logrus" - "github.com/dgrijalva/jwt-go" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" -) - -var ( - // An error returned by SessionBackend.Get() if no session record was found - // in the database - ErrNoActiveSession = errors.New("Active session is not present in the database") - - // An error returned by SessionBackend methods when no session object has - // been created yet but the requested action requires one - ErrSessionNotCreated = errors.New("The session has not been created") - - // An error returned by SessionBackend.Get() when a session is requested but - // it has expired - ErrSessionExpired = errors.New("The session has expired") -) - -const ( - // The name of the session cookie - CookieName = "KolideSession" -) - -// Session is the model object which represents what an active session is -type Session struct { - BaseModel - UserID uint `gorm:"not null"` - Key string `gorm:"not null;unique_index:idx_session_unique_key"` - AccessedAt time.Time -} - -//////////////////////////////////////////////////////////////////////////////// -// Managing sessions -//////////////////////////////////////////////////////////////////////////////// - -// SessionManager is a management object which helps with the administration of -// sessions within the application. Use NewSessionManager to create an instance -type SessionManager struct { - backend SessionBackend - request *http.Request - writer http.ResponseWriter - session *Session - vc *ViewerContext - db *gorm.DB -} - -// NewSessionManager allows you to get a SessionManager instance for a given -// web request. Unless you're interacting with login, logout, or core auth -// code, this should be abstracted by the ViewerContext pattern. -func NewSessionManager(c *gin.Context) *SessionManager { - return &SessionManager{ - request: c.Request, - backend: GetSessionBackend(c), - writer: c.Writer, - db: GetDB(c), - } -} - -// Get the ViewerContext instance for a user represented by the active session -func (sm *SessionManager) VC() *ViewerContext { - if sm.session == nil { - cookie, err := sm.request.Cookie(CookieName) - if err != nil { - switch err { - case http.ErrNoCookie: - // No cookie was set - return EmptyVC() - default: - // Something went wrong and the cookie may or may not be set - logrus.Errorf("Couldn't get cookie: %s", err.Error()) - return EmptyVC() - } - } - - token, err := ParseJWT(cookie.Value) - if err != nil { - logrus.Errorf("Couldn't parse JWT token string from cookie: %s", err.Error()) - return EmptyVC() - } - - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - logrus.Error("Could not parse the claims from the JWT token") - return EmptyVC() - } - - sessionKeyClaim, ok := claims["session_key"] - if !ok { - logrus.Warn("JWT did not have session_key claim") - return EmptyVC() - } - - sessionKey, ok := sessionKeyClaim.(string) - if !ok { - logrus.Warn("JWT session_key claim was not a string") - return EmptyVC() - } - - session, err := sm.backend.FindKey(sessionKey) - if err != nil { - switch err { - case ErrNoActiveSession: - // If the code path got this far, it's likely that the user was logged - // in some time in the past, but their session has been expired since - // their last usage of the application - return EmptyVC() - default: - logrus.Errorf("Couldn't call Get on backend object: %s", err.Error()) - return EmptyVC() - } - } - sm.session = session - } - - if sm.vc == nil { - // Generating a VC requires a user struct. Attempt to populate one using - // the user id of the current session holder - user := &User{BaseModel: BaseModel{ID: sm.session.UserID}} - err := sm.db.Where(user).First(user).Error - if err != nil { - return EmptyVC() - } - - sm.vc = GenerateVC(user) - } - - return sm.vc -} - -// MakeSessionForUserID creates a session in the database for a given user id. -// You must call Save() after calling this. -func (sm *SessionManager) MakeSessionForUserID(id uint) error { - session, err := sm.backend.Create(id) - if err != nil { - return err - } - sm.session = session - return nil -} - -// MakeSessionForUserID creates a session in the database for a given user -// You must call Save() after calling this. -func (sm *SessionManager) MakeSessionForUser(u *User) error { - return sm.MakeSessionForUserID(u.ID) -} - -// Save writes the current session to a token and delivers the token as a cookie -// to the user. Save must be called after every write action on this struct -// (MakeSessionForUser, Destroy, etc.) -func (sm *SessionManager) Save() error { - token, err := GenerateJWT(sm.session.Key) - if err != nil { - return err - } - - // TODO: set proper flags on cookie for maximum security - http.SetCookie(sm.writer, &http.Cookie{ - Name: CookieName, - Value: token, - }) - - return nil -} - -// Destroy deletes the active session from the database and erases the session -// instance from this object's access. You must call Save() after calling this. -func (sm *SessionManager) Destroy() error { - if sm.backend != nil { - err := sm.backend.Destroy(sm.session) - if err != nil { - return err - } - } - return nil -} - -//////////////////////////////////////////////////////////////////////////////// -// Session Backend API -//////////////////////////////////////////////////////////////////////////////// - -// SessionBackend is the abstract interface that all session backends must -// conform to. SessionBackend instances are only expected to exist within the -// context of a single request. -type SessionBackend interface { - // Given a session key, find and return a session object or an error if one - // could not be found for the given key - FindKey(key string) (*Session, error) - - // Given a session id, find and return a session object or an error if one - // could not be found for the given id - FindID(id uint) (*Session, error) - - // Find all of the active sessions for a given user - FindAllForUser(id uint) ([]*Session, error) - - // Create a session object tied to the given user ID - Create(userID uint) (*Session, error) - - // Destroy the currently tracked session - Destroy(session *Session) error - - // Destroy all of the sessions for a given user - DestroyAllForUser(id uint) error - - // Mark the currently tracked session as access to extend expiration - MarkAccessed(session *Session) error -} - -//////////////////////////////////////////////////////////////////////////////// -// Session Backend Plugins -//////////////////////////////////////////////////////////////////////////////// - -// GormSessionBackend stores sessions using a pre-instantiated gorm database -// object -type GormSessionBackend struct { - db *gorm.DB -} - -func (s *GormSessionBackend) validate(session *Session) error { - if time.Since(session.AccessedAt).Seconds() >= config.App.SessionExpirationSeconds { - err := s.db.Delete(session).Error - if err != nil { - return err - } - return ErrSessionExpired - } - - err := s.MarkAccessed(session) - if err != nil { - return err - } - - return nil -} - -func (s *GormSessionBackend) FindID(id uint) (*Session, error) { - session := &Session{ - BaseModel: BaseModel{ - ID: id, - }, - } - - err := s.db.Where(session).First(session).Error - if err != nil { - switch err { - case gorm.ErrRecordNotFound: - return nil, ErrNoActiveSession - default: - return nil, err - } - } - - err = s.validate(session) - if err != nil { - return nil, err - } - - return session, nil - -} - -func (s *GormSessionBackend) FindKey(key string) (*Session, error) { - session := &Session{ - Key: key, - } - - err := s.db.Where(session).First(session).Error - if err != nil { - switch err { - case gorm.ErrRecordNotFound: - return nil, ErrNoActiveSession - default: - return nil, err - } - } - - err = s.validate(session) - if err != nil { - return nil, err - } - - return session, nil -} - -func (s *GormSessionBackend) FindAllForUser(id uint) ([]*Session, error) { - var sessions []*Session - err := s.db.Where("user_id = ?", id).Find(&sessions).Error - return sessions, err -} - -func (s *GormSessionBackend) Create(userID uint) (*Session, error) { - key, err := generateRandomText(config.App.SessionKeySize) - if err != nil { - return nil, err - } - - session := &Session{ - UserID: userID, - Key: key, - } - - err = s.db.Create(session).Error - if err != nil { - return nil, err - } - - err = s.MarkAccessed(session) - if err != nil { - return nil, err - } - - return session, nil -} - -func (s *GormSessionBackend) Destroy(session *Session) error { - err := s.db.Delete(session).Error - if err != nil { - return err - } - - return nil -} - -func (s *GormSessionBackend) DestroyAllForUser(id uint) error { - return s.db.Delete(&Session{}, "user_id = ?", id).Error -} - -func (s *GormSessionBackend) MarkAccessed(session *Session) error { - session.AccessedAt = time.Now().UTC() - return s.db.Save(session).Error -} - -//////////////////////////////////////////////////////////////////////////////// -// Session management HTTP endpoints -//////////////////////////////////////////////////////////////////////////////// - -// Setting the session backend via a middleware -func SessionBackendMiddleware(c *gin.Context) { - db := GetDB(c) - c.Set("SessionBackend", &GormSessionBackend{db}) - c.Next() -} - -// Get the database connection from the context, or panic -func GetSessionBackend(c *gin.Context) SessionBackend { - return c.MustGet("SessionBackend").(SessionBackend) -} - -//////////////////////////////////////////////////////////////////////////////// -// Session management HTTP endpoints -//////////////////////////////////////////////////////////////////////////////// - -type DeleteSessionRequestBody struct { - SessionID uint `json:"session_id" binding:"required"` -} - -func DeleteSession(c *gin.Context) { - var body DeleteSessionRequestBody - err := c.BindJSON(&body) - if err != nil { - logrus.Errorf(err.Error()) - return - } - - vc := VC(c) - if !vc.CanPerformActions() { - UnauthorizedError(c) - return - } - - sb := GetSessionBackend(c) - - session, err := sb.FindID(body.SessionID) - if err != nil { - - } - - db := GetDB(c) - user := &User{ - BaseModel: BaseModel{ - ID: session.UserID, - }, - } - err = db.Where(user).First(user).Error - if err != nil { - DatabaseError(c) - return - } - - if !vc.CanPerformWriteActionOnUser(user) { - UnauthorizedError(c) - return - } - - err = sb.Destroy(session) - if err != nil { - DatabaseError(c) - return - } - - c.JSON(200, nil) -} - -type DeleteSessionsForUserRequestBody struct { - ID uint `json:"id"` - Username string `json:"username"` -} - -func DeleteSessionsForUser(c *gin.Context) { - var body DeleteSessionsForUserRequestBody - err := c.BindJSON(&body) - if err != nil { - logrus.Errorf(err.Error()) - } - - vc := VC(c) - if !vc.CanPerformActions() { - UnauthorizedError(c) - return - } - - db := GetDB(c) - var user User - user.ID = body.ID - user.Username = body.Username - err = db.Where(&user).First(&user).Error - if err != nil { - DatabaseError(c) - return - } - - if !vc.CanPerformWriteActionOnUser(&user) { - UnauthorizedError(c) - return - } - - sb := GetSessionBackend(c) - err = sb.DestroyAllForUser(user.ID) - err = db.Delete(&Session{}, "user_id = ?", user.ID).Error - if err != nil { - DatabaseError(c) - return - } - - c.JSON(200, nil) - -} - -type GetInfoAboutSessionRequestBody struct { - SessionKey string `json:"session_key" binding:"required"` -} - -type SessionInfoResponseBody struct { - SessionID uint `json:"session_id"` - UserID uint `json:"user_id"` - CreatedAt time.Time `json:"created_at"` - AccessedAt time.Time `json:"created_at"` -} - -func GetInfoAboutSession(c *gin.Context) { - var body GetInfoAboutSessionRequestBody - err := c.BindJSON(&body) - if err != nil { - logrus.Errorf(err.Error()) - return - } - - vc := VC(c) - if !vc.CanPerformActions() { - UnauthorizedError(c) - return - } - - sb := GetSessionBackend(c) - session, err := sb.FindKey(body.SessionKey) - if err != nil { - DatabaseError(c) - return - } - - db := GetDB(c) - var user User - user.ID = session.UserID - err = db.Where(&user).First(&user).Error - if err != nil { - DatabaseError(c) - return - } - - if !vc.IsAdmin() && !vc.IsUserID(user.ID) { - UnauthorizedError(c) - return - } - - c.JSON(200, &SessionInfoResponseBody{ - SessionID: session.ID, - UserID: session.UserID, - CreatedAt: session.CreatedAt, - AccessedAt: session.AccessedAt, - }) -} - -type GetInfoAboutSessionsForUserRequestBody struct { - ID uint `json:"id"` - Username string `json:"username"` -} - -type GetInfoAboutSessionsForUserResponseBody struct { - Sessions []*SessionInfoResponseBody `json:"sessions"` -} - -func GetInfoAboutSessionsForUser(c *gin.Context) { - var body GetInfoAboutSessionsForUserRequestBody - err := c.BindJSON(&body) - if err != nil { - logrus.Errorf(err.Error()) - return - } - - vc := VC(c) - if !vc.CanPerformActions() { - UnauthorizedError(c) - return - } - - db := GetDB(c) - var user User - user.ID = body.ID - user.Username = body.Username - err = db.Where(&user).First(&user).Error - if err != nil { - DatabaseError(c) - return - } - - if !vc.IsAdmin() && !vc.IsUserID(user.ID) { - UnauthorizedError(c) - return - } - - sb := GetSessionBackend(c) - sessions, err := sb.FindAllForUser(user.ID) - if err != nil { - DatabaseError(c) - return - } - - var response []*SessionInfoResponseBody - for _, session := range sessions { - response = append(response, &SessionInfoResponseBody{ - SessionID: session.ID, - UserID: session.UserID, - CreatedAt: session.CreatedAt, - AccessedAt: session.AccessedAt, - }) - } - - c.JSON(200, &GetInfoAboutSessionsForUserResponseBody{ - Sessions: response, - }) -} diff --git a/sessions/backends.go b/sessions/backends.go new file mode 100644 index 0000000000..680c31f18b --- /dev/null +++ b/sessions/backends.go @@ -0,0 +1,164 @@ +package sessions + +import ( + "crypto/rand" + "encoding/base64" + "time" + + "github.com/jinzhu/gorm" +) + +//////////////////////////////////////////////////////////////////////////////// +// Session Backend API +//////////////////////////////////////////////////////////////////////////////// + +// SessionBackend is the abstract interface that all session backends must +// conform to. SessionBackend instances are only expected to exist within the +// context of a single request. +type SessionBackend interface { + // Given a session key, find and return a session object or an error if one + // could not be found for the given key + FindKey(key string) (*Session, error) + + // Given a session id, find and return a session object or an error if one + // could not be found for the given id + FindID(id uint) (*Session, error) + + // Find all of the active sessions for a given user + FindAllForUser(id uint) ([]*Session, error) + + // Create a session object tied to the given user ID + Create(userID uint) (*Session, error) + + // Destroy the currently tracked session + Destroy(session *Session) error + + // Destroy all of the sessions for a given user + DestroyAllForUser(id uint) error + + // Mark the currently tracked session as access to extend expiration + MarkAccessed(session *Session) error +} + +//////////////////////////////////////////////////////////////////////////////// +// Session Backend Plugins +//////////////////////////////////////////////////////////////////////////////// + +// GormSessionBackend stores sessions using a pre-instantiated gorm database +// object +type GormSessionBackend struct { + DB *gorm.DB +} + +func (s *GormSessionBackend) validate(session *Session) error { + if time.Since(session.AccessedAt).Seconds() >= Lifespan { + err := s.DB.Delete(session).Error + if err != nil { + return err + } + return ErrSessionExpired + } + + err := s.MarkAccessed(session) + if err != nil { + return err + } + + return nil +} + +func (s *GormSessionBackend) FindID(id uint) (*Session, error) { + session := &Session{ + ID: id, + } + + err := s.DB.Where(session).First(session).Error + if err != nil { + switch err { + case gorm.ErrRecordNotFound: + return nil, ErrNoActiveSession + default: + return nil, err + } + } + + err = s.validate(session) + if err != nil { + return nil, err + } + + return session, nil + +} + +func (s *GormSessionBackend) FindKey(key string) (*Session, error) { + session := &Session{ + Key: key, + } + + err := s.DB.Where(session).First(session).Error + if err != nil { + switch err { + case gorm.ErrRecordNotFound: + return nil, ErrNoActiveSession + default: + return nil, err + } + } + + err = s.validate(session) + if err != nil { + return nil, err + } + + return session, nil +} + +func (s *GormSessionBackend) FindAllForUser(id uint) ([]*Session, error) { + var sessions []*Session + err := s.DB.Where("user_id = ?", id).Find(&sessions).Error + return sessions, err +} + +func (s *GormSessionBackend) Create(userID uint) (*Session, error) { + key := make([]byte, SessionKeySize) + _, err := rand.Read(key) + if err != nil { + return nil, err + } + + session := &Session{ + UserID: userID, + Key: base64.StdEncoding.EncodeToString(key), + } + + err = s.DB.Create(session).Error + if err != nil { + return nil, err + } + + err = s.MarkAccessed(session) + if err != nil { + return nil, err + } + + return session, nil +} + +func (s *GormSessionBackend) Destroy(session *Session) error { + err := s.DB.Delete(session).Error + if err != nil { + return err + } + + return nil +} + +func (s *GormSessionBackend) DestroyAllForUser(id uint) error { + return s.DB.Delete(&Session{}, "user_id = ?", id).Error +} + +func (s *GormSessionBackend) MarkAccessed(session *Session) error { + session.AccessedAt = time.Now().UTC() + return s.DB.Save(session).Error +} diff --git a/sessions/backends_test.go b/sessions/backends_test.go new file mode 100644 index 0000000000..1ef1ae31ae --- /dev/null +++ b/sessions/backends_test.go @@ -0,0 +1,129 @@ +package sessions + +import ( + "crypto/rand" + "encoding/base64" + "net/http" + "testing" + "time" +) + +type mockSessionBackend struct { + sessions []*Session + id uint +} + +func newMockSessionBackend() *mockSessionBackend { + return &mockSessionBackend{ + sessions: []*Session{}, + id: 0, + } +} + +func (s *mockSessionBackend) FindID(id uint) (*Session, error) { + for _, each := range s.sessions { + if each.ID == id { + return each, nil + } + } + return nil, ErrNoActiveSession +} + +func (s *mockSessionBackend) FindKey(key string) (*Session, error) { + for _, each := range s.sessions { + if each.Key == key { + return each, nil + } + } + return nil, ErrNoActiveSession +} + +func (s *mockSessionBackend) FindAllForUser(id uint) ([]*Session, error) { + var sessions []*Session + for _, each := range sessions { + if each.UserID == id { + sessions = append(sessions, each) + } + } + return sessions, nil +} + +func (s *mockSessionBackend) nextID() uint { + s.id = s.id + 1 + return s.id +} + +func (s *mockSessionBackend) Create(userID uint) (*Session, error) { + key := make([]byte, SessionKeySize) + _, err := rand.Read(key) + if err != nil { + return nil, err + } + + session := &Session{ + ID: s.nextID(), + UserID: userID, + Key: base64.StdEncoding.EncodeToString(key), + } + + err = s.MarkAccessed(session) + if err != nil { + return nil, err + } + + s.sessions = append(s.sessions, session) + + return session, nil +} + +func (s *mockSessionBackend) Destroy(session *Session) error { + var sessions []*Session + for _, each := range s.sessions { + if each.ID != session.ID { + sessions = append(sessions, each) + } + } + s.sessions = sessions + return nil +} + +func (s *mockSessionBackend) DestroyAllForUser(id uint) error { + var sessions []*Session + for _, each := range s.sessions { + if each.UserID != id { + sessions = append(sessions, each) + } + } + s.sessions = sessions + return nil +} + +func (s *mockSessionBackend) MarkAccessed(session *Session) error { + session.AccessedAt = time.Now().UTC() + return nil +} + +type mockResponseWriter struct { + headers map[string][]string +} + +func newMocResponseWriter() *mockResponseWriter { + return &mockResponseWriter{ + headers: map[string][]string{}, + } +} + +func (w *mockResponseWriter) Header() http.Header { + return w.headers +} + +func (w *mockResponseWriter) Write([]byte) (int, error) { + return 0, nil +} + +func (w *mockResponseWriter) WriteHeader(int) { +} + +func TestFindID(t *testing.T) { + +} diff --git a/sessions/sessions.go b/sessions/sessions.go new file mode 100644 index 0000000000..3c3aeed693 --- /dev/null +++ b/sessions/sessions.go @@ -0,0 +1,212 @@ +package sessions + +import ( + "errors" + "net/http" + "time" + + "github.com/Sirupsen/logrus" + "github.com/dgrijalva/jwt-go" +) + +var ( + // An error returned by SessionBackend.Get() if no session record was found + // in the database + ErrNoActiveSession = errors.New("Active session is not present in the database") + + // An error returned by SessionBackend methods when no session object has + // been created yet but the requested action requires one + ErrSessionNotCreated = errors.New("The session has not been created") + + // An error returned by SessionBackend.Get() when a session is requested but + // it has expired + ErrSessionExpired = errors.New("The session has expired") + + // An error returned by SessionBackend which indicates that the token + // or it's content were malformed + ErrSessionMalformed = errors.New("The session token was malformed") +) + +var ( + // The name of the session cookie + CookieName = "Session" + + // The key to be used to sign and verify JWTs + jwtKey = "" + + // The amount of random data, in bytes, which will be used to create each + // session key + SessionKeySize = 64 + + // The amount of seconds that will pass before an inactive user is logged out + Lifespan = float64(60 * 60 * 24 * 90) +) + +// Session is the model object which represents what an active session is +type Session struct { + ID uint `gorm:"primary_key"` + CreatedAt time.Time + AccessedAt time.Time + UserID uint `gorm:"not null"` + Key string `gorm:"not null;unique_index:idx_session_unique_key"` +} + +//////////////////////////////////////////////////////////////////////////////// +// Configuring the library +//////////////////////////////////////////////////////////////////////////////// + +type SessionConfiguration struct { + CookieName string + JWTKey string + SessionKeySize int + Lifespan float64 +} + +func Configure(s *SessionConfiguration) { + CookieName = s.CookieName + jwtKey = s.JWTKey + SessionKeySize = s.SessionKeySize + Lifespan = s.Lifespan +} + +// Set the name of the cookie +func SetCookieName(name string) { + CookieName = name +} + +//////////////////////////////////////////////////////////////////////////////// +// Managing sessions +//////////////////////////////////////////////////////////////////////////////// + +// SessionManager is a management object which helps with the administration of +// sessions within the application. Use NewSessionManager to create an instance +type SessionManager struct { + Backend SessionBackend + Request *http.Request + Writer http.ResponseWriter + session *Session +} + +func (sm *SessionManager) Session() (*Session, error) { + if sm.session == nil { + cookie, err := sm.Request.Cookie(CookieName) + if err != nil { + switch err { + case http.ErrNoCookie: + // No cookie was set + return nil, err + default: + // Something went wrong and the cookie may or may not be set + logrus.Errorf("Couldn't get cookie: %s", err.Error()) + return nil, ErrSessionMalformed + } + } + + token, err := ParseJWT(cookie.Value) + if err != nil { + logrus.Errorf("Couldn't parse JWT token string from cookie: %s", err.Error()) + return nil, ErrSessionMalformed + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + logrus.Error("Could not parse the claims from the JWT token") + return nil, ErrSessionMalformed + } + + sessionKeyClaim, ok := claims["session_key"] + if !ok { + logrus.Warn("JWT did not have session_key claim") + return nil, ErrSessionMalformed + } + + sessionKey, ok := sessionKeyClaim.(string) + if !ok { + logrus.Warn("JWT session_key claim was not a string") + return nil, ErrSessionMalformed + } + + session, err := sm.Backend.FindKey(sessionKey) + if err != nil { + switch err { + case ErrNoActiveSession: + // If the code path got this far, it's likely that the user was logged + // in some time in the past, but their session has been expired since + // their last usage of the application + return nil, err + default: + logrus.Errorf("Couldn't call Get on backend object: %s", err.Error()) + return nil, err + } + } + sm.session = session + } + return sm.session, nil + +} + +// MakeSessionForUserID creates a session in the database for a given user id. +// You must call Save() after calling this. +func (sm *SessionManager) MakeSessionForUserID(id uint) error { + session, err := sm.Backend.Create(id) + if err != nil { + return err + } + sm.session = session + return nil +} + +// Save writes the current session to a token and delivers the token as a cookie +// to the user. Save must be called after every write action on this struct +// (MakeSessionForUser, Destroy, etc.) +func (sm *SessionManager) Save() error { + token, err := GenerateJWT(sm.session.Key) + if err != nil { + return err + } + + // TODO: set proper flags on cookie for maximum security + http.SetCookie(sm.Writer, &http.Cookie{ + Name: CookieName, + Value: token, + }) + + return nil +} + +// Destroy deletes the active session from the database and erases the session +// instance from this object's access. You must call Save() after calling this. +func (sm *SessionManager) Destroy() error { + if sm.Backend != nil { + err := sm.Backend.Destroy(sm.session) + if err != nil { + return err + } + } + return nil +} + +//////////////////////////////////////////////////////////////////////////////// +// JSON Web Tokens +//////////////////////////////////////////////////////////////////////////////// + +// Given a session key create a JWT to be delivered to the client +func GenerateJWT(sessionKey string) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "session_key": sessionKey, + }) + + return token.SignedString([]byte(jwtKey)) +} + +// ParseJWT attempts to parse a JWT token in serialized string form into a +// JWT token in a deserialized jwt.Token struct. +func ParseJWT(token string) (*jwt.Token, error) { + return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { + method, ok := t.Method.(*jwt.SigningMethodHMAC) + if !ok || method != jwt.SigningMethodHS256 { + return nil, errors.New("Unexpected signing method") + } + return []byte(jwtKey), nil + }) +} diff --git a/sessions/sessions_test.go b/sessions/sessions_test.go new file mode 100644 index 0000000000..796ecb6f1b --- /dev/null +++ b/sessions/sessions_test.go @@ -0,0 +1,66 @@ +package sessions + +import ( + "net/http" + "strings" + "testing" + + jwt "github.com/dgrijalva/jwt-go" +) + +func TestGenerateJWT(t *testing.T) { + jwtKey = "very secure" + tokenString, err := GenerateJWT("4") + token, err := ParseJWT(tokenString) + if err != nil { + t.Fatal(err.Error()) + } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + t.Fatal("Token is invalid") + } + + sessionKey := claims["session_key"].(string) + if sessionKey != "4" { + t.Fatalf("Claims are incorrect. session key is %s", sessionKey) + } +} + +func TestSessionManager(t *testing.T) { + r, _ := http.NewRequest("GET", "/", nil) + w := newMocResponseWriter() + sb := newMockSessionBackend() + + sm := &SessionManager{ + Backend: sb, + Request: r, + Writer: w, + } + + err := sm.MakeSessionForUserID(1) + if err != nil { + t.Fatalf(err.Error()) + } + + err = sm.Save() + if err != nil { + t.Fatalf(err.Error()) + } + + header := w.Header().Get("Set-Cookie") + tokenString := strings.Split(header, "=")[1] + token, err := ParseJWT(tokenString) + if err != nil { + t.Fatal(err.Error()) + } + session_key := token.Claims.(jwt.MapClaims)["session_key"].(string) + session, err := sb.FindKey(session_key) + if err != nil { + t.Fatal(err.Error()) + } + + if session.UserID != 1 { + t.Fatal("User ID doesn't match. Got: %s", session.UserID) + } + +} diff --git a/sessions_test.go b/sessions_test.go deleted file mode 100644 index bc3589b9d6..0000000000 --- a/sessions_test.go +++ /dev/null @@ -1,117 +0,0 @@ -package main - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/gin-gonic/gin" -) - -type MockResponseWriter struct { -} - -func (w *MockResponseWriter) Header() http.Header { - return map[string][]string{} -} - -func (w *MockResponseWriter) Write([]byte) (int, error) { - return 0, nil -} - -func (w *MockResponseWriter) WriteHeader(int) { -} - -func TestSessionManagerVC(t *testing.T) { - db := openTestDB(t) - - admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false) - if err != nil { - t.Fatal(err.Error()) - } - - backend := &GormSessionBackend{db} - session, err := backend.Create(admin.ID) - if err != nil { - t.Fatal(err.Error()) - } - - if session.UserID != admin.ID { - t.Fatal("IDs do not match") - } - - token, err := GenerateJWT(session.Key) - - cookie := &http.Cookie{ - Name: CookieName, - Value: token, - } - - req, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatal(err.Error()) - } - req.AddCookie(cookie) - - writer := &MockResponseWriter{} - - sm := &SessionManager{ - request: req, - writer: writer, - backend: backend, - db: db, - } - vc := sm.VC() - - if !vc.IsAdmin() { - t.Fatal("User should be admin") - } - - vcID, _ := vc.UserID() - if vcID != admin.ID { - t.Fatal("IDs don't match") - } -} - -func TestSessionCreation(t *testing.T) { - db := openTestDB(t) - r := createEmptyTestServer(db) - admin, _ := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false) - - r.GET("/login", func(c *gin.Context) { - sm := NewSessionManager(c) - sm.MakeSessionForUser(admin) - err := sm.Save() - if err != nil { - t.Fatal(err.Error()) - } - c.JSON(200, nil) - }) - - r.GET("/resource", func(c *gin.Context) { - sm := NewSessionManager(c) - vc := sm.VC() - if !vc.IsAdmin() { - t.Fatal("Request is not admin") - } - c.JSON(200, nil) - }) - - r.GET("/nope", func(c *gin.Context) { - sm := NewSessionManager(c) - vc := sm.VC() - if !vc.IsAdmin() { - t.Fatal("Request is not admin") - } - c.JSON(200, nil) - }) - - res1 := httptest.NewRecorder() - req1, _ := http.NewRequest("GET", "/login", nil) - r.ServeHTTP(res1, req1) - - res2 := httptest.NewRecorder() - req2, _ := http.NewRequest("GET", "/resource", nil) - req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie")) - r.ServeHTTP(res2, req2) -} diff --git a/story_test.go b/story_test.go index 4d8a552244..f086a9b1b9 100644 --- a/story_test.go +++ b/story_test.go @@ -6,6 +6,7 @@ import ( jwt "github.com/dgrijalva/jwt-go" "github.com/jinzhu/gorm" + "github.com/kolide/kolide-ose/sessions" ) func TestUserAndAccountManagement(t *testing.T) { @@ -59,7 +60,7 @@ func TestUserAndAccountManagement(t *testing.T) { } // Pull the token out of the JWT token and get the session info via that - token, err := ParseJWT(strings.Split(adminSession, "=")[1]) + token, err := sessions.ParseJWT(strings.Split(adminSession, "=")[1]) if err != nil { t.Fatal(err.Error()) } @@ -75,7 +76,7 @@ func TestUserAndAccountManagement(t *testing.T) { req.DeleteSession(adminSessionInfo.Sessions[0].SessionID, adminSession) // Verify the session was deleted - sessionVerify := &Session{ + sessionVerify := &sessions.Session{ Key: sessionKey, } err = req.db.Where(sessionVerify).First(sessionVerify).Error diff --git a/users.go b/users.go index 091510be94..24b82a91d4 100644 --- a/users.go +++ b/users.go @@ -2,10 +2,12 @@ package main import ( "fmt" + "time" "github.com/Sirupsen/logrus" "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" + "github.com/kolide/kolide-ose/sessions" "golang.org/x/crypto/bcrypt" ) @@ -452,3 +454,232 @@ func SetUserEnabledState(c *gin.Context) { NeedsPasswordReset: user.NeedsPasswordReset, }) } + +/////////////////////////////////////////////////////////////////////////////// +// Session management HTTP endpoints +//////////////////////////////////////////////////////////////////////////////// + +// Setting the session backend via a middleware +func SessionBackendMiddleware(c *gin.Context) { + db := GetDB(c) + c.Set("SessionBackend", &sessions.GormSessionBackend{db}) + c.Next() +} + +// Get the database connection from the context, or panic +func GetSessionBackend(c *gin.Context) sessions.SessionBackend { + return c.MustGet("SessionBackend").(sessions.SessionBackend) +} + +//////////////////////////////////////////////////////////////////////////////// +// Session management HTTP endpoints +//////////////////////////////////////////////////////////////////////////////// + +type DeleteSessionRequestBody struct { + SessionID uint `json:"session_id" binding:"required"` +} + +func DeleteSession(c *gin.Context) { + var body DeleteSessionRequestBody + err := c.BindJSON(&body) + if err != nil { + logrus.Errorf(err.Error()) + return + } + + vc := VC(c) + if !vc.CanPerformActions() { + UnauthorizedError(c) + return + } + + sb := GetSessionBackend(c) + + session, err := sb.FindID(body.SessionID) + if err != nil { + + } + + db := GetDB(c) + user := &User{ + BaseModel: BaseModel{ + ID: session.UserID, + }, + } + err = db.Where(user).First(user).Error + if err != nil { + DatabaseError(c) + return + } + + if !vc.CanPerformWriteActionOnUser(user) { + UnauthorizedError(c) + return + } + + err = sb.Destroy(session) + if err != nil { + DatabaseError(c) + return + } + + c.JSON(200, nil) +} + +type DeleteSessionsForUserRequestBody struct { + ID uint `json:"id"` + Username string `json:"username"` +} + +func DeleteSessionsForUser(c *gin.Context) { + var body DeleteSessionsForUserRequestBody + err := c.BindJSON(&body) + if err != nil { + logrus.Errorf(err.Error()) + } + + vc := VC(c) + if !vc.CanPerformActions() { + UnauthorizedError(c) + return + } + + db := GetDB(c) + var user User + user.ID = body.ID + user.Username = body.Username + err = db.Where(&user).First(&user).Error + if err != nil { + DatabaseError(c) + return + } + + if !vc.CanPerformWriteActionOnUser(&user) { + UnauthorizedError(c) + return + } + + sb := GetSessionBackend(c) + err = sb.DestroyAllForUser(user.ID) + if err != nil { + DatabaseError(c) + return + } + + c.JSON(200, nil) + +} + +type GetInfoAboutSessionRequestBody struct { + SessionKey string `json:"session_key" binding:"required"` +} + +type SessionInfoResponseBody struct { + SessionID uint `json:"session_id"` + UserID uint `json:"user_id"` + CreatedAt time.Time `json:"created_at"` + AccessedAt time.Time `json:"created_at"` +} + +func GetInfoAboutSession(c *gin.Context) { + var body GetInfoAboutSessionRequestBody + err := c.BindJSON(&body) + if err != nil { + logrus.Errorf(err.Error()) + return + } + + vc := VC(c) + if !vc.CanPerformActions() { + UnauthorizedError(c) + return + } + + sb := GetSessionBackend(c) + session, err := sb.FindKey(body.SessionKey) + if err != nil { + DatabaseError(c) + return + } + + db := GetDB(c) + var user User + user.ID = session.UserID + err = db.Where(&user).First(&user).Error + if err != nil { + DatabaseError(c) + return + } + + if !vc.IsAdmin() && !vc.IsUserID(user.ID) { + UnauthorizedError(c) + return + } + + c.JSON(200, &SessionInfoResponseBody{ + SessionID: session.ID, + UserID: session.UserID, + CreatedAt: session.CreatedAt, + AccessedAt: session.AccessedAt, + }) +} + +type GetInfoAboutSessionsForUserRequestBody struct { + ID uint `json:"id"` + Username string `json:"username"` +} + +type GetInfoAboutSessionsForUserResponseBody struct { + Sessions []*SessionInfoResponseBody `json:"sessions"` +} + +func GetInfoAboutSessionsForUser(c *gin.Context) { + var body GetInfoAboutSessionsForUserRequestBody + err := c.BindJSON(&body) + if err != nil { + logrus.Errorf(err.Error()) + return + } + + vc := VC(c) + if !vc.CanPerformActions() { + UnauthorizedError(c) + return + } + + db := GetDB(c) + var user User + user.ID = body.ID + user.Username = body.Username + err = db.Where(&user).First(&user).Error + if err != nil { + DatabaseError(c) + return + } + + if !vc.IsAdmin() && !vc.IsUserID(user.ID) { + UnauthorizedError(c) + return + } + + sb := GetSessionBackend(c) + sessions, err := sb.FindAllForUser(user.ID) + if err != nil { + DatabaseError(c) + return + } + + var response []*SessionInfoResponseBody + for _, session := range sessions { + response = append(response, &SessionInfoResponseBody{ + SessionID: session.ID, + UserID: session.UserID, + CreatedAt: session.CreatedAt, + AccessedAt: session.AccessedAt, + }) + } + + c.JSON(200, &GetInfoAboutSessionsForUserResponseBody{ + Sessions: response, + }) +}