diff --git a/Makefile b/Makefile index 1ca21c76bb..6d42996796 100644 --- a/Makefile +++ b/Makefile @@ -173,8 +173,8 @@ generate-dev: .prefix NODE_ENV=development webpack --progress --colors --watch generate-mock: .prefix - go install github.com/groob/mockimpl@latest - go generate github.com/fleetdm/fleet/v4/server/mock github.com/fleetdm/fleet/v4/server/mock/mockresult + go install github.com/fleetdm/mockimpl@8d7943aa39d8f5f464d3d3618d9571d385f7bcc5 + go generate github.com/fleetdm/fleet/v4/server/mock github.com/fleetdm/fleet/v4/server/mock/mockresult github.com/fleetdm/fleet/v4/server/service/mock generate-doc: .prefix go generate github.com/fleetdm/fleet/v4/server/fleet @@ -403,4 +403,4 @@ db-replica-reset: fleet # db-replica-run runs fleet serve with one main and one read MySQL instance. db-replica-run: fleet - FLEET_MYSQL_ADDRESS=127.0.0.1:3308 FLEET_MYSQL_READ_REPLICA_ADDRESS=127.0.0.1:3309 FLEET_MYSQL_READ_REPLICA_USERNAME=fleet FLEET_MYSQL_READ_REPLICA_DATABASE=fleet FLEET_MYSQL_READ_REPLICA_PASSWORD=insecure ./build/fleet serve --dev --dev_license \ No newline at end of file + FLEET_MYSQL_ADDRESS=127.0.0.1:3308 FLEET_MYSQL_READ_REPLICA_ADDRESS=127.0.0.1:3309 FLEET_MYSQL_READ_REPLICA_USERNAME=fleet FLEET_MYSQL_READ_REPLICA_DATABASE=fleet FLEET_MYSQL_READ_REPLICA_PASSWORD=insecure ./build/fleet serve --dev --dev_license diff --git a/cmd/fleet/cron.go b/cmd/fleet/cron.go index f27aaede1f..dd93ccded7 100644 --- a/cmd/fleet/cron.go +++ b/cmd/fleet/cron.go @@ -10,6 +10,8 @@ import ( "strings" "time" + apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple" + eewebhooks "github.com/fleetdm/fleet/v4/ee/server/webhooks" "github.com/fleetdm/fleet/v4/server" "github.com/fleetdm/fleet/v4/server/config" @@ -17,7 +19,6 @@ import ( "github.com/fleetdm/fleet/v4/server/contexts/license" "github.com/fleetdm/fleet/v4/server/datastore/mysql" "github.com/fleetdm/fleet/v4/server/fleet" - apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple" "github.com/fleetdm/fleet/v4/server/policies" "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/service/externalsvc" @@ -32,9 +33,6 @@ import ( kitlog "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" "github.com/hashicorp/go-multierror" - "github.com/micromdm/nanodep/godep" - nanodep_log "github.com/micromdm/nanodep/log" - depsync "github.com/micromdm/nanodep/sync" ) func errHandler(ctx context.Context, logger kitlog.Logger, msg string, err error) { @@ -801,32 +799,6 @@ func trySendStatistics(ctx context.Context, ds fleet.Datastore, frequency time.D return ds.RecordStatisticsSent(ctx) } -// NanoDEPLogger is a logger adapter for nanodep. -type NanoDEPLogger struct { - logger kitlog.Logger -} - -func NewNanoDEPLogger(logger kitlog.Logger) *NanoDEPLogger { - return &NanoDEPLogger{ - logger: logger, - } -} - -func (l *NanoDEPLogger) Info(keyvals ...interface{}) { - level.Info(l.logger).Log(keyvals...) -} - -func (l *NanoDEPLogger) Debug(keyvals ...interface{}) { - level.Debug(l.logger).Log(keyvals...) -} - -func (l *NanoDEPLogger) With(keyvals ...interface{}) nanodep_log.Logger { - newLogger := kitlog.With(l.logger, keyvals...) - return &NanoDEPLogger{ - logger: newLogger, - } -} - // newAppleMDMDEPProfileAssigner creates the schedule to run the DEP syncer+assigner. // The DEP syncer+assigner fetches devices from Apple Business Manager (aka ABM) and applies // the current configured DEP profile to them. @@ -840,65 +812,13 @@ func newAppleMDMDEPProfileAssigner( loggingDebug bool, ) (*schedule.Schedule, error) { const name = string(fleet.CronAppleMDMDEPProfileAssigner) - depClient := fleet.NewDEPClient(depStorage, ds, logger) - assignerOpts := []depsync.AssignerOption{ - depsync.WithAssignerLogger(NewNanoDEPLogger(kitlog.With(logger, "component", "nanodep-assigner"))), - } - if loggingDebug { - assignerOpts = append(assignerOpts, depsync.WithDebug()) - } - assigner := depsync.NewAssigner( - depClient, - apple_mdm.DEPName, - depStorage, - assignerOpts..., - ) - syncer := depsync.NewSyncer( - depClient, - apple_mdm.DEPName, - depStorage, - depsync.WithLogger(NewNanoDEPLogger(kitlog.With(logger, "component", "nanodep-syncer"))), - depsync.WithCallback(func(ctx context.Context, isFetch bool, resp *godep.DeviceResponse) error { - n, err := ds.IngestMDMAppleDevicesFromDEPSync(ctx, resp.Devices) - switch { - case err != nil: - level.Error(kitlog.With(logger, "cron", name, "component", "nanodep-syncer")).Log("err", err) - sentry.CaptureException(err) - case n > 0: - level.Info(kitlog.With(logger, "cron", name, "component", "nanodep-syncer")).Log("msg", fmt.Sprintf("added %d new mdm device(s) to pending hosts", n)) - default: - // ok - } - - return assigner.ProcessDeviceResponse(ctx, resp) - }), - ) - logger = kitlog.With(logger, "cron", name) + logger = kitlog.With(logger, "cron", name, "component", "nanodep-syncer") + fleetSyncer := apple_mdm.NewDEPSyncer(ds, depStorage, logger, loggingDebug) s := schedule.New( ctx, name, instanceID, periodicity, ds, ds, schedule.WithLogger(logger), schedule.WithJob("dep_syncer", func(ctx context.Context) error { - profileUUID, profileModTime, err := depStorage.RetrieveAssignerProfile(ctx, apple_mdm.DEPName) - if err != nil { - return err - } - if profileUUID == "" { - logger.Log("msg", "DEP profile not set, nothing to do") - return nil - } - cursor, cursorModTime, err := depStorage.RetrieveCursor(ctx, apple_mdm.DEPName) - if err != nil { - return err - } - // If the DEP Profile was changed since last sync then we clear - // the cursor and perform a full sync of all devices and profile assigning. - if cursor != "" && profileModTime.After(cursorModTime) { - logger.Log("msg", "clearing device syncer cursor") - if err := depStorage.StoreCursor(ctx, apple_mdm.DEPName, ""); err != nil { - return err - } - } - return syncer.Run(ctx) + return fleetSyncer.Run(ctx) }), ) diff --git a/cmd/fleet/serve.go b/cmd/fleet/serve.go index 34edc29077..8a9ed5798e 100644 --- a/cmd/fleet/serve.go +++ b/cmd/fleet/serve.go @@ -19,6 +19,8 @@ import ( "syscall" "time" + scep_depot "github.com/micromdm/scep/v2/depot" + "github.com/WatchBeam/clock" "github.com/e-dard/netbug" "github.com/fleetdm/fleet/v4/ee/server/licensing" @@ -40,7 +42,6 @@ import ( "github.com/fleetdm/fleet/v4/server/live_query" "github.com/fleetdm/fleet/v4/server/logging" "github.com/fleetdm/fleet/v4/server/mail" - apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple" "github.com/fleetdm/fleet/v4/server/pubsub" "github.com/fleetdm/fleet/v4/server/service" "github.com/fleetdm/fleet/v4/server/service/async" @@ -468,7 +469,7 @@ the way that the Fleet server works. } var ( - scepStorage *apple_mdm.SCEPMySQLDepot + scepStorage scep_depot.Depot appleSCEPCertPEM []byte appleSCEPKeyPEM []byte appleAPNsCertPEM []byte @@ -545,7 +546,7 @@ the way that the Fleet server works. initFatal(errors.New("Apple BM configuration must be provided to enable MDM"), "validate Apple MDM") } - scepStorage, err = mds.NewMDMAppleSCEPDepot(appleSCEPCertPEM, appleSCEPKeyPEM) + scepStorage, err = mds.NewSCEPDepot(appleSCEPCertPEM, appleSCEPKeyPEM) if err != nil { initFatal(err, "initialize mdm apple scep storage") } diff --git a/cmd/fleetctl/apple_mdm_test.go b/cmd/fleetctl/apple_mdm_test.go new file mode 100644 index 0000000000..4dfb429e95 --- /dev/null +++ b/cmd/fleetctl/apple_mdm_test.go @@ -0,0 +1,102 @@ +package main + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGenerateMDMAppleBM(t *testing.T) { + outdir, err := os.MkdirTemp("", t.Name()) + require.NoError(t, err) + defer os.Remove(outdir) + publicKeyPath := filepath.Join(outdir, "public-key.crt") + privateKeyPath := filepath.Join(outdir, "private-key.key") + out := runAppForTest(t, []string{ + "generate", "mdm-apple-bm", + "--public-key", publicKeyPath, + "--private-key", privateKeyPath, + }) + + require.Contains(t, out, fmt.Sprintf("Generated your public key at %s", outdir)) + require.Contains(t, out, fmt.Sprintf("Generated your private key at %s", outdir)) + + // validate that the keypair is valid + cert, err := tls.LoadX509KeyPair(publicKeyPath, privateKeyPath) + require.NoError(t, err) + + parsed, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Equal(t, "FleetDM", parsed.Issuer.CommonName) +} + +func TestGenerateMDMApple(t *testing.T) { + t.Run("missing input", func(t *testing.T) { + runAppCheckErr(t, []string{"generate", "mdm-apple"}, `Required flags "email, org" not set`) + runAppCheckErr(t, []string{"generate", "mdm-apple", "--email", "user@example.com"}, `Required flag "org" not set`) + runAppCheckErr(t, []string{"generate", "mdm-apple", "--org", "Acme"}, `Required flag "email" not set`) + }) + + t.Run("CSR API call fails", func(t *testing.T) { + _, _ = runServerWithMockedDS(t) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // fail this call + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("bad request")) + })) + t.Setenv("TEST_FLEETDM_API_URL", srv.URL) + t.Cleanup(srv.Close) + runAppCheckErr( + t, + []string{ + "generate", "mdm-apple", + "--email", "user@example.com", + "--org", "Acme", + }, + `POST /api/latest/fleet/mdm/apple/request_csr received status 502 Bad Gateway: FleetDM CSR request failed: api responded with 400: bad request`, + ) + }) + + t.Run("successful run", func(t *testing.T) { + _, _ = runServerWithMockedDS(t) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + t.Setenv("TEST_FLEETDM_API_URL", srv.URL) + t.Cleanup(srv.Close) + + outdir, err := os.MkdirTemp("", "TestGenerateMDMApple") + require.NoError(t, err) + defer os.Remove(outdir) + apnsKeyPath := filepath.Join(outdir, "apns.key") + scepCertPath := filepath.Join(outdir, "scep.crt") + scepKeyPath := filepath.Join(outdir, "scep.key") + out := runAppForTest(t, []string{ + "generate", "mdm-apple", + "--email", "user@example.com", + "--org", "Acme", + "--apns-key", apnsKeyPath, + "--scep-cert", scepCertPath, + "--scep-key", scepKeyPath, + }) + + require.Contains(t, out, fmt.Sprintf("Generated your APNs key at %s", apnsKeyPath)) + require.Contains(t, out, fmt.Sprintf("Generated your SCEP certificate at %s", scepCertPath)) + require.Contains(t, out, fmt.Sprintf("Generated your SCEP key at %s", scepKeyPath)) + + // validate that the keypair is valid + scepCrt, err := tls.LoadX509KeyPair(scepCertPath, scepKeyPath) + require.NoError(t, err) + parsed, err := x509.ParseCertificate(scepCrt.Certificate[0]) + require.NoError(t, err) + require.Equal(t, "FleetDM", parsed.Issuer.CommonName) + }) +} diff --git a/ee/server/service/mdm.go b/ee/server/service/mdm.go index c56bf97d9e..5f33a8e65b 100644 --- a/ee/server/service/mdm.go +++ b/ee/server/service/mdm.go @@ -47,7 +47,7 @@ func (svc *Service) GetAppleBM(ctx context.Context) (*fleet.AppleBM, error) { } func getAppleBMAccountDetail(ctx context.Context, depStorage storage.AllStorage, ds fleet.Datastore, logger kitlog.Logger) (*fleet.AppleBM, error) { - depClient := fleet.NewDEPClient(depStorage, ds, logger) + depClient := apple_mdm.NewDEPClient(depStorage, ds, logger) res, err := depClient.AccountDetail(ctx, apple_mdm.DEPName) if err != nil { return nil, ctxerr.Wrap(ctx, err, "apple GET /account request failed") diff --git a/server/datastore/mysql/apple_mdm.go b/server/datastore/mysql/apple_mdm.go index b06dfbd7ba..546a8aea97 100644 --- a/server/datastore/mysql/apple_mdm.go +++ b/server/datastore/mysql/apple_mdm.go @@ -273,16 +273,19 @@ func ingestMDMAppleDeviceFromCheckinDB( case err != nil: return ctxerr.Wrap(ctx, err, "get mdm apple host by serial number or udid") - case foundHost.HardwareSerial != mdmHost.SerialNumber || foundHost.UUID != mdmHost.UDID: - return updateMDMAppleHostDB(ctx, tx, foundHost.ID, mdmHost) - default: - // ok, nothing to do here - return nil + return updateMDMAppleHostDB(ctx, tx, foundHost.ID, mdmHost, appCfg) + } } -func updateMDMAppleHostDB(ctx context.Context, tx sqlx.ExtContext, hostID uint, mdmHost fleet.MDMAppleHostDetails) error { +func updateMDMAppleHostDB( + ctx context.Context, + tx sqlx.ExtContext, + hostID uint, + mdmHost fleet.MDMAppleHostDetails, + appCfg *fleet.AppConfig, +) error { updateStmt := ` UPDATE hosts SET hardware_serial = ?, @@ -310,6 +313,10 @@ func updateMDMAppleHostDB(ctx context.Context, tx sqlx.ExtContext, hostID uint, return ctxerr.Wrap(ctx, err, "update mdm apple host") } + if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, false, hostID); err != nil { + return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert MDM info") + } + return nil } @@ -365,7 +372,7 @@ func insertMDMAppleHostDB( return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert label membership") } - if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, false, host); err != nil { + if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, false, host.ID); err != nil { return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert MDM info") } return nil @@ -468,7 +475,11 @@ func (ds *Datastore) IngestMDMAppleDevicesFromDEPSync(ctx context.Context, devic return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert label membership") } - if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, true, hosts...); err != nil { + var ids []uint + for _, h := range hosts { + ids = append(ids, h.ID) + } + if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, true, ids...); err != nil { return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert MDM info") } @@ -497,7 +508,7 @@ func upsertMDMAppleHostDisplayNamesDB(ctx context.Context, tx sqlx.ExtContext, h return nil } -func upsertMDMAppleHostMDMInfoDB(ctx context.Context, tx sqlx.ExtContext, serverURL string, fromSync bool, hosts ...fleet.Host) error { +func upsertMDMAppleHostMDMInfoDB(ctx context.Context, tx sqlx.ExtContext, serverURL string, fromSync bool, hostIDs ...uint) error { result, err := tx.ExecContext(ctx, ` INSERT INTO mobile_device_management_solutions (name, server_url) VALUES (?, ?) ON DUPLICATE KEY UPDATE server_url = VALUES(server_url)`, @@ -522,8 +533,8 @@ func upsertMDMAppleHostMDMInfoDB(ctx context.Context, tx sqlx.ExtContext, server args := []interface{}{} parts := []string{} - for _, h := range hosts { - args = append(args, enrolled, serverURL, fromSync, mdmID, false, h.ID) + for _, id := range hostIDs { + args = append(args, enrolled, serverURL, fromSync, mdmID, false, id) parts = append(parts, "(?, ?, ?, ?, ?, ?)") } diff --git a/server/datastore/mysql/mysql.go b/server/datastore/mysql/mysql.go index a0f39b7214..2226d23e85 100644 --- a/server/datastore/mysql/mysql.go +++ b/server/datastore/mysql/mysql.go @@ -26,7 +26,6 @@ import ( "github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/data" "github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/tables" "github.com/fleetdm/fleet/v4/server/fleet" - apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple" "github.com/fleetdm/goose" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" @@ -36,6 +35,7 @@ import ( nanodep_client "github.com/micromdm/nanodep/client" nanodep_mysql "github.com/micromdm/nanodep/storage/mysql" nanomdm_mysql "github.com/micromdm/nanomdm/storage/mysql" + scep_depot "github.com/micromdm/scep/v2/depot" "github.com/ngrok/sqlmw" semconv "go.opentelemetry.io/otel/semconv/v1.4.0" ) @@ -107,14 +107,10 @@ func (ds *Datastore) loadOrPrepareStmt(ctx context.Context, query string) *sqlx. return stmt } -// NewMDMAppleSCEPDepot returns a *apple_mdm.MySQLDepot that uses the Datastore +// NewMDMAppleSCEPDepot returns a scep_depot.Depot that uses the Datastore // underlying MySQL writer *sql.DB. -func (ds *Datastore) NewMDMAppleSCEPDepot(caCertPEM []byte, caKeyPEM []byte) (*apple_mdm.SCEPMySQLDepot, error) { - depot, err := apple_mdm.NewSCEPMySQLDepot(ds.writer.DB, caCertPEM, caKeyPEM) - if err != nil { - return nil, err - } - return depot, nil +func (ds *Datastore) NewSCEPDepot(caCertPEM []byte, caKeyPEM []byte) (scep_depot.Depot, error) { + return newSCEPDepot(ds.writer.DB, caCertPEM, caKeyPEM) } // NewMDMAppleMDMStorage returns a MySQL nanomdm storage that uses the Datastore diff --git a/server/mdm/apple/scep_mysql.go b/server/datastore/mysql/scep.go similarity index 77% rename from server/mdm/apple/scep_mysql.go rename to server/datastore/mysql/scep.go index 3275c09817..8296e29c9a 100644 --- a/server/mdm/apple/scep_mysql.go +++ b/server/datastore/mysql/scep.go @@ -1,4 +1,4 @@ -package apple_mdm +package mysql import ( "crypto/rsa" @@ -11,12 +11,13 @@ import ( "fmt" "math/big" + apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple" "github.com/micromdm/nanomdm/cryptoutil" "github.com/micromdm/scep/v2/depot" ) -// SCEPMySQLDepot is a MySQL-backed SCEP certificate depot. -type SCEPMySQLDepot struct { +// SCEPDepot is a MySQL-backed SCEP certificate depot. +type SCEPDepot struct { db *sql.DB // caCrt holds the CA's certificate. @@ -25,10 +26,10 @@ type SCEPMySQLDepot struct { caKey *rsa.PrivateKey } -var _ depot.Depot = (*SCEPMySQLDepot)(nil) +var _ depot.Depot = (*SCEPDepot)(nil) -// NewSCEPMySQLDepot creates and returns a *SCEPMySQLDepot. -func NewSCEPMySQLDepot(db *sql.DB, caCertPEM []byte, caKeyPEM []byte) (*SCEPMySQLDepot, error) { +// newSCEPDepot creates and returns a *SCEPDepot. +func newSCEPDepot(db *sql.DB, caCertPEM []byte, caKeyPEM []byte) (*SCEPDepot, error) { if err := db.Ping(); err != nil { return nil, err } @@ -40,7 +41,7 @@ func NewSCEPMySQLDepot(db *sql.DB, caCertPEM []byte, caKeyPEM []byte) (*SCEPMySQ if err != nil { return nil, err } - return &SCEPMySQLDepot{ + return &SCEPDepot{ db: db, caCrt: caCrt, caKey: caKey, @@ -56,12 +57,12 @@ func decodeRSAKeyFromPEM(key []byte) (*rsa.PrivateKey, error) { } // CA returns the CA's certificate and private key. -func (d *SCEPMySQLDepot) CA(_ []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) { +func (d *SCEPDepot) CA(_ []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) { return []*x509.Certificate{d.caCrt}, d.caKey, nil } // Serial allocates and returns a new (increasing) serial number. -func (d *SCEPMySQLDepot) Serial() (*big.Int, error) { +func (d *SCEPDepot) Serial() (*big.Int, error) { result, err := d.db.Exec(`INSERT INTO scep_serials () VALUES ();`) if err != nil { return nil, err @@ -78,7 +79,7 @@ func (d *SCEPMySQLDepot) Serial() (*big.Int, error) { // TODO(lucas): Implement and use allowTime and revokeOldCertificate. // - allowTime are the maximum days before expiration to allow clients to do certificate renewal. // - revokeOldCertificate specifies whether to revoke the old certificate once renewed. -func (d *SCEPMySQLDepot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error) { +func (d *SCEPDepot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error) { var ct int row := d.db.QueryRow(`SELECT COUNT(*) FROM scep_certificates WHERE name = ?`, cn) if err := row.Scan(&ct); err != nil { @@ -91,14 +92,14 @@ func (d *SCEPMySQLDepot) HasCN(cn string, allowTime int, cert *x509.Certificate, // // If the provided certificate has empty crt.Subject.CommonName, // then the hex sha256 of the crt.Raw is used as name. -func (d *SCEPMySQLDepot) Put(name string, crt *x509.Certificate) error { +func (d *SCEPDepot) Put(name string, crt *x509.Certificate) error { if crt.Subject.CommonName == "" { name = fmt.Sprintf("%x", sha256.Sum256(crt.Raw)) } if !crt.SerialNumber.IsInt64() { return errors.New("cannot represent serial number as int64") } - certPEM := EncodeCertPEM(crt) + certPEM := apple_mdm.EncodeCertPEM(crt) _, err := d.db.Exec(` INSERT INTO scep_certificates (serial, name, not_valid_before, not_valid_after, certificate_pem) diff --git a/server/mdm/apple/scep_mysql_test.go b/server/datastore/mysql/scep_test.go similarity index 80% rename from server/mdm/apple/scep_mysql_test.go rename to server/datastore/mysql/scep_test.go index 8c2e23dee3..fd6770474e 100644 --- a/server/mdm/apple/scep_mysql_test.go +++ b/server/datastore/mysql/scep_test.go @@ -1,6 +1,4 @@ -// use a different package name to avoid -// import cycle -package apple_mdm_test +package mysql import ( "crypto/x509" @@ -8,19 +6,19 @@ import ( "math/big" "testing" - "github.com/fleetdm/fleet/v4/server/datastore/mysql" apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple" "github.com/micromdm/nanodep/tokenpki" + scep_depot "github.com/micromdm/scep/v2/depot" "github.com/stretchr/testify/require" ) -func setup(t *testing.T) *apple_mdm.SCEPMySQLDepot { - ds := mysql.CreateNamedMySQLDS(t, t.Name()) +func setup(t *testing.T) scep_depot.Depot { + ds := CreateNamedMySQLDS(t, t.Name()) cert, key, err := apple_mdm.NewSCEPCACertKey() require.NoError(t, err) publicKeyPEM := tokenpki.PEMCertificate(cert.Raw) privateKeyPEM := tokenpki.PEMRSAPrivateKey(key) - depot, err := ds.NewMDMAppleSCEPDepot(publicKeyPEM, privateKeyPEM) + depot, err := ds.NewSCEPDepot(publicKeyPEM, privateKeyPEM) require.NoError(t, err) return depot } diff --git a/server/fleet/mdm.go b/server/fleet/mdm.go index a6d1e058bd..0f2a83dcdb 100644 --- a/server/fleet/mdm.go +++ b/server/fleet/mdm.go @@ -3,11 +3,6 @@ package fleet import ( "context" "time" - - "github.com/fleetdm/fleet/v4/pkg/fleethttp" - kitlog "github.com/go-kit/kit/log" - "github.com/go-kit/kit/log/level" - "github.com/micromdm/nanodep/godep" ) type AppleMDM struct { @@ -52,42 +47,3 @@ type AppConfigUpdater interface { AppConfig(ctx context.Context) (*AppConfig, error) SaveAppConfig(ctx context.Context, info *AppConfig) error } - -// NewDEPClient creates an Apple DEP API HTTP client based on the provided -// storage that will flag the AppConfig's AppleBMTermsExpired field whenever -// the status of the terms changes. -func NewDEPClient(storage godep.ClientStorage, appCfgUpdater AppConfigUpdater, logger kitlog.Logger) *godep.Client { - return godep.NewClient(storage, fleethttp.NewClient(), godep.WithAfterHook(func(ctx context.Context, reqErr error) error { - // if the request failed due to terms not signed, or if it succeeded, - // update the app config flag accordingly. If it failed for any other - // reason, do not update the flag. - termsExpired := reqErr != nil && godep.IsTermsNotSigned(reqErr) - if reqErr == nil || termsExpired { - appCfg, err := appCfgUpdater.AppConfig(ctx) - if err != nil { - level.Error(logger).Log("msg", "Apple DEP client: failed to get app config", "err", err) - return reqErr - } - - var mustSaveAppCfg bool - if termsExpired && !appCfg.MDM.AppleBMTermsExpired { - // flag the AppConfig that the terms have changed and must be accepted - appCfg.MDM.AppleBMTermsExpired = true - mustSaveAppCfg = true - } else if reqErr == nil && appCfg.MDM.AppleBMTermsExpired { - // flag the AppConfig that the terms have been accepted - appCfg.MDM.AppleBMTermsExpired = false - mustSaveAppCfg = true - } - - if mustSaveAppCfg { - if err := appCfgUpdater.SaveAppConfig(ctx, appCfg); err != nil { - level.Error(logger).Log("msg", "Apple DEP client: failed to save app config", "err", err) - } - level.Debug(logger).Log("msg", "Apple DEP client: updated app config Terms Expired flag", - "apple_bm_terms_expired", appCfg.MDM.AppleBMTermsExpired) - } - } - return reqErr - })) -} diff --git a/server/fleet/mdm_test.go b/server/fleet/mdm_test.go index a12cfd3b1b..e4be53b9f6 100644 --- a/server/fleet/mdm_test.go +++ b/server/fleet/mdm_test.go @@ -11,25 +11,13 @@ import ( "github.com/fleetdm/fleet/v4/server/fleet" apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple" "github.com/fleetdm/fleet/v4/server/mock" + nanodep_mock "github.com/fleetdm/fleet/v4/server/mock/nanodep" "github.com/go-kit/log" nanodep_client "github.com/micromdm/nanodep/client" "github.com/micromdm/nanodep/godep" "github.com/stretchr/testify/require" ) -type mockStorage struct { - token string - url string -} - -func (s mockStorage) RetrieveAuthTokens(ctx context.Context, name string) (*nanodep_client.OAuth1Tokens, error) { - return &nanodep_client.OAuth1Tokens{AccessToken: s.token}, nil -} - -func (s mockStorage) RetrieveConfig(context.Context, string) (*nanodep_client.Config, error) { - return &nanodep_client.Config{BaseURL: s.url}, nil -} - func TestDEPClient(t *testing.T) { ctx := context.Background() @@ -141,8 +129,15 @@ func TestDEPClient(t *testing.T) { for i, c := range cases { t.Logf("case %d", i) - store := mockStorage{token: c.token, url: srv.URL} - dep := fleet.NewDEPClient(store, ds, logger) + store := &nanodep_mock.Storage{} + store.RetrieveAuthTokensFunc = func(ctx context.Context, name string) (*nanodep_client.OAuth1Tokens, error) { + return &nanodep_client.OAuth1Tokens{AccessToken: c.token}, nil + } + store.RetrieveConfigFunc = func(context.Context, string) (*nanodep_client.Config, error) { + return &nanodep_client.Config{BaseURL: srv.URL}, nil + } + + dep := apple_mdm.NewDEPClient(store, ds, logger) res, err := dep.AccountDetail(ctx, apple_mdm.DEPName) if c.wantErr { @@ -163,6 +158,8 @@ func TestDEPClient(t *testing.T) { } else { require.NoError(t, err) require.Equal(t, "test", res.AdminID) + require.True(t, store.RetrieveAuthTokensFuncInvoked) + require.True(t, store.RetrieveConfigFuncInvoked) } checkDSCalled(c.readInvoked, c.writeInvoked) require.Equal(t, c.termsFlag, appCfg.MDM.AppleBMTermsExpired) diff --git a/server/logging/nanodep.go b/server/logging/nanodep.go new file mode 100644 index 0000000000..cfc40dcabb --- /dev/null +++ b/server/logging/nanodep.go @@ -0,0 +1,33 @@ +package logging + +import ( + kitlog "github.com/go-kit/kit/log" + "github.com/go-kit/log/level" + nanodep_log "github.com/micromdm/nanodep/log" +) + +// NanoDEPLogger is a logger adapter for nanodep. +type NanoDEPLogger struct { + logger kitlog.Logger +} + +func NewNanoDEPLogger(logger kitlog.Logger) *NanoDEPLogger { + return &NanoDEPLogger{ + logger: logger, + } +} + +func (l *NanoDEPLogger) Info(keyvals ...interface{}) { + level.Info(l.logger).Log(keyvals...) +} + +func (l *NanoDEPLogger) Debug(keyvals ...interface{}) { + level.Debug(l.logger).Log(keyvals...) +} + +func (l *NanoDEPLogger) With(keyvals ...interface{}) nanodep_log.Logger { + newLogger := kitlog.With(l.logger, keyvals...) + return &NanoDEPLogger{ + logger: newLogger, + } +} diff --git a/server/mdm/apple/apple_mdm.go b/server/mdm/apple/apple_mdm.go index e2a7a5db45..d27d411631 100644 --- a/server/mdm/apple/apple_mdm.go +++ b/server/mdm/apple/apple_mdm.go @@ -1,8 +1,21 @@ package apple_mdm import ( + "context" + "fmt" "net/url" "path" + + "github.com/fleetdm/fleet/v4/pkg/fleethttp" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/logging" + "github.com/getsentry/sentry-go" + "github.com/go-kit/log/level" + "github.com/micromdm/nanodep/godep" + + kitlog "github.com/go-kit/kit/log" + nanodep_storage "github.com/micromdm/nanodep/storage" + depsync "github.com/micromdm/nanodep/sync" ) // DEPName is the identifier/name used in nanodep MySQL storage which @@ -47,3 +60,118 @@ func resolveURL(serverURL, relPath string) (string, error) { u.Path = path.Join(u.Path, relPath) return u.String(), nil } + +type DEPSyncer struct { + depStorage nanodep_storage.AllStorage + syncer *depsync.Syncer + logger kitlog.Logger +} + +func (d *DEPSyncer) Run(ctx context.Context) error { + profileUUID, profileModTime, err := d.depStorage.RetrieveAssignerProfile(ctx, DEPName) + if err != nil { + return err + } + if profileUUID == "" { + d.logger.Log("msg", "DEP profile not set, nothing to do") + return nil + } + cursor, cursorModTime, err := d.depStorage.RetrieveCursor(ctx, DEPName) + if err != nil { + return err + } + // If the DEP Profile was changed since last sync then we clear + // the cursor and perform a full sync of all devices and profile assigning. + if cursor != "" && profileModTime.After(cursorModTime) { + d.logger.Log("msg", "clearing device syncer cursor") + if err := d.depStorage.StoreCursor(ctx, DEPName, ""); err != nil { + return err + } + } + return d.syncer.Run(ctx) +} + +func NewDEPSyncer( + ds fleet.Datastore, + depStorage nanodep_storage.AllStorage, + logger kitlog.Logger, + loggingDebug bool, +) *DEPSyncer { + depClient := NewDEPClient(depStorage, ds, logger) + assignerOpts := []depsync.AssignerOption{ + depsync.WithAssignerLogger(logging.NewNanoDEPLogger(kitlog.With(logger, "component", "nanodep-assigner"))), + } + if loggingDebug { + assignerOpts = append(assignerOpts, depsync.WithDebug()) + } + assigner := depsync.NewAssigner( + depClient, + DEPName, + depStorage, + assignerOpts..., + ) + + syncer := depsync.NewSyncer( + depClient, + DEPName, + depStorage, + depsync.WithLogger(logging.NewNanoDEPLogger(kitlog.With(logger, "component", "nanodep-syncer"))), + depsync.WithCallback(func(ctx context.Context, isFetch bool, resp *godep.DeviceResponse) error { + n, err := ds.IngestMDMAppleDevicesFromDEPSync(ctx, resp.Devices) + switch { + case err != nil: + level.Error(kitlog.With(logger)).Log("err", err) + sentry.CaptureException(err) + case n > 0: + level.Info(kitlog.With(logger)).Log("msg", fmt.Sprintf("added %d new mdm device(s) to pending hosts", n)) + } + + return assigner.ProcessDeviceResponse(ctx, resp) + }), + ) + + return &DEPSyncer{ + syncer: syncer, + depStorage: depStorage, + logger: logger, + } +} + +// NewDEPClient creates an Apple DEP API HTTP client based on the provided +// storage that will flag the AppConfig's AppleBMTermsExpired field +// whenever the status of the terms changes. +func NewDEPClient(storage godep.ClientStorage, appCfgUpdater fleet.AppConfigUpdater, logger kitlog.Logger) *godep.Client { + return godep.NewClient(storage, fleethttp.NewClient(), godep.WithAfterHook(func(ctx context.Context, reqErr error) error { + // if the request failed due to terms not signed, or if it succeeded, + // update the app config flag accordingly. If it failed for any other + // reason, do not update the flag. + termsExpired := reqErr != nil && godep.IsTermsNotSigned(reqErr) + if reqErr == nil || termsExpired { + appCfg, err := appCfgUpdater.AppConfig(ctx) + if err != nil { + level.Error(logger).Log("msg", "Apple DEP client: failed to get app config", "err", err) + return reqErr + } + + var mustSaveAppCfg bool + if termsExpired && !appCfg.MDM.AppleBMTermsExpired { + // flag the AppConfig that the terms have changed and must be accepted + appCfg.MDM.AppleBMTermsExpired = true + mustSaveAppCfg = true + } else if reqErr == nil && appCfg.MDM.AppleBMTermsExpired { + // flag the AppConfig that the terms have been accepted + appCfg.MDM.AppleBMTermsExpired = false + mustSaveAppCfg = true + } + + if mustSaveAppCfg { + if err := appCfgUpdater.SaveAppConfig(ctx, appCfg); err != nil { + level.Error(logger).Log("msg", "Apple DEP client: failed to save app config", "err", err) + } + level.Debug(logger).Log("msg", "Apple DEP client: updated app config Terms Expired flag", + "apple_bm_terms_expired", appCfg.MDM.AppleBMTermsExpired) + } + } + return reqErr + })) +} diff --git a/server/mock/datastore.go b/server/mock/datastore.go index 3a1bd8045f..7690c4fe35 100644 --- a/server/mock/datastore.go +++ b/server/mock/datastore.go @@ -8,6 +8,9 @@ import ( //go:generate mockimpl -o datastore_mock.go "s *DataStore" "fleet.Datastore" //go:generate mockimpl -o datastore_installers.go "s *InstallerStore" "fleet.InstallerStore" +//go:generate mockimpl -o nanomdm/storage.go "s *Storage" "github.com/micromdm/nanomdm/storage.AllStorage" +//go:generate mockimpl -o nanodep/storage.go "s *Storage" "github.com/micromdm/nanodep/storage.AllStorage" +//go:generate mockimpl -o scep/depot.go "d *Depot" "depot.Depot" var _ fleet.Datastore = (*Store)(nil) diff --git a/server/mock/nanodep/storage.go b/server/mock/nanodep/storage.go new file mode 100644 index 0000000000..314dce0834 --- /dev/null +++ b/server/mock/nanodep/storage.go @@ -0,0 +1,115 @@ +// Automatically generated by mockimpl. DO NOT EDIT! + +package mock + +import ( + "context" + "time" + + "github.com/micromdm/nanodep/client" + "github.com/micromdm/nanodep/storage" +) + +var _ storage.AllStorage = (*Storage)(nil) + +type RetrieveAuthTokensFunc func(ctx context.Context, name string) (*client.OAuth1Tokens, error) + +type RetrieveConfigFunc func(p0 context.Context, p1 string) (*client.Config, error) + +type RetrieveAssignerProfileFunc func(ctx context.Context, name string) (profileUUID string, modTime time.Time, err error) + +type RetrieveCursorFunc func(ctx context.Context, name string) (cursor string, modTime time.Time, err error) + +type StoreCursorFunc func(ctx context.Context, name string, cursor string) error + +type StoreAuthTokensFunc func(ctx context.Context, name string, tokens *client.OAuth1Tokens) error + +type StoreConfigFunc func(ctx context.Context, name string, config *client.Config) error + +type StoreTokenPKIFunc func(ctx context.Context, name string, pemCert []byte, pemKey []byte) error + +type RetrieveTokenPKIFunc func(ctx context.Context, name string) (pemCert []byte, pemKey []byte, err error) + +type StoreAssignerProfileFunc func(ctx context.Context, name string, profileUUID string) error + +type Storage struct { + RetrieveAuthTokensFunc RetrieveAuthTokensFunc + RetrieveAuthTokensFuncInvoked bool + + RetrieveConfigFunc RetrieveConfigFunc + RetrieveConfigFuncInvoked bool + + RetrieveAssignerProfileFunc RetrieveAssignerProfileFunc + RetrieveAssignerProfileFuncInvoked bool + + RetrieveCursorFunc RetrieveCursorFunc + RetrieveCursorFuncInvoked bool + + StoreCursorFunc StoreCursorFunc + StoreCursorFuncInvoked bool + + StoreAuthTokensFunc StoreAuthTokensFunc + StoreAuthTokensFuncInvoked bool + + StoreConfigFunc StoreConfigFunc + StoreConfigFuncInvoked bool + + StoreTokenPKIFunc StoreTokenPKIFunc + StoreTokenPKIFuncInvoked bool + + RetrieveTokenPKIFunc RetrieveTokenPKIFunc + RetrieveTokenPKIFuncInvoked bool + + StoreAssignerProfileFunc StoreAssignerProfileFunc + StoreAssignerProfileFuncInvoked bool +} + +func (s *Storage) RetrieveAuthTokens(ctx context.Context, name string) (*client.OAuth1Tokens, error) { + s.RetrieveAuthTokensFuncInvoked = true + return s.RetrieveAuthTokensFunc(ctx, name) +} + +func (s *Storage) RetrieveConfig(p0 context.Context, p1 string) (*client.Config, error) { + s.RetrieveConfigFuncInvoked = true + return s.RetrieveConfigFunc(p0, p1) +} + +func (s *Storage) RetrieveAssignerProfile(ctx context.Context, name string) (profileUUID string, modTime time.Time, err error) { + s.RetrieveAssignerProfileFuncInvoked = true + return s.RetrieveAssignerProfileFunc(ctx, name) +} + +func (s *Storage) RetrieveCursor(ctx context.Context, name string) (cursor string, modTime time.Time, err error) { + s.RetrieveCursorFuncInvoked = true + return s.RetrieveCursorFunc(ctx, name) +} + +func (s *Storage) StoreCursor(ctx context.Context, name string, cursor string) error { + s.StoreCursorFuncInvoked = true + return s.StoreCursorFunc(ctx, name, cursor) +} + +func (s *Storage) StoreAuthTokens(ctx context.Context, name string, tokens *client.OAuth1Tokens) error { + s.StoreAuthTokensFuncInvoked = true + return s.StoreAuthTokensFunc(ctx, name, tokens) +} + +func (s *Storage) StoreConfig(ctx context.Context, name string, config *client.Config) error { + s.StoreConfigFuncInvoked = true + return s.StoreConfigFunc(ctx, name, config) +} + +func (s *Storage) StoreTokenPKI(ctx context.Context, name string, pemCert []byte, pemKey []byte) error { + s.StoreTokenPKIFuncInvoked = true + return s.StoreTokenPKIFunc(ctx, name, pemCert, pemKey) +} + +func (s *Storage) RetrieveTokenPKI(ctx context.Context, name string) (pemCert []byte, pemKey []byte, err error) { + s.RetrieveTokenPKIFuncInvoked = true + return s.RetrieveTokenPKIFunc(ctx, name) +} + +func (s *Storage) StoreAssignerProfile(ctx context.Context, name string, profileUUID string) error { + s.StoreAssignerProfileFuncInvoked = true + return s.StoreAssignerProfileFunc(ctx, name, profileUUID) +} diff --git a/server/mock/nanomdm/storage.go b/server/mock/nanomdm/storage.go new file mode 100644 index 0000000000..633ed11719 --- /dev/null +++ b/server/mock/nanomdm/storage.go @@ -0,0 +1,215 @@ +// Automatically generated by mockimpl. DO NOT EDIT! + +package mock + +import ( + "context" + "crypto/tls" + + "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/storage" +) + +var _ storage.AllStorage = (*Storage)(nil) + +type StoreAuthenticateFunc func(r *mdm.Request, msg *mdm.Authenticate) error + +type StoreTokenUpdateFunc func(r *mdm.Request, msg *mdm.TokenUpdate) error + +type StoreUserAuthenticateFunc func(r *mdm.Request, msg *mdm.UserAuthenticate) error + +type DisableFunc func(r *mdm.Request) error + +type StoreCommandReportFunc func(r *mdm.Request, report *mdm.CommandResults) error + +type RetrieveNextCommandFunc func(r *mdm.Request, skipNotNow bool) (*mdm.Command, error) + +type ClearQueueFunc func(r *mdm.Request) error + +type StoreBootstrapTokenFunc func(r *mdm.Request, msg *mdm.SetBootstrapToken) error + +type RetrieveBootstrapTokenFunc func(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) + +type RetrievePushInfoFunc func(p0 context.Context, p1 []string) (map[string]*mdm.Push, error) + +type IsPushCertStaleFunc func(ctx context.Context, topic string, staleToken string) (bool, error) + +type RetrievePushCertFunc func(ctx context.Context, topic string) (cert *tls.Certificate, staleToken string, err error) + +type StorePushCertFunc func(ctx context.Context, pemCert []byte, pemKey []byte) error + +type EnqueueCommandFunc func(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) + +type HasCertHashFunc func(r *mdm.Request, hash string) (bool, error) + +type EnrollmentHasCertHashFunc func(r *mdm.Request, hash string) (bool, error) + +type IsCertHashAssociatedFunc func(r *mdm.Request, hash string) (bool, error) + +type AssociateCertHashFunc func(r *mdm.Request, hash string) error + +type RetrieveMigrationCheckinsFunc func(p0 context.Context, p1 chan<- interface{}) error + +type RetrieveTokenUpdateTallyFunc func(ctx context.Context, id string) (int, error) + +type Storage struct { + StoreAuthenticateFunc StoreAuthenticateFunc + StoreAuthenticateFuncInvoked bool + + StoreTokenUpdateFunc StoreTokenUpdateFunc + StoreTokenUpdateFuncInvoked bool + + StoreUserAuthenticateFunc StoreUserAuthenticateFunc + StoreUserAuthenticateFuncInvoked bool + + DisableFunc DisableFunc + DisableFuncInvoked bool + + StoreCommandReportFunc StoreCommandReportFunc + StoreCommandReportFuncInvoked bool + + RetrieveNextCommandFunc RetrieveNextCommandFunc + RetrieveNextCommandFuncInvoked bool + + ClearQueueFunc ClearQueueFunc + ClearQueueFuncInvoked bool + + StoreBootstrapTokenFunc StoreBootstrapTokenFunc + StoreBootstrapTokenFuncInvoked bool + + RetrieveBootstrapTokenFunc RetrieveBootstrapTokenFunc + RetrieveBootstrapTokenFuncInvoked bool + + RetrievePushInfoFunc RetrievePushInfoFunc + RetrievePushInfoFuncInvoked bool + + IsPushCertStaleFunc IsPushCertStaleFunc + IsPushCertStaleFuncInvoked bool + + RetrievePushCertFunc RetrievePushCertFunc + RetrievePushCertFuncInvoked bool + + StorePushCertFunc StorePushCertFunc + StorePushCertFuncInvoked bool + + EnqueueCommandFunc EnqueueCommandFunc + EnqueueCommandFuncInvoked bool + + HasCertHashFunc HasCertHashFunc + HasCertHashFuncInvoked bool + + EnrollmentHasCertHashFunc EnrollmentHasCertHashFunc + EnrollmentHasCertHashFuncInvoked bool + + IsCertHashAssociatedFunc IsCertHashAssociatedFunc + IsCertHashAssociatedFuncInvoked bool + + AssociateCertHashFunc AssociateCertHashFunc + AssociateCertHashFuncInvoked bool + + RetrieveMigrationCheckinsFunc RetrieveMigrationCheckinsFunc + RetrieveMigrationCheckinsFuncInvoked bool + + RetrieveTokenUpdateTallyFunc RetrieveTokenUpdateTallyFunc + RetrieveTokenUpdateTallyFuncInvoked bool +} + +func (s *Storage) StoreAuthenticate(r *mdm.Request, msg *mdm.Authenticate) error { + s.StoreAuthenticateFuncInvoked = true + return s.StoreAuthenticateFunc(r, msg) +} + +func (s *Storage) StoreTokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate) error { + s.StoreTokenUpdateFuncInvoked = true + return s.StoreTokenUpdateFunc(r, msg) +} + +func (s *Storage) StoreUserAuthenticate(r *mdm.Request, msg *mdm.UserAuthenticate) error { + s.StoreUserAuthenticateFuncInvoked = true + return s.StoreUserAuthenticateFunc(r, msg) +} + +func (s *Storage) Disable(r *mdm.Request) error { + s.DisableFuncInvoked = true + return s.DisableFunc(r) +} + +func (s *Storage) StoreCommandReport(r *mdm.Request, report *mdm.CommandResults) error { + s.StoreCommandReportFuncInvoked = true + return s.StoreCommandReportFunc(r, report) +} + +func (s *Storage) RetrieveNextCommand(r *mdm.Request, skipNotNow bool) (*mdm.Command, error) { + s.RetrieveNextCommandFuncInvoked = true + return s.RetrieveNextCommandFunc(r, skipNotNow) +} + +func (s *Storage) ClearQueue(r *mdm.Request) error { + s.ClearQueueFuncInvoked = true + return s.ClearQueueFunc(r) +} + +func (s *Storage) StoreBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapToken) error { + s.StoreBootstrapTokenFuncInvoked = true + return s.StoreBootstrapTokenFunc(r, msg) +} + +func (s *Storage) RetrieveBootstrapToken(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) { + s.RetrieveBootstrapTokenFuncInvoked = true + return s.RetrieveBootstrapTokenFunc(r, msg) +} + +func (s *Storage) RetrievePushInfo(p0 context.Context, p1 []string) (map[string]*mdm.Push, error) { + s.RetrievePushInfoFuncInvoked = true + return s.RetrievePushInfoFunc(p0, p1) +} + +func (s *Storage) IsPushCertStale(ctx context.Context, topic string, staleToken string) (bool, error) { + s.IsPushCertStaleFuncInvoked = true + return s.IsPushCertStaleFunc(ctx, topic, staleToken) +} + +func (s *Storage) RetrievePushCert(ctx context.Context, topic string) (cert *tls.Certificate, staleToken string, err error) { + s.RetrievePushCertFuncInvoked = true + return s.RetrievePushCertFunc(ctx, topic) +} + +func (s *Storage) StorePushCert(ctx context.Context, pemCert []byte, pemKey []byte) error { + s.StorePushCertFuncInvoked = true + return s.StorePushCertFunc(ctx, pemCert, pemKey) +} + +func (s *Storage) EnqueueCommand(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) { + s.EnqueueCommandFuncInvoked = true + return s.EnqueueCommandFunc(ctx, id, cmd) +} + +func (s *Storage) HasCertHash(r *mdm.Request, hash string) (bool, error) { + s.HasCertHashFuncInvoked = true + return s.HasCertHashFunc(r, hash) +} + +func (s *Storage) EnrollmentHasCertHash(r *mdm.Request, hash string) (bool, error) { + s.EnrollmentHasCertHashFuncInvoked = true + return s.EnrollmentHasCertHashFunc(r, hash) +} + +func (s *Storage) IsCertHashAssociated(r *mdm.Request, hash string) (bool, error) { + s.IsCertHashAssociatedFuncInvoked = true + return s.IsCertHashAssociatedFunc(r, hash) +} + +func (s *Storage) AssociateCertHash(r *mdm.Request, hash string) error { + s.AssociateCertHashFuncInvoked = true + return s.AssociateCertHashFunc(r, hash) +} + +func (s *Storage) RetrieveMigrationCheckins(p0 context.Context, p1 chan<- interface{}) error { + s.RetrieveMigrationCheckinsFuncInvoked = true + return s.RetrieveMigrationCheckinsFunc(p0, p1) +} + +func (s *Storage) RetrieveTokenUpdateTally(ctx context.Context, id string) (int, error) { + s.RetrieveTokenUpdateTallyFuncInvoked = true + return s.RetrieveTokenUpdateTallyFunc(ctx, id) +} diff --git a/server/mock/scep/depot.go b/server/mock/scep/depot.go new file mode 100644 index 0000000000..24eb971a99 --- /dev/null +++ b/server/mock/scep/depot.go @@ -0,0 +1,55 @@ +// Automatically generated by mockimpl. DO NOT EDIT! + +package mock + +import ( + "crypto/rsa" + "crypto/x509" + "math/big" + + "github.com/micromdm/scep/v2/depot" +) + +var _ depot.Depot = (*Depot)(nil) + +type CAFunc func(pass []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) + +type PutFunc func(name string, crt *x509.Certificate) error + +type SerialFunc func() (*big.Int, error) + +type HasCNFunc func(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error) + +type Depot struct { + CAFunc CAFunc + CAFuncInvoked bool + + PutFunc PutFunc + PutFuncInvoked bool + + SerialFunc SerialFunc + SerialFuncInvoked bool + + HasCNFunc HasCNFunc + HasCNFuncInvoked bool +} + +func (d *Depot) CA(pass []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) { + d.CAFuncInvoked = true + return d.CAFunc(pass) +} + +func (d *Depot) Put(name string, crt *x509.Certificate) error { + d.PutFuncInvoked = true + return d.PutFunc(name, crt) +} + +func (d *Depot) Serial() (*big.Int, error) { + d.SerialFuncInvoked = true + return d.SerialFunc() +} + +func (d *Depot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error) { + d.HasCNFuncInvoked = true + return d.HasCNFunc(cn, allowTime, cert, revokeOldCertificate) +} diff --git a/server/service/apple_mdm.go b/server/service/apple_mdm.go index c63d33c744..d237734f1f 100644 --- a/server/service/apple_mdm.go +++ b/server/service/apple_mdm.go @@ -121,7 +121,7 @@ func (svc *Service) setDEPProfile(ctx context.Context, enrollmentProfile *fleet. depProfileRequest.URL = enrollURL depProfileRequest.ConfigurationWebURL = enrollURL - depClient := fleet.NewDEPClient(svc.depStorage, svc.ds, svc.logger) + depClient := apple_mdm.NewDEPClient(svc.depStorage, svc.ds, svc.logger) res, err := depClient.DefineProfile(ctx, apple_mdm.DEPName, &depProfileRequest) if err != nil { return ctxerr.Wrap(ctx, err, "apple POST /profile request failed") @@ -430,7 +430,7 @@ func (svc *Service) ListMDMAppleDEPDevices(ctx context.Context) ([]fleet.MDMAppl if err := svc.authz.Authorize(ctx, &fleet.MDMAppleDEPDevice{}, fleet.ActionWrite); err != nil { return nil, ctxerr.Wrap(ctx, err) } - depClient := fleet.NewDEPClient(svc.depStorage, svc.ds, svc.logger) + depClient := apple_mdm.NewDEPClient(svc.depStorage, svc.ds, svc.logger) // TODO(lucas): Use cursors and limit to fetch in multiple requests. // This single-request version supports up to 1000 devices (max to return in one call). diff --git a/server/service/apple_mdm_test.go b/server/service/apple_mdm_test.go index 5e30285b45..68857a2ad5 100644 --- a/server/service/apple_mdm_test.go +++ b/server/service/apple_mdm_test.go @@ -3,9 +3,11 @@ package service import ( "bytes" "context" + "crypto/tls" "fmt" "net/http" "net/http/httptest" + "os" "strings" "sync/atomic" "testing" @@ -13,50 +15,20 @@ import ( "github.com/fleetdm/fleet/v4/server/authz" "github.com/fleetdm/fleet/v4/server/config" "github.com/fleetdm/fleet/v4/server/contexts/viewer" - "github.com/fleetdm/fleet/v4/server/datastore/mysql" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/mock" + nanodep_mock "github.com/fleetdm/fleet/v4/server/mock/nanodep" + nanomdm_mock "github.com/fleetdm/fleet/v4/server/mock/nanomdm" "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/test" - "github.com/google/uuid" + kitlog "github.com/go-kit/kit/log" nanodep_client "github.com/micromdm/nanodep/client" - nanodep_storage "github.com/micromdm/nanodep/storage" "github.com/micromdm/nanomdm/mdm" - nanomdm_push "github.com/micromdm/nanomdm/push" - "github.com/micromdm/nanomdm/storage" + nanomdm_pushsvc "github.com/micromdm/nanomdm/push/service" "github.com/stretchr/testify/require" ) -type dummyDEPStorage struct { - nanodep_storage.AllStorage - testAuthAddr string -} - -func (d dummyDEPStorage) RetrieveAuthTokens(ctx context.Context, name string) (*nanodep_client.OAuth1Tokens, error) { - return &nanodep_client.OAuth1Tokens{}, nil -} - -func (d dummyDEPStorage) RetrieveConfig(context.Context, string) (*nanodep_client.Config, error) { - return &nanodep_client.Config{ - BaseURL: d.testAuthAddr, - }, nil -} - -type dummyMDMStorage struct { - *mysql.NanoMDMStorage -} - -func (d dummyMDMStorage) EnqueueCommand(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) { - return nil, nil -} - -type dummyMDMPusher struct{} - -func (d dummyMDMPusher) Push(context.Context, []string) (map[string]*nanomdm_push.Response, error) { - return nil, nil -} - -func setupAppleMDMService(t *testing.T, mdmStorage storage.AllStorage, depStorage nanodep_storage.AllStorage, mdmPusher nanomdm_push.Pusher) (fleet.Service, context.Context, *mock.Store) { +func setupAppleMDMService(t *testing.T) (fleet.Service, context.Context, *mock.Store) { ds := new(mock.Store) cfg := config.TestConfig() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -72,23 +44,55 @@ func setupAppleMDMService(t *testing.T, mdmStorage storage.AllStorage, depStorag } })) + mdmStorage := &nanomdm_mock.Storage{} + depStorage := &nanodep_mock.Storage{} + pushFactory, _ := newMockAPNSPushProviderFactory() + pusher := nanomdm_pushsvc.New( + mdmStorage, + mdmStorage, + pushFactory, + NewNanoMDMLogger(kitlog.NewJSONLogger(os.Stdout)), + ) + opts := &TestServerOpts{ FleetConfig: &cfg, - MDMStorage: dummyMDMStorage{}, - DEPStorage: dummyDEPStorage{testAuthAddr: ts.URL}, - MDMPusher: dummyMDMPusher{}, - } - if mdmStorage != nil { - opts.MDMStorage = mdmStorage - } - if depStorage != nil { - opts.DEPStorage = depStorage - } - if mdmPusher != nil { - opts.MDMPusher = mdmPusher + MDMStorage: mdmStorage, + DEPStorage: depStorage, + MDMPusher: pusher, } svc, ctx := newTestServiceWithConfig(t, ds, cfg, nil, nil, opts) + mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) { + return nil, nil + } + mdmStorage.RetrievePushInfoFunc = func(ctx context.Context, tokens []string) (map[string]*mdm.Push, error) { + res := make(map[string]*mdm.Push, len(tokens)) + for _, t := range tokens { + res[t] = &mdm.Push{ + PushMagic: "", + Token: []byte(t), + Topic: "", + } + } + return res, nil + } + mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) { + cert, err := tls.LoadX509KeyPair("testdata/server.pem", "testdata/server.key") + return &cert, "", err + } + mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) { + return false, nil + } + + depStorage.RetrieveAuthTokensFunc = func(ctx context.Context, name string) (*nanodep_client.OAuth1Tokens, error) { + return &nanodep_client.OAuth1Tokens{}, nil + } + depStorage.RetrieveConfigFunc = func(context.Context, string) (*nanodep_client.Config, error) { + return &nanodep_client.Config{ + BaseURL: ts.URL, + }, nil + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { return &fleet.AppConfig{ OrgInfo: fleet.OrgInfo{ @@ -148,7 +152,7 @@ func setupAppleMDMService(t *testing.T, mdmStorage storage.AllStorage, depStorag } func TestAppleMDMAuthorization(t *testing.T) { - svc, ctx, _ := setupAppleMDMService(t, nil, nil, nil) + svc, ctx, _ := setupAppleMDMService(t) checkAuthErr := func(t *testing.T, err error, shouldFailWithAuth bool) { t.Helper() @@ -252,7 +256,7 @@ func TestMDMAppleEnrollURL(t *testing.T) { } func TestAppleMDMEnrollmentProfile(t *testing.T) { - svc, ctx, _ := setupAppleMDMService(t, nil, nil, nil) + svc, ctx, _ := setupAppleMDMService(t) // Only global admins can create enrollment profiles. ctx = test.UserContext(ctx, test.UserAdmin) @@ -273,22 +277,8 @@ func TestAppleMDMEnrollmentProfile(t *testing.T) { } } -type noErrorPusher struct{} - -// Push simulates successful push responses. The result maps each of the provided deviceIDs to a -// internally generated UUID, which is intended here to mock the APNs API response. -func (nep *noErrorPusher) Push(ctx context.Context, deviceIDs []string) (map[string]*nanomdm_push.Response, error) { - res := make(map[string]*nanomdm_push.Response) - for _, s := range deviceIDs { - res[s] = &nanomdm_push.Response{Id: uuid.New().String()} - } - return res, nil -} - func TestMDMCommandAuthz(t *testing.T) { - pusher := noErrorPusher{} - - svc, ctx, ds := setupAppleMDMService(t, nil, nil, &pusher) + svc, ctx, ds := setupAppleMDMService(t) ds.HostLiteFunc = func(ctx context.Context, hostID uint) (*fleet.Host, error) { switch hostID { diff --git a/server/service/handler.go b/server/service/handler.go index 59587d6d24..305bdd87cf 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -668,7 +668,7 @@ func RegisterAppleMDMProtocolServices( mux *http.ServeMux, scepConfig config.MDMAppleSCEPConfig, mdmStorage nanomdm_storage.AllStorage, - scepStorage *apple_mdm.SCEPMySQLDepot, + scepStorage scep_depot.Depot, logger kitlog.Logger, checkinAndCommandService nanomdm_service.CheckinAndCommandService, ) error { @@ -692,7 +692,7 @@ func registerSCEP( scepConfig config.MDMAppleSCEPConfig, scepCert *x509.Certificate, scepKey *rsa.PrivateKey, - scepStorage *apple_mdm.SCEPMySQLDepot, + scepStorage scep_depot.Depot, logger kitlog.Logger, ) error { var signer scepserver.CSRSigner = scep_depot.NewSigner( diff --git a/server/service/integration_mdm_test.go b/server/service/integration_mdm_test.go index c98b902dbc..e3f62954a8 100644 --- a/server/service/integration_mdm_test.go +++ b/server/service/integration_mdm_test.go @@ -17,21 +17,30 @@ import ( "os" "strconv" "strings" + "sync" "testing" "time" + "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/push" + nanomdm_pushsvc "github.com/micromdm/nanomdm/push/service" + "sync/atomic" "github.com/fleetdm/fleet/v4/server/config" "github.com/fleetdm/fleet/v4/server/datastore/mysql" "github.com/fleetdm/fleet/v4/server/fleet" apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple" + "github.com/fleetdm/fleet/v4/server/service/mock" + "github.com/fleetdm/fleet/v4/server/service/schedule" kitlog "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" "github.com/google/uuid" "github.com/groob/plist" "github.com/jmoiron/sqlx" nanodep_client "github.com/micromdm/nanodep/client" + "github.com/micromdm/nanodep/godep" + nanodep_storage "github.com/micromdm/nanodep/storage" "github.com/micromdm/nanodep/tokenpki" scepclient "github.com/micromdm/scep/v2/client" "github.com/micromdm/scep/v2/cryptoutil/x509util" @@ -48,10 +57,13 @@ func TestIntegrationsMDM(t *testing.T) { } type integrationMDMTestSuite struct { - withServer suite.Suite + withServer fleetCfg config.FleetConfig fleetDMFailCSR atomic.Bool + pushProvider *mock.APNSPushProvider + depStorage nanodep_storage.AllStorage + depSchedule *schedule.Schedule } func (s *integrationMDMTestSuite) SetupSuite() { @@ -69,9 +81,18 @@ func (s *integrationMDMTestSuite) SetupSuite() { require.NoError(s.T(), err) depStorage, err := s.ds.NewMDMAppleDEPStorage(*testBMToken) require.NoError(s.T(), err) - scepStorage, err := s.ds.NewMDMAppleSCEPDepot(testCertPEM, testKeyPEM) + scepStorage, err := s.ds.NewSCEPDepot(testCertPEM, testKeyPEM) require.NoError(s.T(), err) + pushFactory, pushProvider := newMockAPNSPushProviderFactory() + mdmPushService := nanomdm_pushsvc.New( + mdmStorage, + mdmStorage, + pushFactory, + NewNanoMDMLogger(kitlog.NewJSONLogger(os.Stdout)), + ) + + var depSchedule *schedule.Schedule config := TestServerOpts{ License: &fleet.LicenseInfo{ Tier: fleet.TierPremium, @@ -80,7 +101,24 @@ func (s *integrationMDMTestSuite) SetupSuite() { MDMStorage: mdmStorage, DEPStorage: depStorage, SCEPStorage: scepStorage, - MDMPusher: dummyMDMPusher{}, + MDMPusher: mdmPushService, + StartCronSchedules: []TestNewScheduleFunc{ + func(ctx context.Context, ds fleet.Datastore) fleet.NewCronScheduleFunc { + return func() (fleet.CronSchedule, error) { + const name = string(fleet.CronAppleMDMDEPProfileAssigner) + logger := kitlog.NewJSONLogger(os.Stdout) + fleetSyncer := apple_mdm.NewDEPSyncer(ds, depStorage, logger, true) + depSchedule = schedule.New( + ctx, name, s.T().Name(), 1*time.Hour, ds, ds, + schedule.WithLogger(logger), + schedule.WithJob("dep_syncer", func(ctx context.Context) error { + return fleetSyncer.Run(ctx) + }), + ) + return depSchedule, nil + } + }, + }, } users, server := RunServerForTestsWithDS(s.T(), s.ds, &config) s.server = server @@ -88,8 +126,11 @@ func (s *integrationMDMTestSuite) SetupSuite() { s.token = s.getTestAdminToken() s.cachedAdminToken = s.token s.fleetCfg = fleetCfg + s.pushProvider = pushProvider + s.depStorage = depStorage + s.depSchedule = depSchedule - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fleetdmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if s.fleetDMFailCSR.Swap(false) { // fail this call w.WriteHeader(http.StatusBadRequest) @@ -99,8 +140,8 @@ func (s *integrationMDMTestSuite) SetupSuite() { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) })) - s.T().Setenv("TEST_FLEETDM_API_URL", srv.URL) - s.T().Cleanup(srv.Close) + s.T().Setenv("TEST_FLEETDM_API_URL", fleetdmSrv.URL) + s.T().Cleanup(fleetdmSrv.Close) } func (s *integrationMDMTestSuite) FailNextCSRRequest() { @@ -115,6 +156,18 @@ func (s *integrationMDMTestSuite) TearDownTest() { s.withServer.commonTearDownTest(s.T()) } +func (s *integrationMDMTestSuite) mockDEPResponse(handler http.Handler) { + t := s.T() + srv := httptest.NewServer(handler) + err := s.depStorage.StoreConfig(context.Background(), apple_mdm.DEPName, &nanodep_client.Config{BaseURL: srv.URL}) + require.NoError(t, err) + t.Cleanup(func() { + srv.Close() + err := s.depStorage.StoreConfig(context.Background(), apple_mdm.DEPName, &nanodep_client.Config{BaseURL: nanodep_client.DefaultBaseURL}) + require.NoError(t, err) + }) +} + func (s *integrationMDMTestSuite) TestAppleGetAppleMDM() { t := s.T() @@ -126,11 +179,173 @@ func (s *integrationMDMTestSuite) TestAppleGetAppleMDM() { require.Equal(t, "FleetDM", mdmResp.CommonName) require.NotZero(t, mdmResp.RenewDate) - // GET /api/latest/fleet/mdm/apple_bm is not tested because it makes a call - // to an Apple API that would a) fail because we use dummy token/certs and b) - // could get us in trouble with many invalid requests. - // TODO: eventually add a way to mock the apple API, maybe with a test http - // server running and a way to use its URL instead of Apple's. (#8948) + s.mockDEPResponse(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + switch r.URL.Path { + case "/session": + _, _ = w.Write([]byte(`{"auth_session_token": "xyz"}`)) + case "/account": + _, _ = w.Write([]byte(`{"admin_id": "abc", "org_name": "test_org"}`)) + } + })) + var getAppleBMResp getAppleBMResponse + s.DoJSON("GET", "/api/latest/fleet/mdm/apple_bm", nil, http.StatusOK, &getAppleBMResp) + require.NoError(t, getAppleBMResp.Err) + require.Equal(t, "abc", getAppleBMResp.AppleID) + require.Equal(t, "test_org", getAppleBMResp.OrgName) + require.Equal(t, "https://example.org/mdm/apple/mdm", getAppleBMResp.MDMServerURL) + require.Empty(t, getAppleBMResp.DefaultTeam) + + // create a new team + tm, err := s.ds.NewTeam(context.Background(), &fleet.Team{ + Name: t.Name(), + Description: "desc", + }) + require.NoError(t, err) + // set the default bm assignment to that team + acResp := appConfigResponse{} + s.DoJSON("PATCH", "/api/latest/fleet/config", json.RawMessage(fmt.Sprintf(`{ + "mdm": { + "apple_bm_default_team": %q + } + }`, tm.Name)), http.StatusOK, &acResp) + + // try again, this time we get a default team in the response + getAppleBMResp = getAppleBMResponse{} + s.DoJSON("GET", "/api/latest/fleet/mdm/apple_bm", nil, http.StatusOK, &getAppleBMResp) + require.NoError(t, getAppleBMResp.Err) + require.Equal(t, "abc", getAppleBMResp.AppleID) + require.Equal(t, "test_org", getAppleBMResp.OrgName) + require.Equal(t, "https://example.org/mdm/apple/mdm", getAppleBMResp.MDMServerURL) + require.Equal(t, tm.Name, getAppleBMResp.DefaultTeam) +} + +func (s *integrationMDMTestSuite) TestABMExpiredToken() { + t := s.T() + s.mockDEPResponse(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"code": "T_C_NOT_SIGNED"}`)) + })) + + config := s.getConfig() + require.False(t, config.MDM.AppleBMTermsExpired) + + var getAppleBMResp getAppleBMResponse + s.DoJSON("GET", "/api/latest/fleet/mdm/apple_bm", nil, http.StatusInternalServerError, &getAppleBMResp) + + config = s.getConfig() + require.True(t, config.MDM.AppleBMTermsExpired) +} + +func (s *integrationMDMTestSuite) TestDEPProfileAssignment() { + t := s.T() + devices := []godep.Device{ + {SerialNumber: uuid.New().String(), Model: "MacBook Pro", OS: "osx", OpType: "added"}, + {SerialNumber: uuid.New().String(), Model: "MacBook Mini", OS: "osx", OpType: "added"}, + } + + var wg sync.WaitGroup + wg.Add(2) + s.mockDEPResponse(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + encoder := json.NewEncoder(w) + switch r.URL.Path { + case "/session": + err := encoder.Encode(map[string]string{"auth_session_token": "xyz"}) + require.NoError(t, err) + case "/profile": + err := encoder.Encode(godep.ProfileResponse{ProfileUUID: "abc"}) + require.NoError(t, err) + case "/server/devices": + // This endpoint is used to get an initial list of + // devices, return a single device + err := encoder.Encode(godep.DeviceResponse{Devices: devices[:1]}) + require.NoError(t, err) + case "/devices/sync": + // This endpoint is polled over time to sync devices from + // ABM, send a repeated serial and a new one + err := encoder.Encode(godep.DeviceResponse{Devices: devices}) + require.NoError(t, err) + case "/profile/devices": + wg.Done() + _, _ = w.Write([]byte(`{}`)) + default: + _, _ = w.Write([]byte(`{}`)) + } + })) + + // create a DEP enrollment profile + profile := json.RawMessage("{}") + var createProfileResp createMDMAppleEnrollmentProfileResponse + createProfileReq := createMDMAppleEnrollmentProfileRequest{ + Type: "automatic", + DEPProfile: &profile, + } + s.DoJSON("POST", "/api/latest/fleet/mdm/apple/enrollmentprofiles", createProfileReq, http.StatusOK, &createProfileResp) + + // query all hosts + listHostsRes := listHostsResponse{} + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &listHostsRes) + require.Empty(t, listHostsRes.Hosts) + + // trigger a profile sync + _, err := s.depSchedule.Trigger() + require.NoError(t, err) + wg.Wait() + + // both hosts should be returned from the hosts endpoint + listHostsRes = listHostsResponse{} + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &listHostsRes) + require.Len(t, listHostsRes.Hosts, 2) + require.Equal(t, listHostsRes.Hosts[0].HardwareSerial, devices[0].SerialNumber) + require.Equal(t, listHostsRes.Hosts[1].HardwareSerial, devices[1].SerialNumber) + require.EqualValues( + t, + []string{devices[0].SerialNumber, devices[1].SerialNumber}, + []string{listHostsRes.Hosts[0].HardwareSerial, listHostsRes.Hosts[1].HardwareSerial}, + ) + + // create a new host + createHostAndDeviceToken(t, s.ds, "not-dep") + listHostsRes = listHostsResponse{} + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &listHostsRes) + require.Len(t, listHostsRes.Hosts, 3) + + // filtering by MDM status works + listHostsRes = listHostsResponse{} + s.DoJSON("GET", "/api/latest/fleet/hosts?mdm_enrollment_status=pending", nil, http.StatusOK, &listHostsRes) + require.Len(t, listHostsRes.Hosts, 2) + + // enroll one of the hosts + d := newDevice(s) + d.serial = devices[0].SerialNumber + d.mdmEnroll(s) + + // only one shows up as pending + listHostsRes = listHostsResponse{} + s.DoJSON("GET", "/api/latest/fleet/hosts?mdm_enrollment_status=pending", nil, http.StatusOK, &listHostsRes) + require.Len(t, listHostsRes.Hosts, 1) + + activities := listActivitiesResponse{} + s.DoJSON("GET", "/api/latest/fleet/activities", nil, http.StatusOK, &activities, "order_key", "created_at") + found := false + for _, activity := range activities.Activities { + if activity.Type == "mdm_enrolled" && + strings.Contains(string(*activity.Details), devices[0].SerialNumber) { + found = true + require.Nil(t, activity.ActorID) + require.Nil(t, activity.ActorFullName) + require.JSONEq( + t, + fmt.Sprintf( + `{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": true}`, + devices[0].SerialNumber, devices[0].Model, devices[0].SerialNumber, + ), + string(*activity.Details), + ) + } + } + require.True(t, found) } func (s *integrationMDMTestSuite) TestDeviceMDMManualEnroll() { @@ -164,7 +379,7 @@ func (s *integrationMDMTestSuite) TestDeviceMDMManualEnroll() { require.Equal(t, apple_mdm.FleetPayloadIdentifier, profile.PayloadIdentifier) } -func (s *integrationMDMTestSuite) TestDeviceEnrollment() { +func (s *integrationMDMTestSuite) TestAppleMDMDeviceEnrollment() { t := s.T() // Enroll two devices into MDM @@ -212,8 +427,8 @@ func (s *integrationMDMTestSuite) TestDeviceEnrollment() { } } require.Len(t, details, 2) - require.JSONEq(t, fmt.Sprintf(`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": false}`, deviceA.serial, deviceA.model, deviceA.serial), string(*details[0])) - require.JSONEq(t, fmt.Sprintf(`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": false}`, deviceB.serial, deviceB.model, deviceB.serial), string(*details[1])) + require.JSONEq(t, fmt.Sprintf(`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": false}`, deviceA.serial, deviceA.model, deviceA.serial), string(*details[len(details)-2])) + require.JSONEq(t, fmt.Sprintf(`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": false}`, deviceB.serial, deviceB.model, deviceB.serial), string(*details[len(details)-1])) // set an enroll secret var applyResp applyEnrollSecretSpecResponse @@ -317,6 +532,60 @@ func (s *integrationMDMTestSuite) TestAppleMDMCSRRequest() { require.Contains(t, string(reqCSRResp.SCEPKey), "-----BEGIN RSA PRIVATE KEY-----\n") } +func (s *integrationMDMTestSuite) TestMDMAppleUnenroll() { + t := s.T() + // enroll into mdm + d := newMDMEnrolledDevice(s) + + // set an enroll secret + var applyResp applyEnrollSecretSpecResponse + s.DoJSON("POST", "/api/latest/fleet/spec/enroll_secret", applyEnrollSecretSpecRequest{ + Spec: &fleet.EnrollSecretSpec{ + Secrets: []*fleet.EnrollSecret{{Secret: t.Name()}}, + }, + }, http.StatusOK, &applyResp) + + // simulate a matching host enrolling via osquery + j, err := json.Marshal(&enrollAgentRequest{ + EnrollSecret: t.Name(), + HostIdentifier: d.uuid, + }) + require.NoError(t, err) + var enrollResp enrollAgentResponse + hres := s.DoRawNoAuth("POST", "/api/osquery/enroll", j, http.StatusOK) + defer hres.Body.Close() + require.NoError(t, json.NewDecoder(hres.Body).Decode(&enrollResp)) + require.NotEmpty(t, enrollResp.NodeKey) + + listHostsRes := listHostsResponse{} + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &listHostsRes) + require.Len(t, listHostsRes.Hosts, 1) + h := listHostsRes.Hosts[0] + + // try to unenroll the host, fails since the host doesn't respond + s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusGatewayTimeout) + + // we're going to modify this mock, make sure we restore its default + originalPushMock := s.pushProvider.PushFunc + defer func() { s.pushProvider.PushFunc = originalPushMock }() + + // TODO: this is not working as expected, we're still waiting on the + // device to unenroll even if we weren't able to send a push + // notification, we should return an error instead. + // + // the APNs service returns an error + // s.pushProvider.PushFunc = mockFailedPush + // s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusOK) + + // try again, but this time the host is online and answers + s.pushProvider.PushFunc = func(pushes []*mdm.Push) (map[string]*push.Response, error) { + res, err := mockSuccessfulPush(pushes) + d.checkout() + return res, err + } + s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusOK) +} + type device struct { uuid string serial string @@ -338,13 +607,18 @@ func newDevice(s *integrationMDMTestSuite) *device { func newMDMEnrolledDevice(s *integrationMDMTestSuite) *device { d := newDevice(s) - d.scepEnroll() - d.authenticate() + d.mdmEnroll(s) return d } +func (d *device) mdmEnroll(s *integrationMDMTestSuite) { + d.scepEnroll() + d.authenticate() + d.tokenUpdate() +} + func (d *device) authenticate() { - payload := map[string]string{ + payload := map[string]any{ "MessageType": "Authenticate", "UDID": d.uuid, "Model": d.model, @@ -356,8 +630,21 @@ func (d *device) authenticate() { d.request("application/x-apple-aspen-mdm-checkin", payload) } +func (d *device) tokenUpdate() { + payload := map[string]any{ + "MessageType": "TokenUpdate", + "UDID": d.uuid, + "Topic": "com.apple.mgmt.External." + d.uuid, + "EnrollmentID": "testenrollmentid-" + d.uuid, + "NotOnConsole": "false", + "PushMagic": "pushmagic" + d.serial, + "Token": []byte("token" + d.serial), + } + d.request("application/x-apple-aspen-mdm-checkin", payload) +} + func (d *device) checkout() { - payload := map[string]string{ + payload := map[string]any{ "MessageType": "CheckOut", "Topic": "com.apple.mgmt.External." + d.uuid, "UDID": d.uuid, @@ -366,7 +653,7 @@ func (d *device) checkout() { d.request("application/x-apple-aspen-mdm-checkin", payload) } -func (d *device) request(reqType string, payload map[string]string) { +func (d *device) request(reqType string, payload map[string]any) { body, err := plist.Marshal(payload) require.NoError(d.s.T(), err) diff --git a/server/service/mock/service.go b/server/service/mock/service.go index 897f54c337..18b8e873cb 100644 --- a/server/service/mock/service.go +++ b/server/service/mock/service.go @@ -1,3 +1,5 @@ package mock //go:generate mockimpl -o service_osquery.go "s *TLSService" "fleet.OsqueryService" +//go:generate mockimpl -o service_pusher_factory.go "s *APNSPushProviderFactory" "github.com/micromdm/nanomdm/push.PushProviderFactory" +//go:generate mockimpl -o service_push_provider.go "s *APNSPushProvider" "github.com/micromdm/nanomdm/push.PushProvider" diff --git a/server/service/mock/service_push_provider.go b/server/service/mock/service_push_provider.go new file mode 100644 index 0000000000..ab2923b020 --- /dev/null +++ b/server/service/mock/service_push_provider.go @@ -0,0 +1,22 @@ +// Automatically generated by mockimpl. DO NOT EDIT! + +package mock + +import ( + "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/push" +) + +var _ push.PushProvider = (*APNSPushProvider)(nil) + +type PushFunc func(p0 []*mdm.Push) (map[string]*push.Response, error) + +type APNSPushProvider struct { + PushFunc PushFunc + PushFuncInvoked bool +} + +func (s *APNSPushProvider) Push(p0 []*mdm.Push) (map[string]*push.Response, error) { + s.PushFuncInvoked = true + return s.PushFunc(p0) +} diff --git a/server/service/mock/service_pusher_factory.go b/server/service/mock/service_pusher_factory.go new file mode 100644 index 0000000000..d4c0631cd3 --- /dev/null +++ b/server/service/mock/service_pusher_factory.go @@ -0,0 +1,23 @@ +// Automatically generated by mockimpl. DO NOT EDIT! + +package mock + +import ( + "crypto/tls" + + "github.com/micromdm/nanomdm/push" +) + +var _ push.PushProviderFactory = (*APNSPushProviderFactory)(nil) + +type NewPushProviderFunc func(p0 *tls.Certificate) (push.PushProvider, error) + +type APNSPushProviderFactory struct { + NewPushProviderFunc NewPushProviderFunc + NewPushProviderFuncInvoked bool +} + +func (s *APNSPushProviderFactory) NewPushProvider(p0 *tls.Certificate) (push.PushProvider, error) { + s.NewPushProviderFuncInvoked = true + return s.NewPushProviderFunc(p0) +} diff --git a/server/service/testing_utils.go b/server/service/testing_utils.go index 330fe9c651..c5ac852c05 100644 --- a/server/service/testing_utils.go +++ b/server/service/testing_utils.go @@ -2,6 +2,7 @@ package service import ( "context" + "crypto/tls" "encoding/json" "fmt" "io" @@ -18,15 +19,19 @@ import ( "github.com/fleetdm/fleet/v4/server/contexts/license" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/logging" - apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple" "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/service/async" + "github.com/fleetdm/fleet/v4/server/service/mock" "github.com/fleetdm/fleet/v4/server/sso" "github.com/fleetdm/fleet/v4/server/test" kitlog "github.com/go-kit/kit/log" + "github.com/google/uuid" nanodep_storage "github.com/micromdm/nanodep/storage" + "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/push" nanomdm_push "github.com/micromdm/nanomdm/push" nanomdm_storage "github.com/micromdm/nanomdm/storage" + scep_depot "github.com/micromdm/scep/v2/depot" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/throttled/throttled/v2" @@ -238,7 +243,7 @@ type TestServerOpts struct { FleetConfig *config.FleetConfig MDMStorage nanomdm_storage.AllStorage DEPStorage nanodep_storage.AllStorage - SCEPStorage *apple_mdm.SCEPMySQLDepot + SCEPStorage scep_depot.Depot MDMPusher nanomdm_push.Pusher HTTPServerConfig *http.Server StartCronSchedules []TestNewScheduleFunc @@ -491,3 +496,25 @@ func (nopEnrollHostLimiter) CanEnrollNewHost(ctx context.Context) (bool, error) func (nopEnrollHostLimiter) SyncEnrolledHostIDs(ctx context.Context) error { return nil } + +func newMockAPNSPushProviderFactory() (*mock.APNSPushProviderFactory, *mock.APNSPushProvider) { + provider := &mock.APNSPushProvider{} + provider.PushFunc = mockSuccessfulPush + factory := &mock.APNSPushProviderFactory{} + factory.NewPushProviderFunc = func(*tls.Certificate) (push.PushProvider, error) { + return provider, nil + } + + return factory, provider +} + +func mockSuccessfulPush(pushes []*mdm.Push) (map[string]*push.Response, error) { + res := make(map[string]*push.Response, len(pushes)) + for _, p := range pushes { + res[p.Token.String()] = &push.Response{ + Id: uuid.New().String(), + Err: nil, + } + } + return res, nil +}