feat: oidc background token refresh (#23727)

Signed-off-by: Mike Cutsail <mcutsail15@apple.com>
This commit is contained in:
Mike Cutsail 2025-11-13 11:37:53 -05:00 committed by GitHub
parent 60f2ff5f77
commit 5c6aa59ed3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 831 additions and 164 deletions

View file

@ -67,6 +67,7 @@ data:
issuer: https://keycloak.example.com/realms/master
clientID: argocd
clientSecret: $oidc.keycloak.clientSecret
refreshTokenThreshold: 2m
requestedScopes: ["openid", "profile", "email", "groups"]
```
@ -77,6 +78,7 @@ Make sure that:
- __clientID__ is set to the Client ID you configured in Keycloak
- __clientSecret__ points to the right key you created in the _argocd-secret_ Secret
- __requestedScopes__ contains the _groups_ claim if you didn't add it to the Default scopes
- __refreshTokenThreshold__ is less than the client token lifetime. If this setting is not less than the token lifetime, a new token will be obtained for every request. Keycloak sets the client token lifetime to 5 minutes by default.
## Keycloak and ArgoCD with PKCE
@ -135,6 +137,7 @@ data:
issuer: https://keycloak.example.com/realms/master
clientID: argocd
enablePKCEAuthentication: true
refreshTokenThreshold: 2m
requestedScopes: ["openid", "profile", "email", "groups"]
```
@ -145,6 +148,7 @@ Make sure that:
- __clientID__ is set to the Client ID you configured in Keycloak
- __enablePKCEAuthentication__ must be set to true to enable correct ArgoCD behaviour with PKCE
- __requestedScopes__ contains the _groups_ claim if you didn't add it to the Default scopes
- __refreshTokenThreshold__ is less than the client token lifetime. If this setting is not less than the token lifetime, a new token will be obtained for every request. Keycloak sets the client token lifetime to 5 minutes by default.
## Configuring the groups claim

View file

@ -162,7 +162,7 @@ func (t *terminalSession) performValidationsAndReconnect(p []byte) (int, error)
}
// check if token still valid
_, newToken, err := t.sessionManager.VerifyToken(*t.token)
_, newToken, err := t.sessionManager.VerifyToken(t.ctx, *t.token)
// err in case if token is revoked, newToken in case if refresh happened
if err != nil || newToken != "" {
// need to send reconnect code in case if token was refreshed

View file

@ -31,7 +31,7 @@ func NewHandler(settingsMrg *settings.SettingsManager, sessionMgr *session.Sessi
type Handler struct {
settingsMgr *settings.SettingsManager
rootPath string
verifyToken func(tokenString string) (jwt.Claims, string, error)
verifyToken func(ctx context.Context, tokenString string) (jwt.Claims, string, error)
revokeToken func(ctx context.Context, id string, expiringAt time.Duration) error
baseHRef string
}
@ -94,7 +94,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Set-Cookie", argocdCookie.String())
}
claims, _, err := h.verifyToken(tokenString)
claims, _, err := h.verifyToken(r.Context(), tokenString)
if err != nil {
http.Redirect(w, r, logoutRedirectURL, http.StatusSeeOther)
return

View file

@ -1,6 +1,7 @@
package logout
import (
"context"
"errors"
"net/http"
"net/http/httptest"
@ -245,28 +246,28 @@ func TestHandlerConstructLogoutURL(t *testing.T) {
sessionManager := session.NewSessionManager(settingsManagerWithOIDCConfig, test.NewFakeProjLister(), "", nil, session.NewUserStateStorage(nil))
oidcHandler := NewHandler(settingsManagerWithOIDCConfig, sessionManager, rootPath, baseHRef)
oidcHandler.verifyToken = func(tokenString string) (jwt.Claims, string, error) {
oidcHandler.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) {
if !validJWTPattern.MatchString(tokenString) {
return nil, "", errors.New("invalid jwt")
}
return &jwt.RegisteredClaims{Issuer: "okta"}, "", nil
}
nonoidcHandler := NewHandler(settingsManagerWithoutOIDCConfig, sessionManager, "", baseHRef)
nonoidcHandler.verifyToken = func(tokenString string) (jwt.Claims, string, error) {
nonoidcHandler.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) {
if !validJWTPattern.MatchString(tokenString) {
return nil, "", errors.New("invalid jwt")
}
return &jwt.RegisteredClaims{Issuer: session.SessionManagerClaimsIssuer}, "", nil
}
oidcHandlerWithoutLogoutURL := NewHandler(settingsManagerWithOIDCConfigButNoLogoutURL, sessionManager, "", baseHRef)
oidcHandlerWithoutLogoutURL.verifyToken = func(tokenString string) (jwt.Claims, string, error) {
oidcHandlerWithoutLogoutURL.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) {
if !validJWTPattern.MatchString(tokenString) {
return nil, "", errors.New("invalid jwt")
}
return &jwt.RegisteredClaims{Issuer: "okta"}, "", nil
}
nonoidcHandlerWithMultipleURLs := NewHandler(settingsManagerWithoutOIDCAndMultipleURLs, sessionManager, "", baseHRef)
nonoidcHandlerWithMultipleURLs.verifyToken = func(tokenString string) (jwt.Claims, string, error) {
nonoidcHandlerWithMultipleURLs.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) {
if !validJWTPattern.MatchString(tokenString) {
return nil, "", errors.New("invalid jwt")
}
@ -274,7 +275,7 @@ func TestHandlerConstructLogoutURL(t *testing.T) {
}
oidcHandlerWithoutBaseURL := NewHandler(settingsManagerWithOIDCConfigButNoURL, sessionManager, "argocd", baseHRef)
oidcHandlerWithoutBaseURL.verifyToken = func(tokenString string) (jwt.Claims, string, error) {
oidcHandlerWithoutBaseURL.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) {
if !validJWTPattern.MatchString(tokenString) {
return nil, "", errors.New("invalid jwt")
}

View file

@ -323,6 +323,8 @@ func NewServer(ctx context.Context, opts ArgoCDServerOpts, appsetOpts Applicatio
appsetLister := appFactory.Argoproj().V1alpha1().ApplicationSets().Lister()
userStateStorage := util_session.NewUserStateStorage(opts.RedisClient)
ssoClientApp, err := oidc.NewClientApp(settings, opts.DexServerAddr, opts.DexTLSConfig, opts.BaseHRef, cacheutil.NewRedisCache(opts.RedisClient, settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone))
errorsutil.CheckError(err)
sessionMgr := util_session.NewSessionManager(settingsMgr, projLister, opts.DexServerAddr, opts.DexTLSConfig, userStateStorage)
enf := rbac.NewEnforcer(opts.KubeClientset, opts.Namespace, common.ArgoCDRBACConfigMapName, nil)
enf.EnableEnforce(!opts.DisableAuth)
@ -370,6 +372,7 @@ func NewServer(ctx context.Context, opts ArgoCDServerOpts, appsetOpts Applicatio
a := &ArgoCDServer{
ArgoCDServerOpts: opts,
ApplicationSetOpts: appsetOpts,
ssoClientApp: ssoClientApp,
log: logger,
settings: settings,
sessionMgr: sessionMgr,
@ -1125,19 +1128,7 @@ func (server *ArgoCDServer) translateGrpcCookieHeader(ctx context.Context, w htt
}
func (server *ArgoCDServer) setTokenCookie(token string, w http.ResponseWriter) error {
cookiePath := "path=/" + strings.TrimRight(strings.TrimLeft(server.BaseHRef, "/"), "/")
flags := []string{cookiePath, "SameSite=lax", "httpOnly"}
if !server.Insecure {
flags = append(flags, "Secure")
}
cookies, err := httputil.MakeCookieMetadata(common.AuthCookieName, token, flags...)
if err != nil {
return fmt.Errorf("error creating cookie metadata: %w", err)
}
for _, cookie := range cookies {
w.Header().Add("Set-Cookie", cookie)
}
return nil
return httputil.SetTokenCookie(token, server.BaseHRef, !server.Insecure, w)
}
func withRootPath(handler http.Handler, a *ArgoCDServer) http.Handler {
@ -1221,9 +1212,6 @@ func (server *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWeb
terminalOpts := application.TerminalOptions{DisableAuth: server.DisableAuth, Enf: server.enf}
// SSO ClientApp
server.ssoClientApp, _ = oidc.NewClientApp(server.settings, server.DexServerAddr, server.DexTLSConfig, server.BaseHRef, cacheutil.NewRedisCache(server.RedisClient, server.settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone))
terminal := application.NewHandler(server.appLister, server.Namespace, server.ApplicationNamespaces, server.db, appResourceTreeFn, server.settings.ExecShells, server.sessionMgr, &terminalOpts).
WithFeatureFlagMiddleware(server.settingsMgr.GetSettings)
th := util_session.WithAuthMiddleware(server.DisableAuth, server.settings.IsSSOConfigured(), server.ssoClientApp, server.sessionMgr, terminal)
@ -1368,9 +1356,7 @@ func (server *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) {
return
}
// Run dex OpenID Connect Identity Provider behind a reverse proxy (served at /api/dex)
var err error
mux.HandleFunc(common.DexAPIEndpoint+"/", dexutil.NewDexHTTPReverseProxy(server.DexServerAddr, server.BaseHRef, server.DexTLSConfig))
errorsutil.CheckError(err)
mux.HandleFunc(common.LoginEndpoint, server.ssoClientApp.HandleLogin)
mux.HandleFunc(common.CallbackEndpoint, server.ssoClientApp.HandleCallback)
}
@ -1566,6 +1552,7 @@ func (server *ArgoCDServer) Authenticate(ctx context.Context) (context.Context,
return ctx, nil
}
// getClaims extracts, validates and refreshes a JWT token from an incoming request context.
func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
@ -1575,17 +1562,29 @@ func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string,
if tokenString == "" {
return nil, "", ErrNoSession
}
claims, newToken, err := server.sessionMgr.VerifyToken(tokenString)
// A valid argocd-issued token is automatically refreshed here prior to expiration.
// OIDC tokens will be verified but will not be refreshed here.
claims, newToken, err := server.sessionMgr.VerifyToken(ctx, tokenString)
if err != nil {
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err)
}
finalClaims := claims
if server.settings.IsSSOConfigured() {
finalClaims, err = server.ssoClientApp.SetGroupsFromUserInfo(claims, util_session.SessionManagerClaimsIssuer)
updatedClaims, err := server.ssoClientApp.SetGroupsFromUserInfo(ctx, claims, util_session.SessionManagerClaimsIssuer)
if err != nil {
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err)
}
finalClaims = updatedClaims
// OIDC tokens are automatically refreshed here prior to expiration
refreshedToken, err := server.ssoClientApp.CheckAndRefreshToken(ctx, updatedClaims, server.settings.OIDCRefreshTokenThreshold)
if err != nil {
log.Errorf("error checking and refreshing token: %v", err)
}
if refreshedToken != "" && refreshedToken != tokenString {
newToken = refreshedToken
log.Infof("refreshed token for subject: %v", jwtutil.StringField(updatedClaims, "sub"))
}
}
return finalClaims, newToken, nil

View file

@ -241,3 +241,23 @@ func drainBody(body io.ReadCloser) {
log.Warnf("error reading response body: %s", err.Error())
}
}
func SetTokenCookie(token string, baseHRef string, isSecure bool, w http.ResponseWriter) error {
var path string
if baseHRef != "" {
path = strings.TrimRight(strings.TrimLeft(baseHRef, "/"), "/")
}
cookiePath := "path=/" + path
flags := []string{cookiePath, "SameSite=lax", "httpOnly"}
if isSecure {
flags = append(flags, "Secure")
}
cookies, err := MakeCookieMetadata(common.AuthCookieName, token, flags...)
if err != nil {
return fmt.Errorf("error creating cookie metadata: %w", err)
}
for _, cookie := range cookies {
w.Header().Add("Set-Cookie", cookie)
}
return nil
}

View file

@ -1,10 +1,13 @@
package http
import (
"fmt"
"net/http"
"strings"
"testing"
"github.com/argoproj/argo-cd/v3/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -49,6 +52,101 @@ func TestSplitCookie(t *testing.T) {
assert.Equal(t, cookieValue, token)
}
// mockResponseWriter is a mock implementation of http.ResponseWriter.
// It captures added headers for verification in tests.
type mockResponseWriter struct {
header http.Header
}
func (m *mockResponseWriter) Header() http.Header {
if m.header == nil {
m.header = make(http.Header)
}
return m.header
}
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (m *mockResponseWriter) WriteHeader(_ int) {}
func TestSetTokenCookie(t *testing.T) {
tests := []struct {
name string
token string
baseHRef string
isSecure bool
expectedCookies []string // Expected Set-Cookie header values
}{
{
name: "Insecure cookie",
token: "insecure-token",
baseHRef: "",
isSecure: false,
expectedCookies: []string{
fmt.Sprintf("%s=%s; path=/; SameSite=lax; httpOnly", common.AuthCookieName, "insecure-token"),
},
},
{
name: "Secure cookie",
token: "secure-token",
baseHRef: "",
isSecure: true,
expectedCookies: []string{
fmt.Sprintf("%s=%s; path=/; SameSite=lax; httpOnly; Secure", common.AuthCookieName, "secure-token"),
},
},
{
name: "Insecure cookie with baseHRef",
token: "token-with-path",
baseHRef: "/app",
isSecure: false,
expectedCookies: []string{
fmt.Sprintf("%s=%s; path=/app; SameSite=lax; httpOnly", common.AuthCookieName, "token-with-path"),
},
},
{
name: "Secure cookie with baseHRef",
token: "secure-token-with-path",
baseHRef: "app/",
isSecure: true,
expectedCookies: []string{
fmt.Sprintf("%s=%s; path=/app; SameSite=lax; httpOnly; Secure", common.AuthCookieName, "secure-token-with-path"),
},
},
{
name: "Unsecured cookie, baseHRef with multiple segments and mixed slashes",
token: "complex-path-token",
baseHRef: "///api/v1/auth///",
isSecure: false,
expectedCookies: []string{
fmt.Sprintf("%s=%s; path=/api/v1/auth; SameSite=lax; httpOnly", common.AuthCookieName, "complex-path-token"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &mockResponseWriter{}
err := SetTokenCookie(tt.token, tt.baseHRef, tt.isSecure, w)
if err != nil {
t.Fatalf("%s: Unexpected error: %v", tt.name, err)
}
setCookieHeaders := w.Header()["Set-Cookie"]
if len(setCookieHeaders) != len(tt.expectedCookies) {
t.Errorf("Mistmatch in Set-Cookie header length: %s\nExpected: %d\nGot: %d",
tt.name, len(tt.expectedCookies), len(setCookieHeaders))
return
}
if len(tt.expectedCookies) > 0 && setCookieHeaders[0] != tt.expectedCookies[0] {
t.Errorf("Mismatch in Set-Cookie header: %s\nExpected: %s\nGot: %s",
tt.name, tt.expectedCookies[0], setCookieHeaders[0])
}
})
}
}
// TestRoundTripper just copy request headers to the resposne.
type TestRoundTripper struct{}

View file

@ -43,6 +43,7 @@ const (
ResponseTypeCode = "code"
UserInfoResponseCachePrefix = "userinfo_response"
AccessTokenCachePrefix = "access_token"
OidcTokenCachePrefix = "oidc_token"
)
// OIDCConfiguration holds a subset of interested fields from the OIDC configuration spec
@ -87,6 +88,8 @@ type ClientApp struct {
clientCache cache.CacheClient
// properties for azure workload identity.
azure azureApp
// preemptive token refresh threshold
refreshTokenThreshold time.Duration
}
type azureApp struct {
@ -98,6 +101,63 @@ type azureApp struct {
mtx *sync.RWMutex
}
// OidcTokenCache is a serialization wrapper around oauth2 provider configuration needed to generate a TokenSource
type OidcTokenCache struct {
// Redirect URL is needed for oauth2 config initialization
RedirectURL string `json:"redirect_url"`
// oauth2 Token
Token *oauth2.Token `json:"token"`
// TokenExtraIdToken captures value of id_token
TokenExtraIdToken string `json:"token_extra_id_token"`
}
// NewOidcTokenCache initializes the struct from a redirect URL and an existing token
func NewOidcTokenCache(redirectURL string, token *oauth2.Token) *OidcTokenCache {
var idToken string
if token.Extra("id_token") == nil {
idToken = ""
} else {
idToken = token.Extra("id_token").(string)
}
return &OidcTokenCache{
RedirectURL: redirectURL,
Token: token,
TokenExtraIdToken: idToken,
}
}
// GetOidcTokenCacheFromJSON deserializes the json representation of OidcTokenCache. The Token extra map is updated from
// the serialization wrapper to propagate the id_token. This will ensure that the TokenSource always retrieves a usable token.
func GetOidcTokenCacheFromJSON(jsonBytes []byte) (*OidcTokenCache, error) {
var newToken OidcTokenCache
err := json.Unmarshal(jsonBytes, &newToken)
if err != nil {
return nil, err
}
if newToken.Token == nil {
return nil, errors.New("empty token")
}
newToken.Token = newToken.Token.WithExtra(map[string]any{
"id_token": newToken.TokenExtraIdToken,
})
return &newToken, nil
}
// GetTokenSourceFromCache creates an oauth2 TokenSource from a cached oidc token. The TokenSource will be configured
// with an early expiration based on the refreshTokenThreshold.
func (a *ClientApp) GetTokenSourceFromCache(ctx context.Context, oidcTokenCache *OidcTokenCache) (oauth2.TokenSource, error) {
if oidcTokenCache == nil {
return nil, errors.New("oidcTokenCache is required")
}
config, err := a.getOauth2ConfigForRedirectURI(oidcTokenCache.RedirectURL)
if err != nil {
return nil, err
}
baseTokenSource := config.TokenSource(ctx, oidcTokenCache.Token)
tokenRefresher := oauth2.ReuseTokenSourceWithExpiry(oidcTokenCache.Token, baseTokenSource, a.refreshTokenThreshold)
return tokenRefresher, nil
}
func GetScopesOrDefault(scopes []string) []string {
if len(scopes) == 0 {
return []string{"openid", "profile", "email", "groups"}
@ -127,6 +187,7 @@ func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTL
encryptionKey: encryptionKey,
clientCache: cacheClient,
azure: azureApp{mtx: &sync.RWMutex{}},
refreshTokenThreshold: settings.OIDCRefreshTokenThreshold,
}
log.Infof("Creating client app (%s)", a.clientID)
u, err := url.Parse(settings.URL)
@ -165,23 +226,27 @@ func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTL
return &a, nil
}
func (a *ClientApp) oauth2Config(request *http.Request, scopes []string) (*oauth2.Config, error) {
func (a *ClientApp) getRedirectURIForRequest(req *http.Request) string {
redirectURI, err := a.settings.RedirectURLForRequest(req)
if err != nil {
log.Warnf("Unable to find ArgoCD URL from request, falling back to configured redirect URI: %v", err)
redirectURI = a.redirectURI
}
return redirectURI
}
func (a *ClientApp) getOauth2ConfigForRedirectURI(redirectURI string) (*oauth2.Config, error) {
endpoint, err := a.provider.Endpoint()
if err != nil {
return nil, err
}
redirectURL, err := a.settings.RedirectURLForRequest(request)
if err != nil {
log.Warnf("Unable to find ArgoCD URL from request, falling back to configured redirect URI: %v", err)
redirectURL = a.redirectURI
}
return &oauth2.Config{
ClientID: a.clientID,
ClientSecret: a.clientSecret,
Endpoint: *endpoint,
Scopes: scopes,
RedirectURL: redirectURL,
Scopes: a.getScopes(),
RedirectURL: redirectURI,
}, nil
}
@ -315,17 +380,13 @@ func (a *ClientApp) HandleLogin(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
scopes := make([]string, 0)
pkceVerifier := ""
var opts []oauth2.AuthCodeOption
if config := a.settings.OIDCConfig(); config != nil {
scopes = GetScopesOrDefault(config.RequestedScopes)
opts = AppendClaimsAuthenticationRequestParameter(opts, config.RequestedIDTokenClaims)
} else if a.settings.IsDexConfigured() {
scopes = append(GetScopesOrDefault(nil), common.DexFederatedScope)
}
oauth2Config, err := a.oauth2Config(r, scopes)
oauth2Config, err := a.getOauth2ConfigForRedirectURI(a.getRedirectURIForRequest(r))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -406,7 +467,7 @@ func (a *azureApp) getFederatedServiceAccountToken(context.Context) (string, err
// HandleCallback is the callback handler for an OAuth2 login flow
func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
oauth2Config, err := a.oauth2Config(r, nil)
oauth2Config, err := a.getOauth2ConfigForRedirectURI(a.getRedirectURIForRequest(r))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -456,27 +517,21 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
return
}
// Parse out id token
idTokenRAW, ok := token.Extra("id_token").(string)
if !ok {
http.Error(w, "no id_token in token response", http.StatusInternalServerError)
return
}
idToken, err := a.provider.Verify(idTokenRAW, a.settings)
idToken, err := a.provider.Verify(ctx, idTokenRAW, a.settings)
if err != nil {
log.Warnf("Failed to verify token: %s", err)
log.Warnf("Failed to verify oidc token: %s", err)
http.Error(w, common.TokenVerificationError, http.StatusInternalServerError)
return
}
path := "/"
if a.baseHRef != "" {
path = strings.TrimRight(strings.TrimLeft(a.baseHRef, "/"), "/")
}
cookiePath := "path=/" + path
flags := []string{cookiePath, "SameSite=lax", "httpOnly"}
if a.secureCookie {
flags = append(flags, "Secure")
}
// Set cache
var claims jwt.MapClaims
err = idToken.Claims(&claims)
if err != nil {
@ -484,38 +539,38 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
return
}
// save the accessToken in memory for later use
encToken, err := crypto.Encrypt([]byte(token.AccessToken), a.encryptionKey)
sub := jwtutil.StringField(claims, "sub")
err = a.SetValueInEncryptedCache(FormatAccessTokenCacheKey(sub), []byte(token.AccessToken), GetTokenExpiration(claims))
if err != nil {
claimsJSON, _ := json.Marshal(claims)
http.Error(w, "failed encrypting token", http.StatusInternalServerError)
log.Errorf("cannot encrypt accessToken: %v (claims=%s)", err, claimsJSON)
log.Errorf("cannot cache encrypted accessToken: %v (claims=%s)", err, claimsJSON)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
sub := jwtutil.StringField(claims, "sub")
err = a.clientCache.Set(&cache.Item{
Key: FormatAccessTokenCacheKey(sub),
Object: encToken,
CacheActionOpts: cache.CacheActionOpts{
Expiration: getTokenExpiration(claims),
},
})
// Cache encrypted raw token for background refresh
oidcTokenCache := NewOidcTokenCache(a.getRedirectURIForRequest(r), token)
oidcTokenCacheJSON, err := json.Marshal(oidcTokenCache)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
sid := jwtutil.StringField(claims, "sid")
err = a.SetValueInEncryptedCache(formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, GetTokenExpiration(claims))
if err != nil {
claimsJSON, _ := json.Marshal(claims)
http.Error(w, fmt.Sprintf("claims=%s, err=%v", claimsJSON, err), http.StatusInternalServerError)
log.Errorf("cannot cache encrypted oidc token: %v (claims=%s)", err, claimsJSON)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if idTokenRAW != "" {
cookies, err := httputil.MakeCookieMetadata(common.AuthCookieName, idTokenRAW, flags...)
err = httputil.SetTokenCookie(idTokenRAW, a.baseHRef, a.secureCookie, w)
if err != nil {
claimsJSON, _ := json.Marshal(claims)
http.Error(w, fmt.Sprintf("claims=%s, err=%v", claimsJSON, err), http.StatusInternalServerError)
return
}
for _, cookie := range cookies {
w.Header().Add("Set-Cookie", cookie)
}
}
claimsJSON, _ := json.Marshal(claims)
@ -528,6 +583,109 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
}
}
// GetValueFromEncryptedCache is a convenience method for retreiving a value from cache and decrypting it. If the cache
// does not contain a value for the given key, a nil value is returned. Return handling should check for error and then
// check for nil.
func (a *ClientApp) GetValueFromEncryptedCache(key string) (value []byte, err error) {
var encryptedValue []byte
err = a.clientCache.Get(key, &encryptedValue)
if err != nil {
if errors.Is(err, cache.ErrCacheMiss) {
// Return nil to signify a cache miss
return nil, nil
}
return nil, fmt.Errorf("failed to get encrypted value from cache: %w", err)
}
value, err = crypto.Decrypt(encryptedValue, a.encryptionKey)
if err != nil {
return nil, fmt.Errorf("failed to decrypt value from cache: %w", err)
}
return value, err
}
// SetValueFromEncyrptedCache is a convenience method for encrypting a value and storing it in the cache at a given key.
// Cache expiration is set based on input.
func (a *ClientApp) SetValueInEncryptedCache(key string, value []byte, expiration time.Duration) error {
encryptedValue, err := crypto.Encrypt(value, a.encryptionKey)
if err != nil {
return err
}
err = a.clientCache.Set(&cache.Item{
Key: key,
Object: encryptedValue,
CacheActionOpts: cache.CacheActionOpts{
Expiration: expiration,
},
})
if err != nil {
return err
}
return nil
}
func (a *ClientApp) CheckAndRefreshToken(ctx context.Context, groupClaims jwt.MapClaims, refreshTokenThreshold time.Duration) (string, error) {
sub := jwtutil.StringField(groupClaims, "sub")
sid := jwtutil.StringField(groupClaims, "sid")
if GetTokenExpiration(groupClaims) < refreshTokenThreshold {
token, err := a.GetUpdatedOidcTokenFromCache(ctx, sub, sid)
if err != nil {
log.Errorf("Failed to get token from cache: %v", err)
return "", err
}
if token != nil {
idTokenRAW, ok := token.Extra("id_token").(string)
if !ok {
return "", errors.New("empty id_token")
}
return idTokenRAW, nil
}
}
return "", nil
}
// GetUpdatedOidcTokenFromCache fetches a token from cache and refreshes it if under the threshold for expiration.
// The cached token will also be updated if it is refreshed. Returns latest token or an error if the process fails.
func (a *ClientApp) GetUpdatedOidcTokenFromCache(ctx context.Context, subject string, sessionId string) (*oauth2.Token, error) {
ctx = gooidc.ClientContext(ctx, a.client)
// Get oauth2 config
cacheKey := formatOidcTokenCacheKey(subject, sessionId)
oidcTokenCacheJSON, err := a.GetValueFromEncryptedCache(cacheKey)
if err != nil {
return nil, err
}
if oidcTokenCacheJSON == nil {
return nil, nil
}
oidcTokenCache, err := GetOidcTokenCacheFromJSON(oidcTokenCacheJSON)
if err != nil {
err = fmt.Errorf("failed to unmarshal cached oidc token: %w", err)
return nil, err
}
tokenSource, err := a.GetTokenSourceFromCache(ctx, oidcTokenCache)
if err != nil {
err = fmt.Errorf("failed to get token source from cached oidc token: %w", err)
return nil, err
}
token, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("failed to refresh token from source: %w", err)
}
if token.AccessToken != oidcTokenCache.Token.AccessToken {
oidcTokenCache = NewOidcTokenCache(oidcTokenCache.RedirectURL, token)
oidcTokenCacheJSON, err = json.Marshal(oidcTokenCache)
if err != nil {
return nil, fmt.Errorf("failed to marshal oidc oidcTokenCache refresher: %w", err)
}
err = a.SetValueInEncryptedCache(cacheKey, oidcTokenCacheJSON, time.Until(token.Expiry))
if err != nil {
return nil, err
}
}
return token, nil
}
var implicitFlowTmpl = template.Must(template.New("implicit.html").Parse(`<script>
var hash = window.location.hash.substr(1);
var result = hash.split('&').reduce(function (result, item) {
@ -645,7 +803,7 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc
// If querying the UserInfo endpoint fails, we return an error to indicate the session is invalid
// we assume that everywhere in argocd jwt.MapClaims is used as type for interface jwt.Claims
// otherwise this would cause a panic
func (a *ClientApp) SetGroupsFromUserInfo(claims jwt.Claims, sessionManagerClaimsIssuer string) (jwt.MapClaims, error) {
func (a *ClientApp) SetGroupsFromUserInfo(ctx context.Context, claims jwt.Claims, sessionManagerClaimsIssuer string) (jwt.MapClaims, error) {
var groupClaims jwt.MapClaims
var ok bool
if groupClaims, ok = claims.(jwt.MapClaims); !ok {
@ -657,7 +815,7 @@ func (a *ClientApp) SetGroupsFromUserInfo(claims jwt.Claims, sessionManagerClaim
}
iss := jwtutil.StringField(groupClaims, "iss")
if iss != sessionManagerClaimsIssuer && a.settings.UserInfoGroupsEnabled() && a.settings.UserInfoPath() != "" {
userInfo, unauthorized, err := a.GetUserInfo(groupClaims, a.settings.IssuerURL(), a.settings.UserInfoPath())
userInfo, unauthorized, err := a.GetUserInfo(ctx, groupClaims, a.settings.IssuerURL(), a.settings.UserInfoPath())
if unauthorized {
return groupClaims, fmt.Errorf("error while quering userinfo endpoint: %w", err)
}
@ -674,7 +832,7 @@ func (a *ClientApp) SetGroupsFromUserInfo(claims jwt.Claims, sessionManagerClaim
}
// GetUserInfo queries the IDP userinfo endpoint for claims
func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
func (a *ClientApp) GetUserInfo(ctx context.Context, actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
sub := jwtutil.StringField(actualClaims, "sub")
var claims jwt.MapClaims
var encClaims []byte
@ -696,19 +854,13 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
}
// check if the accessToken for the user is still present
var encAccessToken []byte
err := a.clientCache.Get(FormatAccessTokenCacheKey(sub), &encAccessToken)
// without an accessToken we can't query the user info endpoint
// thus the user needs to reauthenticate for argocd to get a new accessToken
if errors.Is(err, cache.ErrCacheMiss) {
return claims, true, fmt.Errorf("no accessToken for %s: %w", sub, err)
} else if err != nil {
accessTokenBytes, err := a.GetValueFromEncryptedCache(FormatAccessTokenCacheKey(sub))
if err != nil {
return claims, true, fmt.Errorf("could not read accessToken from cache for %s: %w", sub, err)
}
accessToken, err := crypto.Decrypt(encAccessToken, a.encryptionKey)
if err != nil {
return claims, true, fmt.Errorf("could not decrypt accessToken for %s: %w", sub, err)
if accessTokenBytes == nil {
return claims, true, fmt.Errorf("no accessToken for %s: %w", sub, err)
}
url := issuerURL + userInfoPath
@ -718,7 +870,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
return claims, false, err
}
bearer := fmt.Sprintf("Bearer %s", accessToken)
bearer := "Bearer " + string(accessTokenBytes)
request.Header.Set("Authorization", bearer)
response, err := a.client.Do(request)
@ -740,7 +892,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
switch header {
case "application/jwt":
// if body is JWT, first validate it before extracting claims
idToken, err := a.provider.Verify(string(rawBody), a.settings)
idToken, err := a.provider.Verify(ctx, string(rawBody), a.settings)
if err != nil {
return claims, false, fmt.Errorf("user info response in jwt format not valid: %w", err)
}
@ -760,7 +912,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
// but first let's determine the expiry of the cache
var cacheExpiry time.Duration
settingExpiry := a.settings.UserInfoCacheExpiration()
tokenExpiry := getTokenExpiration(claims)
tokenExpiry := GetTokenExpiration(claims)
// only use configured expiry if the token lives longer and the expiry is configured
// if the token has no expiry, use the expiry of the actual token
@ -769,7 +921,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
case settingExpiry < tokenExpiry && settingExpiry != 0:
cacheExpiry = settingExpiry
case tokenExpiry < 0:
cacheExpiry = getTokenExpiration(actualClaims)
cacheExpiry = GetTokenExpiration(actualClaims)
default:
cacheExpiry = tokenExpiry
}
@ -797,8 +949,8 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
return claims, false, nil
}
// getTokenExpiration returns a time.Duration until the token expires
func getTokenExpiration(claims jwt.MapClaims) time.Duration {
// GetTokenExpiration returns a time.Duration until the token expires
func GetTokenExpiration(claims jwt.MapClaims) time.Duration {
// get duration until token expires
exp := jwtutil.Float64Field(claims, "exp")
tm := time.Unix(int64(exp), 0)
@ -806,12 +958,28 @@ func getTokenExpiration(claims jwt.MapClaims) time.Duration {
return tokenExpiry
}
// formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache
// getScopes returns scopes based on provider configuration
func (a *ClientApp) getScopes() []string {
scopes := make([]string, 0)
if config := a.settings.OIDCConfig(); config != nil {
scopes = GetScopesOrDefault(config.RequestedScopes)
} else if a.settings.IsDexConfigured() {
scopes = append(GetScopesOrDefault(nil), common.DexFederatedScope)
}
return scopes
}
// FormatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache
func FormatUserInfoResponseCacheKey(sub string) string {
return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub)
}
// formatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache
// FormatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache
func FormatAccessTokenCacheKey(sub string) string {
return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub)
}
// formatRefreshTokenCacheKey returns the key which is used to store the oidc Token for a session in cache
func formatOidcTokenCacheKey(sub string, sid string) string {
return fmt.Sprintf("%s_%s_%s", OidcTokenCachePrefix, sub, sid)
}

View file

@ -1,9 +1,11 @@
package oidc
import (
"context"
"crypto/tls"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
@ -15,6 +17,8 @@ import (
"testing"
"time"
log "github.com/sirupsen/logrus"
gooidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
@ -28,6 +32,7 @@ import (
"github.com/argoproj/argo-cd/v3/util/cache"
"github.com/argoproj/argo-cd/v3/util/crypto"
"github.com/argoproj/argo-cd/v3/util/dex"
jwtutil "github.com/argoproj/argo-cd/v3/util/jwt"
"github.com/argoproj/argo-cd/v3/util/settings"
"github.com/argoproj/argo-cd/v3/util/test"
)
@ -97,9 +102,14 @@ func TestIDTokenClaims(t *testing.T) {
assert.JSONEq(t, "{\"id_token\":{\"groups\":{\"essential\":true}}}", values.Get("claims"))
}
type fakeProvider struct{}
type fakeProvider struct {
EndpointError bool
}
func (p *fakeProvider) Endpoint() (*oauth2.Endpoint, error) {
if p.EndpointError {
return nil, errors.New("fake provider endpoint error")
}
return &oauth2.Endpoint{}, nil
}
@ -107,7 +117,7 @@ func (p *fakeProvider) ParseConfig() (*OIDCConfiguration, error) {
return nil, nil
}
func (p *fakeProvider) Verify(_ string, _ *settings.ArgoCDSettings) (*gooidc.IDToken, error) {
func (p *fakeProvider) Verify(_ context.Context, _ string, _ *settings.ArgoCDSettings) (*gooidc.IDToken, error) {
return nil, nil
}
@ -530,7 +540,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
app.HandleCallback(w, req)
if !strings.Contains(w.Body.String(), "certificate signed by unknown authority") && !strings.Contains(w.Body.String(), "certificate is not trusted") {
t.Fatal("did not receive expected certificate verification failure error")
t.Fatalf("did not receive expected certificate verification failure error: %v", w.Code)
}
cdSettings.OIDCTLSInsecureSkipVerify = true
@ -700,7 +710,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
app.HandleCallback(w, req)
if !strings.Contains(w.Body.String(), "certificate signed by unknown authority") && !strings.Contains(w.Body.String(), "certificate is not trusted") {
t.Fatal("did not receive expected certificate verification failure error")
t.Fatalf("did not receive expected certificate verification failure error: %v", w.Code)
}
cdSettings.OIDCTLSInsecureSkipVerify = true
@ -1147,7 +1157,7 @@ func TestGetUserInfo(t *testing.T) {
require.NoError(t, err)
}
got, unauthenticated, err := a.GetUserInfo(tt.idpClaims, ts.URL, tt.userInfoPath)
got, unauthenticated, err := a.GetUserInfo(t.Context(), tt.idpClaims, ts.URL, tt.userInfoPath)
assert.Equal(t, tt.expectedOutput, got)
assert.Equal(t, tt.expectUnauthenticated, unauthenticated)
if tt.expectError {
@ -1253,7 +1263,7 @@ userInfoPath: /`,
require.NoError(t, err, "failed setting item to in-memory cache")
}
receivedClaims, err := a.SetGroupsFromUserInfo(tt.inputClaims, "argocd")
receivedClaims, err := a.SetGroupsFromUserInfo(t.Context(), tt.inputClaims, "argocd")
if tt.expectError {
require.Error(t, err)
} else {
@ -1263,3 +1273,300 @@ userInfoPath: /`,
})
}
}
func TestGetOidcTokenCacheFromJSON(t *testing.T) {
tests := []struct {
name string
oidcTokenCache *OidcTokenCache
expectErrorContains string
expectIdToken string
}{
{
name: "empty",
oidcTokenCache: &OidcTokenCache{},
expectErrorContains: "empty token",
},
{
name: "empty id token",
oidcTokenCache: &OidcTokenCache{
Token: &oauth2.Token{},
},
expectIdToken: "",
},
{
name: "simple",
oidcTokenCache: NewOidcTokenCache("", (&oauth2.Token{}).WithExtra(map[string]any{"id_token": "simple"})),
expectIdToken: "simple",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokenJSON, err := json.Marshal(tt.oidcTokenCache)
require.NoError(t, err)
token, err := GetOidcTokenCacheFromJSON(tokenJSON)
if tt.expectErrorContains != "" {
assert.ErrorContains(t, err, tt.expectErrorContains)
return
}
require.NoError(t, err)
if tt.expectIdToken != "" {
assert.Equal(t, tt.expectIdToken, token.Token.Extra("id_token").(string))
}
})
}
}
func TestClientApp_GetTokenSourceFromCache(t *testing.T) {
tests := []struct {
name string
oidcTokenCache *OidcTokenCache
expectErrorContains string
provider Provider
}{
{
name: "provider error",
oidcTokenCache: &OidcTokenCache{},
expectErrorContains: "fake provider endpoint error",
provider: &fakeProvider{
EndpointError: true,
},
},
{
name: "empty oidcTokenCache",
expectErrorContains: "oidcTokenCache is required",
provider: &fakeProvider{},
},
{
name: "simple",
oidcTokenCache: NewOidcTokenCache("", (&oauth2.Token{}).WithExtra(map[string]any{"id_token": "simple"})),
provider: &fakeProvider{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := ClientApp{provider: tt.provider, settings: &settings.ArgoCDSettings{}}
tokenSource, err := app.GetTokenSourceFromCache(t.Context(), tt.oidcTokenCache)
if tt.expectErrorContains != "" {
assert.ErrorContains(t, err, tt.expectErrorContains)
return
}
require.NoError(t, err)
assert.NotNil(t, tokenSource)
})
}
}
func TestClientApp_GetUpdatedOidcTokenFromCache(t *testing.T) {
tests := []struct {
name string
subject string
session string
insertIntoCache bool
oidcTokenCache *OidcTokenCache
expectErrorContains string
expectTokenNotNil bool
}{
{
name: "empty token cache",
subject: "alice",
session: "111",
insertIntoCache: true,
expectErrorContains: "failed to unmarshal cached oidc token: empty token",
},
{
name: "no refresh token",
subject: "alice",
session: "111",
insertIntoCache: true,
oidcTokenCache: &OidcTokenCache{Token: &oauth2.Token{}},
expectErrorContains: "failed to refresh token from source: oauth2: token expired and refresh token is not set",
},
{
name: "cache miss",
subject: "",
session: "",
insertIntoCache: false,
},
{
name: "updated token from cache",
subject: "alice",
session: "111",
insertIntoCache: true,
oidcTokenCache: &OidcTokenCache{Token: &oauth2.Token{
RefreshToken: "not empty",
}},
expectTokenNotNil: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oidcTestServer := test.GetOIDCTestServer(t, nil)
t.Cleanup(oidcTestServer.Close)
cdSettings := &settings.ArgoCDSettings{
URL: "https://argocd.example.com",
OIDCConfigRAW: fmt.Sprintf(`
name: Test
issuer: %s
clientID: test-client-id
clientSecret: test-client-secret
requestedScopes: ["oidc"]`, oidcTestServer.URL),
OIDCTLSInsecureSkipVerify: true,
}
app, err := NewClientApp(cdSettings, "", nil, "/", cache.NewInMemoryCache(24*time.Hour))
require.NoError(t, err)
if tt.insertIntoCache {
oidcTokenCacheJSON, err := json.Marshal(tt.oidcTokenCache)
require.NoError(t, err)
require.NoError(t, app.SetValueInEncryptedCache(formatOidcTokenCacheKey(tt.subject, tt.session), oidcTokenCacheJSON, time.Minute))
}
token, err := app.GetUpdatedOidcTokenFromCache(t.Context(), tt.subject, tt.session)
if tt.expectErrorContains != "" {
assert.ErrorContains(t, err, tt.expectErrorContains)
return
}
require.NoError(t, err)
if tt.expectTokenNotNil {
assert.NotNil(t, token)
}
})
}
}
func TestClientApp_CheckAndGetRefreshToken(t *testing.T) {
tests := []struct {
name string
expectErrorContains string
expectNewToken bool
groupClaims jwt.MapClaims
refreshTokenThreshold string
}{
{
name: "no new token",
groupClaims: jwt.MapClaims{
"aud": common.ArgoCDClientAppID,
"exp": float64(time.Now().Add(time.Hour).Unix()),
"sub": "randomUser",
"sid": "1111",
"iss": "issuer",
"groups": "group1",
},
expectNewToken: false,
refreshTokenThreshold: "1m",
},
{
name: "new token",
groupClaims: jwt.MapClaims{
"aud": common.ArgoCDClientAppID,
"exp": float64(time.Now().Add(55 * time.Second).Unix()),
"sub": "randomUser",
"sid": "1111",
"iss": "issuer",
"groups": "group1",
},
expectNewToken: true,
refreshTokenThreshold: "1m",
},
{
name: "parse error",
groupClaims: jwt.MapClaims{
"aud": common.ArgoCDClientAppID,
"exp": float64(time.Now().Add(time.Minute).Unix()),
"sub": "randomUser",
"sid": "1111",
"iss": "issuer",
"groups": "group1",
},
expectNewToken: false,
refreshTokenThreshold: "1xx",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oidcTestServer := test.GetOIDCTestServer(t, nil)
t.Cleanup(oidcTestServer.Close)
cdSettings := &settings.ArgoCDSettings{
URL: "https://argocd.example.com",
OIDCConfigRAW: fmt.Sprintf(`
name: Test
issuer: %s
clientID: test-client-id
clientSecret: test-client-secret
refreshTokenThreshold: %s
requestedScopes: ["oidc"]`, oidcTestServer.URL, tt.refreshTokenThreshold),
OIDCTLSInsecureSkipVerify: true,
}
// The base href (the last argument for NewClientApp) is what HandleLogin will fall back to when no explicit
// redirect URL is given.
app, err := NewClientApp(cdSettings, "", nil, "/", cache.NewInMemoryCache(24*time.Hour))
require.NoError(t, err)
oidcTokenCacheJSON, err := json.Marshal(&OidcTokenCache{Token: &oauth2.Token{
RefreshToken: "not empty",
}})
require.NoError(t, err)
sub := jwtutil.StringField(tt.groupClaims, "sub")
require.NotEmpty(t, sub)
sid := jwtutil.StringField(tt.groupClaims, "sid")
require.NotEmpty(t, sid)
require.NoError(t, app.SetValueInEncryptedCache(formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, time.Minute))
token, err := app.CheckAndRefreshToken(t.Context(), tt.groupClaims, cdSettings.RefreshTokenThreshold())
if tt.expectErrorContains != "" {
require.ErrorContains(t, err, tt.expectErrorContains)
return
}
require.NoError(t, err)
if tt.expectNewToken {
require.NotEmpty(t, token)
} else {
require.Empty(t, token)
}
})
}
}
func TestClientApp_getRedirectURIForRequest(t *testing.T) {
tests := []struct {
name string
req *http.Request
expectLogContains string
expectedRequestURI string
expectError bool
}{
{
name: "empty",
req: &http.Request{
URL: &url.URL{},
},
},
{
name: "nil URL",
expectLogContains: "falling back to configured redirect URI",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := ClientApp{provider: &fakeProvider{}, settings: &settings.ArgoCDSettings{}}
hook := test.LogHook{}
log.AddHook(&hook)
t.Cleanup(func() {
log.StandardLogger().ReplaceHooks(log.LevelHooks{})
})
redirectURI := app.getRedirectURIForRequest(tt.req)
if tt.expectLogContains != "" {
assert.NotEmpty(t, hook.GetRegexMatchesInEntries(tt.expectLogContains), "expected log")
} else {
assert.Empty(t, hook.Entries, "expected log")
}
if tt.req == nil {
return
}
expectedRedirectURI, err := app.settings.RedirectURLForRequest(tt.req)
if tt.expectError {
assert.Error(t, err)
return
}
assert.Equal(t, expectedRedirectURI, redirectURI, "expected URI")
})
}
}

View file

@ -27,7 +27,7 @@ type Provider interface {
ParseConfig() (*OIDCConfiguration, error)
Verify(tokenString string, argoSettings *settings.ArgoCDSettings) (*gooidc.IDToken, error)
Verify(ctx context.Context, tokenString string, argoSettings *settings.ArgoCDSettings) (*gooidc.IDToken, error)
}
type providerImpl struct {
@ -85,7 +85,7 @@ func (t tokenVerificationError) Error() string {
return "token verification failed for all audiences: " + strings.Join(errorStrings, ", ")
}
func (p *providerImpl) Verify(tokenString string, argoSettings *settings.ArgoCDSettings) (*gooidc.IDToken, error) {
func (p *providerImpl) Verify(ctx context.Context, tokenString string, argoSettings *settings.ArgoCDSettings) (*gooidc.IDToken, error) {
// According to the JWT spec, the aud claim is optional. The spec also says (emphasis mine):
//
// If the principal processing the claim does not identify itself with a value in the "aud" claim _when this
@ -110,7 +110,7 @@ func (p *providerImpl) Verify(tokenString string, argoSettings *settings.ArgoCDS
var idToken *gooidc.IDToken
if !unverifiedHasAudClaim {
idToken, err = p.verify("", tokenString, argoSettings.SkipAudienceCheckWhenTokenHasNoAudience())
idToken, err = p.verify(ctx, "", tokenString, argoSettings.SkipAudienceCheckWhenTokenHasNoAudience())
} else {
allowedAudiences := argoSettings.OAuth2AllowedAudiences()
if len(allowedAudiences) == 0 {
@ -119,7 +119,7 @@ func (p *providerImpl) Verify(tokenString string, argoSettings *settings.ArgoCDS
tokenVerificationErrors := make(map[string]error)
// Token must be verified for at least one allowed audience
for _, aud := range allowedAudiences {
idToken, err = p.verify(aud, tokenString, false)
idToken, err = p.verify(ctx, aud, tokenString, false)
tokenExpiredError := &gooidc.TokenExpiredError{}
if errors.As(err, &tokenExpiredError) {
// If the token is expired, we won't bother checking other audiences. It's important to return a
@ -143,14 +143,13 @@ func (p *providerImpl) Verify(tokenString string, argoSettings *settings.ArgoCDS
}
if err != nil {
return nil, fmt.Errorf("failed to verify token: %w", err)
return nil, fmt.Errorf("failed to verify provider token: %w", err)
}
return idToken, nil
}
func (p *providerImpl) verify(clientID, tokenString string, skipClientIDCheck bool) (*gooidc.IDToken, error) {
ctx := context.Background()
func (p *providerImpl) verify(ctx context.Context, clientID, tokenString string, skipClientIDCheck bool) (*gooidc.IDToken, error) {
prov, err := p.provider()
if err != nil {
return nil, err

View file

@ -489,7 +489,7 @@ func (mgr *SessionManager) AuthMiddlewareFunc(disabled bool, isSSOConfigured boo
// TokenVerifier defines the contract to invoke token
// verification logic
type TokenVerifier interface {
VerifyToken(token string) (jwt.Claims, string, error)
VerifyToken(ctx context.Context, token string) (jwt.Claims, string, error)
}
// WithAuthMiddleware is an HTTP middleware used to ensure incoming
@ -504,12 +504,13 @@ func WithAuthMiddleware(disabled bool, isSSOConfigured bool, ssoClientApp *oidcu
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookies := r.Cookies()
ctx := r.Context()
tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies)
if err != nil {
http.Error(w, "Auth cookie not found", http.StatusBadRequest)
return
}
claims, _, err := authn.VerifyToken(tokenString)
claims, _, err := authn.VerifyToken(ctx, tokenString)
if err != nil {
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
@ -517,14 +518,12 @@ func WithAuthMiddleware(disabled bool, isSSOConfigured bool, ssoClientApp *oidcu
finalClaims := claims
if isSSOConfigured {
finalClaims, err = ssoClientApp.SetGroupsFromUserInfo(claims, SessionManagerClaimsIssuer)
finalClaims, err = ssoClientApp.SetGroupsFromUserInfo(ctx, claims, SessionManagerClaimsIssuer)
if err != nil {
http.Error(w, "Invalid session", http.StatusUnauthorized)
return
}
}
ctx := r.Context()
// Add claims to the context to inspect for RBAC
//nolint:staticcheck
ctx = context.WithValue(ctx, "claims", finalClaims)
@ -536,7 +535,7 @@ func WithAuthMiddleware(disabled bool, isSSOConfigured bool, ssoClientApp *oidcu
// VerifyToken verifies if a token is correct. Tokens can be issued either from us or by an IDP.
// We choose how to verify based on the issuer.
func (mgr *SessionManager) VerifyToken(tokenString string) (jwt.Claims, string, error) {
func (mgr *SessionManager) VerifyToken(ctx context.Context, tokenString string) (jwt.Claims, string, error) {
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
claims := jwt.MapClaims{}
_, _, err := parser.ParseUnverified(tokenString, &claims)
@ -564,12 +563,12 @@ func (mgr *SessionManager) VerifyToken(tokenString string) (jwt.Claims, string,
return nil, "", errors.New("settings are not available while verifying the token")
}
idToken, err := prov.Verify(tokenString, argoSettings)
idToken, err := prov.Verify(ctx, tokenString, argoSettings)
// The token verification has failed. If the token has expired, we will
// return a dummy claims only containing a value for the issuer, so the
// UI can handle expired tokens appropriately.
if err != nil {
log.Warnf("Failed to verify token: %s", err)
log.Warnf("Failed to verify session token: %s", err)
tokenExpiredError := &oidc.TokenExpiredError{}
if errors.As(err, &tokenExpiredError) {
claims = jwt.MapClaims{

View file

@ -228,7 +228,7 @@ type tokenVerifierMock struct {
err error
}
func (tm *tokenVerifierMock) VerifyToken(_ string) (jwt.Claims, string, error) {
func (tm *tokenVerifierMock) VerifyToken(_ context.Context, _ string) (jwt.Claims, string, error) {
if tm.claims == nil {
return nil, "", tm.err
}
@ -255,7 +255,7 @@ func TestSessionManager_WithAuthMiddleware(t *testing.T) {
}
}
jsonClaims, err := json.Marshal(gotClaims)
require.NoError(t, err, "erorr marshalling claims set by AuthMiddleware")
require.NoError(t, err, "error marshalling claims set by AuthMiddleware")
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonClaims)
require.NoError(t, err, "error writing response: %s", err)
@ -720,7 +720,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
assert.NotContains(t, err.Error(), "oidc: id token signed with unsupported algorithm")
})
@ -752,7 +752,7 @@ rootCA: |
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
// If the root CA is being respected, we won't get this error. The error message is environment-dependent, so
// we check for either of the error messages associated with a failed cert check.
assert.NotContains(t, err.Error(), "certificate is not trusted")
@ -789,7 +789,7 @@ rootCA: |
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
require.Error(t, err)
assert.ErrorIs(t, err, common.ErrTokenVerification)
})
@ -824,7 +824,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
require.Error(t, err)
assert.ErrorIs(t, err, common.ErrTokenVerification)
})
@ -859,7 +859,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
require.Error(t, err)
assert.ErrorIs(t, err, common.ErrTokenVerification)
})
@ -895,7 +895,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
assert.NotContains(t, err.Error(), "certificate is not trusted")
assert.NotContains(t, err.Error(), "certificate signed by unknown authority")
})
@ -924,7 +924,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
// This is the error thrown when the test server's certificate _is_ being verified.
assert.NotContains(t, err.Error(), "certificate is not trusted")
assert.NotContains(t, err.Error(), "certificate signed by unknown authority")
@ -961,7 +961,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
require.Error(t, err)
})
@ -997,7 +997,7 @@ skipAudienceCheckWhenTokenHasNoAudience: true`, oidcTestServer.URL),
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
require.NoError(t, err)
})
@ -1033,7 +1033,7 @@ skipAudienceCheckWhenTokenHasNoAudience: false`, oidcTestServer.URL),
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
require.Error(t, err)
assert.ErrorIs(t, err, common.ErrTokenVerification)
})
@ -1069,7 +1069,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
require.NoError(t, err)
})
@ -1106,7 +1106,7 @@ allowedAudiences:
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
require.NoError(t, err)
})
@ -1143,7 +1143,7 @@ allowedAudiences:
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
require.Error(t, err)
assert.ErrorIs(t, err, common.ErrTokenVerification)
})
@ -1179,7 +1179,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
require.Error(t, err)
assert.ErrorIs(t, err, common.ErrTokenVerification)
})
@ -1216,7 +1216,7 @@ allowedAudiences: []`, oidcTestServer.URL),
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
require.Error(t, err)
assert.ErrorIs(t, err, common.ErrTokenVerification)
})
@ -1254,7 +1254,7 @@ allowedAudiences: ["aud-a", "aud-b"]`, oidcTestServer.URL),
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
require.NoError(t, err)
})
@ -1289,7 +1289,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
tokenString, err := token.SignedString(key)
require.NoError(t, err)
_, _, err = mgr.VerifyToken(tokenString)
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
require.Error(t, err)
assert.ErrorIs(t, err, common.ErrTokenVerification)
})

View file

@ -136,6 +136,9 @@ type ArgoCDSettings struct {
// token verification to pass despite the OIDC provider having an invalid certificate. Only set to `true` if you
// understand the risks.
OIDCTLSInsecureSkipVerify bool `json:"oidcTLSInsecureSkipVerify"`
// OIDCRefreshTokenThreshold sets the threshold for preemptive server-side token refresh. If set to 0, tokens
// will not be refreshed and will expire before client is redirected to login.
OIDCRefreshTokenThreshold time.Duration `json:"oidcRefreshTokenThreshold,omitempty"`
// AppsInAnyNamespaceEnabled indicates whether applications are allowed to be created in any namespace
AppsInAnyNamespaceEnabled bool `json:"appsInAnyNamespaceEnabled"`
// ExtensionConfig configurations related to ArgoCD proxy extensions. The keys are the extension name.
@ -193,6 +196,7 @@ func (o *oidcConfig) toExported() *OIDCConfig {
UserInfoPath: o.UserInfoPath,
EnableUserInfoGroups: o.EnableUserInfoGroups,
UserInfoCacheExpiration: o.UserInfoCacheExpiration,
RefreshTokenThreshold: o.RefreshTokenThreshold,
RequestedScopes: o.RequestedScopes,
RequestedIDTokenClaims: o.RequestedIDTokenClaims,
LogoutURL: o.LogoutURL,
@ -218,6 +222,7 @@ type OIDCConfig struct {
EnablePKCEAuthentication bool `json:"enablePKCEAuthentication,omitempty"`
DomainHint string `json:"domainHint,omitempty"`
Azure *AzureOIDCConfig `json:"azure,omitempty"`
RefreshTokenThreshold string `json:"refreshTokenThreshold,omitempty"`
}
type AzureOIDCConfig struct {
@ -1432,6 +1437,7 @@ func getDownloadBinaryUrlsFromConfigMap(argoCDCM *corev1.ConfigMap) map[string]s
func updateSettingsFromConfigMap(settings *ArgoCDSettings, argoCDCM *corev1.ConfigMap) {
settings.DexConfig = argoCDCM.Data[settingDexConfigKey]
settings.OIDCConfigRAW = argoCDCM.Data[settingsOIDCConfigKey]
settings.OIDCRefreshTokenThreshold = settings.RefreshTokenThreshold()
settings.KustomizeBuildOptions = argoCDCM.Data[kustomizeBuildOptionsKey]
settings.StatusBadgeEnabled = argoCDCM.Data[statusBadgeEnabledKey] == "true"
settings.StatusBadgeRootUrl = argoCDCM.Data[statusBadgeRootURLKey]
@ -1882,6 +1888,18 @@ func (a *ArgoCDSettings) UserInfoCacheExpiration() time.Duration {
return 0
}
// RefreshTokenThreshold returns the duration before token expiration that a token should be refreshed by the server
func (a *ArgoCDSettings) RefreshTokenThreshold() time.Duration {
if oidcConfig := a.OIDCConfig(); oidcConfig != nil && oidcConfig.RefreshTokenThreshold != "" {
refreshTokenThreshold, err := time.ParseDuration(oidcConfig.RefreshTokenThreshold)
if err != nil {
log.Warnf("Failed to parse 'oidc.config.refreshTokenThreshold' key: %v", err)
}
return refreshTokenThreshold
}
return 0
}
func (a *ArgoCDSettings) OAuth2ClientID() string {
if oidcConfig := a.OIDCConfig(); oidcConfig != nil {
return oidcConfig.ClientID
@ -2001,6 +2019,9 @@ func (a *ArgoCDSettings) ArgoURLForRequest(r *http.Request) (string, error) {
}
func (a *ArgoCDSettings) RedirectURLForRequest(r *http.Request) (string, error) {
if r == nil {
return "", errors.New("request is nil")
}
base, err := a.ArgoURLForRequest(r)
if err != nil {
return "", err

View file

@ -1020,43 +1020,94 @@ func TestSettingsManager_GetSettings(t *testing.T) {
}
func TestGetOIDCConfig(t *testing.T) {
kubeClient := fake.NewClientset(
&corev1.ConfigMap{
ObjectMeta: metav1.ObjectMeta{
Name: common.ArgoCDConfigMapName,
Namespace: "default",
Labels: map[string]string{
"app.kubernetes.io/part-of": "argocd",
},
},
Data: map[string]string{
testCases := []struct {
name string
configMapData map[string]string
testFunc func(t *testing.T, settingsManager *SettingsManager)
}{
{
name: "requestedIDTokenClaims",
configMapData: map[string]string{
"oidc.config": "\n requestedIDTokenClaims: {\"groups\": {\"essential\": true}}\n",
},
testFunc: func(t *testing.T, settingsManager *SettingsManager) {
t.Helper()
settings, err := settingsManager.GetSettings()
require.NoError(t, err)
oidcConfig := settings.OIDCConfig()
assert.NotNil(t, oidcConfig)
claim := oidcConfig.RequestedIDTokenClaims["groups"]
assert.NotNil(t, claim)
assert.True(t, claim.Essential)
},
},
&corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: common.ArgoCDSecretName,
Namespace: "default",
Labels: map[string]string{
"app.kubernetes.io/part-of": "argocd",
{
name: "refreshTokenThreshold success",
configMapData: map[string]string{
"oidc.config": "\n refreshTokenThreshold: 5m\n",
},
testFunc: func(t *testing.T, settingsManager *SettingsManager) {
t.Helper()
settings, err := settingsManager.GetSettings()
require.NoError(t, err)
oidcConfig := settings.OIDCConfig()
assert.NotNil(t, oidcConfig)
assert.Equal(t, 5*time.Minute, settings.RefreshTokenThreshold())
},
},
{
name: "refreshTokenThreshold parse failure",
configMapData: map[string]string{
"oidc.config": "\n refreshTokenThreshold: 5xx\n",
},
testFunc: func(t *testing.T, settingsManager *SettingsManager) {
t.Helper()
settings, err := settingsManager.GetSettings()
require.NoError(t, err)
oidcConfig := settings.OIDCConfig()
assert.NotNil(t, oidcConfig)
assert.Equal(t, time.Duration(0), settings.RefreshTokenThreshold())
},
},
}
for i := range testCases {
tc := testCases[i]
t.Run(tc.name, func(t *testing.T) {
kubeClient := fake.NewClientset(
&corev1.ConfigMap{
ObjectMeta: metav1.ObjectMeta{
Name: common.ArgoCDConfigMapName,
Namespace: "default",
Labels: map[string]string{
"app.kubernetes.io/part-of": "argocd",
},
},
Data: tc.configMapData,
},
},
Data: map[string][]byte{
"admin.password": nil,
"server.secretkey": nil,
},
},
)
settingsManager := NewSettingsManager(t.Context(), kubeClient, "default")
settings, err := settingsManager.GetSettings()
require.NoError(t, err)
oidcConfig := settings.OIDCConfig()
assert.NotNil(t, oidcConfig)
claim := oidcConfig.RequestedIDTokenClaims["groups"]
assert.NotNil(t, claim)
assert.True(t, claim.Essential)
&corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: common.ArgoCDSecretName,
Namespace: "default",
Labels: map[string]string{
"app.kubernetes.io/part-of": "argocd",
},
},
Data: map[string][]byte{
"admin.password": nil,
"server.secretkey": nil,
},
},
)
settingsManager := NewSettingsManager(t.Context(), kubeClient, "default")
tc.testFunc(t, settingsManager)
})
}
}
func TestRedirectURL(t *testing.T) {