diff --git a/changes/issue-3927-remove-api-only-token-expiration b/changes/issue-3927-remove-api-only-token-expiration new file mode 100644 index 0000000000..70b83bcc00 --- /dev/null +++ b/changes/issue-3927-remove-api-only-token-expiration @@ -0,0 +1 @@ +* Remove expiration for API-only user tokens. diff --git a/server/datastore/mysql/sessions.go b/server/datastore/mysql/sessions.go index 652058b89f..576d27a928 100644 --- a/server/datastore/mysql/sessions.go +++ b/server/datastore/mysql/sessions.go @@ -11,8 +11,10 @@ import ( func (ds *Datastore) SessionByKey(ctx context.Context, key string) (*fleet.Session, error) { sqlStatement := ` - SELECT * FROM sessions - WHERE ` + "`key`" + ` = ? LIMIT 1 + SELECT s.*, u.api_only FROM sessions s + LEFT JOIN users u + ON s.user_id = u.id + WHERE ` + "s.`key`" + ` = ? LIMIT 1 ` session := &fleet.Session{} err := sqlx.GetContext(ctx, ds.reader, session, sqlStatement, key) @@ -28,8 +30,10 @@ func (ds *Datastore) SessionByKey(ctx context.Context, key string) (*fleet.Sessi func (ds *Datastore) SessionByID(ctx context.Context, id uint) (*fleet.Session, error) { sqlStatement := ` - SELECT * FROM sessions - WHERE id = ? + SELECT s.*, u.api_only FROM sessions s + LEFT JOIN users u + ON s.user_id = u.id + WHERE s.id = ? LIMIT 1 ` session := &fleet.Session{} @@ -46,8 +50,10 @@ func (ds *Datastore) SessionByID(ctx context.Context, id uint) (*fleet.Session, func (ds *Datastore) ListSessionsForUser(ctx context.Context, id uint) ([]*fleet.Session, error) { sqlStatement := ` - SELECT * FROM sessions - WHERE user_id = ? + SELECT s.*, u.api_only FROM sessions s + INNER JOIN users u + ON s.user_id = u.id + WHERE s.user_id = ? ` sessions := []*fleet.Session{} err := sqlx.SelectContext(ctx, ds.reader, &sessions, sqlStatement, id) diff --git a/server/datastore/mysql/sessions_test.go b/server/datastore/mysql/sessions_test.go index 446052e653..bd4b1bb58a 100644 --- a/server/datastore/mysql/sessions_test.go +++ b/server/datastore/mysql/sessions_test.go @@ -42,10 +42,14 @@ func testSessionsGetters(t *testing.T, ds *Datastore) { gotByID, err := ds.SessionByID(context.Background(), session.ID) require.NoError(t, err) assert.Equal(t, session.Key, gotByID.Key) + require.NotNil(t, gotByID.APIOnly) + assert.False(t, *gotByID.APIOnly) gotByKey, err := ds.SessionByKey(context.Background(), session.Key) require.NoError(t, err) assert.Equal(t, session.ID, gotByKey.ID) + require.NotNil(t, gotByKey.APIOnly) + assert.False(t, *gotByKey.APIOnly) newSession, err := ds.NewSession(context.Background(), &fleet.Session{UserID: user.ID, Key: "somekey2"}) require.NoError(t, err) @@ -66,4 +70,41 @@ func testSessionsGetters(t *testing.T, ds *Datastore) { require.NotEqual(t, prevAccessedAt, sessions[0].AccessedAt) require.NoError(t, ds.DestroyAllSessionsForUser(context.Background(), user.ID)) + + // session for a non-existing user + newSession, err = ds.NewSession(context.Background(), &fleet.Session{UserID: user.ID + 1, Key: "someotherkey"}) + require.NoError(t, err) + + gotByKey, err = ds.SessionByKey(context.Background(), newSession.Key) + require.NoError(t, err) + assert.Equal(t, newSession.ID, gotByKey.ID) + require.Nil(t, gotByKey.APIOnly) + + gotByID, err = ds.SessionByID(context.Background(), newSession.ID) + require.NoError(t, err) + assert.Equal(t, newSession.ID, gotByKey.ID) + require.Nil(t, gotByKey.APIOnly) + + apiUser, err := ds.NewUser(context.Background(), &fleet.User{ + Password: []byte("supersecret"), + GlobalRole: ptr.String(fleet.RoleObserver), + APIOnly: true, + }) + require.NoError(t, err) + + // session for an api user + apiSession, err := ds.NewSession(context.Background(), &fleet.Session{UserID: apiUser.ID, Key: "someapikey"}) + require.NoError(t, err) + + gotByKey, err = ds.SessionByKey(context.Background(), apiSession.Key) + require.NoError(t, err) + assert.Equal(t, apiSession.ID, gotByKey.ID) + require.NotNil(t, gotByKey.APIOnly) + assert.True(t, *gotByKey.APIOnly) + + gotByID, err = ds.SessionByID(context.Background(), apiSession.ID) + require.NoError(t, err) + assert.Equal(t, apiSession.ID, gotByKey.ID) + require.NotNil(t, gotByKey.APIOnly) + assert.True(t, *gotByKey.APIOnly) } diff --git a/server/fleet/sessions.go b/server/fleet/sessions.go index ac149568b3..bcfa454693 100644 --- a/server/fleet/sessions.go +++ b/server/fleet/sessions.go @@ -31,6 +31,7 @@ type Session struct { AccessedAt time.Time `db:"accessed_at"` UserID uint `json:"user_id" db:"user_id"` Key string + APIOnly *bool `json:"-" db:"api_only"` } func (s Session) AuthzType() string { diff --git a/server/service/service_sessions.go b/server/service/service_sessions.go index e8877e21c0..9aade15b3f 100644 --- a/server/service/service_sessions.go +++ b/server/service/service_sessions.go @@ -308,6 +308,10 @@ func (svc *Service) validateSession(ctx context.Context, session *fleet.Session) } sessionDuration := svc.config.Session.Duration + if session.APIOnly != nil && *session.APIOnly { + sessionDuration = 0 // make API-only tokens unlimited + } + // duration 0 = unlimited if sessionDuration != 0 && time.Since(session.AccessedAt) >= sessionDuration { err := svc.ds.DestroySession(ctx, session) diff --git a/server/service/service_sessions_test.go b/server/service/service_sessions_test.go index be31ef550d..c1e24c6afe 100644 --- a/server/service/service_sessions_test.go +++ b/server/service/service_sessions_test.go @@ -1,10 +1,15 @@ package service import ( + "context" "testing" "time" + "github.com/fleetdm/fleet/v4/server/config" "github.com/fleetdm/fleet/v4/server/datastore/mysql" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/mock" + "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -52,3 +57,55 @@ func TestAuthenticate(t *testing.T) { }) } } + +func TestGetSessionByKey(t *testing.T) { + ds := new(mock.Store) + svc := newTestService(ds, nil, nil) + cfg := config.TestConfig() + + theSession := &fleet.Session{UserID: 123, Key: "abc"} + + ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) { + return theSession, nil + } + ds.DestroySessionFunc = func(ctx context.Context, ssn *fleet.Session) error { + return nil + } + ds.MarkSessionAccessedFunc = func(ctx context.Context, ssn *fleet.Session) error { + return nil + } + + cases := []struct { + desc string + accessed time.Duration + apiOnly bool + fail bool + }{ + {"real user, accessed recently", -1 * time.Hour, false, false}, + {"real user, accessed too long ago", -(cfg.Session.Duration + time.Hour), false, true}, + {"api-only, accessed recently", -1 * time.Hour, true, false}, + {"api-only, accessed long ago", -(cfg.Session.Duration + time.Hour), true, false}, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + var authErr *fleet.AuthRequiredError + ds.SessionByKeyFuncInvoked, ds.DestroySessionFuncInvoked, ds.MarkSessionAccessedFuncInvoked = false, false, false + + theSession.AccessedAt = time.Now().Add(tc.accessed) + theSession.APIOnly = ptr.Bool(tc.apiOnly) + _, err := svc.GetSessionByKey(context.Background(), theSession.Key) + if tc.fail { + require.Error(t, err) + require.ErrorAs(t, err, &authErr) + require.True(t, ds.SessionByKeyFuncInvoked) + require.True(t, ds.DestroySessionFuncInvoked) + require.False(t, ds.MarkSessionAccessedFuncInvoked) + } else { + require.NoError(t, err) + require.True(t, ds.SessionByKeyFuncInvoked) + require.False(t, ds.DestroySessionFuncInvoked) + require.True(t, ds.MarkSessionAccessedFuncInvoked) + } + }) + } +}