diff --git a/ee/server/service/mdm.go b/ee/server/service/mdm.go index 7992d2f52b..f64cd3db9f 100644 --- a/ee/server/service/mdm.go +++ b/ee/server/service/mdm.go @@ -712,7 +712,7 @@ func (svc *Service) DeleteMDMAppleSetupAssistant(ctx context.Context, teamID *ui const appleMDMAccountDrivenEnrollmentUrl = "/api/mdm/apple/account_driven_enroll" -func (svc *Service) InitiateMDMSSO(ctx context.Context, initiator, customOriginalURL string) (sessionID string, sessionDurationSeconds int, idpURL string, err error) { +func (svc *Service) InitiateMDMSSO(ctx context.Context, initiator, customOriginalURL string, hostUUID string) (sessionID string, sessionDurationSeconds int, idpURL string, err error) { // skipauth: User context does not yet exist. Unauthenticated users may // initiate SSO. svc.authz.SkipAuthorization(ctx) @@ -771,6 +771,10 @@ func (svc *Service) InitiateMDMSSO(ctx context.Context, initiator, customOrigina sessionID, idpURL, err = sso.CreateAuthorizationRequest(ctx, samlProvider, svc.ssoSessionStore, originalURL, uint(sessionDurationSeconds), //nolint:gosec // dismiss G115 + sso.SSORequestData{ + HostUUID: hostUUID, + Initiator: initiator, + }, ) if err != nil { return "", 0, "", ctxerr.Wrap(ctx, err, "InitiateMDMSSO creating authorization") @@ -786,13 +790,13 @@ func (svc *Service) MDMSSOCallback(ctx context.Context, sessionID string, samlRe logging.WithLevel(logging.WithNoUser(ctx), level.Info) - profileToken, enrollmentRef, eulaToken, originalURL, err := svc.mdmSSOHandleCallbackAuth(ctx, sessionID, samlResponse) + profileToken, enrollmentRef, eulaToken, originalURL, ssoRequestData, err := svc.mdmSSOHandleCallbackAuth(ctx, sessionID, samlResponse) if err != nil { logging.WithErr(ctx, err) return apple_mdm.FleetUISSOCallbackPath + "?error=true", "" } - if !strings.HasPrefix(originalURL, "/enroll?") { + if !strings.HasPrefix(originalURL, "/enroll?") && ssoRequestData.Initiator != "setup_experience" { // for flows other than the /enroll BYOD, we have to ensure that Apple MDM // is enabled (this was previously done in a middleware on the route, but // we do it here now so the middleware is disabled for the BYOD flow, which @@ -812,6 +816,8 @@ func (svc *Service) MDMSSOCallback(ctx context.Context, sessionID string, samlRe q.Add("eula_token", eulaToken) } + q.Add("initiator", ssoRequestData.Initiator) + switch { case originalURL == appleMDMAccountDrivenEnrollmentUrl: // For account driven enrollment we have to use this special protocol URL scheme to pass the @@ -845,17 +851,17 @@ func (svc *Service) mdmSSOHandleCallbackAuth( sessionID string, samlResponse []byte, ) (profileToken string, enrollmentReference string, - eulaToken string, originalURL string, err error, + eulaToken string, originalURL string, ssoRequestData sso.SSORequestData, err error, ) { appConfig, err := svc.ds.AppConfig(ctx) if err != nil { - return "", "", "", "", ctxerr.Wrap(ctx, err, "get config for sso") + return "", "", "", "", sso.SSORequestData{}, ctxerr.Wrap(ctx, err, "get config for sso") } serverURL := appConfig.MDMUrl() acsURL, err := url.Parse(serverURL + svc.config.Server.URLPrefix + "/api/v1/fleet/mdm/sso/callback") if err != nil { - return "", "", "", "", ctxerr.Wrap(ctx, err, "failed to parse ACS URL") + return "", "", "", "", sso.SSORequestData{}, ctxerr.Wrap(ctx, err, "failed to parse ACS URL") } mdmSSOSettings := appConfig.MDM.EndUserAuthentication.SSOProviderSettings @@ -866,7 +872,7 @@ func (svc *Service) mdmSSOHandleCallbackAuth( // this means some teams may not use SSO even if it is configured. if mdmSSOSettings.IsEmpty() { err := &fleet.BadRequestError{Message: "organization not configured to use sso"} - return "", "", "", "", ctxerr.Wrap(ctx, err, "get config for mdm sso callback") + return "", "", "", "", sso.SSORequestData{}, ctxerr.Wrap(ctx, err, "get config for mdm sso callback") } expectedAudiences := []string{ @@ -874,11 +880,11 @@ func (svc *Service) mdmSSOHandleCallbackAuth( appConfig.MDMUrl(), appConfig.MDMUrl() + svc.config.Server.URLPrefix + "/api/v1/fleet/mdm/sso/callback", } - samlProvider, requestID, originalURL, err := sso.SAMLProviderFromSession( + samlProvider, requestID, originalURL, ssoRequestData, err := sso.SAMLProviderFromSession( ctx, sessionID, svc.ssoSessionStore, acsURL, mdmSSOSettings.EntityID, expectedAudiences, ) if err != nil { - return "", "", "", "", ctxerr.Wrap(ctx, err, "failed to create provider from metadata") + return "", "", "", "", sso.SSORequestData{}, ctxerr.Wrap(ctx, err, "failed to create provider from metadata") } // Parse and verify SAMLResponse (verifies fields, expected IDs and signature). @@ -886,7 +892,7 @@ func (svc *Service) mdmSSOHandleCallbackAuth( if err != nil { // We actually don't return 401 to clients and instead return an HTML page with /login?status=error, // but to be consistent we will return fleet.AuthFailedError which is used for unauthorized access. - return "", "", "", "", ctxerr.Wrap(ctx, fleet.NewAuthFailedError(err.Error())) + return "", "", "", "", sso.SSORequestData{}, ctxerr.Wrap(ctx, fleet.NewAuthFailedError(err.Error())) } // Store information for automatic account population/creation @@ -902,12 +908,13 @@ func (svc *Service) mdmSSOHandleCallbackAuth( } err = svc.ds.InsertMDMIdPAccount(ctx, &fleet.MDMIdPAccount{ + UUID: ssoRequestData.HostUUID, Username: username, Fullname: auth.UserDisplayName(), Email: auth.UserID(), }) if err != nil { - return "", "", "", "", ctxerr.Wrap(ctx, err, "saving account data from IdP") + return "", "", "", "", sso.SSORequestData{}, ctxerr.Wrap(ctx, err, "saving account data from IdP") } idpAcc, err := svc.ds.GetMDMIdPAccountByEmail( @@ -917,12 +924,21 @@ func (svc *Service) mdmSSOHandleCallbackAuth( auth.UserID(), ) if err != nil { - return "", "", "", "", ctxerr.Wrap(ctx, err, "retrieving new account data from IdP") + return "", "", "", "", sso.SSORequestData{}, ctxerr.Wrap(ctx, err, "retrieving new account data from IdP") + } + + // If the initiator is "setup_experience", we can insert the host idp account record + // right away, as the host uuid is provided in the SSO request data. + if ssoRequestData.Initiator == "setup_experience" && ssoRequestData.HostUUID != "" { + err = svc.ds.AssociateHostMDMIdPAccountDB(ctx, ssoRequestData.HostUUID, idpAcc.UUID) + if err != nil { + return "", "", "", "", sso.SSORequestData{}, ctxerr.Wrap(ctx, err, "saving host-account link from IdP") + } } eula, err := svc.ds.MDMGetEULAMetadata(ctx) if err != nil && !fleet.IsNotFound(err) { - return "", "", "", "", ctxerr.Wrap(ctx, err, "getting EULA metadata") + return "", "", "", "", sso.SSORequestData{}, ctxerr.Wrap(ctx, err, "getting EULA metadata") } if eula != nil { @@ -931,22 +947,22 @@ func (svc *Service) mdmSSOHandleCallbackAuth( // If this is account driven enrollment there is no need to fetch the profile if originalURL == appleMDMAccountDrivenEnrollmentUrl { - return "", idpAcc.UUID, eulaToken, originalURL, nil + return "", idpAcc.UUID, eulaToken, originalURL, ssoRequestData, nil } // get the automatic profile to access the authentication token. depProf, err := svc.getAutomaticEnrollmentProfile(ctx) if err != nil { - return "", "", "", "", ctxerr.Wrap(ctx, err, "listing profiles") + return "", "", "", "", sso.SSORequestData{}, ctxerr.Wrap(ctx, err, "listing profiles") } if depProf == nil { - return "", "", "", "", ctxerr.Wrap(ctx, err, "missing profile") + return "", "", "", "", sso.SSORequestData{}, ctxerr.Wrap(ctx, err, "missing profile") } // using the idp token as a reference just because that's the // only thing we're referencing later on during enrollment. - return depProf.Token, idpAcc.UUID, eulaToken, originalURL, nil + return depProf.Token, idpAcc.UUID, eulaToken, originalURL, ssoRequestData, nil } func (svc *Service) mdmAppleSyncDEPProfiles(ctx context.Context) error { diff --git a/frontend/pages/MDMAppleSSOCallbackPage/MDMAppleSSOCallbackPage.tsx b/frontend/pages/MDMAppleSSOCallbackPage/MDMAppleSSOCallbackPage.tsx index b9a985c66e..7ae9f88bec 100644 --- a/frontend/pages/MDMAppleSSOCallbackPage/MDMAppleSSOCallbackPage.tsx +++ b/frontend/pages/MDMAppleSSOCallbackPage/MDMAppleSSOCallbackPage.tsx @@ -7,6 +7,8 @@ import Spinner from "components/Spinner/Spinner"; import SSOError from "components/MDM/SSOError"; import Button from "components/buttons/Button"; +import AuthenticationFormWrapper from "components/AuthenticationFormWrapper"; + const baseClass = "mdm-apple-sso-callback-page"; const RedirectTo = ({ url }: { url: string }) => { @@ -18,6 +20,7 @@ interface IEnrollmentGateProps { profileToken?: string; eulaToken?: string; enrollmentReference?: string; + initiator?: string; error?: boolean; } @@ -25,6 +28,7 @@ const EnrollmentGate = ({ profileToken, eulaToken, enrollmentReference, + initiator, error, }: IEnrollmentGateProps) => { const [showEULA, setShowEULA] = useState(Boolean(eulaToken)); @@ -35,6 +39,16 @@ const EnrollmentGate = ({ return ; } + if (initiator === "setup_experience") { + return ( + +
+

You’re done! You may now close this window.

+
+
+ ); + } + if (showEULA && eulaToken) { return (
@@ -70,6 +84,7 @@ interface IMDMSSOCallbackQuery { eula_token?: string; profile_token?: string; enrollment_reference?: string; + initiator?: string; error?: boolean; } @@ -80,6 +95,7 @@ const MDMAppleSSOCallbackPage = ( eula_token, profile_token, enrollment_reference, + initiator, error, } = props.location.query; return ( @@ -88,6 +104,7 @@ const MDMAppleSSOCallbackPage = ( eulaToken={eula_token} profileToken={profile_token} enrollmentReference={enrollment_reference} + initiator={initiator} error={error} />
diff --git a/frontend/pages/MDMAppleSSOCallbackPage/_styles.scss b/frontend/pages/MDMAppleSSOCallbackPage/_styles.scss index 5140589b92..9f7e1b7d0e 100644 --- a/frontend/pages/MDMAppleSSOCallbackPage/_styles.scss +++ b/frontend/pages/MDMAppleSSOCallbackPage/_styles.scss @@ -19,4 +19,8 @@ &__agree-btn { width: 80%; } + + &.form { + height: auto; + } } diff --git a/frontend/pages/MDMAppleSSOPage/MDMAppleSSOPage.tsx b/frontend/pages/MDMAppleSSOPage/MDMAppleSSOPage.tsx index df3ae27e82..88b0cc3066 100644 --- a/frontend/pages/MDMAppleSSOPage/MDMAppleSSOPage.tsx +++ b/frontend/pages/MDMAppleSSOPage/MDMAppleSSOPage.tsx @@ -1,4 +1,4 @@ -import React from "react"; +import React, { useState } from "react"; import { useQuery } from "react-query"; import { AxiosError } from "axios"; import { WithRouterProps } from "react-router"; @@ -7,22 +7,29 @@ import mdmAPI, { IMDMSSOParams } from "services/entities/mdm"; import SSOError from "components/MDM/SSOError"; import Spinner from "components/Spinner/Spinner"; +import Button from "components/buttons/Button"; +import CustomLink from "components/CustomLink"; import { IMdmSSOReponse } from "interfaces/mdm"; +import AuthenticationFormWrapper from "components/AuthenticationFormWrapper"; const baseClass = "mdm-apple-sso-page"; const DEPSSOLoginPage = ({ location: { pathname, query }, }: WithRouterProps) => { + const [clickedLogin, setClickedLogin] = useState(false); localStorage.setItem("deviceinfo", query.deviceinfo || ""); - query.initiator = "mdm_sso"; - if (pathname === "/mdm/apple/account_driven_enroll/sso") { - query.initiator = "account_driven_enroll"; + if (!query.initiator) { + query.initiator = + pathname === "/mdm/apple/account_driven_enroll/sso" + ? "account_driven_enroll" + : "mdm_sso"; } const { error } = useQuery( ["dep_sso"], () => mdmAPI.initiateMDMAppleSSO(query), { + enabled: clickedLogin || query.initiator !== "setup_experience", retry: false, refetchOnWindowFocus: false, onSuccess: ({ url }) => { @@ -31,6 +38,35 @@ const DEPSSOLoginPage = ({ } ); + if (query.initiator === "setup_experience") { + return ( + +
+

+ Your organization requires you to authenticate before setting up + your device. Please sign in to continue. +

+ +

+ +

+
+
+ ); + } + return
{error ? : }
; }; diff --git a/frontend/pages/MDMAppleSSOPage/_styles.scss b/frontend/pages/MDMAppleSSOPage/_styles.scss index 1d3c32f2f7..b3a83e4795 100644 --- a/frontend/pages/MDMAppleSSOPage/_styles.scss +++ b/frontend/pages/MDMAppleSSOPage/_styles.scss @@ -6,4 +6,23 @@ display: flex; align-items: center; justify-content: center; + + &.form { + height: auto; + + .mdm-apple-sso-page__sso-btn { + width: 240px; + } + + } + + + &__transparency-link { + text-align: center; + + .custom-link { + font-size: $xxx-small; + } + + } } diff --git a/frontend/services/entities/mdm.ts b/frontend/services/entities/mdm.ts index ea1492f263..27ef2fa5fc 100644 --- a/frontend/services/entities/mdm.ts +++ b/frontend/services/entities/mdm.ts @@ -76,6 +76,9 @@ export interface IAppleSetupEnrollmentProfileResponse { export interface IMDMSSOParams { deviceinfo: string; initiator: string; + // optional host_uuid to link SSO to a specific host; used in Orbit-initiated + // enrollments with end-user authentication. + host_uuid?: string; } export interface IMDMAppleEnrollmentProfileParams { diff --git a/server/datastore/mysql/apple_mdm.go b/server/datastore/mysql/apple_mdm.go index 7eb9873cf1..4cffb0230a 100644 --- a/server/datastore/mysql/apple_mdm.go +++ b/server/datastore/mysql/apple_mdm.go @@ -3756,12 +3756,12 @@ func (ds *Datastore) InsertMDMIdPAccount(ctx context.Context, account *fleet.MDM INSERT INTO mdm_idp_accounts (uuid, username, fullname, email) VALUES - (UUID(), ?, ?, ?) - ON DUPLICATE KEY UPDATE + (COALESCE(NULLIF(TRIM(?), ''), UUID()), ?, ?, ?) + ON DUPLICATE KEY UPDATE username = VALUES(username), fullname = VALUES(fullname)` - _, err := ds.writer(ctx).ExecContext(ctx, stmt, account.Username, account.Fullname, account.Email) + _, err := ds.writer(ctx).ExecContext(ctx, stmt, account.UUID, account.Username, account.Fullname, account.Email) return ctxerr.Wrap(ctx, err, "creating new MDM IdP account") } @@ -6986,7 +6986,11 @@ func (ds *Datastore) ReconcileMDMAppleEnrollRef(ctx context.Context, enrollRef s return result, err } -func associateHostMDMIdPAccountDB(ctx context.Context, tx sqlx.ExtContext, hostUUID, acctUUID string) error { +func (ds *Datastore) AssociateHostMDMIdPAccountDB(ctx context.Context, hostUUID string, acctUUID string) error { + return associateHostMDMIdPAccountDB(ctx, ds.writer(ctx), hostUUID, acctUUID) +} + +func associateHostMDMIdPAccountDB(ctx context.Context, tx sqlx.ExtContext, hostUUID string, acctUUID string) error { const stmt = ` INSERT INTO host_mdm_idp_accounts (host_uuid, account_uuid) VALUES (?, ?) diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index f72c561243..7bba804c3f 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -1395,6 +1395,9 @@ type Datastore interface { // InsertMDMIdPAccount inserts a new MDM IdP account InsertMDMIdPAccount(ctx context.Context, account *MDMIdPAccount) error + // AssociateHostMDMIdPAccountDB associates a host with an MDM IdP account + AssociateHostMDMIdPAccountDB(ctx context.Context, hostUUID string, acctUUID string) error + // GetMDMIdPAccountByUUID returns MDM IdP account that matches the given token. GetMDMIdPAccountByUUID(ctx context.Context, uuid string) (*MDMIdPAccount, error) diff --git a/server/fleet/errors.go b/server/fleet/errors.go index 0ef4dca1d2..ee201aa15f 100644 --- a/server/fleet/errors.go +++ b/server/fleet/errors.go @@ -633,6 +633,7 @@ func (fe FleetdError) ToMap() map[string]any { // with a failed request's response. type OrbitError struct { Message string + code int } // Error implements the error interface for the OrbitError. @@ -640,6 +641,21 @@ func (e OrbitError) Error() string { return e.Message } +// StatusCode implements the ErrWithStatusCode interface for the OrbitError. +func (e OrbitError) StatusCode() int { + if e.code == 0 { + return http.StatusInternalServerError + } + return e.code +} + +func NewOrbitIDPAuthRequiredError() *OrbitError { + return &OrbitError{ + Message: "END_USER_AUTH_REQUIRED", + code: http.StatusUnauthorized, + } +} + // Message that may surfaced by the server or the fleetctl client. const ( // Hosts, general diff --git a/server/fleet/service.go b/server/fleet/service.go index 0542c12586..d2b547174c 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -181,7 +181,9 @@ type Service interface { // different from InitiateSSO because it receives a different // configuration and only supports a subset of the features (eg: we // don't want to allow IdP initiated authentications) - InitiateMDMSSO(ctx context.Context, initiator, customOriginalURL string) (sessionID string, sessionDurationSeconds int, idpURL string, err error) + // When initiated from Orbit, the hostUUID is used to link the SSO + // session to a specific host. + InitiateMDMSSO(ctx context.Context, initiator, customOriginalURL string, hostUUID string) (sessionID string, sessionDurationSeconds int, idpURL string, err error) // InitSSOCallback handles the IdP SAMLResponse and ensures the credentials are valid. // The sessionID is used to identify the SSO session and samlResponse is the raw SAMLResponse. diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index e84c32dbc6..4ce97b06e5 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -951,6 +951,8 @@ type GetMDMAppleProfilesSummaryFunc func(ctx context.Context, teamID *uint) (*fl type InsertMDMIdPAccountFunc func(ctx context.Context, account *fleet.MDMIdPAccount) error +type AssociateHostMDMIdPAccountDBFunc func(ctx context.Context, hostUUID string, acctUUID string) error + type GetMDMIdPAccountByUUIDFunc func(ctx context.Context, uuid string) (*fleet.MDMIdPAccount, error) type GetMDMIdPAccountByEmailFunc func(ctx context.Context, email string) (*fleet.MDMIdPAccount, error) @@ -2982,6 +2984,9 @@ type DataStore struct { InsertMDMIdPAccountFunc InsertMDMIdPAccountFunc InsertMDMIdPAccountFuncInvoked bool + AssociateHostMDMIdPAccountDBFunc AssociateHostMDMIdPAccountDBFunc + AssociateHostMDMIdPAccountDBFuncInvoked bool + GetMDMIdPAccountByUUIDFunc GetMDMIdPAccountByUUIDFunc GetMDMIdPAccountByUUIDFuncInvoked bool @@ -7190,6 +7195,13 @@ func (s *DataStore) InsertMDMIdPAccount(ctx context.Context, account *fleet.MDMI return s.InsertMDMIdPAccountFunc(ctx, account) } +func (s *DataStore) AssociateHostMDMIdPAccountDB(ctx context.Context, hostUUID string, acctUUID string) error { + s.mu.Lock() + s.AssociateHostMDMIdPAccountDBFuncInvoked = true + s.mu.Unlock() + return s.AssociateHostMDMIdPAccountDBFunc(ctx, hostUUID, acctUUID) +} + func (s *DataStore) GetMDMIdPAccountByUUID(ctx context.Context, uuid string) (*fleet.MDMIdPAccount, error) { s.mu.Lock() s.GetMDMIdPAccountByUUIDFuncInvoked = true diff --git a/server/mock/service/service_mock.go b/server/mock/service/service_mock.go index e4cb18e21a..604cc871df 100644 --- a/server/mock/service/service_mock.go +++ b/server/mock/service/service_mock.go @@ -85,7 +85,7 @@ type GetUserSettingsFunc func(ctx context.Context, id uint) (settings *fleet.Use type InitiateSSOFunc func(ctx context.Context, redirectURL string) (sessionID string, sessionDurationSeconds int, idpURL string, err error) -type InitiateMDMSSOFunc func(ctx context.Context, initiator string, customOriginalURL string) (sessionID string, sessionDurationSeconds int, idpURL string, err error) +type InitiateMDMSSOFunc func(ctx context.Context, initiator string, customOriginalURL string, hostUUID string) (sessionID string, sessionDurationSeconds int, idpURL string, err error) type InitSSOCallbackFunc func(ctx context.Context, sessionID string, samlResponse []byte) (auth fleet.Auth, redirectURL string, err error) @@ -2314,11 +2314,11 @@ func (s *Service) InitiateSSO(ctx context.Context, redirectURL string) (sessionI return s.InitiateSSOFunc(ctx, redirectURL) } -func (s *Service) InitiateMDMSSO(ctx context.Context, initiator string, customOriginalURL string) (sessionID string, sessionDurationSeconds int, idpURL string, err error) { +func (s *Service) InitiateMDMSSO(ctx context.Context, initiator string, customOriginalURL string, hostUUID string) (sessionID string, sessionDurationSeconds int, idpURL string, err error) { s.mu.Lock() s.InitiateMDMSSOFuncInvoked = true s.mu.Unlock() - return s.InitiateMDMSSOFunc(ctx, initiator, customOriginalURL) + return s.InitiateMDMSSOFunc(ctx, initiator, customOriginalURL, hostUUID) } func (s *Service) InitSSOCallback(ctx context.Context, sessionID string, samlResponse []byte) (auth fleet.Auth, redirectURL string, err error) { diff --git a/server/service/apple_mdm.go b/server/service/apple_mdm.go index 2ddc8c4dbb..88b0cd605e 100644 --- a/server/service/apple_mdm.go +++ b/server/service/apple_mdm.go @@ -3341,8 +3341,9 @@ func (svc *Service) UpdateMDMAppleSetup(ctx context.Context, payload fleet.MDMAp //////////////////////////////////////////////////////////////////////////////// type initiateMDMSSORequest struct { - Initiator string `json:"initiator,omitempty"` // optional, passed by the UI during account-driven enrollment + Initiator string `json:"initiator,omitempty"` // optional, passed by the UI during account-driven enrollment, or by Orbit for non-Apple IdP auth. UserIdentifier string `json:"user_identifier,omitempty"` // optional, passed by Apple for account-driven enrollment + HostUUID string `json:"host_uuid,omitempty"` // optional, passed by Orbit for non-Apple IdP auth } type initiateMDMSSOResponse struct { @@ -3361,7 +3362,7 @@ func (r initiateMDMSSOResponse) SetCookies(_ context.Context, w http.ResponseWri func initiateMDMSSOEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (fleet.Errorer, error) { req := request.(*initiateMDMSSORequest) - sessionID, sessionDurationSeconds, idpProviderURL, err := svc.InitiateMDMSSO(ctx, req.Initiator, "") + sessionID, sessionDurationSeconds, idpProviderURL, err := svc.InitiateMDMSSO(ctx, req.Initiator, "", req.HostUUID) if err != nil { return initiateMDMSSOResponse{Err: err}, nil } @@ -3374,7 +3375,7 @@ func initiateMDMSSOEndpoint(ctx context.Context, request interface{}, svc fleet. }, nil } -func (svc *Service) InitiateMDMSSO(ctx context.Context, initiator, customOriginalURL string) (sessionID string, sessionDurationSeconds int, idpURL string, err error) { +func (svc *Service) InitiateMDMSSO(ctx context.Context, initiator, customOriginalURL string, hostUUID string) (sessionID string, sessionDurationSeconds int, idpURL string, err error) { // skipauth: No authorization check needed due to implementation // returning only license error. svc.authz.SkipAuthorization(ctx) diff --git a/server/service/frontend.go b/server/service/frontend.go index 973af030ce..5bdeb3a1bf 100644 --- a/server/service/frontend.go +++ b/server/service/frontend.go @@ -224,7 +224,7 @@ func renderEnrollPage(w io.Writer, appCfg *fleet.AppConfig, urlPrefix, enrollSec } func initiateOTAEnrollSSO(svc fleet.Service, w http.ResponseWriter, r *http.Request, enrollSecret string) error { - ssnID, ssnDurationSecs, idpURL, err := svc.InitiateMDMSSO(r.Context(), "ota_enroll", "/enroll?enroll_secret="+url.QueryEscape(enrollSecret)) + ssnID, ssnDurationSecs, idpURL, err := svc.InitiateMDMSSO(r.Context(), "ota_enroll", "/enroll?enroll_secret="+url.QueryEscape(enrollSecret), "") if err != nil { return err } diff --git a/server/service/orbit.go b/server/service/orbit.go index e1b1ec4c46..bee958e178 100644 --- a/server/service/orbit.go +++ b/server/service/orbit.go @@ -177,6 +177,29 @@ func (svc *Service) EnrollOrbit(ctx context.Context, hostInfo fleet.OrbitHostInf if err != nil { return "", fleet.OrbitError{Message: "app config load failed: " + err.Error()} } + isEndUserAuthRequired := appConfig.MDM.MacOSSetup.EnableEndUserAuthentication + // If the secret is for a team, get the team config as well. + if secret.TeamID != nil { + team, err := svc.ds.Team(ctx, *secret.TeamID) + if err != nil { + return "", fleet.OrbitError{Message: "failed to get team config: " + err.Error()} + } + isEndUserAuthRequired = team.Config.MDM.MacOSSetup.EnableEndUserAuthentication + } + + if isEndUserAuthRequired { + if hostInfo.HardwareUUID == "" { + return "", fleet.OrbitError{Message: "failed to get IdP account: hardware uuid is empty"} + } + // Try to find an IdP account for this host. + idpAccount, err := svc.ds.GetMDMIdPAccountByHostUUID(ctx, hostInfo.HardwareUUID) + if err != nil { + return "", fleet.OrbitError{Message: "failed to get IdP account: " + err.Error()} + } + if idpAccount == nil { + return "", fleet.NewOrbitIDPAuthRequiredError() + } + } var stickyEnrollment *string if svc.keyValueStore != nil { diff --git a/server/service/sessions.go b/server/service/sessions.go index 8e3f6ef5f2..c693206a65 100644 --- a/server/service/sessions.go +++ b/server/service/sessions.go @@ -493,6 +493,7 @@ func (svc *Service) InitiateSSO(ctx context.Context, redirectURL string) (sessio sessionID, idpURL, err = sso.CreateAuthorizationRequest( ctx, samlProvider, svc.ssoSessionStore, redirectURL, uint(sessionDurationSeconds), //nolint:gosec // dismiss G115 + sso.SSORequestData{}, ) if err != nil { return "", 0, "", ctxerr.Wrap(ctx, err, "InitiateSSO creating authorization") diff --git a/server/sso/authorization_request.go b/server/sso/authorization_request.go index 5e1821698c..9a8eeae017 100644 --- a/server/sso/authorization_request.go +++ b/server/sso/authorization_request.go @@ -35,6 +35,7 @@ func CreateAuthorizationRequest( sessionStore SessionStore, originalURL string, sessionTTLSeconds uint, + requestData SSORequestData, ) (sessionID string, idpURL string, err error) { idpURL, err = getDestinationURL(samlProvider.IDPMetadata) if err != nil { @@ -76,6 +77,7 @@ func CreateAuthorizationRequest( originalURL, metadataWriter.String(), sessionLifetimeSeconds, + requestData, ) if err != nil { return "", "", fmt.Errorf("caching SSO session while creating auth request: %w", err) diff --git a/server/sso/authorization_request_test.go b/server/sso/authorization_request_test.go index f90ca088ad..01004ad117 100644 --- a/server/sso/authorization_request_test.go +++ b/server/sso/authorization_request_test.go @@ -47,6 +47,10 @@ func TestCreateAuthorizationRequest(t *testing.T) { store, "/redir", 0, + SSORequestData{ + HostUUID: "host-uuid-123", + Initiator: "test_initiator", + }, ) require.NoError(t, err) assert.Equal(t, 300*time.Second, store.sessionLifetime) // check default is used @@ -67,6 +71,8 @@ func TestCreateAuthorizationRequest(t *testing.T) { ssn := store.session require.NotNil(t, ssn) assert.Equal(t, "/redir", ssn.OriginalURL) + assert.Equal(t, "host-uuid-123", ssn.RequestData.HostUUID) + assert.Equal(t, "test_initiator", ssn.RequestData.Initiator) assert.Equal(t, 5*time.Minute, store.sessionLifetime) var meta saml.EntityDescriptor @@ -79,6 +85,7 @@ func TestCreateAuthorizationRequest(t *testing.T) { store, "/redir", sessionTTL, + SSORequestData{}, ) require.NoError(t, err) assert.Equal(t, 1*time.Hour, store.sessionLifetime) @@ -105,11 +112,12 @@ type mockStore struct { sessionLifetime time.Duration } -func (s *mockStore) create(sessionID, requestID, originalURL, metadata string, lifetimeSecs uint) error { +func (s *mockStore) create(sessionID, requestID, originalURL, metadata string, lifetimeSecs uint, requestData SSORequestData) error { s.session = &Session{ RequestID: requestID, OriginalURL: originalURL, Metadata: metadata, + RequestData: requestData, } s.sessionLifetime = time.Duration(lifetimeSecs) * time.Second // nolint:gosec // dismiss G115 return nil diff --git a/server/sso/saml_provider.go b/server/sso/saml_provider.go index cfdb0d64ce..73ed97c2e1 100644 --- a/server/sso/saml_provider.go +++ b/server/sso/saml_provider.go @@ -64,14 +64,14 @@ func SAMLProviderFromSession( acsURL *url.URL, entityID string, expectedAudiences []string, -) (samlProvider *saml.ServiceProvider, requestID, originalURL string, err error) { +) (samlProvider *saml.ServiceProvider, requestID, originalURL string, ssoRequestData SSORequestData, err error) { session, err := sessionStore.Fullfill(sessionID) if err != nil { - return nil, "", "", ctxerr.Wrap(ctx, err, "validate request in session") + return nil, "", "", SSORequestData{}, ctxerr.Wrap(ctx, err, "validate request in session") } entityDescriptor, err := ParseMetadata([]byte(session.Metadata)) if err != nil { - return nil, "", "", ctxerr.Wrap(ctx, err, "failed to parse metadata") + return nil, "", "", SSORequestData{}, ctxerr.Wrap(ctx, err, "failed to parse metadata") } return &saml.ServiceProvider{ @@ -81,7 +81,7 @@ func SAMLProviderFromSession( ValidateAudienceRestriction: func(assertion *saml.Assertion) error { return validateAudiences(assertion, expectedAudiences) }, - }, session.RequestID, session.OriginalURL, nil + }, session.RequestID, session.OriginalURL, session.RequestData, nil } // SAMLProviderFromSessionOrConfiguredMetadata creates a SAML provider that can validate SAML responses. diff --git a/server/sso/session_store.go b/server/sso/session_store.go index 1c9b974938..b180c40a04 100644 --- a/server/sso/session_store.go +++ b/server/sso/session_store.go @@ -11,6 +11,11 @@ import ( redigo "github.com/gomodule/redigo/redis" ) +type SSORequestData struct { + HostUUID string `json:"host_uuid,omitempty"` + Initiator string `json:"initiator,omitempty"` +} + // Session stores state for the lifetime of a single sign on session. type Session struct { // RequestID is the SAMLRequest ID that must match "InResponseTo" in the SAMLResponse. @@ -19,6 +24,8 @@ type Session struct { Metadata string `json:"metadata"` // OriginalURL is the resource being accessed when login request was triggered OriginalURL string `json:"original_url"` + // Additional request data that may be needed to complete the SSO process. + RequestData SSORequestData `json:"request_data,omitempty"` } // SessionStore persists state of a sso session across process boundries and @@ -27,7 +34,7 @@ type Session struct { // is constrained in the backing store (Redis) so if the sso process is not completed in // a reasonable amount of time, it automatically expires and is removed. type SessionStore interface { - create(sessionID, requestID, originalURL, metadata string, lifetimeSecs uint) error + create(sessionID, requestID, originalURL, metadata string, lifetimeSecs uint, requestData SSORequestData) error get(sessionID string) (*Session, error) expire(sessionID string) error // Fullfill loads a session with the given session ID, deletes it and returns it. @@ -43,7 +50,7 @@ type store struct { pool fleet.RedisPool } -func (s *store) create(sessionID, requestID, originalURL, metadata string, lifetimeSecs uint) error { +func (s *store) create(sessionID, requestID, originalURL, metadata string, lifetimeSecs uint, requestData SSORequestData) error { if len(sessionID) < 8 { return errors.New("request id must be 8 or more characters in length") } @@ -54,6 +61,7 @@ func (s *store) create(sessionID, requestID, originalURL, metadata string, lifet RequestID: requestID, Metadata: metadata, OriginalURL: originalURL, + RequestData: requestData, } var writer bytes.Buffer err := json.NewEncoder(&writer).Encode(session) diff --git a/server/sso/session_store_test.go b/server/sso/session_store_test.go index 63347f30de..2e6a6d7919 100644 --- a/server/sso/session_store_test.go +++ b/server/sso/session_store_test.go @@ -15,7 +15,7 @@ func TestSessionStore(t *testing.T) { store := NewSessionStore(pool) // Create session that lives for 1 second. - err := store.create("sessionID123", "requestID123", "https://originalurl.com", "some metadata", 1) + err := store.create("sessionID123", "requestID123", "https://originalurl.com", "some metadata", 1, SSORequestData{HostUUID: "host-uuid-123"}) require.NoError(t, err) sess, err := store.get("sessionID123") @@ -24,6 +24,7 @@ func TestSessionStore(t *testing.T) { assert.Equal(t, "requestID123", sess.RequestID) assert.Equal(t, "https://originalurl.com", sess.OriginalURL) assert.Equal(t, "some metadata", sess.Metadata) + assert.Equal(t, "host-uuid-123", sess.RequestData.HostUUID) // Wait a little bit more than one second, session should no longer be present. time.Sleep(1100 * time.Millisecond) @@ -33,7 +34,7 @@ func TestSessionStore(t *testing.T) { assert.Nil(t, sess) // Create another session for 1 second - err = store.create("sessionID456", "requestID456", "https://originalurl.com", "some metadata", 1) + err = store.create("sessionID456", "requestID456", "https://originalurl.com", "some metadata", 1, SSORequestData{}) require.NoError(t, err) // Forcefully expire it