Include an error code as query string in /sso/callback response in case of failure (#6286)

This commit is contained in:
Martin Angers 2022-06-21 09:04:50 -04:00 committed by GitHub
parent 984605f630
commit 7bfe93f5d7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 73 additions and 4 deletions

View file

@ -0,0 +1 @@
* Added `status` query string parameter in redirect to `/login` on SSO authentication failure, to assist the frontend in displaying a helpful error message.

View file

@ -5105,6 +5105,25 @@ func (s *integrationTestSuite) TestHostsReportDownload() {
t.Log(rows)
}
func (s *integrationTestSuite) TestSSODisabled() {
t := s.T()
var initiateResp initiateSSOResponse
s.DoJSON("POST", "/api/v1/fleet/sso", struct{}{}, http.StatusBadRequest, &initiateResp)
var callbackResp callbackSSOResponse
// callback without SAML response
s.DoJSON("POST", "/api/v1/fleet/sso/callback", nil, http.StatusBadRequest, &callbackResp)
// callback with invalid SAML response
s.DoJSON("POST", "/api/v1/fleet/sso/callback?SAMLResponse=zz", nil, http.StatusBadRequest, &callbackResp)
// callback with valid SAML response (<samlp:AuthnRequest></samlp:AuthnRequest>)
res := s.DoRaw("POST", "/api/v1/fleet/sso/callback?SAMLResponse=PHNhbWxwOkF1dGhuUmVxdWVzdD48L3NhbWxwOkF1dGhuUmVxdWVzdD4%3D", nil, http.StatusOK)
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.Contains(t, string(body), "/login?status=org_disabled") // html contains a script that redirects to this path
}
// this test can be deleted once the "v1" version is removed.
func (s *integrationTestSuite) TestAPIVersion_v1_2022_04() {
t := s.T()

View file

@ -21,3 +21,31 @@ func (a alreadyExistsError) Error() string {
func (a alreadyExistsError) IsExists() bool {
return true
}
// ssoErrCode defines a code for the type of SSO error that occurred. This is
// used to indicate to the frontend why the SSO login attempt failed so that
// it can provide a helpful and appropriate error message.
type ssoErrCode string
// List of valid SSO error codes.
const (
ssoOtherError ssoErrCode = "error"
ssoOrgDisabled ssoErrCode = "org_disabled"
ssoAccountDisabled ssoErrCode = "account_disabled"
ssoAccountInvalid ssoErrCode = "account_invalid"
)
// ssoError is an error that occurs during the Single-Sign-On flow. Its code
// indicates the type of error.
type ssoError struct {
err error
code ssoErrCode
}
func (e ssoError) Error() string {
return string(e.code) + ": " + e.err.Error()
}
func (e ssoError) Unwrap() error {
return e.err
}

View file

@ -277,6 +277,11 @@ func (svc *Service) InitiateSSO(ctx context.Context, redirectURL string) (string
return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting app config")
}
if !appConfig.SSOSettings.EnableSSO {
err := &badRequestError{message: "organization not configured to use sso"}
return "", ctxerr.Wrap(ctx, ssoError{err: err, code: ssoOrgDisabled}, "callback sso")
}
metadata, err := svc.getMetadata(appConfig)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "InitiateSSO getting metadata")
@ -321,11 +326,11 @@ type callbackSSORequest struct{}
func (callbackSSORequest) DecodeRequest(ctx context.Context, r *http.Request) (interface{}, error) {
err := r.ParseForm()
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "decode sso callback")
return nil, ctxerr.Wrap(ctx, &badRequestError{message: err.Error()}, "decode sso callback")
}
authResponse, err := sso.DecodeAuthResponse(r.FormValue("SAMLResponse"))
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "decoding sso callback")
return nil, ctxerr.Wrap(ctx, &badRequestError{message: err.Error()}, "decoding sso callback")
}
return authResponse, nil
}
@ -346,10 +351,16 @@ func makeCallbackSSOEndpoint(urlPrefix string) handlerFunc {
session, err := svc.CallbackSSO(ctx, authResponse)
var resp callbackSSOResponse
if err != nil {
var ssoErr ssoError
status := ssoOtherError
if errors.As(err, &ssoErr) {
status = ssoErr.code
}
// redirect to login page on front end if there was some problem,
// errors should still be logged
session = &fleet.SSOSession{
RedirectURL: urlPrefix + "/login",
RedirectURL: urlPrefix + "/login?status=" + string(status),
Token: "",
}
resp.Err = err
@ -391,6 +402,11 @@ func (svc *Service) CallbackSSO(ctx context.Context, auth fleet.Auth) (*fleet.SS
return nil, ctxerr.Wrap(ctx, err, "get config for sso")
}
if !appConfig.SSOSettings.EnableSSO {
err := ctxerr.New(ctx, "organization not configured to use sso")
return nil, ctxerr.Wrap(ctx, ssoError{err: err, code: ssoOrgDisabled}, "callback sso")
}
// Load the request metadata if available
// localhost:9080/simplesaml/saml2/idp/SSOService.php?spentityid=https://localhost:8080
@ -444,11 +460,16 @@ func (svc *Service) CallbackSSO(ctx context.Context, auth fleet.Auth) (*fleet.SS
// Get and log in user
user, err := svc.ds.UserByEmail(ctx, auth.UserID())
if err != nil {
var nfe notFoundErrorInterface
if errors.As(err, &nfe) {
return nil, ctxerr.Wrap(ctx, ssoError{err: err, code: ssoAccountInvalid})
}
return nil, ctxerr.Wrap(ctx, err, "find user in sso callback")
}
// if the user is not sso enabled they are not authorized
if !user.SSOEnabled {
return nil, ctxerr.New(ctx, "user not configured to use sso")
err := ctxerr.New(ctx, "user not configured to use sso")
return nil, ctxerr.Wrap(ctx, ssoError{err: err, code: ssoAccountDisabled})
}
session, err := svc.makeSession(ctx, user.ID)
if err != nil {