mirror of
https://github.com/argoproj/argo-cd
synced 2026-04-21 17:07:16 +00:00
feat: oidc background token refresh (#23727)
Signed-off-by: Mike Cutsail <mcutsail15@apple.com>
This commit is contained in:
parent
60f2ff5f77
commit
5c6aa59ed3
14 changed files with 831 additions and 164 deletions
|
|
@ -67,6 +67,7 @@ data:
|
|||
issuer: https://keycloak.example.com/realms/master
|
||||
clientID: argocd
|
||||
clientSecret: $oidc.keycloak.clientSecret
|
||||
refreshTokenThreshold: 2m
|
||||
requestedScopes: ["openid", "profile", "email", "groups"]
|
||||
```
|
||||
|
||||
|
|
@ -77,6 +78,7 @@ Make sure that:
|
|||
- __clientID__ is set to the Client ID you configured in Keycloak
|
||||
- __clientSecret__ points to the right key you created in the _argocd-secret_ Secret
|
||||
- __requestedScopes__ contains the _groups_ claim if you didn't add it to the Default scopes
|
||||
- __refreshTokenThreshold__ is less than the client token lifetime. If this setting is not less than the token lifetime, a new token will be obtained for every request. Keycloak sets the client token lifetime to 5 minutes by default.
|
||||
|
||||
## Keycloak and ArgoCD with PKCE
|
||||
|
||||
|
|
@ -135,6 +137,7 @@ data:
|
|||
issuer: https://keycloak.example.com/realms/master
|
||||
clientID: argocd
|
||||
enablePKCEAuthentication: true
|
||||
refreshTokenThreshold: 2m
|
||||
requestedScopes: ["openid", "profile", "email", "groups"]
|
||||
```
|
||||
|
||||
|
|
@ -145,6 +148,7 @@ Make sure that:
|
|||
- __clientID__ is set to the Client ID you configured in Keycloak
|
||||
- __enablePKCEAuthentication__ must be set to true to enable correct ArgoCD behaviour with PKCE
|
||||
- __requestedScopes__ contains the _groups_ claim if you didn't add it to the Default scopes
|
||||
- __refreshTokenThreshold__ is less than the client token lifetime. If this setting is not less than the token lifetime, a new token will be obtained for every request. Keycloak sets the client token lifetime to 5 minutes by default.
|
||||
|
||||
## Configuring the groups claim
|
||||
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ func (t *terminalSession) performValidationsAndReconnect(p []byte) (int, error)
|
|||
}
|
||||
|
||||
// check if token still valid
|
||||
_, newToken, err := t.sessionManager.VerifyToken(*t.token)
|
||||
_, newToken, err := t.sessionManager.VerifyToken(t.ctx, *t.token)
|
||||
// err in case if token is revoked, newToken in case if refresh happened
|
||||
if err != nil || newToken != "" {
|
||||
// need to send reconnect code in case if token was refreshed
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ func NewHandler(settingsMrg *settings.SettingsManager, sessionMgr *session.Sessi
|
|||
type Handler struct {
|
||||
settingsMgr *settings.SettingsManager
|
||||
rootPath string
|
||||
verifyToken func(tokenString string) (jwt.Claims, string, error)
|
||||
verifyToken func(ctx context.Context, tokenString string) (jwt.Claims, string, error)
|
||||
revokeToken func(ctx context.Context, id string, expiringAt time.Duration) error
|
||||
baseHRef string
|
||||
}
|
||||
|
|
@ -94,7 +94,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
w.Header().Add("Set-Cookie", argocdCookie.String())
|
||||
}
|
||||
|
||||
claims, _, err := h.verifyToken(tokenString)
|
||||
claims, _, err := h.verifyToken(r.Context(), tokenString)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, logoutRedirectURL, http.StatusSeeOther)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package logout
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
|
@ -245,28 +246,28 @@ func TestHandlerConstructLogoutURL(t *testing.T) {
|
|||
sessionManager := session.NewSessionManager(settingsManagerWithOIDCConfig, test.NewFakeProjLister(), "", nil, session.NewUserStateStorage(nil))
|
||||
|
||||
oidcHandler := NewHandler(settingsManagerWithOIDCConfig, sessionManager, rootPath, baseHRef)
|
||||
oidcHandler.verifyToken = func(tokenString string) (jwt.Claims, string, error) {
|
||||
oidcHandler.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) {
|
||||
if !validJWTPattern.MatchString(tokenString) {
|
||||
return nil, "", errors.New("invalid jwt")
|
||||
}
|
||||
return &jwt.RegisteredClaims{Issuer: "okta"}, "", nil
|
||||
}
|
||||
nonoidcHandler := NewHandler(settingsManagerWithoutOIDCConfig, sessionManager, "", baseHRef)
|
||||
nonoidcHandler.verifyToken = func(tokenString string) (jwt.Claims, string, error) {
|
||||
nonoidcHandler.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) {
|
||||
if !validJWTPattern.MatchString(tokenString) {
|
||||
return nil, "", errors.New("invalid jwt")
|
||||
}
|
||||
return &jwt.RegisteredClaims{Issuer: session.SessionManagerClaimsIssuer}, "", nil
|
||||
}
|
||||
oidcHandlerWithoutLogoutURL := NewHandler(settingsManagerWithOIDCConfigButNoLogoutURL, sessionManager, "", baseHRef)
|
||||
oidcHandlerWithoutLogoutURL.verifyToken = func(tokenString string) (jwt.Claims, string, error) {
|
||||
oidcHandlerWithoutLogoutURL.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) {
|
||||
if !validJWTPattern.MatchString(tokenString) {
|
||||
return nil, "", errors.New("invalid jwt")
|
||||
}
|
||||
return &jwt.RegisteredClaims{Issuer: "okta"}, "", nil
|
||||
}
|
||||
nonoidcHandlerWithMultipleURLs := NewHandler(settingsManagerWithoutOIDCAndMultipleURLs, sessionManager, "", baseHRef)
|
||||
nonoidcHandlerWithMultipleURLs.verifyToken = func(tokenString string) (jwt.Claims, string, error) {
|
||||
nonoidcHandlerWithMultipleURLs.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) {
|
||||
if !validJWTPattern.MatchString(tokenString) {
|
||||
return nil, "", errors.New("invalid jwt")
|
||||
}
|
||||
|
|
@ -274,7 +275,7 @@ func TestHandlerConstructLogoutURL(t *testing.T) {
|
|||
}
|
||||
|
||||
oidcHandlerWithoutBaseURL := NewHandler(settingsManagerWithOIDCConfigButNoURL, sessionManager, "argocd", baseHRef)
|
||||
oidcHandlerWithoutBaseURL.verifyToken = func(tokenString string) (jwt.Claims, string, error) {
|
||||
oidcHandlerWithoutBaseURL.verifyToken = func(_ context.Context, tokenString string) (jwt.Claims, string, error) {
|
||||
if !validJWTPattern.MatchString(tokenString) {
|
||||
return nil, "", errors.New("invalid jwt")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -323,6 +323,8 @@ func NewServer(ctx context.Context, opts ArgoCDServerOpts, appsetOpts Applicatio
|
|||
appsetLister := appFactory.Argoproj().V1alpha1().ApplicationSets().Lister()
|
||||
|
||||
userStateStorage := util_session.NewUserStateStorage(opts.RedisClient)
|
||||
ssoClientApp, err := oidc.NewClientApp(settings, opts.DexServerAddr, opts.DexTLSConfig, opts.BaseHRef, cacheutil.NewRedisCache(opts.RedisClient, settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone))
|
||||
errorsutil.CheckError(err)
|
||||
sessionMgr := util_session.NewSessionManager(settingsMgr, projLister, opts.DexServerAddr, opts.DexTLSConfig, userStateStorage)
|
||||
enf := rbac.NewEnforcer(opts.KubeClientset, opts.Namespace, common.ArgoCDRBACConfigMapName, nil)
|
||||
enf.EnableEnforce(!opts.DisableAuth)
|
||||
|
|
@ -370,6 +372,7 @@ func NewServer(ctx context.Context, opts ArgoCDServerOpts, appsetOpts Applicatio
|
|||
a := &ArgoCDServer{
|
||||
ArgoCDServerOpts: opts,
|
||||
ApplicationSetOpts: appsetOpts,
|
||||
ssoClientApp: ssoClientApp,
|
||||
log: logger,
|
||||
settings: settings,
|
||||
sessionMgr: sessionMgr,
|
||||
|
|
@ -1125,19 +1128,7 @@ func (server *ArgoCDServer) translateGrpcCookieHeader(ctx context.Context, w htt
|
|||
}
|
||||
|
||||
func (server *ArgoCDServer) setTokenCookie(token string, w http.ResponseWriter) error {
|
||||
cookiePath := "path=/" + strings.TrimRight(strings.TrimLeft(server.BaseHRef, "/"), "/")
|
||||
flags := []string{cookiePath, "SameSite=lax", "httpOnly"}
|
||||
if !server.Insecure {
|
||||
flags = append(flags, "Secure")
|
||||
}
|
||||
cookies, err := httputil.MakeCookieMetadata(common.AuthCookieName, token, flags...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating cookie metadata: %w", err)
|
||||
}
|
||||
for _, cookie := range cookies {
|
||||
w.Header().Add("Set-Cookie", cookie)
|
||||
}
|
||||
return nil
|
||||
return httputil.SetTokenCookie(token, server.BaseHRef, !server.Insecure, w)
|
||||
}
|
||||
|
||||
func withRootPath(handler http.Handler, a *ArgoCDServer) http.Handler {
|
||||
|
|
@ -1221,9 +1212,6 @@ func (server *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWeb
|
|||
|
||||
terminalOpts := application.TerminalOptions{DisableAuth: server.DisableAuth, Enf: server.enf}
|
||||
|
||||
// SSO ClientApp
|
||||
server.ssoClientApp, _ = oidc.NewClientApp(server.settings, server.DexServerAddr, server.DexTLSConfig, server.BaseHRef, cacheutil.NewRedisCache(server.RedisClient, server.settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone))
|
||||
|
||||
terminal := application.NewHandler(server.appLister, server.Namespace, server.ApplicationNamespaces, server.db, appResourceTreeFn, server.settings.ExecShells, server.sessionMgr, &terminalOpts).
|
||||
WithFeatureFlagMiddleware(server.settingsMgr.GetSettings)
|
||||
th := util_session.WithAuthMiddleware(server.DisableAuth, server.settings.IsSSOConfigured(), server.ssoClientApp, server.sessionMgr, terminal)
|
||||
|
|
@ -1368,9 +1356,7 @@ func (server *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) {
|
|||
return
|
||||
}
|
||||
// Run dex OpenID Connect Identity Provider behind a reverse proxy (served at /api/dex)
|
||||
var err error
|
||||
mux.HandleFunc(common.DexAPIEndpoint+"/", dexutil.NewDexHTTPReverseProxy(server.DexServerAddr, server.BaseHRef, server.DexTLSConfig))
|
||||
errorsutil.CheckError(err)
|
||||
mux.HandleFunc(common.LoginEndpoint, server.ssoClientApp.HandleLogin)
|
||||
mux.HandleFunc(common.CallbackEndpoint, server.ssoClientApp.HandleCallback)
|
||||
}
|
||||
|
|
@ -1566,6 +1552,7 @@ func (server *ArgoCDServer) Authenticate(ctx context.Context) (context.Context,
|
|||
return ctx, nil
|
||||
}
|
||||
|
||||
// getClaims extracts, validates and refreshes a JWT token from an incoming request context.
|
||||
func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
|
|
@ -1575,17 +1562,29 @@ func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string,
|
|||
if tokenString == "" {
|
||||
return nil, "", ErrNoSession
|
||||
}
|
||||
claims, newToken, err := server.sessionMgr.VerifyToken(tokenString)
|
||||
// A valid argocd-issued token is automatically refreshed here prior to expiration.
|
||||
// OIDC tokens will be verified but will not be refreshed here.
|
||||
claims, newToken, err := server.sessionMgr.VerifyToken(ctx, tokenString)
|
||||
if err != nil {
|
||||
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err)
|
||||
}
|
||||
|
||||
finalClaims := claims
|
||||
if server.settings.IsSSOConfigured() {
|
||||
finalClaims, err = server.ssoClientApp.SetGroupsFromUserInfo(claims, util_session.SessionManagerClaimsIssuer)
|
||||
updatedClaims, err := server.ssoClientApp.SetGroupsFromUserInfo(ctx, claims, util_session.SessionManagerClaimsIssuer)
|
||||
if err != nil {
|
||||
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err)
|
||||
}
|
||||
finalClaims = updatedClaims
|
||||
// OIDC tokens are automatically refreshed here prior to expiration
|
||||
refreshedToken, err := server.ssoClientApp.CheckAndRefreshToken(ctx, updatedClaims, server.settings.OIDCRefreshTokenThreshold)
|
||||
if err != nil {
|
||||
log.Errorf("error checking and refreshing token: %v", err)
|
||||
}
|
||||
if refreshedToken != "" && refreshedToken != tokenString {
|
||||
newToken = refreshedToken
|
||||
log.Infof("refreshed token for subject: %v", jwtutil.StringField(updatedClaims, "sub"))
|
||||
}
|
||||
}
|
||||
|
||||
return finalClaims, newToken, nil
|
||||
|
|
|
|||
|
|
@ -241,3 +241,23 @@ func drainBody(body io.ReadCloser) {
|
|||
log.Warnf("error reading response body: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func SetTokenCookie(token string, baseHRef string, isSecure bool, w http.ResponseWriter) error {
|
||||
var path string
|
||||
if baseHRef != "" {
|
||||
path = strings.TrimRight(strings.TrimLeft(baseHRef, "/"), "/")
|
||||
}
|
||||
cookiePath := "path=/" + path
|
||||
flags := []string{cookiePath, "SameSite=lax", "httpOnly"}
|
||||
if isSecure {
|
||||
flags = append(flags, "Secure")
|
||||
}
|
||||
cookies, err := MakeCookieMetadata(common.AuthCookieName, token, flags...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating cookie metadata: %w", err)
|
||||
}
|
||||
for _, cookie := range cookies {
|
||||
w.Header().Add("Set-Cookie", cookie)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/argoproj/argo-cd/v3/common"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
|
@ -49,6 +52,101 @@ func TestSplitCookie(t *testing.T) {
|
|||
assert.Equal(t, cookieValue, token)
|
||||
}
|
||||
|
||||
// mockResponseWriter is a mock implementation of http.ResponseWriter.
|
||||
// It captures added headers for verification in tests.
|
||||
type mockResponseWriter struct {
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) Header() http.Header {
|
||||
if m.header == nil {
|
||||
m.header = make(http.Header)
|
||||
}
|
||||
return m.header
|
||||
}
|
||||
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
func (m *mockResponseWriter) WriteHeader(_ int) {}
|
||||
|
||||
func TestSetTokenCookie(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
baseHRef string
|
||||
isSecure bool
|
||||
expectedCookies []string // Expected Set-Cookie header values
|
||||
}{
|
||||
{
|
||||
name: "Insecure cookie",
|
||||
token: "insecure-token",
|
||||
baseHRef: "",
|
||||
isSecure: false,
|
||||
expectedCookies: []string{
|
||||
fmt.Sprintf("%s=%s; path=/; SameSite=lax; httpOnly", common.AuthCookieName, "insecure-token"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Secure cookie",
|
||||
token: "secure-token",
|
||||
baseHRef: "",
|
||||
isSecure: true,
|
||||
expectedCookies: []string{
|
||||
fmt.Sprintf("%s=%s; path=/; SameSite=lax; httpOnly; Secure", common.AuthCookieName, "secure-token"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Insecure cookie with baseHRef",
|
||||
token: "token-with-path",
|
||||
baseHRef: "/app",
|
||||
isSecure: false,
|
||||
expectedCookies: []string{
|
||||
fmt.Sprintf("%s=%s; path=/app; SameSite=lax; httpOnly", common.AuthCookieName, "token-with-path"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Secure cookie with baseHRef",
|
||||
token: "secure-token-with-path",
|
||||
baseHRef: "app/",
|
||||
isSecure: true,
|
||||
expectedCookies: []string{
|
||||
fmt.Sprintf("%s=%s; path=/app; SameSite=lax; httpOnly; Secure", common.AuthCookieName, "secure-token-with-path"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Unsecured cookie, baseHRef with multiple segments and mixed slashes",
|
||||
token: "complex-path-token",
|
||||
baseHRef: "///api/v1/auth///",
|
||||
isSecure: false,
|
||||
expectedCookies: []string{
|
||||
fmt.Sprintf("%s=%s; path=/api/v1/auth; SameSite=lax; httpOnly", common.AuthCookieName, "complex-path-token"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := &mockResponseWriter{}
|
||||
|
||||
err := SetTokenCookie(tt.token, tt.baseHRef, tt.isSecure, w)
|
||||
if err != nil {
|
||||
t.Fatalf("%s: Unexpected error: %v", tt.name, err)
|
||||
}
|
||||
|
||||
setCookieHeaders := w.Header()["Set-Cookie"]
|
||||
|
||||
if len(setCookieHeaders) != len(tt.expectedCookies) {
|
||||
t.Errorf("Mistmatch in Set-Cookie header length: %s\nExpected: %d\nGot: %d",
|
||||
tt.name, len(tt.expectedCookies), len(setCookieHeaders))
|
||||
return
|
||||
}
|
||||
|
||||
if len(tt.expectedCookies) > 0 && setCookieHeaders[0] != tt.expectedCookies[0] {
|
||||
t.Errorf("Mismatch in Set-Cookie header: %s\nExpected: %s\nGot: %s",
|
||||
tt.name, tt.expectedCookies[0], setCookieHeaders[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoundTripper just copy request headers to the resposne.
|
||||
type TestRoundTripper struct{}
|
||||
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ const (
|
|||
ResponseTypeCode = "code"
|
||||
UserInfoResponseCachePrefix = "userinfo_response"
|
||||
AccessTokenCachePrefix = "access_token"
|
||||
OidcTokenCachePrefix = "oidc_token"
|
||||
)
|
||||
|
||||
// OIDCConfiguration holds a subset of interested fields from the OIDC configuration spec
|
||||
|
|
@ -87,6 +88,8 @@ type ClientApp struct {
|
|||
clientCache cache.CacheClient
|
||||
// properties for azure workload identity.
|
||||
azure azureApp
|
||||
// preemptive token refresh threshold
|
||||
refreshTokenThreshold time.Duration
|
||||
}
|
||||
|
||||
type azureApp struct {
|
||||
|
|
@ -98,6 +101,63 @@ type azureApp struct {
|
|||
mtx *sync.RWMutex
|
||||
}
|
||||
|
||||
// OidcTokenCache is a serialization wrapper around oauth2 provider configuration needed to generate a TokenSource
|
||||
type OidcTokenCache struct {
|
||||
// Redirect URL is needed for oauth2 config initialization
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
// oauth2 Token
|
||||
Token *oauth2.Token `json:"token"`
|
||||
// TokenExtraIdToken captures value of id_token
|
||||
TokenExtraIdToken string `json:"token_extra_id_token"`
|
||||
}
|
||||
|
||||
// NewOidcTokenCache initializes the struct from a redirect URL and an existing token
|
||||
func NewOidcTokenCache(redirectURL string, token *oauth2.Token) *OidcTokenCache {
|
||||
var idToken string
|
||||
if token.Extra("id_token") == nil {
|
||||
idToken = ""
|
||||
} else {
|
||||
idToken = token.Extra("id_token").(string)
|
||||
}
|
||||
return &OidcTokenCache{
|
||||
RedirectURL: redirectURL,
|
||||
Token: token,
|
||||
TokenExtraIdToken: idToken,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOidcTokenCacheFromJSON deserializes the json representation of OidcTokenCache. The Token extra map is updated from
|
||||
// the serialization wrapper to propagate the id_token. This will ensure that the TokenSource always retrieves a usable token.
|
||||
func GetOidcTokenCacheFromJSON(jsonBytes []byte) (*OidcTokenCache, error) {
|
||||
var newToken OidcTokenCache
|
||||
err := json.Unmarshal(jsonBytes, &newToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if newToken.Token == nil {
|
||||
return nil, errors.New("empty token")
|
||||
}
|
||||
newToken.Token = newToken.Token.WithExtra(map[string]any{
|
||||
"id_token": newToken.TokenExtraIdToken,
|
||||
})
|
||||
return &newToken, nil
|
||||
}
|
||||
|
||||
// GetTokenSourceFromCache creates an oauth2 TokenSource from a cached oidc token. The TokenSource will be configured
|
||||
// with an early expiration based on the refreshTokenThreshold.
|
||||
func (a *ClientApp) GetTokenSourceFromCache(ctx context.Context, oidcTokenCache *OidcTokenCache) (oauth2.TokenSource, error) {
|
||||
if oidcTokenCache == nil {
|
||||
return nil, errors.New("oidcTokenCache is required")
|
||||
}
|
||||
config, err := a.getOauth2ConfigForRedirectURI(oidcTokenCache.RedirectURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
baseTokenSource := config.TokenSource(ctx, oidcTokenCache.Token)
|
||||
tokenRefresher := oauth2.ReuseTokenSourceWithExpiry(oidcTokenCache.Token, baseTokenSource, a.refreshTokenThreshold)
|
||||
return tokenRefresher, nil
|
||||
}
|
||||
|
||||
func GetScopesOrDefault(scopes []string) []string {
|
||||
if len(scopes) == 0 {
|
||||
return []string{"openid", "profile", "email", "groups"}
|
||||
|
|
@ -127,6 +187,7 @@ func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTL
|
|||
encryptionKey: encryptionKey,
|
||||
clientCache: cacheClient,
|
||||
azure: azureApp{mtx: &sync.RWMutex{}},
|
||||
refreshTokenThreshold: settings.OIDCRefreshTokenThreshold,
|
||||
}
|
||||
log.Infof("Creating client app (%s)", a.clientID)
|
||||
u, err := url.Parse(settings.URL)
|
||||
|
|
@ -165,23 +226,27 @@ func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTL
|
|||
return &a, nil
|
||||
}
|
||||
|
||||
func (a *ClientApp) oauth2Config(request *http.Request, scopes []string) (*oauth2.Config, error) {
|
||||
func (a *ClientApp) getRedirectURIForRequest(req *http.Request) string {
|
||||
redirectURI, err := a.settings.RedirectURLForRequest(req)
|
||||
if err != nil {
|
||||
log.Warnf("Unable to find ArgoCD URL from request, falling back to configured redirect URI: %v", err)
|
||||
redirectURI = a.redirectURI
|
||||
}
|
||||
return redirectURI
|
||||
}
|
||||
|
||||
func (a *ClientApp) getOauth2ConfigForRedirectURI(redirectURI string) (*oauth2.Config, error) {
|
||||
endpoint, err := a.provider.Endpoint()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
redirectURL, err := a.settings.RedirectURLForRequest(request)
|
||||
if err != nil {
|
||||
log.Warnf("Unable to find ArgoCD URL from request, falling back to configured redirect URI: %v", err)
|
||||
redirectURL = a.redirectURI
|
||||
}
|
||||
|
||||
return &oauth2.Config{
|
||||
ClientID: a.clientID,
|
||||
ClientSecret: a.clientSecret,
|
||||
Endpoint: *endpoint,
|
||||
Scopes: scopes,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: a.getScopes(),
|
||||
RedirectURL: redirectURI,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -315,17 +380,13 @@ func (a *ClientApp) HandleLogin(w http.ResponseWriter, r *http.Request) {
|
|||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
scopes := make([]string, 0)
|
||||
pkceVerifier := ""
|
||||
var opts []oauth2.AuthCodeOption
|
||||
if config := a.settings.OIDCConfig(); config != nil {
|
||||
scopes = GetScopesOrDefault(config.RequestedScopes)
|
||||
opts = AppendClaimsAuthenticationRequestParameter(opts, config.RequestedIDTokenClaims)
|
||||
} else if a.settings.IsDexConfigured() {
|
||||
scopes = append(GetScopesOrDefault(nil), common.DexFederatedScope)
|
||||
}
|
||||
|
||||
oauth2Config, err := a.oauth2Config(r, scopes)
|
||||
oauth2Config, err := a.getOauth2ConfigForRedirectURI(a.getRedirectURIForRequest(r))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
|
@ -406,7 +467,7 @@ func (a *azureApp) getFederatedServiceAccountToken(context.Context) (string, err
|
|||
|
||||
// HandleCallback is the callback handler for an OAuth2 login flow
|
||||
func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
oauth2Config, err := a.oauth2Config(r, nil)
|
||||
oauth2Config, err := a.getOauth2ConfigForRedirectURI(a.getRedirectURIForRequest(r))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
|
@ -456,27 +517,21 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// Parse out id token
|
||||
idTokenRAW, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
http.Error(w, "no id_token in token response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := a.provider.Verify(idTokenRAW, a.settings)
|
||||
idToken, err := a.provider.Verify(ctx, idTokenRAW, a.settings)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to verify token: %s", err)
|
||||
log.Warnf("Failed to verify oidc token: %s", err)
|
||||
http.Error(w, common.TokenVerificationError, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
path := "/"
|
||||
if a.baseHRef != "" {
|
||||
path = strings.TrimRight(strings.TrimLeft(a.baseHRef, "/"), "/")
|
||||
}
|
||||
cookiePath := "path=/" + path
|
||||
flags := []string{cookiePath, "SameSite=lax", "httpOnly"}
|
||||
if a.secureCookie {
|
||||
flags = append(flags, "Secure")
|
||||
}
|
||||
|
||||
// Set cache
|
||||
var claims jwt.MapClaims
|
||||
err = idToken.Claims(&claims)
|
||||
if err != nil {
|
||||
|
|
@ -484,38 +539,38 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
// save the accessToken in memory for later use
|
||||
encToken, err := crypto.Encrypt([]byte(token.AccessToken), a.encryptionKey)
|
||||
sub := jwtutil.StringField(claims, "sub")
|
||||
err = a.SetValueInEncryptedCache(FormatAccessTokenCacheKey(sub), []byte(token.AccessToken), GetTokenExpiration(claims))
|
||||
if err != nil {
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
http.Error(w, "failed encrypting token", http.StatusInternalServerError)
|
||||
log.Errorf("cannot encrypt accessToken: %v (claims=%s)", err, claimsJSON)
|
||||
log.Errorf("cannot cache encrypted accessToken: %v (claims=%s)", err, claimsJSON)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
sub := jwtutil.StringField(claims, "sub")
|
||||
err = a.clientCache.Set(&cache.Item{
|
||||
Key: FormatAccessTokenCacheKey(sub),
|
||||
Object: encToken,
|
||||
CacheActionOpts: cache.CacheActionOpts{
|
||||
Expiration: getTokenExpiration(claims),
|
||||
},
|
||||
})
|
||||
|
||||
// Cache encrypted raw token for background refresh
|
||||
oidcTokenCache := NewOidcTokenCache(a.getRedirectURIForRequest(r), token)
|
||||
oidcTokenCacheJSON, err := json.Marshal(oidcTokenCache)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
sid := jwtutil.StringField(claims, "sid")
|
||||
err = a.SetValueInEncryptedCache(formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, GetTokenExpiration(claims))
|
||||
if err != nil {
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
http.Error(w, fmt.Sprintf("claims=%s, err=%v", claimsJSON, err), http.StatusInternalServerError)
|
||||
log.Errorf("cannot cache encrypted oidc token: %v (claims=%s)", err, claimsJSON)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if idTokenRAW != "" {
|
||||
cookies, err := httputil.MakeCookieMetadata(common.AuthCookieName, idTokenRAW, flags...)
|
||||
err = httputil.SetTokenCookie(idTokenRAW, a.baseHRef, a.secureCookie, w)
|
||||
if err != nil {
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
http.Error(w, fmt.Sprintf("claims=%s, err=%v", claimsJSON, err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
for _, cookie := range cookies {
|
||||
w.Header().Add("Set-Cookie", cookie)
|
||||
}
|
||||
}
|
||||
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
|
|
@ -528,6 +583,109 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// GetValueFromEncryptedCache is a convenience method for retreiving a value from cache and decrypting it. If the cache
|
||||
// does not contain a value for the given key, a nil value is returned. Return handling should check for error and then
|
||||
// check for nil.
|
||||
func (a *ClientApp) GetValueFromEncryptedCache(key string) (value []byte, err error) {
|
||||
var encryptedValue []byte
|
||||
err = a.clientCache.Get(key, &encryptedValue)
|
||||
if err != nil {
|
||||
if errors.Is(err, cache.ErrCacheMiss) {
|
||||
// Return nil to signify a cache miss
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get encrypted value from cache: %w", err)
|
||||
}
|
||||
value, err = crypto.Decrypt(encryptedValue, a.encryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt value from cache: %w", err)
|
||||
}
|
||||
return value, err
|
||||
}
|
||||
|
||||
// SetValueFromEncyrptedCache is a convenience method for encrypting a value and storing it in the cache at a given key.
|
||||
// Cache expiration is set based on input.
|
||||
func (a *ClientApp) SetValueInEncryptedCache(key string, value []byte, expiration time.Duration) error {
|
||||
encryptedValue, err := crypto.Encrypt(value, a.encryptionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = a.clientCache.Set(&cache.Item{
|
||||
Key: key,
|
||||
Object: encryptedValue,
|
||||
CacheActionOpts: cache.CacheActionOpts{
|
||||
Expiration: expiration,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *ClientApp) CheckAndRefreshToken(ctx context.Context, groupClaims jwt.MapClaims, refreshTokenThreshold time.Duration) (string, error) {
|
||||
sub := jwtutil.StringField(groupClaims, "sub")
|
||||
sid := jwtutil.StringField(groupClaims, "sid")
|
||||
if GetTokenExpiration(groupClaims) < refreshTokenThreshold {
|
||||
token, err := a.GetUpdatedOidcTokenFromCache(ctx, sub, sid)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get token from cache: %v", err)
|
||||
return "", err
|
||||
}
|
||||
if token != nil {
|
||||
idTokenRAW, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return "", errors.New("empty id_token")
|
||||
}
|
||||
return idTokenRAW, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// GetUpdatedOidcTokenFromCache fetches a token from cache and refreshes it if under the threshold for expiration.
|
||||
// The cached token will also be updated if it is refreshed. Returns latest token or an error if the process fails.
|
||||
func (a *ClientApp) GetUpdatedOidcTokenFromCache(ctx context.Context, subject string, sessionId string) (*oauth2.Token, error) {
|
||||
ctx = gooidc.ClientContext(ctx, a.client)
|
||||
|
||||
// Get oauth2 config
|
||||
cacheKey := formatOidcTokenCacheKey(subject, sessionId)
|
||||
oidcTokenCacheJSON, err := a.GetValueFromEncryptedCache(cacheKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oidcTokenCacheJSON == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
oidcTokenCache, err := GetOidcTokenCacheFromJSON(oidcTokenCacheJSON)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to unmarshal cached oidc token: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
tokenSource, err := a.GetTokenSourceFromCache(ctx, oidcTokenCache)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to get token source from cached oidc token: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refresh token from source: %w", err)
|
||||
}
|
||||
if token.AccessToken != oidcTokenCache.Token.AccessToken {
|
||||
oidcTokenCache = NewOidcTokenCache(oidcTokenCache.RedirectURL, token)
|
||||
oidcTokenCacheJSON, err = json.Marshal(oidcTokenCache)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal oidc oidcTokenCache refresher: %w", err)
|
||||
}
|
||||
err = a.SetValueInEncryptedCache(cacheKey, oidcTokenCacheJSON, time.Until(token.Expiry))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
var implicitFlowTmpl = template.Must(template.New("implicit.html").Parse(`<script>
|
||||
var hash = window.location.hash.substr(1);
|
||||
var result = hash.split('&').reduce(function (result, item) {
|
||||
|
|
@ -645,7 +803,7 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc
|
|||
// If querying the UserInfo endpoint fails, we return an error to indicate the session is invalid
|
||||
// we assume that everywhere in argocd jwt.MapClaims is used as type for interface jwt.Claims
|
||||
// otherwise this would cause a panic
|
||||
func (a *ClientApp) SetGroupsFromUserInfo(claims jwt.Claims, sessionManagerClaimsIssuer string) (jwt.MapClaims, error) {
|
||||
func (a *ClientApp) SetGroupsFromUserInfo(ctx context.Context, claims jwt.Claims, sessionManagerClaimsIssuer string) (jwt.MapClaims, error) {
|
||||
var groupClaims jwt.MapClaims
|
||||
var ok bool
|
||||
if groupClaims, ok = claims.(jwt.MapClaims); !ok {
|
||||
|
|
@ -657,7 +815,7 @@ func (a *ClientApp) SetGroupsFromUserInfo(claims jwt.Claims, sessionManagerClaim
|
|||
}
|
||||
iss := jwtutil.StringField(groupClaims, "iss")
|
||||
if iss != sessionManagerClaimsIssuer && a.settings.UserInfoGroupsEnabled() && a.settings.UserInfoPath() != "" {
|
||||
userInfo, unauthorized, err := a.GetUserInfo(groupClaims, a.settings.IssuerURL(), a.settings.UserInfoPath())
|
||||
userInfo, unauthorized, err := a.GetUserInfo(ctx, groupClaims, a.settings.IssuerURL(), a.settings.UserInfoPath())
|
||||
if unauthorized {
|
||||
return groupClaims, fmt.Errorf("error while quering userinfo endpoint: %w", err)
|
||||
}
|
||||
|
|
@ -674,7 +832,7 @@ func (a *ClientApp) SetGroupsFromUserInfo(claims jwt.Claims, sessionManagerClaim
|
|||
}
|
||||
|
||||
// GetUserInfo queries the IDP userinfo endpoint for claims
|
||||
func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
|
||||
func (a *ClientApp) GetUserInfo(ctx context.Context, actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
|
||||
sub := jwtutil.StringField(actualClaims, "sub")
|
||||
var claims jwt.MapClaims
|
||||
var encClaims []byte
|
||||
|
|
@ -696,19 +854,13 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
|
|||
}
|
||||
|
||||
// check if the accessToken for the user is still present
|
||||
var encAccessToken []byte
|
||||
err := a.clientCache.Get(FormatAccessTokenCacheKey(sub), &encAccessToken)
|
||||
// without an accessToken we can't query the user info endpoint
|
||||
// thus the user needs to reauthenticate for argocd to get a new accessToken
|
||||
if errors.Is(err, cache.ErrCacheMiss) {
|
||||
return claims, true, fmt.Errorf("no accessToken for %s: %w", sub, err)
|
||||
} else if err != nil {
|
||||
accessTokenBytes, err := a.GetValueFromEncryptedCache(FormatAccessTokenCacheKey(sub))
|
||||
if err != nil {
|
||||
return claims, true, fmt.Errorf("could not read accessToken from cache for %s: %w", sub, err)
|
||||
}
|
||||
|
||||
accessToken, err := crypto.Decrypt(encAccessToken, a.encryptionKey)
|
||||
if err != nil {
|
||||
return claims, true, fmt.Errorf("could not decrypt accessToken for %s: %w", sub, err)
|
||||
if accessTokenBytes == nil {
|
||||
return claims, true, fmt.Errorf("no accessToken for %s: %w", sub, err)
|
||||
}
|
||||
|
||||
url := issuerURL + userInfoPath
|
||||
|
|
@ -718,7 +870,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
|
|||
return claims, false, err
|
||||
}
|
||||
|
||||
bearer := fmt.Sprintf("Bearer %s", accessToken)
|
||||
bearer := "Bearer " + string(accessTokenBytes)
|
||||
request.Header.Set("Authorization", bearer)
|
||||
|
||||
response, err := a.client.Do(request)
|
||||
|
|
@ -740,7 +892,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
|
|||
switch header {
|
||||
case "application/jwt":
|
||||
// if body is JWT, first validate it before extracting claims
|
||||
idToken, err := a.provider.Verify(string(rawBody), a.settings)
|
||||
idToken, err := a.provider.Verify(ctx, string(rawBody), a.settings)
|
||||
if err != nil {
|
||||
return claims, false, fmt.Errorf("user info response in jwt format not valid: %w", err)
|
||||
}
|
||||
|
|
@ -760,7 +912,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
|
|||
// but first let's determine the expiry of the cache
|
||||
var cacheExpiry time.Duration
|
||||
settingExpiry := a.settings.UserInfoCacheExpiration()
|
||||
tokenExpiry := getTokenExpiration(claims)
|
||||
tokenExpiry := GetTokenExpiration(claims)
|
||||
|
||||
// only use configured expiry if the token lives longer and the expiry is configured
|
||||
// if the token has no expiry, use the expiry of the actual token
|
||||
|
|
@ -769,7 +921,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
|
|||
case settingExpiry < tokenExpiry && settingExpiry != 0:
|
||||
cacheExpiry = settingExpiry
|
||||
case tokenExpiry < 0:
|
||||
cacheExpiry = getTokenExpiration(actualClaims)
|
||||
cacheExpiry = GetTokenExpiration(actualClaims)
|
||||
default:
|
||||
cacheExpiry = tokenExpiry
|
||||
}
|
||||
|
|
@ -797,8 +949,8 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
|
|||
return claims, false, nil
|
||||
}
|
||||
|
||||
// getTokenExpiration returns a time.Duration until the token expires
|
||||
func getTokenExpiration(claims jwt.MapClaims) time.Duration {
|
||||
// GetTokenExpiration returns a time.Duration until the token expires
|
||||
func GetTokenExpiration(claims jwt.MapClaims) time.Duration {
|
||||
// get duration until token expires
|
||||
exp := jwtutil.Float64Field(claims, "exp")
|
||||
tm := time.Unix(int64(exp), 0)
|
||||
|
|
@ -806,12 +958,28 @@ func getTokenExpiration(claims jwt.MapClaims) time.Duration {
|
|||
return tokenExpiry
|
||||
}
|
||||
|
||||
// formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache
|
||||
// getScopes returns scopes based on provider configuration
|
||||
func (a *ClientApp) getScopes() []string {
|
||||
scopes := make([]string, 0)
|
||||
if config := a.settings.OIDCConfig(); config != nil {
|
||||
scopes = GetScopesOrDefault(config.RequestedScopes)
|
||||
} else if a.settings.IsDexConfigured() {
|
||||
scopes = append(GetScopesOrDefault(nil), common.DexFederatedScope)
|
||||
}
|
||||
return scopes
|
||||
}
|
||||
|
||||
// FormatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache
|
||||
func FormatUserInfoResponseCacheKey(sub string) string {
|
||||
return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub)
|
||||
}
|
||||
|
||||
// formatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache
|
||||
// FormatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache
|
||||
func FormatAccessTokenCacheKey(sub string) string {
|
||||
return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub)
|
||||
}
|
||||
|
||||
// formatRefreshTokenCacheKey returns the key which is used to store the oidc Token for a session in cache
|
||||
func formatOidcTokenCacheKey(sub string, sid string) string {
|
||||
return fmt.Sprintf("%s_%s_%s", OidcTokenCachePrefix, sub, sid)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
|
@ -15,6 +17,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
gooidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
|
@ -28,6 +32,7 @@ import (
|
|||
"github.com/argoproj/argo-cd/v3/util/cache"
|
||||
"github.com/argoproj/argo-cd/v3/util/crypto"
|
||||
"github.com/argoproj/argo-cd/v3/util/dex"
|
||||
jwtutil "github.com/argoproj/argo-cd/v3/util/jwt"
|
||||
"github.com/argoproj/argo-cd/v3/util/settings"
|
||||
"github.com/argoproj/argo-cd/v3/util/test"
|
||||
)
|
||||
|
|
@ -97,9 +102,14 @@ func TestIDTokenClaims(t *testing.T) {
|
|||
assert.JSONEq(t, "{\"id_token\":{\"groups\":{\"essential\":true}}}", values.Get("claims"))
|
||||
}
|
||||
|
||||
type fakeProvider struct{}
|
||||
type fakeProvider struct {
|
||||
EndpointError bool
|
||||
}
|
||||
|
||||
func (p *fakeProvider) Endpoint() (*oauth2.Endpoint, error) {
|
||||
if p.EndpointError {
|
||||
return nil, errors.New("fake provider endpoint error")
|
||||
}
|
||||
return &oauth2.Endpoint{}, nil
|
||||
}
|
||||
|
||||
|
|
@ -107,7 +117,7 @@ func (p *fakeProvider) ParseConfig() (*OIDCConfiguration, error) {
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *fakeProvider) Verify(_ string, _ *settings.ArgoCDSettings) (*gooidc.IDToken, error) {
|
||||
func (p *fakeProvider) Verify(_ context.Context, _ string, _ *settings.ArgoCDSettings) (*gooidc.IDToken, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
|
@ -530,7 +540,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
|
|||
app.HandleCallback(w, req)
|
||||
|
||||
if !strings.Contains(w.Body.String(), "certificate signed by unknown authority") && !strings.Contains(w.Body.String(), "certificate is not trusted") {
|
||||
t.Fatal("did not receive expected certificate verification failure error")
|
||||
t.Fatalf("did not receive expected certificate verification failure error: %v", w.Code)
|
||||
}
|
||||
|
||||
cdSettings.OIDCTLSInsecureSkipVerify = true
|
||||
|
|
@ -700,7 +710,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
|
|||
app.HandleCallback(w, req)
|
||||
|
||||
if !strings.Contains(w.Body.String(), "certificate signed by unknown authority") && !strings.Contains(w.Body.String(), "certificate is not trusted") {
|
||||
t.Fatal("did not receive expected certificate verification failure error")
|
||||
t.Fatalf("did not receive expected certificate verification failure error: %v", w.Code)
|
||||
}
|
||||
|
||||
cdSettings.OIDCTLSInsecureSkipVerify = true
|
||||
|
|
@ -1147,7 +1157,7 @@ func TestGetUserInfo(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
got, unauthenticated, err := a.GetUserInfo(tt.idpClaims, ts.URL, tt.userInfoPath)
|
||||
got, unauthenticated, err := a.GetUserInfo(t.Context(), tt.idpClaims, ts.URL, tt.userInfoPath)
|
||||
assert.Equal(t, tt.expectedOutput, got)
|
||||
assert.Equal(t, tt.expectUnauthenticated, unauthenticated)
|
||||
if tt.expectError {
|
||||
|
|
@ -1253,7 +1263,7 @@ userInfoPath: /`,
|
|||
require.NoError(t, err, "failed setting item to in-memory cache")
|
||||
}
|
||||
|
||||
receivedClaims, err := a.SetGroupsFromUserInfo(tt.inputClaims, "argocd")
|
||||
receivedClaims, err := a.SetGroupsFromUserInfo(t.Context(), tt.inputClaims, "argocd")
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
|
|
@ -1263,3 +1273,300 @@ userInfoPath: /`,
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOidcTokenCacheFromJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
oidcTokenCache *OidcTokenCache
|
||||
expectErrorContains string
|
||||
expectIdToken string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
oidcTokenCache: &OidcTokenCache{},
|
||||
expectErrorContains: "empty token",
|
||||
},
|
||||
{
|
||||
name: "empty id token",
|
||||
oidcTokenCache: &OidcTokenCache{
|
||||
Token: &oauth2.Token{},
|
||||
},
|
||||
expectIdToken: "",
|
||||
},
|
||||
{
|
||||
name: "simple",
|
||||
oidcTokenCache: NewOidcTokenCache("", (&oauth2.Token{}).WithExtra(map[string]any{"id_token": "simple"})),
|
||||
expectIdToken: "simple",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenJSON, err := json.Marshal(tt.oidcTokenCache)
|
||||
require.NoError(t, err)
|
||||
token, err := GetOidcTokenCacheFromJSON(tokenJSON)
|
||||
if tt.expectErrorContains != "" {
|
||||
assert.ErrorContains(t, err, tt.expectErrorContains)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
if tt.expectIdToken != "" {
|
||||
assert.Equal(t, tt.expectIdToken, token.Token.Extra("id_token").(string))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientApp_GetTokenSourceFromCache(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
oidcTokenCache *OidcTokenCache
|
||||
expectErrorContains string
|
||||
provider Provider
|
||||
}{
|
||||
{
|
||||
name: "provider error",
|
||||
oidcTokenCache: &OidcTokenCache{},
|
||||
expectErrorContains: "fake provider endpoint error",
|
||||
provider: &fakeProvider{
|
||||
EndpointError: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty oidcTokenCache",
|
||||
expectErrorContains: "oidcTokenCache is required",
|
||||
provider: &fakeProvider{},
|
||||
},
|
||||
{
|
||||
name: "simple",
|
||||
oidcTokenCache: NewOidcTokenCache("", (&oauth2.Token{}).WithExtra(map[string]any{"id_token": "simple"})),
|
||||
provider: &fakeProvider{},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := ClientApp{provider: tt.provider, settings: &settings.ArgoCDSettings{}}
|
||||
tokenSource, err := app.GetTokenSourceFromCache(t.Context(), tt.oidcTokenCache)
|
||||
if tt.expectErrorContains != "" {
|
||||
assert.ErrorContains(t, err, tt.expectErrorContains)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, tokenSource)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientApp_GetUpdatedOidcTokenFromCache(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
subject string
|
||||
session string
|
||||
insertIntoCache bool
|
||||
oidcTokenCache *OidcTokenCache
|
||||
expectErrorContains string
|
||||
expectTokenNotNil bool
|
||||
}{
|
||||
{
|
||||
name: "empty token cache",
|
||||
subject: "alice",
|
||||
session: "111",
|
||||
insertIntoCache: true,
|
||||
expectErrorContains: "failed to unmarshal cached oidc token: empty token",
|
||||
},
|
||||
{
|
||||
name: "no refresh token",
|
||||
subject: "alice",
|
||||
session: "111",
|
||||
insertIntoCache: true,
|
||||
oidcTokenCache: &OidcTokenCache{Token: &oauth2.Token{}},
|
||||
expectErrorContains: "failed to refresh token from source: oauth2: token expired and refresh token is not set",
|
||||
},
|
||||
{
|
||||
name: "cache miss",
|
||||
subject: "",
|
||||
session: "",
|
||||
insertIntoCache: false,
|
||||
},
|
||||
{
|
||||
name: "updated token from cache",
|
||||
subject: "alice",
|
||||
session: "111",
|
||||
insertIntoCache: true,
|
||||
oidcTokenCache: &OidcTokenCache{Token: &oauth2.Token{
|
||||
RefreshToken: "not empty",
|
||||
}},
|
||||
expectTokenNotNil: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
oidcTestServer := test.GetOIDCTestServer(t, nil)
|
||||
t.Cleanup(oidcTestServer.Close)
|
||||
|
||||
cdSettings := &settings.ArgoCDSettings{
|
||||
URL: "https://argocd.example.com",
|
||||
OIDCConfigRAW: fmt.Sprintf(`
|
||||
name: Test
|
||||
issuer: %s
|
||||
clientID: test-client-id
|
||||
clientSecret: test-client-secret
|
||||
requestedScopes: ["oidc"]`, oidcTestServer.URL),
|
||||
OIDCTLSInsecureSkipVerify: true,
|
||||
}
|
||||
app, err := NewClientApp(cdSettings, "", nil, "/", cache.NewInMemoryCache(24*time.Hour))
|
||||
require.NoError(t, err)
|
||||
if tt.insertIntoCache {
|
||||
oidcTokenCacheJSON, err := json.Marshal(tt.oidcTokenCache)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, app.SetValueInEncryptedCache(formatOidcTokenCacheKey(tt.subject, tt.session), oidcTokenCacheJSON, time.Minute))
|
||||
}
|
||||
token, err := app.GetUpdatedOidcTokenFromCache(t.Context(), tt.subject, tt.session)
|
||||
if tt.expectErrorContains != "" {
|
||||
assert.ErrorContains(t, err, tt.expectErrorContains)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
if tt.expectTokenNotNil {
|
||||
assert.NotNil(t, token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientApp_CheckAndGetRefreshToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expectErrorContains string
|
||||
expectNewToken bool
|
||||
groupClaims jwt.MapClaims
|
||||
refreshTokenThreshold string
|
||||
}{
|
||||
{
|
||||
name: "no new token",
|
||||
groupClaims: jwt.MapClaims{
|
||||
"aud": common.ArgoCDClientAppID,
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
"sub": "randomUser",
|
||||
"sid": "1111",
|
||||
"iss": "issuer",
|
||||
"groups": "group1",
|
||||
},
|
||||
expectNewToken: false,
|
||||
refreshTokenThreshold: "1m",
|
||||
},
|
||||
{
|
||||
name: "new token",
|
||||
groupClaims: jwt.MapClaims{
|
||||
"aud": common.ArgoCDClientAppID,
|
||||
"exp": float64(time.Now().Add(55 * time.Second).Unix()),
|
||||
"sub": "randomUser",
|
||||
"sid": "1111",
|
||||
"iss": "issuer",
|
||||
"groups": "group1",
|
||||
},
|
||||
expectNewToken: true,
|
||||
refreshTokenThreshold: "1m",
|
||||
},
|
||||
{
|
||||
name: "parse error",
|
||||
groupClaims: jwt.MapClaims{
|
||||
"aud": common.ArgoCDClientAppID,
|
||||
"exp": float64(time.Now().Add(time.Minute).Unix()),
|
||||
"sub": "randomUser",
|
||||
"sid": "1111",
|
||||
"iss": "issuer",
|
||||
"groups": "group1",
|
||||
},
|
||||
expectNewToken: false,
|
||||
refreshTokenThreshold: "1xx",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
oidcTestServer := test.GetOIDCTestServer(t, nil)
|
||||
t.Cleanup(oidcTestServer.Close)
|
||||
|
||||
cdSettings := &settings.ArgoCDSettings{
|
||||
URL: "https://argocd.example.com",
|
||||
OIDCConfigRAW: fmt.Sprintf(`
|
||||
name: Test
|
||||
issuer: %s
|
||||
clientID: test-client-id
|
||||
clientSecret: test-client-secret
|
||||
refreshTokenThreshold: %s
|
||||
requestedScopes: ["oidc"]`, oidcTestServer.URL, tt.refreshTokenThreshold),
|
||||
OIDCTLSInsecureSkipVerify: true,
|
||||
}
|
||||
// The base href (the last argument for NewClientApp) is what HandleLogin will fall back to when no explicit
|
||||
// redirect URL is given.
|
||||
app, err := NewClientApp(cdSettings, "", nil, "/", cache.NewInMemoryCache(24*time.Hour))
|
||||
require.NoError(t, err)
|
||||
oidcTokenCacheJSON, err := json.Marshal(&OidcTokenCache{Token: &oauth2.Token{
|
||||
RefreshToken: "not empty",
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
sub := jwtutil.StringField(tt.groupClaims, "sub")
|
||||
require.NotEmpty(t, sub)
|
||||
sid := jwtutil.StringField(tt.groupClaims, "sid")
|
||||
require.NotEmpty(t, sid)
|
||||
require.NoError(t, app.SetValueInEncryptedCache(formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, time.Minute))
|
||||
token, err := app.CheckAndRefreshToken(t.Context(), tt.groupClaims, cdSettings.RefreshTokenThreshold())
|
||||
if tt.expectErrorContains != "" {
|
||||
require.ErrorContains(t, err, tt.expectErrorContains)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
if tt.expectNewToken {
|
||||
require.NotEmpty(t, token)
|
||||
} else {
|
||||
require.Empty(t, token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientApp_getRedirectURIForRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *http.Request
|
||||
expectLogContains string
|
||||
expectedRequestURI string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
req: &http.Request{
|
||||
URL: &url.URL{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil URL",
|
||||
expectLogContains: "falling back to configured redirect URI",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := ClientApp{provider: &fakeProvider{}, settings: &settings.ArgoCDSettings{}}
|
||||
hook := test.LogHook{}
|
||||
log.AddHook(&hook)
|
||||
t.Cleanup(func() {
|
||||
log.StandardLogger().ReplaceHooks(log.LevelHooks{})
|
||||
})
|
||||
redirectURI := app.getRedirectURIForRequest(tt.req)
|
||||
if tt.expectLogContains != "" {
|
||||
assert.NotEmpty(t, hook.GetRegexMatchesInEntries(tt.expectLogContains), "expected log")
|
||||
} else {
|
||||
assert.Empty(t, hook.Entries, "expected log")
|
||||
}
|
||||
if tt.req == nil {
|
||||
return
|
||||
}
|
||||
expectedRedirectURI, err := app.settings.RedirectURLForRequest(tt.req)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, expectedRedirectURI, redirectURI, "expected URI")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ type Provider interface {
|
|||
|
||||
ParseConfig() (*OIDCConfiguration, error)
|
||||
|
||||
Verify(tokenString string, argoSettings *settings.ArgoCDSettings) (*gooidc.IDToken, error)
|
||||
Verify(ctx context.Context, tokenString string, argoSettings *settings.ArgoCDSettings) (*gooidc.IDToken, error)
|
||||
}
|
||||
|
||||
type providerImpl struct {
|
||||
|
|
@ -85,7 +85,7 @@ func (t tokenVerificationError) Error() string {
|
|||
return "token verification failed for all audiences: " + strings.Join(errorStrings, ", ")
|
||||
}
|
||||
|
||||
func (p *providerImpl) Verify(tokenString string, argoSettings *settings.ArgoCDSettings) (*gooidc.IDToken, error) {
|
||||
func (p *providerImpl) Verify(ctx context.Context, tokenString string, argoSettings *settings.ArgoCDSettings) (*gooidc.IDToken, error) {
|
||||
// According to the JWT spec, the aud claim is optional. The spec also says (emphasis mine):
|
||||
//
|
||||
// If the principal processing the claim does not identify itself with a value in the "aud" claim _when this
|
||||
|
|
@ -110,7 +110,7 @@ func (p *providerImpl) Verify(tokenString string, argoSettings *settings.ArgoCDS
|
|||
|
||||
var idToken *gooidc.IDToken
|
||||
if !unverifiedHasAudClaim {
|
||||
idToken, err = p.verify("", tokenString, argoSettings.SkipAudienceCheckWhenTokenHasNoAudience())
|
||||
idToken, err = p.verify(ctx, "", tokenString, argoSettings.SkipAudienceCheckWhenTokenHasNoAudience())
|
||||
} else {
|
||||
allowedAudiences := argoSettings.OAuth2AllowedAudiences()
|
||||
if len(allowedAudiences) == 0 {
|
||||
|
|
@ -119,7 +119,7 @@ func (p *providerImpl) Verify(tokenString string, argoSettings *settings.ArgoCDS
|
|||
tokenVerificationErrors := make(map[string]error)
|
||||
// Token must be verified for at least one allowed audience
|
||||
for _, aud := range allowedAudiences {
|
||||
idToken, err = p.verify(aud, tokenString, false)
|
||||
idToken, err = p.verify(ctx, aud, tokenString, false)
|
||||
tokenExpiredError := &gooidc.TokenExpiredError{}
|
||||
if errors.As(err, &tokenExpiredError) {
|
||||
// If the token is expired, we won't bother checking other audiences. It's important to return a
|
||||
|
|
@ -143,14 +143,13 @@ func (p *providerImpl) Verify(tokenString string, argoSettings *settings.ArgoCDS
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify token: %w", err)
|
||||
return nil, fmt.Errorf("failed to verify provider token: %w", err)
|
||||
}
|
||||
|
||||
return idToken, nil
|
||||
}
|
||||
|
||||
func (p *providerImpl) verify(clientID, tokenString string, skipClientIDCheck bool) (*gooidc.IDToken, error) {
|
||||
ctx := context.Background()
|
||||
func (p *providerImpl) verify(ctx context.Context, clientID, tokenString string, skipClientIDCheck bool) (*gooidc.IDToken, error) {
|
||||
prov, err := p.provider()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -489,7 +489,7 @@ func (mgr *SessionManager) AuthMiddlewareFunc(disabled bool, isSSOConfigured boo
|
|||
// TokenVerifier defines the contract to invoke token
|
||||
// verification logic
|
||||
type TokenVerifier interface {
|
||||
VerifyToken(token string) (jwt.Claims, string, error)
|
||||
VerifyToken(ctx context.Context, token string) (jwt.Claims, string, error)
|
||||
}
|
||||
|
||||
// WithAuthMiddleware is an HTTP middleware used to ensure incoming
|
||||
|
|
@ -504,12 +504,13 @@ func WithAuthMiddleware(disabled bool, isSSOConfigured bool, ssoClientApp *oidcu
|
|||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookies := r.Cookies()
|
||||
ctx := r.Context()
|
||||
tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies)
|
||||
if err != nil {
|
||||
http.Error(w, "Auth cookie not found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
claims, _, err := authn.VerifyToken(tokenString)
|
||||
claims, _, err := authn.VerifyToken(ctx, tokenString)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
||||
return
|
||||
|
|
@ -517,14 +518,12 @@ func WithAuthMiddleware(disabled bool, isSSOConfigured bool, ssoClientApp *oidcu
|
|||
|
||||
finalClaims := claims
|
||||
if isSSOConfigured {
|
||||
finalClaims, err = ssoClientApp.SetGroupsFromUserInfo(claims, SessionManagerClaimsIssuer)
|
||||
finalClaims, err = ssoClientApp.SetGroupsFromUserInfo(ctx, claims, SessionManagerClaimsIssuer)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid session", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
// Add claims to the context to inspect for RBAC
|
||||
//nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, "claims", finalClaims)
|
||||
|
|
@ -536,7 +535,7 @@ func WithAuthMiddleware(disabled bool, isSSOConfigured bool, ssoClientApp *oidcu
|
|||
|
||||
// VerifyToken verifies if a token is correct. Tokens can be issued either from us or by an IDP.
|
||||
// We choose how to verify based on the issuer.
|
||||
func (mgr *SessionManager) VerifyToken(tokenString string) (jwt.Claims, string, error) {
|
||||
func (mgr *SessionManager) VerifyToken(ctx context.Context, tokenString string) (jwt.Claims, string, error) {
|
||||
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
|
||||
claims := jwt.MapClaims{}
|
||||
_, _, err := parser.ParseUnverified(tokenString, &claims)
|
||||
|
|
@ -564,12 +563,12 @@ func (mgr *SessionManager) VerifyToken(tokenString string) (jwt.Claims, string,
|
|||
return nil, "", errors.New("settings are not available while verifying the token")
|
||||
}
|
||||
|
||||
idToken, err := prov.Verify(tokenString, argoSettings)
|
||||
idToken, err := prov.Verify(ctx, tokenString, argoSettings)
|
||||
// The token verification has failed. If the token has expired, we will
|
||||
// return a dummy claims only containing a value for the issuer, so the
|
||||
// UI can handle expired tokens appropriately.
|
||||
if err != nil {
|
||||
log.Warnf("Failed to verify token: %s", err)
|
||||
log.Warnf("Failed to verify session token: %s", err)
|
||||
tokenExpiredError := &oidc.TokenExpiredError{}
|
||||
if errors.As(err, &tokenExpiredError) {
|
||||
claims = jwt.MapClaims{
|
||||
|
|
|
|||
|
|
@ -228,7 +228,7 @@ type tokenVerifierMock struct {
|
|||
err error
|
||||
}
|
||||
|
||||
func (tm *tokenVerifierMock) VerifyToken(_ string) (jwt.Claims, string, error) {
|
||||
func (tm *tokenVerifierMock) VerifyToken(_ context.Context, _ string) (jwt.Claims, string, error) {
|
||||
if tm.claims == nil {
|
||||
return nil, "", tm.err
|
||||
}
|
||||
|
|
@ -255,7 +255,7 @@ func TestSessionManager_WithAuthMiddleware(t *testing.T) {
|
|||
}
|
||||
}
|
||||
jsonClaims, err := json.Marshal(gotClaims)
|
||||
require.NoError(t, err, "erorr marshalling claims set by AuthMiddleware")
|
||||
require.NoError(t, err, "error marshalling claims set by AuthMiddleware")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonClaims)
|
||||
require.NoError(t, err, "error writing response: %s", err)
|
||||
|
|
@ -720,7 +720,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
assert.NotContains(t, err.Error(), "oidc: id token signed with unsupported algorithm")
|
||||
})
|
||||
|
||||
|
|
@ -752,7 +752,7 @@ rootCA: |
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
// If the root CA is being respected, we won't get this error. The error message is environment-dependent, so
|
||||
// we check for either of the error messages associated with a failed cert check.
|
||||
assert.NotContains(t, err.Error(), "certificate is not trusted")
|
||||
|
|
@ -789,7 +789,7 @@ rootCA: |
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, common.ErrTokenVerification)
|
||||
})
|
||||
|
|
@ -824,7 +824,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, common.ErrTokenVerification)
|
||||
})
|
||||
|
|
@ -859,7 +859,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, common.ErrTokenVerification)
|
||||
})
|
||||
|
|
@ -895,7 +895,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
assert.NotContains(t, err.Error(), "certificate is not trusted")
|
||||
assert.NotContains(t, err.Error(), "certificate signed by unknown authority")
|
||||
})
|
||||
|
|
@ -924,7 +924,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
// This is the error thrown when the test server's certificate _is_ being verified.
|
||||
assert.NotContains(t, err.Error(), "certificate is not trusted")
|
||||
assert.NotContains(t, err.Error(), "certificate signed by unknown authority")
|
||||
|
|
@ -961,7 +961,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
|
|
@ -997,7 +997,7 @@ skipAudienceCheckWhenTokenHasNoAudience: true`, oidcTestServer.URL),
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
|
|
@ -1033,7 +1033,7 @@ skipAudienceCheckWhenTokenHasNoAudience: false`, oidcTestServer.URL),
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, common.ErrTokenVerification)
|
||||
})
|
||||
|
|
@ -1069,7 +1069,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
|
|
@ -1106,7 +1106,7 @@ allowedAudiences:
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
|
|
@ -1143,7 +1143,7 @@ allowedAudiences:
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, common.ErrTokenVerification)
|
||||
})
|
||||
|
|
@ -1179,7 +1179,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, common.ErrTokenVerification)
|
||||
})
|
||||
|
|
@ -1216,7 +1216,7 @@ allowedAudiences: []`, oidcTestServer.URL),
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, common.ErrTokenVerification)
|
||||
})
|
||||
|
|
@ -1254,7 +1254,7 @@ allowedAudiences: ["aud-a", "aud-b"]`, oidcTestServer.URL),
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
|
|
@ -1289,7 +1289,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
|
|||
tokenString, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = mgr.VerifyToken(tokenString)
|
||||
_, _, err = mgr.VerifyToken(t.Context(), tokenString)
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, common.ErrTokenVerification)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -136,6 +136,9 @@ type ArgoCDSettings struct {
|
|||
// token verification to pass despite the OIDC provider having an invalid certificate. Only set to `true` if you
|
||||
// understand the risks.
|
||||
OIDCTLSInsecureSkipVerify bool `json:"oidcTLSInsecureSkipVerify"`
|
||||
// OIDCRefreshTokenThreshold sets the threshold for preemptive server-side token refresh. If set to 0, tokens
|
||||
// will not be refreshed and will expire before client is redirected to login.
|
||||
OIDCRefreshTokenThreshold time.Duration `json:"oidcRefreshTokenThreshold,omitempty"`
|
||||
// AppsInAnyNamespaceEnabled indicates whether applications are allowed to be created in any namespace
|
||||
AppsInAnyNamespaceEnabled bool `json:"appsInAnyNamespaceEnabled"`
|
||||
// ExtensionConfig configurations related to ArgoCD proxy extensions. The keys are the extension name.
|
||||
|
|
@ -193,6 +196,7 @@ func (o *oidcConfig) toExported() *OIDCConfig {
|
|||
UserInfoPath: o.UserInfoPath,
|
||||
EnableUserInfoGroups: o.EnableUserInfoGroups,
|
||||
UserInfoCacheExpiration: o.UserInfoCacheExpiration,
|
||||
RefreshTokenThreshold: o.RefreshTokenThreshold,
|
||||
RequestedScopes: o.RequestedScopes,
|
||||
RequestedIDTokenClaims: o.RequestedIDTokenClaims,
|
||||
LogoutURL: o.LogoutURL,
|
||||
|
|
@ -218,6 +222,7 @@ type OIDCConfig struct {
|
|||
EnablePKCEAuthentication bool `json:"enablePKCEAuthentication,omitempty"`
|
||||
DomainHint string `json:"domainHint,omitempty"`
|
||||
Azure *AzureOIDCConfig `json:"azure,omitempty"`
|
||||
RefreshTokenThreshold string `json:"refreshTokenThreshold,omitempty"`
|
||||
}
|
||||
|
||||
type AzureOIDCConfig struct {
|
||||
|
|
@ -1432,6 +1437,7 @@ func getDownloadBinaryUrlsFromConfigMap(argoCDCM *corev1.ConfigMap) map[string]s
|
|||
func updateSettingsFromConfigMap(settings *ArgoCDSettings, argoCDCM *corev1.ConfigMap) {
|
||||
settings.DexConfig = argoCDCM.Data[settingDexConfigKey]
|
||||
settings.OIDCConfigRAW = argoCDCM.Data[settingsOIDCConfigKey]
|
||||
settings.OIDCRefreshTokenThreshold = settings.RefreshTokenThreshold()
|
||||
settings.KustomizeBuildOptions = argoCDCM.Data[kustomizeBuildOptionsKey]
|
||||
settings.StatusBadgeEnabled = argoCDCM.Data[statusBadgeEnabledKey] == "true"
|
||||
settings.StatusBadgeRootUrl = argoCDCM.Data[statusBadgeRootURLKey]
|
||||
|
|
@ -1882,6 +1888,18 @@ func (a *ArgoCDSettings) UserInfoCacheExpiration() time.Duration {
|
|||
return 0
|
||||
}
|
||||
|
||||
// RefreshTokenThreshold returns the duration before token expiration that a token should be refreshed by the server
|
||||
func (a *ArgoCDSettings) RefreshTokenThreshold() time.Duration {
|
||||
if oidcConfig := a.OIDCConfig(); oidcConfig != nil && oidcConfig.RefreshTokenThreshold != "" {
|
||||
refreshTokenThreshold, err := time.ParseDuration(oidcConfig.RefreshTokenThreshold)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to parse 'oidc.config.refreshTokenThreshold' key: %v", err)
|
||||
}
|
||||
return refreshTokenThreshold
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (a *ArgoCDSettings) OAuth2ClientID() string {
|
||||
if oidcConfig := a.OIDCConfig(); oidcConfig != nil {
|
||||
return oidcConfig.ClientID
|
||||
|
|
@ -2001,6 +2019,9 @@ func (a *ArgoCDSettings) ArgoURLForRequest(r *http.Request) (string, error) {
|
|||
}
|
||||
|
||||
func (a *ArgoCDSettings) RedirectURLForRequest(r *http.Request) (string, error) {
|
||||
if r == nil {
|
||||
return "", errors.New("request is nil")
|
||||
}
|
||||
base, err := a.ArgoURLForRequest(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
|
|||
|
|
@ -1020,43 +1020,94 @@ func TestSettingsManager_GetSettings(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGetOIDCConfig(t *testing.T) {
|
||||
kubeClient := fake.NewClientset(
|
||||
&corev1.ConfigMap{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: common.ArgoCDConfigMapName,
|
||||
Namespace: "default",
|
||||
Labels: map[string]string{
|
||||
"app.kubernetes.io/part-of": "argocd",
|
||||
},
|
||||
},
|
||||
Data: map[string]string{
|
||||
testCases := []struct {
|
||||
name string
|
||||
configMapData map[string]string
|
||||
testFunc func(t *testing.T, settingsManager *SettingsManager)
|
||||
}{
|
||||
{
|
||||
name: "requestedIDTokenClaims",
|
||||
configMapData: map[string]string{
|
||||
"oidc.config": "\n requestedIDTokenClaims: {\"groups\": {\"essential\": true}}\n",
|
||||
},
|
||||
testFunc: func(t *testing.T, settingsManager *SettingsManager) {
|
||||
t.Helper()
|
||||
settings, err := settingsManager.GetSettings()
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcConfig := settings.OIDCConfig()
|
||||
assert.NotNil(t, oidcConfig)
|
||||
|
||||
claim := oidcConfig.RequestedIDTokenClaims["groups"]
|
||||
assert.NotNil(t, claim)
|
||||
assert.True(t, claim.Essential)
|
||||
},
|
||||
},
|
||||
&corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: common.ArgoCDSecretName,
|
||||
Namespace: "default",
|
||||
Labels: map[string]string{
|
||||
"app.kubernetes.io/part-of": "argocd",
|
||||
{
|
||||
name: "refreshTokenThreshold success",
|
||||
configMapData: map[string]string{
|
||||
"oidc.config": "\n refreshTokenThreshold: 5m\n",
|
||||
},
|
||||
testFunc: func(t *testing.T, settingsManager *SettingsManager) {
|
||||
t.Helper()
|
||||
settings, err := settingsManager.GetSettings()
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcConfig := settings.OIDCConfig()
|
||||
assert.NotNil(t, oidcConfig)
|
||||
|
||||
assert.Equal(t, 5*time.Minute, settings.RefreshTokenThreshold())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "refreshTokenThreshold parse failure",
|
||||
configMapData: map[string]string{
|
||||
"oidc.config": "\n refreshTokenThreshold: 5xx\n",
|
||||
},
|
||||
testFunc: func(t *testing.T, settingsManager *SettingsManager) {
|
||||
t.Helper()
|
||||
settings, err := settingsManager.GetSettings()
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcConfig := settings.OIDCConfig()
|
||||
assert.NotNil(t, oidcConfig)
|
||||
|
||||
assert.Equal(t, time.Duration(0), settings.RefreshTokenThreshold())
|
||||
},
|
||||
},
|
||||
}
|
||||
for i := range testCases {
|
||||
tc := testCases[i]
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
kubeClient := fake.NewClientset(
|
||||
&corev1.ConfigMap{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: common.ArgoCDConfigMapName,
|
||||
Namespace: "default",
|
||||
Labels: map[string]string{
|
||||
"app.kubernetes.io/part-of": "argocd",
|
||||
},
|
||||
},
|
||||
Data: tc.configMapData,
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"admin.password": nil,
|
||||
"server.secretkey": nil,
|
||||
},
|
||||
},
|
||||
)
|
||||
settingsManager := NewSettingsManager(t.Context(), kubeClient, "default")
|
||||
settings, err := settingsManager.GetSettings()
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcConfig := settings.OIDCConfig()
|
||||
assert.NotNil(t, oidcConfig)
|
||||
|
||||
claim := oidcConfig.RequestedIDTokenClaims["groups"]
|
||||
assert.NotNil(t, claim)
|
||||
assert.True(t, claim.Essential)
|
||||
&corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: common.ArgoCDSecretName,
|
||||
Namespace: "default",
|
||||
Labels: map[string]string{
|
||||
"app.kubernetes.io/part-of": "argocd",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"admin.password": nil,
|
||||
"server.secretkey": nil,
|
||||
},
|
||||
},
|
||||
)
|
||||
settingsManager := NewSettingsManager(t.Context(), kubeClient, "default")
|
||||
tc.testFunc(t, settingsManager)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectURL(t *testing.T) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue