mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 13:37:30 +00:00
Added signup_token for Android signup callback authentication. (#26681)
For #26218 - Added signup_token authentication for Android enterprise callback and fixed API path to match API doc # Checklist for submitter - [x] Added/updated automated tests - [x] Manual QA for all new/changed functionality
This commit is contained in:
parent
3d9072981b
commit
b21f54d648
9 changed files with 71 additions and 15 deletions
|
|
@ -9,6 +9,7 @@ import (
|
|||
type Datastore interface {
|
||||
CreateEnterprise(ctx context.Context, userID uint) (uint, error)
|
||||
GetEnterpriseByID(ctx context.Context, ID uint) (*EnterpriseDetails, error)
|
||||
GetEnterpriseBySignupToken(ctx context.Context, signupToken string) (*EnterpriseDetails, error)
|
||||
GetEnterprise(ctx context.Context) (*Enterprise, error)
|
||||
UpdateEnterprise(ctx context.Context, enterprise *EnterpriseDetails) error
|
||||
DeleteAllEnterprises(ctx context.Context) error
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ type CreateEnterpriseFunc func(ctx context.Context, userID uint) (uint, error)
|
|||
|
||||
type GetEnterpriseByIDFunc func(ctx context.Context, ID uint) (*android.EnterpriseDetails, error)
|
||||
|
||||
type GetEnterpriseBySignupTokenFunc func(ctx context.Context, signupToken string) (*android.EnterpriseDetails, error)
|
||||
|
||||
type GetEnterpriseFunc func(ctx context.Context) (*android.Enterprise, error)
|
||||
|
||||
type UpdateEnterpriseFunc func(ctx context.Context, enterprise *android.EnterpriseDetails) error
|
||||
|
|
@ -35,6 +37,9 @@ type Datastore struct {
|
|||
GetEnterpriseByIDFunc GetEnterpriseByIDFunc
|
||||
GetEnterpriseByIDFuncInvoked bool
|
||||
|
||||
GetEnterpriseBySignupTokenFunc GetEnterpriseBySignupTokenFunc
|
||||
GetEnterpriseBySignupTokenFuncInvoked bool
|
||||
|
||||
GetEnterpriseFunc GetEnterpriseFunc
|
||||
GetEnterpriseFuncInvoked bool
|
||||
|
||||
|
|
@ -70,6 +75,13 @@ func (ds *Datastore) GetEnterpriseByID(ctx context.Context, ID uint) (*android.E
|
|||
return ds.GetEnterpriseByIDFunc(ctx, ID)
|
||||
}
|
||||
|
||||
func (ds *Datastore) GetEnterpriseBySignupToken(ctx context.Context, signupToken string) (*android.EnterpriseDetails, error) {
|
||||
ds.mu.Lock()
|
||||
ds.GetEnterpriseBySignupTokenFuncInvoked = true
|
||||
ds.mu.Unlock()
|
||||
return ds.GetEnterpriseBySignupTokenFunc(ctx, signupToken)
|
||||
}
|
||||
|
||||
func (ds *Datastore) GetEnterprise(ctx context.Context) (*android.Enterprise, error) {
|
||||
ds.mu.Lock()
|
||||
ds.GetEnterpriseFuncInvoked = true
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package mock
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/android"
|
||||
)
|
||||
|
|
@ -19,6 +20,12 @@ func (s *Datastore) InitCommonMocks() {
|
|||
s.GetEnterpriseByIDFunc = func(ctx context.Context, ID uint) (*android.EnterpriseDetails, error) {
|
||||
return &android.EnterpriseDetails{}, nil
|
||||
}
|
||||
s.GetEnterpriseBySignupTokenFunc = func(ctx context.Context, signupToken string) (*android.EnterpriseDetails, error) {
|
||||
if signupToken == "signup_token" {
|
||||
return &android.EnterpriseDetails{}, nil
|
||||
}
|
||||
return nil, ¬FoundError{errors.New("not found")}
|
||||
}
|
||||
s.DeleteAllEnterprisesFunc = func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -26,3 +33,11 @@ func (s *Datastore) InitCommonMocks() {
|
|||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type notFoundError struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (e *notFoundError) IsNotFound() bool {
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,6 +35,19 @@ func (ds *Datastore) GetEnterpriseByID(ctx context.Context, id uint) (*android.E
|
|||
return &enterprise, nil
|
||||
}
|
||||
|
||||
func (ds *Datastore) GetEnterpriseBySignupToken(ctx context.Context, signupToken string) (*android.EnterpriseDetails, error) {
|
||||
stmt := `SELECT id, signup_name, enterprise_id, pubsub_topic_id, signup_token, user_id FROM android_enterprises WHERE signup_token = ?`
|
||||
var enterprise android.EnterpriseDetails
|
||||
err := sqlx.GetContext(ctx, ds.reader(ctx), &enterprise, stmt, signupToken)
|
||||
switch {
|
||||
case errors.Is(err, sql.ErrNoRows):
|
||||
return nil, common_mysql.NotFound("Android enterprise")
|
||||
case err != nil:
|
||||
return nil, ctxerr.Wrap(ctx, err, "getting enterprise by signup token")
|
||||
}
|
||||
return &enterprise, nil
|
||||
}
|
||||
|
||||
func (ds *Datastore) GetEnterprise(ctx context.Context) (*android.Enterprise, error) {
|
||||
stmt := `SELECT id, enterprise_id FROM android_enterprises WHERE enterprise_id != '' LIMIT 1`
|
||||
var enterprise android.Enterprise
|
||||
|
|
|
|||
|
|
@ -73,6 +73,10 @@ func testUpdateEnterprise(t *testing.T, ds *Datastore) {
|
|||
require.NoError(t, err)
|
||||
assert.Equal(t, enterprise, resultEnriched)
|
||||
|
||||
resultEnrichedByToken, err := ds.GetEnterpriseBySignupToken(testCtx(), enterprise.SignupToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, enterprise, resultEnrichedByToken)
|
||||
|
||||
result, err := ds.GetEnterprise(testCtx())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, enterprise.Enterprise, *result)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
type Service interface {
|
||||
EnterpriseSignup(ctx context.Context) (*SignupDetails, error)
|
||||
EnterpriseSignupCallback(ctx context.Context, enterpriseID uint, enterpriseToken string) error
|
||||
EnterpriseSignupCallback(ctx context.Context, signupToken string, enterpriseToken string) error
|
||||
GetEnterprise(ctx context.Context) (*Enterprise, error)
|
||||
DeleteEnterprise(ctx context.Context) error
|
||||
EnterpriseSignupSSE(ctx context.Context) (chan string, error)
|
||||
|
|
|
|||
|
|
@ -110,8 +110,10 @@ func TestEnterprisesAuth(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run("unauthorized", func(t *testing.T) {
|
||||
err := svc.EnterpriseSignupCallback(context.Background(), 1, "token")
|
||||
err := svc.EnterpriseSignupCallback(context.Background(), "signup_token", "token")
|
||||
checkAuthErr(t, false, err)
|
||||
err = svc.EnterpriseSignupCallback(context.Background(), "bad_token", "token")
|
||||
checkAuthErr(t, true, err)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ func attachFleetAPIRoutes(r *mux.Router, fleetSvc fleet.Service, svc android.Ser
|
|||
// These endpoints should do custom one-time authentication by verifying that a valid secret token is provided with the request.
|
||||
ne := newNoAuthEndpointer(fleetSvc, svc, opts, r, apiVersions()...)
|
||||
|
||||
ne.GET("/api/_version_/fleet/android_enterprise/{id:[0-9]+}/connect", enterpriseSignupCallbackEndpoint, enterpriseSignupCallbackRequest{})
|
||||
ne.GET("/api/_version_/fleet/android_enterprise/connect/{token}", enterpriseSignupCallbackEndpoint, enterpriseSignupCallbackRequest{})
|
||||
ne.GET("/api/_version_/fleet/android_enterprise/enrollment_token", enrollmentTokenEndpoint, enrollmentTokenRequest{})
|
||||
ne.POST(pubSubPushPath, pubSubPushEndpoint, pubSubPushRequest{})
|
||||
|
||||
|
|
|
|||
|
|
@ -101,7 +101,13 @@ func (svc *Service) EnterpriseSignup(ctx context.Context) (*android.SignupDetail
|
|||
return nil, ctxerr.Wrap(ctx, err, "creating enterprise")
|
||||
}
|
||||
|
||||
callbackURL := fmt.Sprintf("%s/api/v1/fleet/android_enterprise/%d/connect", appConfig.ServerSettings.ServerURL, id)
|
||||
// signupToken is used to authenticate the signup callback URL -- to ensure that the callback came from our Android enterprise signup flow
|
||||
signupToken, err := server.GenerateRandomURLSafeText(32)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "generating Android enterprise signup token")
|
||||
}
|
||||
|
||||
callbackURL := fmt.Sprintf("%s/api/v1/fleet/android_enterprise/connect/%s", appConfig.ServerSettings.ServerURL, signupToken)
|
||||
signupDetails, err := svc.proxy.SignupURLsCreate(callbackURL)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "creating signup url")
|
||||
|
|
@ -111,7 +117,8 @@ func (svc *Service) EnterpriseSignup(ctx context.Context) (*android.SignupDetail
|
|||
Enterprise: android.Enterprise{
|
||||
ID: id,
|
||||
},
|
||||
SignupName: signupDetails.Name,
|
||||
SignupName: signupDetails.Name,
|
||||
SignupToken: signupToken,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "updating enterprise")
|
||||
|
|
@ -133,20 +140,23 @@ func (svc *Service) checkIfAndroidAlreadyConfigured(ctx context.Context) (*fleet
|
|||
}
|
||||
|
||||
type enterpriseSignupCallbackRequest struct {
|
||||
ID uint `url:"id"`
|
||||
SignupToken string `url:"token"`
|
||||
EnterpriseToken string `query:"enterpriseToken"`
|
||||
}
|
||||
|
||||
func enterpriseSignupCallbackEndpoint(ctx context.Context, request interface{}, svc android.Service) fleet.Errorer {
|
||||
req := request.(*enterpriseSignupCallbackRequest)
|
||||
err := svc.EnterpriseSignupCallback(ctx, req.ID, req.EnterpriseToken)
|
||||
err := svc.EnterpriseSignupCallback(ctx, req.SignupToken, req.EnterpriseToken)
|
||||
return android.DefaultResponse{Err: err}
|
||||
}
|
||||
|
||||
func (svc *Service) EnterpriseSignupCallback(ctx context.Context, id uint, enterpriseToken string) error {
|
||||
// Skip authorization because the callback is called by Google.
|
||||
// TODO(26218): Add some authorization here so random people can't bind random Android enterprises just for fun.
|
||||
// This call will fail if Proxy (Google Project) is not configured.
|
||||
// EnterpriseSignupCallback handles the callback from Google UI during signup flow.
|
||||
// signupToken is for authentication with Fleet
|
||||
// enterpriseToken is for authentication with Google
|
||||
func (svc *Service) EnterpriseSignupCallback(ctx context.Context, signupToken string, enterpriseToken string) error {
|
||||
// Authorization is done by GetEnterpriseBySignupToken below.
|
||||
// We call SkipAuthorization here to avoid explicitly calling it when errors occur.
|
||||
// Also, this method call will fail if Proxy (Google Project) is not configured.
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
||||
appConfig, err := svc.checkIfAndroidAlreadyConfigured(ctx)
|
||||
|
|
@ -154,11 +164,10 @@ func (svc *Service) EnterpriseSignupCallback(ctx context.Context, id uint, enter
|
|||
return err
|
||||
}
|
||||
|
||||
enterprise, err := svc.ds.GetEnterpriseByID(ctx, id)
|
||||
enterprise, err := svc.ds.GetEnterpriseBySignupToken(ctx, signupToken)
|
||||
switch {
|
||||
case fleet.IsNotFound(err):
|
||||
return fleet.NewInvalidArgumentError("id",
|
||||
fmt.Sprintf("Enterprise with ID %d not found", id)).WithStatus(http.StatusNotFound)
|
||||
return authz.ForbiddenWithInternal("invalid signup token", nil, nil, nil)
|
||||
case err != nil:
|
||||
return ctxerr.Wrap(ctx, err, "getting enterprise")
|
||||
}
|
||||
|
|
@ -232,7 +241,7 @@ func (svc *Service) EnterpriseSignupCallback(ctx context.Context, id uint, enter
|
|||
return ctxerr.Wrapf(ctx, err, "patching %d policy", defaultAndroidPolicyID)
|
||||
}
|
||||
|
||||
err = svc.ds.DeleteOtherEnterprises(ctx, id)
|
||||
err = svc.ds.DeleteOtherEnterprises(ctx, enterprise.ID)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "deleting temp enterprises")
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue