diff --git a/server/mdm/android/datastore.go b/server/mdm/android/datastore.go index d693e646dc..51b0bb4309 100644 --- a/server/mdm/android/datastore.go +++ b/server/mdm/android/datastore.go @@ -9,6 +9,7 @@ import ( type Datastore interface { CreateEnterprise(ctx context.Context, userID uint) (uint, error) GetEnterpriseByID(ctx context.Context, ID uint) (*EnterpriseDetails, error) + GetEnterpriseBySignupToken(ctx context.Context, signupToken string) (*EnterpriseDetails, error) GetEnterprise(ctx context.Context) (*Enterprise, error) UpdateEnterprise(ctx context.Context, enterprise *EnterpriseDetails) error DeleteAllEnterprises(ctx context.Context) error diff --git a/server/mdm/android/mock/datastore.go b/server/mdm/android/mock/datastore.go index 15dac12194..1b34d36e24 100644 --- a/server/mdm/android/mock/datastore.go +++ b/server/mdm/android/mock/datastore.go @@ -16,6 +16,8 @@ type CreateEnterpriseFunc func(ctx context.Context, userID uint) (uint, error) type GetEnterpriseByIDFunc func(ctx context.Context, ID uint) (*android.EnterpriseDetails, error) +type GetEnterpriseBySignupTokenFunc func(ctx context.Context, signupToken string) (*android.EnterpriseDetails, error) + type GetEnterpriseFunc func(ctx context.Context) (*android.Enterprise, error) type UpdateEnterpriseFunc func(ctx context.Context, enterprise *android.EnterpriseDetails) error @@ -35,6 +37,9 @@ type Datastore struct { GetEnterpriseByIDFunc GetEnterpriseByIDFunc GetEnterpriseByIDFuncInvoked bool + GetEnterpriseBySignupTokenFunc GetEnterpriseBySignupTokenFunc + GetEnterpriseBySignupTokenFuncInvoked bool + GetEnterpriseFunc GetEnterpriseFunc GetEnterpriseFuncInvoked bool @@ -70,6 +75,13 @@ func (ds *Datastore) GetEnterpriseByID(ctx context.Context, ID uint) (*android.E return ds.GetEnterpriseByIDFunc(ctx, ID) } +func (ds *Datastore) GetEnterpriseBySignupToken(ctx context.Context, signupToken string) (*android.EnterpriseDetails, error) { + ds.mu.Lock() + ds.GetEnterpriseBySignupTokenFuncInvoked = true + ds.mu.Unlock() + return ds.GetEnterpriseBySignupTokenFunc(ctx, signupToken) +} + func (ds *Datastore) GetEnterprise(ctx context.Context) (*android.Enterprise, error) { ds.mu.Lock() ds.GetEnterpriseFuncInvoked = true diff --git a/server/mdm/android/mock/datastore_setup.go b/server/mdm/android/mock/datastore_setup.go index 9a27584eb5..702bdf73d0 100644 --- a/server/mdm/android/mock/datastore_setup.go +++ b/server/mdm/android/mock/datastore_setup.go @@ -2,6 +2,7 @@ package mock import ( "context" + "errors" "github.com/fleetdm/fleet/v4/server/mdm/android" ) @@ -19,6 +20,12 @@ func (s *Datastore) InitCommonMocks() { s.GetEnterpriseByIDFunc = func(ctx context.Context, ID uint) (*android.EnterpriseDetails, error) { return &android.EnterpriseDetails{}, nil } + s.GetEnterpriseBySignupTokenFunc = func(ctx context.Context, signupToken string) (*android.EnterpriseDetails, error) { + if signupToken == "signup_token" { + return &android.EnterpriseDetails{}, nil + } + return nil, ¬FoundError{errors.New("not found")} + } s.DeleteAllEnterprisesFunc = func(ctx context.Context) error { return nil } @@ -26,3 +33,11 @@ func (s *Datastore) InitCommonMocks() { return nil } } + +type notFoundError struct { + error +} + +func (e *notFoundError) IsNotFound() bool { + return true +} diff --git a/server/mdm/android/mysql/enterprises.go b/server/mdm/android/mysql/enterprises.go index b3bccd02ce..1b5a1951a7 100644 --- a/server/mdm/android/mysql/enterprises.go +++ b/server/mdm/android/mysql/enterprises.go @@ -35,6 +35,19 @@ func (ds *Datastore) GetEnterpriseByID(ctx context.Context, id uint) (*android.E return &enterprise, nil } +func (ds *Datastore) GetEnterpriseBySignupToken(ctx context.Context, signupToken string) (*android.EnterpriseDetails, error) { + stmt := `SELECT id, signup_name, enterprise_id, pubsub_topic_id, signup_token, user_id FROM android_enterprises WHERE signup_token = ?` + var enterprise android.EnterpriseDetails + err := sqlx.GetContext(ctx, ds.reader(ctx), &enterprise, stmt, signupToken) + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, common_mysql.NotFound("Android enterprise") + case err != nil: + return nil, ctxerr.Wrap(ctx, err, "getting enterprise by signup token") + } + return &enterprise, nil +} + func (ds *Datastore) GetEnterprise(ctx context.Context) (*android.Enterprise, error) { stmt := `SELECT id, enterprise_id FROM android_enterprises WHERE enterprise_id != '' LIMIT 1` var enterprise android.Enterprise diff --git a/server/mdm/android/mysql/enterprises_test.go b/server/mdm/android/mysql/enterprises_test.go index 8baf5701b1..76f5b6ebed 100644 --- a/server/mdm/android/mysql/enterprises_test.go +++ b/server/mdm/android/mysql/enterprises_test.go @@ -73,6 +73,10 @@ func testUpdateEnterprise(t *testing.T, ds *Datastore) { require.NoError(t, err) assert.Equal(t, enterprise, resultEnriched) + resultEnrichedByToken, err := ds.GetEnterpriseBySignupToken(testCtx(), enterprise.SignupToken) + require.NoError(t, err) + assert.Equal(t, enterprise, resultEnrichedByToken) + result, err := ds.GetEnterprise(testCtx()) require.NoError(t, err) assert.Equal(t, enterprise.Enterprise, *result) diff --git a/server/mdm/android/service.go b/server/mdm/android/service.go index 72a2657e14..026b583be4 100644 --- a/server/mdm/android/service.go +++ b/server/mdm/android/service.go @@ -6,7 +6,7 @@ import ( type Service interface { EnterpriseSignup(ctx context.Context) (*SignupDetails, error) - EnterpriseSignupCallback(ctx context.Context, enterpriseID uint, enterpriseToken string) error + EnterpriseSignupCallback(ctx context.Context, signupToken string, enterpriseToken string) error GetEnterprise(ctx context.Context) (*Enterprise, error) DeleteEnterprise(ctx context.Context) error EnterpriseSignupSSE(ctx context.Context) (chan string, error) diff --git a/server/mdm/android/service/enterprises_test.go b/server/mdm/android/service/enterprises_test.go index 34968cd8e9..146a84ff6c 100644 --- a/server/mdm/android/service/enterprises_test.go +++ b/server/mdm/android/service/enterprises_test.go @@ -110,8 +110,10 @@ func TestEnterprisesAuth(t *testing.T) { } t.Run("unauthorized", func(t *testing.T) { - err := svc.EnterpriseSignupCallback(context.Background(), 1, "token") + err := svc.EnterpriseSignupCallback(context.Background(), "signup_token", "token") checkAuthErr(t, false, err) + err = svc.EnterpriseSignupCallback(context.Background(), "bad_token", "token") + checkAuthErr(t, true, err) }) } diff --git a/server/mdm/android/service/handler.go b/server/mdm/android/service/handler.go index 9ede251100..17ffefd313 100644 --- a/server/mdm/android/service/handler.go +++ b/server/mdm/android/service/handler.go @@ -32,7 +32,7 @@ func attachFleetAPIRoutes(r *mux.Router, fleetSvc fleet.Service, svc android.Ser // These endpoints should do custom one-time authentication by verifying that a valid secret token is provided with the request. ne := newNoAuthEndpointer(fleetSvc, svc, opts, r, apiVersions()...) - ne.GET("/api/_version_/fleet/android_enterprise/{id:[0-9]+}/connect", enterpriseSignupCallbackEndpoint, enterpriseSignupCallbackRequest{}) + ne.GET("/api/_version_/fleet/android_enterprise/connect/{token}", enterpriseSignupCallbackEndpoint, enterpriseSignupCallbackRequest{}) ne.GET("/api/_version_/fleet/android_enterprise/enrollment_token", enrollmentTokenEndpoint, enrollmentTokenRequest{}) ne.POST(pubSubPushPath, pubSubPushEndpoint, pubSubPushRequest{}) diff --git a/server/mdm/android/service/service.go b/server/mdm/android/service/service.go index 9779163434..eeb9bdfd0a 100644 --- a/server/mdm/android/service/service.go +++ b/server/mdm/android/service/service.go @@ -101,7 +101,13 @@ func (svc *Service) EnterpriseSignup(ctx context.Context) (*android.SignupDetail return nil, ctxerr.Wrap(ctx, err, "creating enterprise") } - callbackURL := fmt.Sprintf("%s/api/v1/fleet/android_enterprise/%d/connect", appConfig.ServerSettings.ServerURL, id) + // signupToken is used to authenticate the signup callback URL -- to ensure that the callback came from our Android enterprise signup flow + signupToken, err := server.GenerateRandomURLSafeText(32) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "generating Android enterprise signup token") + } + + callbackURL := fmt.Sprintf("%s/api/v1/fleet/android_enterprise/connect/%s", appConfig.ServerSettings.ServerURL, signupToken) signupDetails, err := svc.proxy.SignupURLsCreate(callbackURL) if err != nil { return nil, ctxerr.Wrap(ctx, err, "creating signup url") @@ -111,7 +117,8 @@ func (svc *Service) EnterpriseSignup(ctx context.Context) (*android.SignupDetail Enterprise: android.Enterprise{ ID: id, }, - SignupName: signupDetails.Name, + SignupName: signupDetails.Name, + SignupToken: signupToken, }) if err != nil { return nil, ctxerr.Wrap(ctx, err, "updating enterprise") @@ -133,20 +140,23 @@ func (svc *Service) checkIfAndroidAlreadyConfigured(ctx context.Context) (*fleet } type enterpriseSignupCallbackRequest struct { - ID uint `url:"id"` + SignupToken string `url:"token"` EnterpriseToken string `query:"enterpriseToken"` } func enterpriseSignupCallbackEndpoint(ctx context.Context, request interface{}, svc android.Service) fleet.Errorer { req := request.(*enterpriseSignupCallbackRequest) - err := svc.EnterpriseSignupCallback(ctx, req.ID, req.EnterpriseToken) + err := svc.EnterpriseSignupCallback(ctx, req.SignupToken, req.EnterpriseToken) return android.DefaultResponse{Err: err} } -func (svc *Service) EnterpriseSignupCallback(ctx context.Context, id uint, enterpriseToken string) error { - // Skip authorization because the callback is called by Google. - // TODO(26218): Add some authorization here so random people can't bind random Android enterprises just for fun. - // This call will fail if Proxy (Google Project) is not configured. +// EnterpriseSignupCallback handles the callback from Google UI during signup flow. +// signupToken is for authentication with Fleet +// enterpriseToken is for authentication with Google +func (svc *Service) EnterpriseSignupCallback(ctx context.Context, signupToken string, enterpriseToken string) error { + // Authorization is done by GetEnterpriseBySignupToken below. + // We call SkipAuthorization here to avoid explicitly calling it when errors occur. + // Also, this method call will fail if Proxy (Google Project) is not configured. svc.authz.SkipAuthorization(ctx) appConfig, err := svc.checkIfAndroidAlreadyConfigured(ctx) @@ -154,11 +164,10 @@ func (svc *Service) EnterpriseSignupCallback(ctx context.Context, id uint, enter return err } - enterprise, err := svc.ds.GetEnterpriseByID(ctx, id) + enterprise, err := svc.ds.GetEnterpriseBySignupToken(ctx, signupToken) switch { case fleet.IsNotFound(err): - return fleet.NewInvalidArgumentError("id", - fmt.Sprintf("Enterprise with ID %d not found", id)).WithStatus(http.StatusNotFound) + return authz.ForbiddenWithInternal("invalid signup token", nil, nil, nil) case err != nil: return ctxerr.Wrap(ctx, err, "getting enterprise") } @@ -232,7 +241,7 @@ func (svc *Service) EnterpriseSignupCallback(ctx context.Context, id uint, enter return ctxerr.Wrapf(ctx, err, "patching %d policy", defaultAndroidPolicyID) } - err = svc.ds.DeleteOtherEnterprises(ctx, id) + err = svc.ds.DeleteOtherEnterprises(ctx, enterprise.ID) if err != nil { return ctxerr.Wrap(ctx, err, "deleting temp enterprises") }