Remove expiration of API-only user tokens (#4314)

This commit is contained in:
Martin Angers 2022-02-22 08:12:03 -05:00 committed by GitHub
parent 93b50c3787
commit 2ab1b9ec85
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 116 additions and 6 deletions

View file

@ -0,0 +1 @@
* Remove expiration for API-only user tokens.

View file

@ -11,8 +11,10 @@ import (
func (ds *Datastore) SessionByKey(ctx context.Context, key string) (*fleet.Session, error) {
sqlStatement := `
SELECT * FROM sessions
WHERE ` + "`key`" + ` = ? LIMIT 1
SELECT s.*, u.api_only FROM sessions s
LEFT JOIN users u
ON s.user_id = u.id
WHERE ` + "s.`key`" + ` = ? LIMIT 1
`
session := &fleet.Session{}
err := sqlx.GetContext(ctx, ds.reader, session, sqlStatement, key)
@ -28,8 +30,10 @@ func (ds *Datastore) SessionByKey(ctx context.Context, key string) (*fleet.Sessi
func (ds *Datastore) SessionByID(ctx context.Context, id uint) (*fleet.Session, error) {
sqlStatement := `
SELECT * FROM sessions
WHERE id = ?
SELECT s.*, u.api_only FROM sessions s
LEFT JOIN users u
ON s.user_id = u.id
WHERE s.id = ?
LIMIT 1
`
session := &fleet.Session{}
@ -46,8 +50,10 @@ func (ds *Datastore) SessionByID(ctx context.Context, id uint) (*fleet.Session,
func (ds *Datastore) ListSessionsForUser(ctx context.Context, id uint) ([]*fleet.Session, error) {
sqlStatement := `
SELECT * FROM sessions
WHERE user_id = ?
SELECT s.*, u.api_only FROM sessions s
INNER JOIN users u
ON s.user_id = u.id
WHERE s.user_id = ?
`
sessions := []*fleet.Session{}
err := sqlx.SelectContext(ctx, ds.reader, &sessions, sqlStatement, id)

View file

@ -42,10 +42,14 @@ func testSessionsGetters(t *testing.T, ds *Datastore) {
gotByID, err := ds.SessionByID(context.Background(), session.ID)
require.NoError(t, err)
assert.Equal(t, session.Key, gotByID.Key)
require.NotNil(t, gotByID.APIOnly)
assert.False(t, *gotByID.APIOnly)
gotByKey, err := ds.SessionByKey(context.Background(), session.Key)
require.NoError(t, err)
assert.Equal(t, session.ID, gotByKey.ID)
require.NotNil(t, gotByKey.APIOnly)
assert.False(t, *gotByKey.APIOnly)
newSession, err := ds.NewSession(context.Background(), &fleet.Session{UserID: user.ID, Key: "somekey2"})
require.NoError(t, err)
@ -66,4 +70,41 @@ func testSessionsGetters(t *testing.T, ds *Datastore) {
require.NotEqual(t, prevAccessedAt, sessions[0].AccessedAt)
require.NoError(t, ds.DestroyAllSessionsForUser(context.Background(), user.ID))
// session for a non-existing user
newSession, err = ds.NewSession(context.Background(), &fleet.Session{UserID: user.ID + 1, Key: "someotherkey"})
require.NoError(t, err)
gotByKey, err = ds.SessionByKey(context.Background(), newSession.Key)
require.NoError(t, err)
assert.Equal(t, newSession.ID, gotByKey.ID)
require.Nil(t, gotByKey.APIOnly)
gotByID, err = ds.SessionByID(context.Background(), newSession.ID)
require.NoError(t, err)
assert.Equal(t, newSession.ID, gotByKey.ID)
require.Nil(t, gotByKey.APIOnly)
apiUser, err := ds.NewUser(context.Background(), &fleet.User{
Password: []byte("supersecret"),
GlobalRole: ptr.String(fleet.RoleObserver),
APIOnly: true,
})
require.NoError(t, err)
// session for an api user
apiSession, err := ds.NewSession(context.Background(), &fleet.Session{UserID: apiUser.ID, Key: "someapikey"})
require.NoError(t, err)
gotByKey, err = ds.SessionByKey(context.Background(), apiSession.Key)
require.NoError(t, err)
assert.Equal(t, apiSession.ID, gotByKey.ID)
require.NotNil(t, gotByKey.APIOnly)
assert.True(t, *gotByKey.APIOnly)
gotByID, err = ds.SessionByID(context.Background(), apiSession.ID)
require.NoError(t, err)
assert.Equal(t, apiSession.ID, gotByKey.ID)
require.NotNil(t, gotByKey.APIOnly)
assert.True(t, *gotByKey.APIOnly)
}

View file

@ -31,6 +31,7 @@ type Session struct {
AccessedAt time.Time `db:"accessed_at"`
UserID uint `json:"user_id" db:"user_id"`
Key string
APIOnly *bool `json:"-" db:"api_only"`
}
func (s Session) AuthzType() string {

View file

@ -308,6 +308,10 @@ func (svc *Service) validateSession(ctx context.Context, session *fleet.Session)
}
sessionDuration := svc.config.Session.Duration
if session.APIOnly != nil && *session.APIOnly {
sessionDuration = 0 // make API-only tokens unlimited
}
// duration 0 = unlimited
if sessionDuration != 0 && time.Since(session.AccessedAt) >= sessionDuration {
err := svc.ds.DestroySession(ctx, session)

View file

@ -1,10 +1,15 @@
package service
import (
"context"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -52,3 +57,55 @@ func TestAuthenticate(t *testing.T) {
})
}
}
func TestGetSessionByKey(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
cfg := config.TestConfig()
theSession := &fleet.Session{UserID: 123, Key: "abc"}
ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) {
return theSession, nil
}
ds.DestroySessionFunc = func(ctx context.Context, ssn *fleet.Session) error {
return nil
}
ds.MarkSessionAccessedFunc = func(ctx context.Context, ssn *fleet.Session) error {
return nil
}
cases := []struct {
desc string
accessed time.Duration
apiOnly bool
fail bool
}{
{"real user, accessed recently", -1 * time.Hour, false, false},
{"real user, accessed too long ago", -(cfg.Session.Duration + time.Hour), false, true},
{"api-only, accessed recently", -1 * time.Hour, true, false},
{"api-only, accessed long ago", -(cfg.Session.Duration + time.Hour), true, false},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
var authErr *fleet.AuthRequiredError
ds.SessionByKeyFuncInvoked, ds.DestroySessionFuncInvoked, ds.MarkSessionAccessedFuncInvoked = false, false, false
theSession.AccessedAt = time.Now().Add(tc.accessed)
theSession.APIOnly = ptr.Bool(tc.apiOnly)
_, err := svc.GetSessionByKey(context.Background(), theSession.Key)
if tc.fail {
require.Error(t, err)
require.ErrorAs(t, err, &authErr)
require.True(t, ds.SessionByKeyFuncInvoked)
require.True(t, ds.DestroySessionFuncInvoked)
require.False(t, ds.MarkSessionAccessedFuncInvoked)
} else {
require.NoError(t, err)
require.True(t, ds.SessionByKeyFuncInvoked)
require.False(t, ds.DestroySessionFuncInvoked)
require.True(t, ds.MarkSessionAccessedFuncInvoked)
}
})
}
}