add mocks + tests and move things around (#9574)

#8948

- Add more go:generate commands for MDM mocks
- Add unit and integration tests for MDM code
- Move interfaces from their PoC location to match existing patterns
This commit is contained in:
Roberto Dip 2023-01-31 11:46:01 -03:00 committed by GitHub
parent f3642b18da
commit 4c4c114e96
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 1160 additions and 278 deletions

View file

@ -173,8 +173,8 @@ generate-dev: .prefix
NODE_ENV=development webpack --progress --colors --watch
generate-mock: .prefix
go install github.com/groob/mockimpl@latest
go generate github.com/fleetdm/fleet/v4/server/mock github.com/fleetdm/fleet/v4/server/mock/mockresult
go install github.com/fleetdm/mockimpl@8d7943aa39d8f5f464d3d3618d9571d385f7bcc5
go generate github.com/fleetdm/fleet/v4/server/mock github.com/fleetdm/fleet/v4/server/mock/mockresult github.com/fleetdm/fleet/v4/server/service/mock
generate-doc: .prefix
go generate github.com/fleetdm/fleet/v4/server/fleet
@ -403,4 +403,4 @@ db-replica-reset: fleet
# db-replica-run runs fleet serve with one main and one read MySQL instance.
db-replica-run: fleet
FLEET_MYSQL_ADDRESS=127.0.0.1:3308 FLEET_MYSQL_READ_REPLICA_ADDRESS=127.0.0.1:3309 FLEET_MYSQL_READ_REPLICA_USERNAME=fleet FLEET_MYSQL_READ_REPLICA_DATABASE=fleet FLEET_MYSQL_READ_REPLICA_PASSWORD=insecure ./build/fleet serve --dev --dev_license
FLEET_MYSQL_ADDRESS=127.0.0.1:3308 FLEET_MYSQL_READ_REPLICA_ADDRESS=127.0.0.1:3309 FLEET_MYSQL_READ_REPLICA_USERNAME=fleet FLEET_MYSQL_READ_REPLICA_DATABASE=fleet FLEET_MYSQL_READ_REPLICA_PASSWORD=insecure ./build/fleet serve --dev --dev_license

View file

@ -10,6 +10,8 @@ import (
"strings"
"time"
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
eewebhooks "github.com/fleetdm/fleet/v4/ee/server/webhooks"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/config"
@ -17,7 +19,6 @@ import (
"github.com/fleetdm/fleet/v4/server/contexts/license"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
"github.com/fleetdm/fleet/v4/server/policies"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/service/externalsvc"
@ -32,9 +33,6 @@ import (
kitlog "github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"github.com/hashicorp/go-multierror"
"github.com/micromdm/nanodep/godep"
nanodep_log "github.com/micromdm/nanodep/log"
depsync "github.com/micromdm/nanodep/sync"
)
func errHandler(ctx context.Context, logger kitlog.Logger, msg string, err error) {
@ -801,32 +799,6 @@ func trySendStatistics(ctx context.Context, ds fleet.Datastore, frequency time.D
return ds.RecordStatisticsSent(ctx)
}
// NanoDEPLogger is a logger adapter for nanodep.
type NanoDEPLogger struct {
logger kitlog.Logger
}
func NewNanoDEPLogger(logger kitlog.Logger) *NanoDEPLogger {
return &NanoDEPLogger{
logger: logger,
}
}
func (l *NanoDEPLogger) Info(keyvals ...interface{}) {
level.Info(l.logger).Log(keyvals...)
}
func (l *NanoDEPLogger) Debug(keyvals ...interface{}) {
level.Debug(l.logger).Log(keyvals...)
}
func (l *NanoDEPLogger) With(keyvals ...interface{}) nanodep_log.Logger {
newLogger := kitlog.With(l.logger, keyvals...)
return &NanoDEPLogger{
logger: newLogger,
}
}
// newAppleMDMDEPProfileAssigner creates the schedule to run the DEP syncer+assigner.
// The DEP syncer+assigner fetches devices from Apple Business Manager (aka ABM) and applies
// the current configured DEP profile to them.
@ -840,65 +812,13 @@ func newAppleMDMDEPProfileAssigner(
loggingDebug bool,
) (*schedule.Schedule, error) {
const name = string(fleet.CronAppleMDMDEPProfileAssigner)
depClient := fleet.NewDEPClient(depStorage, ds, logger)
assignerOpts := []depsync.AssignerOption{
depsync.WithAssignerLogger(NewNanoDEPLogger(kitlog.With(logger, "component", "nanodep-assigner"))),
}
if loggingDebug {
assignerOpts = append(assignerOpts, depsync.WithDebug())
}
assigner := depsync.NewAssigner(
depClient,
apple_mdm.DEPName,
depStorage,
assignerOpts...,
)
syncer := depsync.NewSyncer(
depClient,
apple_mdm.DEPName,
depStorage,
depsync.WithLogger(NewNanoDEPLogger(kitlog.With(logger, "component", "nanodep-syncer"))),
depsync.WithCallback(func(ctx context.Context, isFetch bool, resp *godep.DeviceResponse) error {
n, err := ds.IngestMDMAppleDevicesFromDEPSync(ctx, resp.Devices)
switch {
case err != nil:
level.Error(kitlog.With(logger, "cron", name, "component", "nanodep-syncer")).Log("err", err)
sentry.CaptureException(err)
case n > 0:
level.Info(kitlog.With(logger, "cron", name, "component", "nanodep-syncer")).Log("msg", fmt.Sprintf("added %d new mdm device(s) to pending hosts", n))
default:
// ok
}
return assigner.ProcessDeviceResponse(ctx, resp)
}),
)
logger = kitlog.With(logger, "cron", name)
logger = kitlog.With(logger, "cron", name, "component", "nanodep-syncer")
fleetSyncer := apple_mdm.NewDEPSyncer(ds, depStorage, logger, loggingDebug)
s := schedule.New(
ctx, name, instanceID, periodicity, ds, ds,
schedule.WithLogger(logger),
schedule.WithJob("dep_syncer", func(ctx context.Context) error {
profileUUID, profileModTime, err := depStorage.RetrieveAssignerProfile(ctx, apple_mdm.DEPName)
if err != nil {
return err
}
if profileUUID == "" {
logger.Log("msg", "DEP profile not set, nothing to do")
return nil
}
cursor, cursorModTime, err := depStorage.RetrieveCursor(ctx, apple_mdm.DEPName)
if err != nil {
return err
}
// If the DEP Profile was changed since last sync then we clear
// the cursor and perform a full sync of all devices and profile assigning.
if cursor != "" && profileModTime.After(cursorModTime) {
logger.Log("msg", "clearing device syncer cursor")
if err := depStorage.StoreCursor(ctx, apple_mdm.DEPName, ""); err != nil {
return err
}
}
return syncer.Run(ctx)
return fleetSyncer.Run(ctx)
}),
)

View file

@ -19,6 +19,8 @@ import (
"syscall"
"time"
scep_depot "github.com/micromdm/scep/v2/depot"
"github.com/WatchBeam/clock"
"github.com/e-dard/netbug"
"github.com/fleetdm/fleet/v4/ee/server/licensing"
@ -40,7 +42,6 @@ import (
"github.com/fleetdm/fleet/v4/server/live_query"
"github.com/fleetdm/fleet/v4/server/logging"
"github.com/fleetdm/fleet/v4/server/mail"
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
"github.com/fleetdm/fleet/v4/server/pubsub"
"github.com/fleetdm/fleet/v4/server/service"
"github.com/fleetdm/fleet/v4/server/service/async"
@ -468,7 +469,7 @@ the way that the Fleet server works.
}
var (
scepStorage *apple_mdm.SCEPMySQLDepot
scepStorage scep_depot.Depot
appleSCEPCertPEM []byte
appleSCEPKeyPEM []byte
appleAPNsCertPEM []byte
@ -545,7 +546,7 @@ the way that the Fleet server works.
initFatal(errors.New("Apple BM configuration must be provided to enable MDM"), "validate Apple MDM")
}
scepStorage, err = mds.NewMDMAppleSCEPDepot(appleSCEPCertPEM, appleSCEPKeyPEM)
scepStorage, err = mds.NewSCEPDepot(appleSCEPCertPEM, appleSCEPKeyPEM)
if err != nil {
initFatal(err, "initialize mdm apple scep storage")
}

View file

@ -0,0 +1,102 @@
package main
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
func TestGenerateMDMAppleBM(t *testing.T) {
outdir, err := os.MkdirTemp("", t.Name())
require.NoError(t, err)
defer os.Remove(outdir)
publicKeyPath := filepath.Join(outdir, "public-key.crt")
privateKeyPath := filepath.Join(outdir, "private-key.key")
out := runAppForTest(t, []string{
"generate", "mdm-apple-bm",
"--public-key", publicKeyPath,
"--private-key", privateKeyPath,
})
require.Contains(t, out, fmt.Sprintf("Generated your public key at %s", outdir))
require.Contains(t, out, fmt.Sprintf("Generated your private key at %s", outdir))
// validate that the keypair is valid
cert, err := tls.LoadX509KeyPair(publicKeyPath, privateKeyPath)
require.NoError(t, err)
parsed, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
require.Equal(t, "FleetDM", parsed.Issuer.CommonName)
}
func TestGenerateMDMApple(t *testing.T) {
t.Run("missing input", func(t *testing.T) {
runAppCheckErr(t, []string{"generate", "mdm-apple"}, `Required flags "email, org" not set`)
runAppCheckErr(t, []string{"generate", "mdm-apple", "--email", "user@example.com"}, `Required flag "org" not set`)
runAppCheckErr(t, []string{"generate", "mdm-apple", "--org", "Acme"}, `Required flag "email" not set`)
})
t.Run("CSR API call fails", func(t *testing.T) {
_, _ = runServerWithMockedDS(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// fail this call
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("bad request"))
}))
t.Setenv("TEST_FLEETDM_API_URL", srv.URL)
t.Cleanup(srv.Close)
runAppCheckErr(
t,
[]string{
"generate", "mdm-apple",
"--email", "user@example.com",
"--org", "Acme",
},
`POST /api/latest/fleet/mdm/apple/request_csr received status 502 Bad Gateway: FleetDM CSR request failed: api responded with 400: bad request`,
)
})
t.Run("successful run", func(t *testing.T) {
_, _ = runServerWithMockedDS(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
t.Setenv("TEST_FLEETDM_API_URL", srv.URL)
t.Cleanup(srv.Close)
outdir, err := os.MkdirTemp("", "TestGenerateMDMApple")
require.NoError(t, err)
defer os.Remove(outdir)
apnsKeyPath := filepath.Join(outdir, "apns.key")
scepCertPath := filepath.Join(outdir, "scep.crt")
scepKeyPath := filepath.Join(outdir, "scep.key")
out := runAppForTest(t, []string{
"generate", "mdm-apple",
"--email", "user@example.com",
"--org", "Acme",
"--apns-key", apnsKeyPath,
"--scep-cert", scepCertPath,
"--scep-key", scepKeyPath,
})
require.Contains(t, out, fmt.Sprintf("Generated your APNs key at %s", apnsKeyPath))
require.Contains(t, out, fmt.Sprintf("Generated your SCEP certificate at %s", scepCertPath))
require.Contains(t, out, fmt.Sprintf("Generated your SCEP key at %s", scepKeyPath))
// validate that the keypair is valid
scepCrt, err := tls.LoadX509KeyPair(scepCertPath, scepKeyPath)
require.NoError(t, err)
parsed, err := x509.ParseCertificate(scepCrt.Certificate[0])
require.NoError(t, err)
require.Equal(t, "FleetDM", parsed.Issuer.CommonName)
})
}

View file

@ -47,7 +47,7 @@ func (svc *Service) GetAppleBM(ctx context.Context) (*fleet.AppleBM, error) {
}
func getAppleBMAccountDetail(ctx context.Context, depStorage storage.AllStorage, ds fleet.Datastore, logger kitlog.Logger) (*fleet.AppleBM, error) {
depClient := fleet.NewDEPClient(depStorage, ds, logger)
depClient := apple_mdm.NewDEPClient(depStorage, ds, logger)
res, err := depClient.AccountDetail(ctx, apple_mdm.DEPName)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "apple GET /account request failed")

View file

@ -273,16 +273,19 @@ func ingestMDMAppleDeviceFromCheckinDB(
case err != nil:
return ctxerr.Wrap(ctx, err, "get mdm apple host by serial number or udid")
case foundHost.HardwareSerial != mdmHost.SerialNumber || foundHost.UUID != mdmHost.UDID:
return updateMDMAppleHostDB(ctx, tx, foundHost.ID, mdmHost)
default:
// ok, nothing to do here
return nil
return updateMDMAppleHostDB(ctx, tx, foundHost.ID, mdmHost, appCfg)
}
}
func updateMDMAppleHostDB(ctx context.Context, tx sqlx.ExtContext, hostID uint, mdmHost fleet.MDMAppleHostDetails) error {
func updateMDMAppleHostDB(
ctx context.Context,
tx sqlx.ExtContext,
hostID uint,
mdmHost fleet.MDMAppleHostDetails,
appCfg *fleet.AppConfig,
) error {
updateStmt := `
UPDATE hosts SET
hardware_serial = ?,
@ -310,6 +313,10 @@ func updateMDMAppleHostDB(ctx context.Context, tx sqlx.ExtContext, hostID uint,
return ctxerr.Wrap(ctx, err, "update mdm apple host")
}
if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, false, hostID); err != nil {
return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert MDM info")
}
return nil
}
@ -365,7 +372,7 @@ func insertMDMAppleHostDB(
return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert label membership")
}
if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, false, host); err != nil {
if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, false, host.ID); err != nil {
return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert MDM info")
}
return nil
@ -468,7 +475,11 @@ func (ds *Datastore) IngestMDMAppleDevicesFromDEPSync(ctx context.Context, devic
return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert label membership")
}
if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, true, hosts...); err != nil {
var ids []uint
for _, h := range hosts {
ids = append(ids, h.ID)
}
if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, true, ids...); err != nil {
return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert MDM info")
}
@ -497,7 +508,7 @@ func upsertMDMAppleHostDisplayNamesDB(ctx context.Context, tx sqlx.ExtContext, h
return nil
}
func upsertMDMAppleHostMDMInfoDB(ctx context.Context, tx sqlx.ExtContext, serverURL string, fromSync bool, hosts ...fleet.Host) error {
func upsertMDMAppleHostMDMInfoDB(ctx context.Context, tx sqlx.ExtContext, serverURL string, fromSync bool, hostIDs ...uint) error {
result, err := tx.ExecContext(ctx, `
INSERT INTO mobile_device_management_solutions (name, server_url) VALUES (?, ?)
ON DUPLICATE KEY UPDATE server_url = VALUES(server_url)`,
@ -522,8 +533,8 @@ func upsertMDMAppleHostMDMInfoDB(ctx context.Context, tx sqlx.ExtContext, server
args := []interface{}{}
parts := []string{}
for _, h := range hosts {
args = append(args, enrolled, serverURL, fromSync, mdmID, false, h.ID)
for _, id := range hostIDs {
args = append(args, enrolled, serverURL, fromSync, mdmID, false, id)
parts = append(parts, "(?, ?, ?, ?, ?, ?)")
}

View file

@ -26,7 +26,6 @@ import (
"github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/data"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/tables"
"github.com/fleetdm/fleet/v4/server/fleet"
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
"github.com/fleetdm/goose"
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
@ -36,6 +35,7 @@ import (
nanodep_client "github.com/micromdm/nanodep/client"
nanodep_mysql "github.com/micromdm/nanodep/storage/mysql"
nanomdm_mysql "github.com/micromdm/nanomdm/storage/mysql"
scep_depot "github.com/micromdm/scep/v2/depot"
"github.com/ngrok/sqlmw"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
)
@ -107,14 +107,10 @@ func (ds *Datastore) loadOrPrepareStmt(ctx context.Context, query string) *sqlx.
return stmt
}
// NewMDMAppleSCEPDepot returns a *apple_mdm.MySQLDepot that uses the Datastore
// NewMDMAppleSCEPDepot returns a scep_depot.Depot that uses the Datastore
// underlying MySQL writer *sql.DB.
func (ds *Datastore) NewMDMAppleSCEPDepot(caCertPEM []byte, caKeyPEM []byte) (*apple_mdm.SCEPMySQLDepot, error) {
depot, err := apple_mdm.NewSCEPMySQLDepot(ds.writer.DB, caCertPEM, caKeyPEM)
if err != nil {
return nil, err
}
return depot, nil
func (ds *Datastore) NewSCEPDepot(caCertPEM []byte, caKeyPEM []byte) (scep_depot.Depot, error) {
return newSCEPDepot(ds.writer.DB, caCertPEM, caKeyPEM)
}
// NewMDMAppleMDMStorage returns a MySQL nanomdm storage that uses the Datastore

View file

@ -1,4 +1,4 @@
package apple_mdm
package mysql
import (
"crypto/rsa"
@ -11,12 +11,13 @@ import (
"fmt"
"math/big"
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
"github.com/micromdm/nanomdm/cryptoutil"
"github.com/micromdm/scep/v2/depot"
)
// SCEPMySQLDepot is a MySQL-backed SCEP certificate depot.
type SCEPMySQLDepot struct {
// SCEPDepot is a MySQL-backed SCEP certificate depot.
type SCEPDepot struct {
db *sql.DB
// caCrt holds the CA's certificate.
@ -25,10 +26,10 @@ type SCEPMySQLDepot struct {
caKey *rsa.PrivateKey
}
var _ depot.Depot = (*SCEPMySQLDepot)(nil)
var _ depot.Depot = (*SCEPDepot)(nil)
// NewSCEPMySQLDepot creates and returns a *SCEPMySQLDepot.
func NewSCEPMySQLDepot(db *sql.DB, caCertPEM []byte, caKeyPEM []byte) (*SCEPMySQLDepot, error) {
// newSCEPDepot creates and returns a *SCEPDepot.
func newSCEPDepot(db *sql.DB, caCertPEM []byte, caKeyPEM []byte) (*SCEPDepot, error) {
if err := db.Ping(); err != nil {
return nil, err
}
@ -40,7 +41,7 @@ func NewSCEPMySQLDepot(db *sql.DB, caCertPEM []byte, caKeyPEM []byte) (*SCEPMySQ
if err != nil {
return nil, err
}
return &SCEPMySQLDepot{
return &SCEPDepot{
db: db,
caCrt: caCrt,
caKey: caKey,
@ -56,12 +57,12 @@ func decodeRSAKeyFromPEM(key []byte) (*rsa.PrivateKey, error) {
}
// CA returns the CA's certificate and private key.
func (d *SCEPMySQLDepot) CA(_ []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) {
func (d *SCEPDepot) CA(_ []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) {
return []*x509.Certificate{d.caCrt}, d.caKey, nil
}
// Serial allocates and returns a new (increasing) serial number.
func (d *SCEPMySQLDepot) Serial() (*big.Int, error) {
func (d *SCEPDepot) Serial() (*big.Int, error) {
result, err := d.db.Exec(`INSERT INTO scep_serials () VALUES ();`)
if err != nil {
return nil, err
@ -78,7 +79,7 @@ func (d *SCEPMySQLDepot) Serial() (*big.Int, error) {
// TODO(lucas): Implement and use allowTime and revokeOldCertificate.
// - allowTime are the maximum days before expiration to allow clients to do certificate renewal.
// - revokeOldCertificate specifies whether to revoke the old certificate once renewed.
func (d *SCEPMySQLDepot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error) {
func (d *SCEPDepot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error) {
var ct int
row := d.db.QueryRow(`SELECT COUNT(*) FROM scep_certificates WHERE name = ?`, cn)
if err := row.Scan(&ct); err != nil {
@ -91,14 +92,14 @@ func (d *SCEPMySQLDepot) HasCN(cn string, allowTime int, cert *x509.Certificate,
//
// If the provided certificate has empty crt.Subject.CommonName,
// then the hex sha256 of the crt.Raw is used as name.
func (d *SCEPMySQLDepot) Put(name string, crt *x509.Certificate) error {
func (d *SCEPDepot) Put(name string, crt *x509.Certificate) error {
if crt.Subject.CommonName == "" {
name = fmt.Sprintf("%x", sha256.Sum256(crt.Raw))
}
if !crt.SerialNumber.IsInt64() {
return errors.New("cannot represent serial number as int64")
}
certPEM := EncodeCertPEM(crt)
certPEM := apple_mdm.EncodeCertPEM(crt)
_, err := d.db.Exec(`
INSERT INTO scep_certificates
(serial, name, not_valid_before, not_valid_after, certificate_pem)

View file

@ -1,6 +1,4 @@
// use a different package name to avoid
// import cycle
package apple_mdm_test
package mysql
import (
"crypto/x509"
@ -8,19 +6,19 @@ import (
"math/big"
"testing"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
"github.com/micromdm/nanodep/tokenpki"
scep_depot "github.com/micromdm/scep/v2/depot"
"github.com/stretchr/testify/require"
)
func setup(t *testing.T) *apple_mdm.SCEPMySQLDepot {
ds := mysql.CreateNamedMySQLDS(t, t.Name())
func setup(t *testing.T) scep_depot.Depot {
ds := CreateNamedMySQLDS(t, t.Name())
cert, key, err := apple_mdm.NewSCEPCACertKey()
require.NoError(t, err)
publicKeyPEM := tokenpki.PEMCertificate(cert.Raw)
privateKeyPEM := tokenpki.PEMRSAPrivateKey(key)
depot, err := ds.NewMDMAppleSCEPDepot(publicKeyPEM, privateKeyPEM)
depot, err := ds.NewSCEPDepot(publicKeyPEM, privateKeyPEM)
require.NoError(t, err)
return depot
}

View file

@ -3,11 +3,6 @@ package fleet
import (
"context"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
kitlog "github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"github.com/micromdm/nanodep/godep"
)
type AppleMDM struct {
@ -52,42 +47,3 @@ type AppConfigUpdater interface {
AppConfig(ctx context.Context) (*AppConfig, error)
SaveAppConfig(ctx context.Context, info *AppConfig) error
}
// NewDEPClient creates an Apple DEP API HTTP client based on the provided
// storage that will flag the AppConfig's AppleBMTermsExpired field whenever
// the status of the terms changes.
func NewDEPClient(storage godep.ClientStorage, appCfgUpdater AppConfigUpdater, logger kitlog.Logger) *godep.Client {
return godep.NewClient(storage, fleethttp.NewClient(), godep.WithAfterHook(func(ctx context.Context, reqErr error) error {
// if the request failed due to terms not signed, or if it succeeded,
// update the app config flag accordingly. If it failed for any other
// reason, do not update the flag.
termsExpired := reqErr != nil && godep.IsTermsNotSigned(reqErr)
if reqErr == nil || termsExpired {
appCfg, err := appCfgUpdater.AppConfig(ctx)
if err != nil {
level.Error(logger).Log("msg", "Apple DEP client: failed to get app config", "err", err)
return reqErr
}
var mustSaveAppCfg bool
if termsExpired && !appCfg.MDM.AppleBMTermsExpired {
// flag the AppConfig that the terms have changed and must be accepted
appCfg.MDM.AppleBMTermsExpired = true
mustSaveAppCfg = true
} else if reqErr == nil && appCfg.MDM.AppleBMTermsExpired {
// flag the AppConfig that the terms have been accepted
appCfg.MDM.AppleBMTermsExpired = false
mustSaveAppCfg = true
}
if mustSaveAppCfg {
if err := appCfgUpdater.SaveAppConfig(ctx, appCfg); err != nil {
level.Error(logger).Log("msg", "Apple DEP client: failed to save app config", "err", err)
}
level.Debug(logger).Log("msg", "Apple DEP client: updated app config Terms Expired flag",
"apple_bm_terms_expired", appCfg.MDM.AppleBMTermsExpired)
}
}
return reqErr
}))
}

View file

@ -11,25 +11,13 @@ import (
"github.com/fleetdm/fleet/v4/server/fleet"
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
"github.com/fleetdm/fleet/v4/server/mock"
nanodep_mock "github.com/fleetdm/fleet/v4/server/mock/nanodep"
"github.com/go-kit/log"
nanodep_client "github.com/micromdm/nanodep/client"
"github.com/micromdm/nanodep/godep"
"github.com/stretchr/testify/require"
)
type mockStorage struct {
token string
url string
}
func (s mockStorage) RetrieveAuthTokens(ctx context.Context, name string) (*nanodep_client.OAuth1Tokens, error) {
return &nanodep_client.OAuth1Tokens{AccessToken: s.token}, nil
}
func (s mockStorage) RetrieveConfig(context.Context, string) (*nanodep_client.Config, error) {
return &nanodep_client.Config{BaseURL: s.url}, nil
}
func TestDEPClient(t *testing.T) {
ctx := context.Background()
@ -141,8 +129,15 @@ func TestDEPClient(t *testing.T) {
for i, c := range cases {
t.Logf("case %d", i)
store := mockStorage{token: c.token, url: srv.URL}
dep := fleet.NewDEPClient(store, ds, logger)
store := &nanodep_mock.Storage{}
store.RetrieveAuthTokensFunc = func(ctx context.Context, name string) (*nanodep_client.OAuth1Tokens, error) {
return &nanodep_client.OAuth1Tokens{AccessToken: c.token}, nil
}
store.RetrieveConfigFunc = func(context.Context, string) (*nanodep_client.Config, error) {
return &nanodep_client.Config{BaseURL: srv.URL}, nil
}
dep := apple_mdm.NewDEPClient(store, ds, logger)
res, err := dep.AccountDetail(ctx, apple_mdm.DEPName)
if c.wantErr {
@ -163,6 +158,8 @@ func TestDEPClient(t *testing.T) {
} else {
require.NoError(t, err)
require.Equal(t, "test", res.AdminID)
require.True(t, store.RetrieveAuthTokensFuncInvoked)
require.True(t, store.RetrieveConfigFuncInvoked)
}
checkDSCalled(c.readInvoked, c.writeInvoked)
require.Equal(t, c.termsFlag, appCfg.MDM.AppleBMTermsExpired)

33
server/logging/nanodep.go Normal file
View file

@ -0,0 +1,33 @@
package logging
import (
kitlog "github.com/go-kit/kit/log"
"github.com/go-kit/log/level"
nanodep_log "github.com/micromdm/nanodep/log"
)
// NanoDEPLogger is a logger adapter for nanodep.
type NanoDEPLogger struct {
logger kitlog.Logger
}
func NewNanoDEPLogger(logger kitlog.Logger) *NanoDEPLogger {
return &NanoDEPLogger{
logger: logger,
}
}
func (l *NanoDEPLogger) Info(keyvals ...interface{}) {
level.Info(l.logger).Log(keyvals...)
}
func (l *NanoDEPLogger) Debug(keyvals ...interface{}) {
level.Debug(l.logger).Log(keyvals...)
}
func (l *NanoDEPLogger) With(keyvals ...interface{}) nanodep_log.Logger {
newLogger := kitlog.With(l.logger, keyvals...)
return &NanoDEPLogger{
logger: newLogger,
}
}

View file

@ -1,8 +1,21 @@
package apple_mdm
import (
"context"
"fmt"
"net/url"
"path"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/logging"
"github.com/getsentry/sentry-go"
"github.com/go-kit/log/level"
"github.com/micromdm/nanodep/godep"
kitlog "github.com/go-kit/kit/log"
nanodep_storage "github.com/micromdm/nanodep/storage"
depsync "github.com/micromdm/nanodep/sync"
)
// DEPName is the identifier/name used in nanodep MySQL storage which
@ -47,3 +60,118 @@ func resolveURL(serverURL, relPath string) (string, error) {
u.Path = path.Join(u.Path, relPath)
return u.String(), nil
}
type DEPSyncer struct {
depStorage nanodep_storage.AllStorage
syncer *depsync.Syncer
logger kitlog.Logger
}
func (d *DEPSyncer) Run(ctx context.Context) error {
profileUUID, profileModTime, err := d.depStorage.RetrieveAssignerProfile(ctx, DEPName)
if err != nil {
return err
}
if profileUUID == "" {
d.logger.Log("msg", "DEP profile not set, nothing to do")
return nil
}
cursor, cursorModTime, err := d.depStorage.RetrieveCursor(ctx, DEPName)
if err != nil {
return err
}
// If the DEP Profile was changed since last sync then we clear
// the cursor and perform a full sync of all devices and profile assigning.
if cursor != "" && profileModTime.After(cursorModTime) {
d.logger.Log("msg", "clearing device syncer cursor")
if err := d.depStorage.StoreCursor(ctx, DEPName, ""); err != nil {
return err
}
}
return d.syncer.Run(ctx)
}
func NewDEPSyncer(
ds fleet.Datastore,
depStorage nanodep_storage.AllStorage,
logger kitlog.Logger,
loggingDebug bool,
) *DEPSyncer {
depClient := NewDEPClient(depStorage, ds, logger)
assignerOpts := []depsync.AssignerOption{
depsync.WithAssignerLogger(logging.NewNanoDEPLogger(kitlog.With(logger, "component", "nanodep-assigner"))),
}
if loggingDebug {
assignerOpts = append(assignerOpts, depsync.WithDebug())
}
assigner := depsync.NewAssigner(
depClient,
DEPName,
depStorage,
assignerOpts...,
)
syncer := depsync.NewSyncer(
depClient,
DEPName,
depStorage,
depsync.WithLogger(logging.NewNanoDEPLogger(kitlog.With(logger, "component", "nanodep-syncer"))),
depsync.WithCallback(func(ctx context.Context, isFetch bool, resp *godep.DeviceResponse) error {
n, err := ds.IngestMDMAppleDevicesFromDEPSync(ctx, resp.Devices)
switch {
case err != nil:
level.Error(kitlog.With(logger)).Log("err", err)
sentry.CaptureException(err)
case n > 0:
level.Info(kitlog.With(logger)).Log("msg", fmt.Sprintf("added %d new mdm device(s) to pending hosts", n))
}
return assigner.ProcessDeviceResponse(ctx, resp)
}),
)
return &DEPSyncer{
syncer: syncer,
depStorage: depStorage,
logger: logger,
}
}
// NewDEPClient creates an Apple DEP API HTTP client based on the provided
// storage that will flag the AppConfig's AppleBMTermsExpired field
// whenever the status of the terms changes.
func NewDEPClient(storage godep.ClientStorage, appCfgUpdater fleet.AppConfigUpdater, logger kitlog.Logger) *godep.Client {
return godep.NewClient(storage, fleethttp.NewClient(), godep.WithAfterHook(func(ctx context.Context, reqErr error) error {
// if the request failed due to terms not signed, or if it succeeded,
// update the app config flag accordingly. If it failed for any other
// reason, do not update the flag.
termsExpired := reqErr != nil && godep.IsTermsNotSigned(reqErr)
if reqErr == nil || termsExpired {
appCfg, err := appCfgUpdater.AppConfig(ctx)
if err != nil {
level.Error(logger).Log("msg", "Apple DEP client: failed to get app config", "err", err)
return reqErr
}
var mustSaveAppCfg bool
if termsExpired && !appCfg.MDM.AppleBMTermsExpired {
// flag the AppConfig that the terms have changed and must be accepted
appCfg.MDM.AppleBMTermsExpired = true
mustSaveAppCfg = true
} else if reqErr == nil && appCfg.MDM.AppleBMTermsExpired {
// flag the AppConfig that the terms have been accepted
appCfg.MDM.AppleBMTermsExpired = false
mustSaveAppCfg = true
}
if mustSaveAppCfg {
if err := appCfgUpdater.SaveAppConfig(ctx, appCfg); err != nil {
level.Error(logger).Log("msg", "Apple DEP client: failed to save app config", "err", err)
}
level.Debug(logger).Log("msg", "Apple DEP client: updated app config Terms Expired flag",
"apple_bm_terms_expired", appCfg.MDM.AppleBMTermsExpired)
}
}
return reqErr
}))
}

View file

@ -8,6 +8,9 @@ import (
//go:generate mockimpl -o datastore_mock.go "s *DataStore" "fleet.Datastore"
//go:generate mockimpl -o datastore_installers.go "s *InstallerStore" "fleet.InstallerStore"
//go:generate mockimpl -o nanomdm/storage.go "s *Storage" "github.com/micromdm/nanomdm/storage.AllStorage"
//go:generate mockimpl -o nanodep/storage.go "s *Storage" "github.com/micromdm/nanodep/storage.AllStorage"
//go:generate mockimpl -o scep/depot.go "d *Depot" "depot.Depot"
var _ fleet.Datastore = (*Store)(nil)

View file

@ -0,0 +1,115 @@
// Automatically generated by mockimpl. DO NOT EDIT!
package mock
import (
"context"
"time"
"github.com/micromdm/nanodep/client"
"github.com/micromdm/nanodep/storage"
)
var _ storage.AllStorage = (*Storage)(nil)
type RetrieveAuthTokensFunc func(ctx context.Context, name string) (*client.OAuth1Tokens, error)
type RetrieveConfigFunc func(p0 context.Context, p1 string) (*client.Config, error)
type RetrieveAssignerProfileFunc func(ctx context.Context, name string) (profileUUID string, modTime time.Time, err error)
type RetrieveCursorFunc func(ctx context.Context, name string) (cursor string, modTime time.Time, err error)
type StoreCursorFunc func(ctx context.Context, name string, cursor string) error
type StoreAuthTokensFunc func(ctx context.Context, name string, tokens *client.OAuth1Tokens) error
type StoreConfigFunc func(ctx context.Context, name string, config *client.Config) error
type StoreTokenPKIFunc func(ctx context.Context, name string, pemCert []byte, pemKey []byte) error
type RetrieveTokenPKIFunc func(ctx context.Context, name string) (pemCert []byte, pemKey []byte, err error)
type StoreAssignerProfileFunc func(ctx context.Context, name string, profileUUID string) error
type Storage struct {
RetrieveAuthTokensFunc RetrieveAuthTokensFunc
RetrieveAuthTokensFuncInvoked bool
RetrieveConfigFunc RetrieveConfigFunc
RetrieveConfigFuncInvoked bool
RetrieveAssignerProfileFunc RetrieveAssignerProfileFunc
RetrieveAssignerProfileFuncInvoked bool
RetrieveCursorFunc RetrieveCursorFunc
RetrieveCursorFuncInvoked bool
StoreCursorFunc StoreCursorFunc
StoreCursorFuncInvoked bool
StoreAuthTokensFunc StoreAuthTokensFunc
StoreAuthTokensFuncInvoked bool
StoreConfigFunc StoreConfigFunc
StoreConfigFuncInvoked bool
StoreTokenPKIFunc StoreTokenPKIFunc
StoreTokenPKIFuncInvoked bool
RetrieveTokenPKIFunc RetrieveTokenPKIFunc
RetrieveTokenPKIFuncInvoked bool
StoreAssignerProfileFunc StoreAssignerProfileFunc
StoreAssignerProfileFuncInvoked bool
}
func (s *Storage) RetrieveAuthTokens(ctx context.Context, name string) (*client.OAuth1Tokens, error) {
s.RetrieveAuthTokensFuncInvoked = true
return s.RetrieveAuthTokensFunc(ctx, name)
}
func (s *Storage) RetrieveConfig(p0 context.Context, p1 string) (*client.Config, error) {
s.RetrieveConfigFuncInvoked = true
return s.RetrieveConfigFunc(p0, p1)
}
func (s *Storage) RetrieveAssignerProfile(ctx context.Context, name string) (profileUUID string, modTime time.Time, err error) {
s.RetrieveAssignerProfileFuncInvoked = true
return s.RetrieveAssignerProfileFunc(ctx, name)
}
func (s *Storage) RetrieveCursor(ctx context.Context, name string) (cursor string, modTime time.Time, err error) {
s.RetrieveCursorFuncInvoked = true
return s.RetrieveCursorFunc(ctx, name)
}
func (s *Storage) StoreCursor(ctx context.Context, name string, cursor string) error {
s.StoreCursorFuncInvoked = true
return s.StoreCursorFunc(ctx, name, cursor)
}
func (s *Storage) StoreAuthTokens(ctx context.Context, name string, tokens *client.OAuth1Tokens) error {
s.StoreAuthTokensFuncInvoked = true
return s.StoreAuthTokensFunc(ctx, name, tokens)
}
func (s *Storage) StoreConfig(ctx context.Context, name string, config *client.Config) error {
s.StoreConfigFuncInvoked = true
return s.StoreConfigFunc(ctx, name, config)
}
func (s *Storage) StoreTokenPKI(ctx context.Context, name string, pemCert []byte, pemKey []byte) error {
s.StoreTokenPKIFuncInvoked = true
return s.StoreTokenPKIFunc(ctx, name, pemCert, pemKey)
}
func (s *Storage) RetrieveTokenPKI(ctx context.Context, name string) (pemCert []byte, pemKey []byte, err error) {
s.RetrieveTokenPKIFuncInvoked = true
return s.RetrieveTokenPKIFunc(ctx, name)
}
func (s *Storage) StoreAssignerProfile(ctx context.Context, name string, profileUUID string) error {
s.StoreAssignerProfileFuncInvoked = true
return s.StoreAssignerProfileFunc(ctx, name, profileUUID)
}

View file

@ -0,0 +1,215 @@
// Automatically generated by mockimpl. DO NOT EDIT!
package mock
import (
"context"
"crypto/tls"
"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/storage"
)
var _ storage.AllStorage = (*Storage)(nil)
type StoreAuthenticateFunc func(r *mdm.Request, msg *mdm.Authenticate) error
type StoreTokenUpdateFunc func(r *mdm.Request, msg *mdm.TokenUpdate) error
type StoreUserAuthenticateFunc func(r *mdm.Request, msg *mdm.UserAuthenticate) error
type DisableFunc func(r *mdm.Request) error
type StoreCommandReportFunc func(r *mdm.Request, report *mdm.CommandResults) error
type RetrieveNextCommandFunc func(r *mdm.Request, skipNotNow bool) (*mdm.Command, error)
type ClearQueueFunc func(r *mdm.Request) error
type StoreBootstrapTokenFunc func(r *mdm.Request, msg *mdm.SetBootstrapToken) error
type RetrieveBootstrapTokenFunc func(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error)
type RetrievePushInfoFunc func(p0 context.Context, p1 []string) (map[string]*mdm.Push, error)
type IsPushCertStaleFunc func(ctx context.Context, topic string, staleToken string) (bool, error)
type RetrievePushCertFunc func(ctx context.Context, topic string) (cert *tls.Certificate, staleToken string, err error)
type StorePushCertFunc func(ctx context.Context, pemCert []byte, pemKey []byte) error
type EnqueueCommandFunc func(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error)
type HasCertHashFunc func(r *mdm.Request, hash string) (bool, error)
type EnrollmentHasCertHashFunc func(r *mdm.Request, hash string) (bool, error)
type IsCertHashAssociatedFunc func(r *mdm.Request, hash string) (bool, error)
type AssociateCertHashFunc func(r *mdm.Request, hash string) error
type RetrieveMigrationCheckinsFunc func(p0 context.Context, p1 chan<- interface{}) error
type RetrieveTokenUpdateTallyFunc func(ctx context.Context, id string) (int, error)
type Storage struct {
StoreAuthenticateFunc StoreAuthenticateFunc
StoreAuthenticateFuncInvoked bool
StoreTokenUpdateFunc StoreTokenUpdateFunc
StoreTokenUpdateFuncInvoked bool
StoreUserAuthenticateFunc StoreUserAuthenticateFunc
StoreUserAuthenticateFuncInvoked bool
DisableFunc DisableFunc
DisableFuncInvoked bool
StoreCommandReportFunc StoreCommandReportFunc
StoreCommandReportFuncInvoked bool
RetrieveNextCommandFunc RetrieveNextCommandFunc
RetrieveNextCommandFuncInvoked bool
ClearQueueFunc ClearQueueFunc
ClearQueueFuncInvoked bool
StoreBootstrapTokenFunc StoreBootstrapTokenFunc
StoreBootstrapTokenFuncInvoked bool
RetrieveBootstrapTokenFunc RetrieveBootstrapTokenFunc
RetrieveBootstrapTokenFuncInvoked bool
RetrievePushInfoFunc RetrievePushInfoFunc
RetrievePushInfoFuncInvoked bool
IsPushCertStaleFunc IsPushCertStaleFunc
IsPushCertStaleFuncInvoked bool
RetrievePushCertFunc RetrievePushCertFunc
RetrievePushCertFuncInvoked bool
StorePushCertFunc StorePushCertFunc
StorePushCertFuncInvoked bool
EnqueueCommandFunc EnqueueCommandFunc
EnqueueCommandFuncInvoked bool
HasCertHashFunc HasCertHashFunc
HasCertHashFuncInvoked bool
EnrollmentHasCertHashFunc EnrollmentHasCertHashFunc
EnrollmentHasCertHashFuncInvoked bool
IsCertHashAssociatedFunc IsCertHashAssociatedFunc
IsCertHashAssociatedFuncInvoked bool
AssociateCertHashFunc AssociateCertHashFunc
AssociateCertHashFuncInvoked bool
RetrieveMigrationCheckinsFunc RetrieveMigrationCheckinsFunc
RetrieveMigrationCheckinsFuncInvoked bool
RetrieveTokenUpdateTallyFunc RetrieveTokenUpdateTallyFunc
RetrieveTokenUpdateTallyFuncInvoked bool
}
func (s *Storage) StoreAuthenticate(r *mdm.Request, msg *mdm.Authenticate) error {
s.StoreAuthenticateFuncInvoked = true
return s.StoreAuthenticateFunc(r, msg)
}
func (s *Storage) StoreTokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate) error {
s.StoreTokenUpdateFuncInvoked = true
return s.StoreTokenUpdateFunc(r, msg)
}
func (s *Storage) StoreUserAuthenticate(r *mdm.Request, msg *mdm.UserAuthenticate) error {
s.StoreUserAuthenticateFuncInvoked = true
return s.StoreUserAuthenticateFunc(r, msg)
}
func (s *Storage) Disable(r *mdm.Request) error {
s.DisableFuncInvoked = true
return s.DisableFunc(r)
}
func (s *Storage) StoreCommandReport(r *mdm.Request, report *mdm.CommandResults) error {
s.StoreCommandReportFuncInvoked = true
return s.StoreCommandReportFunc(r, report)
}
func (s *Storage) RetrieveNextCommand(r *mdm.Request, skipNotNow bool) (*mdm.Command, error) {
s.RetrieveNextCommandFuncInvoked = true
return s.RetrieveNextCommandFunc(r, skipNotNow)
}
func (s *Storage) ClearQueue(r *mdm.Request) error {
s.ClearQueueFuncInvoked = true
return s.ClearQueueFunc(r)
}
func (s *Storage) StoreBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapToken) error {
s.StoreBootstrapTokenFuncInvoked = true
return s.StoreBootstrapTokenFunc(r, msg)
}
func (s *Storage) RetrieveBootstrapToken(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) {
s.RetrieveBootstrapTokenFuncInvoked = true
return s.RetrieveBootstrapTokenFunc(r, msg)
}
func (s *Storage) RetrievePushInfo(p0 context.Context, p1 []string) (map[string]*mdm.Push, error) {
s.RetrievePushInfoFuncInvoked = true
return s.RetrievePushInfoFunc(p0, p1)
}
func (s *Storage) IsPushCertStale(ctx context.Context, topic string, staleToken string) (bool, error) {
s.IsPushCertStaleFuncInvoked = true
return s.IsPushCertStaleFunc(ctx, topic, staleToken)
}
func (s *Storage) RetrievePushCert(ctx context.Context, topic string) (cert *tls.Certificate, staleToken string, err error) {
s.RetrievePushCertFuncInvoked = true
return s.RetrievePushCertFunc(ctx, topic)
}
func (s *Storage) StorePushCert(ctx context.Context, pemCert []byte, pemKey []byte) error {
s.StorePushCertFuncInvoked = true
return s.StorePushCertFunc(ctx, pemCert, pemKey)
}
func (s *Storage) EnqueueCommand(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) {
s.EnqueueCommandFuncInvoked = true
return s.EnqueueCommandFunc(ctx, id, cmd)
}
func (s *Storage) HasCertHash(r *mdm.Request, hash string) (bool, error) {
s.HasCertHashFuncInvoked = true
return s.HasCertHashFunc(r, hash)
}
func (s *Storage) EnrollmentHasCertHash(r *mdm.Request, hash string) (bool, error) {
s.EnrollmentHasCertHashFuncInvoked = true
return s.EnrollmentHasCertHashFunc(r, hash)
}
func (s *Storage) IsCertHashAssociated(r *mdm.Request, hash string) (bool, error) {
s.IsCertHashAssociatedFuncInvoked = true
return s.IsCertHashAssociatedFunc(r, hash)
}
func (s *Storage) AssociateCertHash(r *mdm.Request, hash string) error {
s.AssociateCertHashFuncInvoked = true
return s.AssociateCertHashFunc(r, hash)
}
func (s *Storage) RetrieveMigrationCheckins(p0 context.Context, p1 chan<- interface{}) error {
s.RetrieveMigrationCheckinsFuncInvoked = true
return s.RetrieveMigrationCheckinsFunc(p0, p1)
}
func (s *Storage) RetrieveTokenUpdateTally(ctx context.Context, id string) (int, error) {
s.RetrieveTokenUpdateTallyFuncInvoked = true
return s.RetrieveTokenUpdateTallyFunc(ctx, id)
}

55
server/mock/scep/depot.go Normal file
View file

@ -0,0 +1,55 @@
// Automatically generated by mockimpl. DO NOT EDIT!
package mock
import (
"crypto/rsa"
"crypto/x509"
"math/big"
"github.com/micromdm/scep/v2/depot"
)
var _ depot.Depot = (*Depot)(nil)
type CAFunc func(pass []byte) ([]*x509.Certificate, *rsa.PrivateKey, error)
type PutFunc func(name string, crt *x509.Certificate) error
type SerialFunc func() (*big.Int, error)
type HasCNFunc func(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error)
type Depot struct {
CAFunc CAFunc
CAFuncInvoked bool
PutFunc PutFunc
PutFuncInvoked bool
SerialFunc SerialFunc
SerialFuncInvoked bool
HasCNFunc HasCNFunc
HasCNFuncInvoked bool
}
func (d *Depot) CA(pass []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) {
d.CAFuncInvoked = true
return d.CAFunc(pass)
}
func (d *Depot) Put(name string, crt *x509.Certificate) error {
d.PutFuncInvoked = true
return d.PutFunc(name, crt)
}
func (d *Depot) Serial() (*big.Int, error) {
d.SerialFuncInvoked = true
return d.SerialFunc()
}
func (d *Depot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error) {
d.HasCNFuncInvoked = true
return d.HasCNFunc(cn, allowTime, cert, revokeOldCertificate)
}

View file

@ -121,7 +121,7 @@ func (svc *Service) setDEPProfile(ctx context.Context, enrollmentProfile *fleet.
depProfileRequest.URL = enrollURL
depProfileRequest.ConfigurationWebURL = enrollURL
depClient := fleet.NewDEPClient(svc.depStorage, svc.ds, svc.logger)
depClient := apple_mdm.NewDEPClient(svc.depStorage, svc.ds, svc.logger)
res, err := depClient.DefineProfile(ctx, apple_mdm.DEPName, &depProfileRequest)
if err != nil {
return ctxerr.Wrap(ctx, err, "apple POST /profile request failed")
@ -430,7 +430,7 @@ func (svc *Service) ListMDMAppleDEPDevices(ctx context.Context) ([]fleet.MDMAppl
if err := svc.authz.Authorize(ctx, &fleet.MDMAppleDEPDevice{}, fleet.ActionWrite); err != nil {
return nil, ctxerr.Wrap(ctx, err)
}
depClient := fleet.NewDEPClient(svc.depStorage, svc.ds, svc.logger)
depClient := apple_mdm.NewDEPClient(svc.depStorage, svc.ds, svc.logger)
// TODO(lucas): Use cursors and limit to fetch in multiple requests.
// This single-request version supports up to 1000 devices (max to return in one call).

View file

@ -3,9 +3,11 @@ package service
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync/atomic"
"testing"
@ -13,50 +15,20 @@ import (
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
nanodep_mock "github.com/fleetdm/fleet/v4/server/mock/nanodep"
nanomdm_mock "github.com/fleetdm/fleet/v4/server/mock/nanomdm"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/google/uuid"
kitlog "github.com/go-kit/kit/log"
nanodep_client "github.com/micromdm/nanodep/client"
nanodep_storage "github.com/micromdm/nanodep/storage"
"github.com/micromdm/nanomdm/mdm"
nanomdm_push "github.com/micromdm/nanomdm/push"
"github.com/micromdm/nanomdm/storage"
nanomdm_pushsvc "github.com/micromdm/nanomdm/push/service"
"github.com/stretchr/testify/require"
)
type dummyDEPStorage struct {
nanodep_storage.AllStorage
testAuthAddr string
}
func (d dummyDEPStorage) RetrieveAuthTokens(ctx context.Context, name string) (*nanodep_client.OAuth1Tokens, error) {
return &nanodep_client.OAuth1Tokens{}, nil
}
func (d dummyDEPStorage) RetrieveConfig(context.Context, string) (*nanodep_client.Config, error) {
return &nanodep_client.Config{
BaseURL: d.testAuthAddr,
}, nil
}
type dummyMDMStorage struct {
*mysql.NanoMDMStorage
}
func (d dummyMDMStorage) EnqueueCommand(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) {
return nil, nil
}
type dummyMDMPusher struct{}
func (d dummyMDMPusher) Push(context.Context, []string) (map[string]*nanomdm_push.Response, error) {
return nil, nil
}
func setupAppleMDMService(t *testing.T, mdmStorage storage.AllStorage, depStorage nanodep_storage.AllStorage, mdmPusher nanomdm_push.Pusher) (fleet.Service, context.Context, *mock.Store) {
func setupAppleMDMService(t *testing.T) (fleet.Service, context.Context, *mock.Store) {
ds := new(mock.Store)
cfg := config.TestConfig()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -72,23 +44,55 @@ func setupAppleMDMService(t *testing.T, mdmStorage storage.AllStorage, depStorag
}
}))
mdmStorage := &nanomdm_mock.Storage{}
depStorage := &nanodep_mock.Storage{}
pushFactory, _ := newMockAPNSPushProviderFactory()
pusher := nanomdm_pushsvc.New(
mdmStorage,
mdmStorage,
pushFactory,
NewNanoMDMLogger(kitlog.NewJSONLogger(os.Stdout)),
)
opts := &TestServerOpts{
FleetConfig: &cfg,
MDMStorage: dummyMDMStorage{},
DEPStorage: dummyDEPStorage{testAuthAddr: ts.URL},
MDMPusher: dummyMDMPusher{},
}
if mdmStorage != nil {
opts.MDMStorage = mdmStorage
}
if depStorage != nil {
opts.DEPStorage = depStorage
}
if mdmPusher != nil {
opts.MDMPusher = mdmPusher
MDMStorage: mdmStorage,
DEPStorage: depStorage,
MDMPusher: pusher,
}
svc, ctx := newTestServiceWithConfig(t, ds, cfg, nil, nil, opts)
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) {
return nil, nil
}
mdmStorage.RetrievePushInfoFunc = func(ctx context.Context, tokens []string) (map[string]*mdm.Push, error) {
res := make(map[string]*mdm.Push, len(tokens))
for _, t := range tokens {
res[t] = &mdm.Push{
PushMagic: "",
Token: []byte(t),
Topic: "",
}
}
return res, nil
}
mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) {
cert, err := tls.LoadX509KeyPair("testdata/server.pem", "testdata/server.key")
return &cert, "", err
}
mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) {
return false, nil
}
depStorage.RetrieveAuthTokensFunc = func(ctx context.Context, name string) (*nanodep_client.OAuth1Tokens, error) {
return &nanodep_client.OAuth1Tokens{}, nil
}
depStorage.RetrieveConfigFunc = func(context.Context, string) (*nanodep_client.Config, error) {
return &nanodep_client.Config{
BaseURL: ts.URL,
}, nil
}
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
return &fleet.AppConfig{
OrgInfo: fleet.OrgInfo{
@ -148,7 +152,7 @@ func setupAppleMDMService(t *testing.T, mdmStorage storage.AllStorage, depStorag
}
func TestAppleMDMAuthorization(t *testing.T) {
svc, ctx, _ := setupAppleMDMService(t, nil, nil, nil)
svc, ctx, _ := setupAppleMDMService(t)
checkAuthErr := func(t *testing.T, err error, shouldFailWithAuth bool) {
t.Helper()
@ -252,7 +256,7 @@ func TestMDMAppleEnrollURL(t *testing.T) {
}
func TestAppleMDMEnrollmentProfile(t *testing.T) {
svc, ctx, _ := setupAppleMDMService(t, nil, nil, nil)
svc, ctx, _ := setupAppleMDMService(t)
// Only global admins can create enrollment profiles.
ctx = test.UserContext(ctx, test.UserAdmin)
@ -273,22 +277,8 @@ func TestAppleMDMEnrollmentProfile(t *testing.T) {
}
}
type noErrorPusher struct{}
// Push simulates successful push responses. The result maps each of the provided deviceIDs to a
// internally generated UUID, which is intended here to mock the APNs API response.
func (nep *noErrorPusher) Push(ctx context.Context, deviceIDs []string) (map[string]*nanomdm_push.Response, error) {
res := make(map[string]*nanomdm_push.Response)
for _, s := range deviceIDs {
res[s] = &nanomdm_push.Response{Id: uuid.New().String()}
}
return res, nil
}
func TestMDMCommandAuthz(t *testing.T) {
pusher := noErrorPusher{}
svc, ctx, ds := setupAppleMDMService(t, nil, nil, &pusher)
svc, ctx, ds := setupAppleMDMService(t)
ds.HostLiteFunc = func(ctx context.Context, hostID uint) (*fleet.Host, error) {
switch hostID {

View file

@ -668,7 +668,7 @@ func RegisterAppleMDMProtocolServices(
mux *http.ServeMux,
scepConfig config.MDMAppleSCEPConfig,
mdmStorage nanomdm_storage.AllStorage,
scepStorage *apple_mdm.SCEPMySQLDepot,
scepStorage scep_depot.Depot,
logger kitlog.Logger,
checkinAndCommandService nanomdm_service.CheckinAndCommandService,
) error {
@ -692,7 +692,7 @@ func registerSCEP(
scepConfig config.MDMAppleSCEPConfig,
scepCert *x509.Certificate,
scepKey *rsa.PrivateKey,
scepStorage *apple_mdm.SCEPMySQLDepot,
scepStorage scep_depot.Depot,
logger kitlog.Logger,
) error {
var signer scepserver.CSRSigner = scep_depot.NewSigner(

View file

@ -17,21 +17,30 @@ import (
"os"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/push"
nanomdm_pushsvc "github.com/micromdm/nanomdm/push/service"
"sync/atomic"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
"github.com/fleetdm/fleet/v4/server/service/mock"
"github.com/fleetdm/fleet/v4/server/service/schedule"
kitlog "github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"github.com/google/uuid"
"github.com/groob/plist"
"github.com/jmoiron/sqlx"
nanodep_client "github.com/micromdm/nanodep/client"
"github.com/micromdm/nanodep/godep"
nanodep_storage "github.com/micromdm/nanodep/storage"
"github.com/micromdm/nanodep/tokenpki"
scepclient "github.com/micromdm/scep/v2/client"
"github.com/micromdm/scep/v2/cryptoutil/x509util"
@ -48,10 +57,13 @@ func TestIntegrationsMDM(t *testing.T) {
}
type integrationMDMTestSuite struct {
withServer
suite.Suite
withServer
fleetCfg config.FleetConfig
fleetDMFailCSR atomic.Bool
pushProvider *mock.APNSPushProvider
depStorage nanodep_storage.AllStorage
depSchedule *schedule.Schedule
}
func (s *integrationMDMTestSuite) SetupSuite() {
@ -69,9 +81,18 @@ func (s *integrationMDMTestSuite) SetupSuite() {
require.NoError(s.T(), err)
depStorage, err := s.ds.NewMDMAppleDEPStorage(*testBMToken)
require.NoError(s.T(), err)
scepStorage, err := s.ds.NewMDMAppleSCEPDepot(testCertPEM, testKeyPEM)
scepStorage, err := s.ds.NewSCEPDepot(testCertPEM, testKeyPEM)
require.NoError(s.T(), err)
pushFactory, pushProvider := newMockAPNSPushProviderFactory()
mdmPushService := nanomdm_pushsvc.New(
mdmStorage,
mdmStorage,
pushFactory,
NewNanoMDMLogger(kitlog.NewJSONLogger(os.Stdout)),
)
var depSchedule *schedule.Schedule
config := TestServerOpts{
License: &fleet.LicenseInfo{
Tier: fleet.TierPremium,
@ -80,7 +101,24 @@ func (s *integrationMDMTestSuite) SetupSuite() {
MDMStorage: mdmStorage,
DEPStorage: depStorage,
SCEPStorage: scepStorage,
MDMPusher: dummyMDMPusher{},
MDMPusher: mdmPushService,
StartCronSchedules: []TestNewScheduleFunc{
func(ctx context.Context, ds fleet.Datastore) fleet.NewCronScheduleFunc {
return func() (fleet.CronSchedule, error) {
const name = string(fleet.CronAppleMDMDEPProfileAssigner)
logger := kitlog.NewJSONLogger(os.Stdout)
fleetSyncer := apple_mdm.NewDEPSyncer(ds, depStorage, logger, true)
depSchedule = schedule.New(
ctx, name, s.T().Name(), 1*time.Hour, ds, ds,
schedule.WithLogger(logger),
schedule.WithJob("dep_syncer", func(ctx context.Context) error {
return fleetSyncer.Run(ctx)
}),
)
return depSchedule, nil
}
},
},
}
users, server := RunServerForTestsWithDS(s.T(), s.ds, &config)
s.server = server
@ -88,8 +126,11 @@ func (s *integrationMDMTestSuite) SetupSuite() {
s.token = s.getTestAdminToken()
s.cachedAdminToken = s.token
s.fleetCfg = fleetCfg
s.pushProvider = pushProvider
s.depStorage = depStorage
s.depSchedule = depSchedule
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fleetdmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if s.fleetDMFailCSR.Swap(false) {
// fail this call
w.WriteHeader(http.StatusBadRequest)
@ -99,8 +140,8 @@ func (s *integrationMDMTestSuite) SetupSuite() {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
s.T().Setenv("TEST_FLEETDM_API_URL", srv.URL)
s.T().Cleanup(srv.Close)
s.T().Setenv("TEST_FLEETDM_API_URL", fleetdmSrv.URL)
s.T().Cleanup(fleetdmSrv.Close)
}
func (s *integrationMDMTestSuite) FailNextCSRRequest() {
@ -115,6 +156,18 @@ func (s *integrationMDMTestSuite) TearDownTest() {
s.withServer.commonTearDownTest(s.T())
}
func (s *integrationMDMTestSuite) mockDEPResponse(handler http.Handler) {
t := s.T()
srv := httptest.NewServer(handler)
err := s.depStorage.StoreConfig(context.Background(), apple_mdm.DEPName, &nanodep_client.Config{BaseURL: srv.URL})
require.NoError(t, err)
t.Cleanup(func() {
srv.Close()
err := s.depStorage.StoreConfig(context.Background(), apple_mdm.DEPName, &nanodep_client.Config{BaseURL: nanodep_client.DefaultBaseURL})
require.NoError(t, err)
})
}
func (s *integrationMDMTestSuite) TestAppleGetAppleMDM() {
t := s.T()
@ -126,11 +179,173 @@ func (s *integrationMDMTestSuite) TestAppleGetAppleMDM() {
require.Equal(t, "FleetDM", mdmResp.CommonName)
require.NotZero(t, mdmResp.RenewDate)
// GET /api/latest/fleet/mdm/apple_bm is not tested because it makes a call
// to an Apple API that would a) fail because we use dummy token/certs and b)
// could get us in trouble with many invalid requests.
// TODO: eventually add a way to mock the apple API, maybe with a test http
// server running and a way to use its URL instead of Apple's. (#8948)
s.mockDEPResponse(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
switch r.URL.Path {
case "/session":
_, _ = w.Write([]byte(`{"auth_session_token": "xyz"}`))
case "/account":
_, _ = w.Write([]byte(`{"admin_id": "abc", "org_name": "test_org"}`))
}
}))
var getAppleBMResp getAppleBMResponse
s.DoJSON("GET", "/api/latest/fleet/mdm/apple_bm", nil, http.StatusOK, &getAppleBMResp)
require.NoError(t, getAppleBMResp.Err)
require.Equal(t, "abc", getAppleBMResp.AppleID)
require.Equal(t, "test_org", getAppleBMResp.OrgName)
require.Equal(t, "https://example.org/mdm/apple/mdm", getAppleBMResp.MDMServerURL)
require.Empty(t, getAppleBMResp.DefaultTeam)
// create a new team
tm, err := s.ds.NewTeam(context.Background(), &fleet.Team{
Name: t.Name(),
Description: "desc",
})
require.NoError(t, err)
// set the default bm assignment to that team
acResp := appConfigResponse{}
s.DoJSON("PATCH", "/api/latest/fleet/config", json.RawMessage(fmt.Sprintf(`{
"mdm": {
"apple_bm_default_team": %q
}
}`, tm.Name)), http.StatusOK, &acResp)
// try again, this time we get a default team in the response
getAppleBMResp = getAppleBMResponse{}
s.DoJSON("GET", "/api/latest/fleet/mdm/apple_bm", nil, http.StatusOK, &getAppleBMResp)
require.NoError(t, getAppleBMResp.Err)
require.Equal(t, "abc", getAppleBMResp.AppleID)
require.Equal(t, "test_org", getAppleBMResp.OrgName)
require.Equal(t, "https://example.org/mdm/apple/mdm", getAppleBMResp.MDMServerURL)
require.Equal(t, tm.Name, getAppleBMResp.DefaultTeam)
}
func (s *integrationMDMTestSuite) TestABMExpiredToken() {
t := s.T()
s.mockDEPResponse(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`{"code": "T_C_NOT_SIGNED"}`))
}))
config := s.getConfig()
require.False(t, config.MDM.AppleBMTermsExpired)
var getAppleBMResp getAppleBMResponse
s.DoJSON("GET", "/api/latest/fleet/mdm/apple_bm", nil, http.StatusInternalServerError, &getAppleBMResp)
config = s.getConfig()
require.True(t, config.MDM.AppleBMTermsExpired)
}
func (s *integrationMDMTestSuite) TestDEPProfileAssignment() {
t := s.T()
devices := []godep.Device{
{SerialNumber: uuid.New().String(), Model: "MacBook Pro", OS: "osx", OpType: "added"},
{SerialNumber: uuid.New().String(), Model: "MacBook Mini", OS: "osx", OpType: "added"},
}
var wg sync.WaitGroup
wg.Add(2)
s.mockDEPResponse(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
encoder := json.NewEncoder(w)
switch r.URL.Path {
case "/session":
err := encoder.Encode(map[string]string{"auth_session_token": "xyz"})
require.NoError(t, err)
case "/profile":
err := encoder.Encode(godep.ProfileResponse{ProfileUUID: "abc"})
require.NoError(t, err)
case "/server/devices":
// This endpoint is used to get an initial list of
// devices, return a single device
err := encoder.Encode(godep.DeviceResponse{Devices: devices[:1]})
require.NoError(t, err)
case "/devices/sync":
// This endpoint is polled over time to sync devices from
// ABM, send a repeated serial and a new one
err := encoder.Encode(godep.DeviceResponse{Devices: devices})
require.NoError(t, err)
case "/profile/devices":
wg.Done()
_, _ = w.Write([]byte(`{}`))
default:
_, _ = w.Write([]byte(`{}`))
}
}))
// create a DEP enrollment profile
profile := json.RawMessage("{}")
var createProfileResp createMDMAppleEnrollmentProfileResponse
createProfileReq := createMDMAppleEnrollmentProfileRequest{
Type: "automatic",
DEPProfile: &profile,
}
s.DoJSON("POST", "/api/latest/fleet/mdm/apple/enrollmentprofiles", createProfileReq, http.StatusOK, &createProfileResp)
// query all hosts
listHostsRes := listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &listHostsRes)
require.Empty(t, listHostsRes.Hosts)
// trigger a profile sync
_, err := s.depSchedule.Trigger()
require.NoError(t, err)
wg.Wait()
// both hosts should be returned from the hosts endpoint
listHostsRes = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &listHostsRes)
require.Len(t, listHostsRes.Hosts, 2)
require.Equal(t, listHostsRes.Hosts[0].HardwareSerial, devices[0].SerialNumber)
require.Equal(t, listHostsRes.Hosts[1].HardwareSerial, devices[1].SerialNumber)
require.EqualValues(
t,
[]string{devices[0].SerialNumber, devices[1].SerialNumber},
[]string{listHostsRes.Hosts[0].HardwareSerial, listHostsRes.Hosts[1].HardwareSerial},
)
// create a new host
createHostAndDeviceToken(t, s.ds, "not-dep")
listHostsRes = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &listHostsRes)
require.Len(t, listHostsRes.Hosts, 3)
// filtering by MDM status works
listHostsRes = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts?mdm_enrollment_status=pending", nil, http.StatusOK, &listHostsRes)
require.Len(t, listHostsRes.Hosts, 2)
// enroll one of the hosts
d := newDevice(s)
d.serial = devices[0].SerialNumber
d.mdmEnroll(s)
// only one shows up as pending
listHostsRes = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts?mdm_enrollment_status=pending", nil, http.StatusOK, &listHostsRes)
require.Len(t, listHostsRes.Hosts, 1)
activities := listActivitiesResponse{}
s.DoJSON("GET", "/api/latest/fleet/activities", nil, http.StatusOK, &activities, "order_key", "created_at")
found := false
for _, activity := range activities.Activities {
if activity.Type == "mdm_enrolled" &&
strings.Contains(string(*activity.Details), devices[0].SerialNumber) {
found = true
require.Nil(t, activity.ActorID)
require.Nil(t, activity.ActorFullName)
require.JSONEq(
t,
fmt.Sprintf(
`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": true}`,
devices[0].SerialNumber, devices[0].Model, devices[0].SerialNumber,
),
string(*activity.Details),
)
}
}
require.True(t, found)
}
func (s *integrationMDMTestSuite) TestDeviceMDMManualEnroll() {
@ -164,7 +379,7 @@ func (s *integrationMDMTestSuite) TestDeviceMDMManualEnroll() {
require.Equal(t, apple_mdm.FleetPayloadIdentifier, profile.PayloadIdentifier)
}
func (s *integrationMDMTestSuite) TestDeviceEnrollment() {
func (s *integrationMDMTestSuite) TestAppleMDMDeviceEnrollment() {
t := s.T()
// Enroll two devices into MDM
@ -212,8 +427,8 @@ func (s *integrationMDMTestSuite) TestDeviceEnrollment() {
}
}
require.Len(t, details, 2)
require.JSONEq(t, fmt.Sprintf(`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": false}`, deviceA.serial, deviceA.model, deviceA.serial), string(*details[0]))
require.JSONEq(t, fmt.Sprintf(`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": false}`, deviceB.serial, deviceB.model, deviceB.serial), string(*details[1]))
require.JSONEq(t, fmt.Sprintf(`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": false}`, deviceA.serial, deviceA.model, deviceA.serial), string(*details[len(details)-2]))
require.JSONEq(t, fmt.Sprintf(`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": false}`, deviceB.serial, deviceB.model, deviceB.serial), string(*details[len(details)-1]))
// set an enroll secret
var applyResp applyEnrollSecretSpecResponse
@ -317,6 +532,60 @@ func (s *integrationMDMTestSuite) TestAppleMDMCSRRequest() {
require.Contains(t, string(reqCSRResp.SCEPKey), "-----BEGIN RSA PRIVATE KEY-----\n")
}
func (s *integrationMDMTestSuite) TestMDMAppleUnenroll() {
t := s.T()
// enroll into mdm
d := newMDMEnrolledDevice(s)
// set an enroll secret
var applyResp applyEnrollSecretSpecResponse
s.DoJSON("POST", "/api/latest/fleet/spec/enroll_secret", applyEnrollSecretSpecRequest{
Spec: &fleet.EnrollSecretSpec{
Secrets: []*fleet.EnrollSecret{{Secret: t.Name()}},
},
}, http.StatusOK, &applyResp)
// simulate a matching host enrolling via osquery
j, err := json.Marshal(&enrollAgentRequest{
EnrollSecret: t.Name(),
HostIdentifier: d.uuid,
})
require.NoError(t, err)
var enrollResp enrollAgentResponse
hres := s.DoRawNoAuth("POST", "/api/osquery/enroll", j, http.StatusOK)
defer hres.Body.Close()
require.NoError(t, json.NewDecoder(hres.Body).Decode(&enrollResp))
require.NotEmpty(t, enrollResp.NodeKey)
listHostsRes := listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &listHostsRes)
require.Len(t, listHostsRes.Hosts, 1)
h := listHostsRes.Hosts[0]
// try to unenroll the host, fails since the host doesn't respond
s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusGatewayTimeout)
// we're going to modify this mock, make sure we restore its default
originalPushMock := s.pushProvider.PushFunc
defer func() { s.pushProvider.PushFunc = originalPushMock }()
// TODO: this is not working as expected, we're still waiting on the
// device to unenroll even if we weren't able to send a push
// notification, we should return an error instead.
//
// the APNs service returns an error
// s.pushProvider.PushFunc = mockFailedPush
// s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusOK)
// try again, but this time the host is online and answers
s.pushProvider.PushFunc = func(pushes []*mdm.Push) (map[string]*push.Response, error) {
res, err := mockSuccessfulPush(pushes)
d.checkout()
return res, err
}
s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusOK)
}
type device struct {
uuid string
serial string
@ -338,13 +607,18 @@ func newDevice(s *integrationMDMTestSuite) *device {
func newMDMEnrolledDevice(s *integrationMDMTestSuite) *device {
d := newDevice(s)
d.scepEnroll()
d.authenticate()
d.mdmEnroll(s)
return d
}
func (d *device) mdmEnroll(s *integrationMDMTestSuite) {
d.scepEnroll()
d.authenticate()
d.tokenUpdate()
}
func (d *device) authenticate() {
payload := map[string]string{
payload := map[string]any{
"MessageType": "Authenticate",
"UDID": d.uuid,
"Model": d.model,
@ -356,8 +630,21 @@ func (d *device) authenticate() {
d.request("application/x-apple-aspen-mdm-checkin", payload)
}
func (d *device) tokenUpdate() {
payload := map[string]any{
"MessageType": "TokenUpdate",
"UDID": d.uuid,
"Topic": "com.apple.mgmt.External." + d.uuid,
"EnrollmentID": "testenrollmentid-" + d.uuid,
"NotOnConsole": "false",
"PushMagic": "pushmagic" + d.serial,
"Token": []byte("token" + d.serial),
}
d.request("application/x-apple-aspen-mdm-checkin", payload)
}
func (d *device) checkout() {
payload := map[string]string{
payload := map[string]any{
"MessageType": "CheckOut",
"Topic": "com.apple.mgmt.External." + d.uuid,
"UDID": d.uuid,
@ -366,7 +653,7 @@ func (d *device) checkout() {
d.request("application/x-apple-aspen-mdm-checkin", payload)
}
func (d *device) request(reqType string, payload map[string]string) {
func (d *device) request(reqType string, payload map[string]any) {
body, err := plist.Marshal(payload)
require.NoError(d.s.T(), err)

View file

@ -1,3 +1,5 @@
package mock
//go:generate mockimpl -o service_osquery.go "s *TLSService" "fleet.OsqueryService"
//go:generate mockimpl -o service_pusher_factory.go "s *APNSPushProviderFactory" "github.com/micromdm/nanomdm/push.PushProviderFactory"
//go:generate mockimpl -o service_push_provider.go "s *APNSPushProvider" "github.com/micromdm/nanomdm/push.PushProvider"

View file

@ -0,0 +1,22 @@
// Automatically generated by mockimpl. DO NOT EDIT!
package mock
import (
"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/push"
)
var _ push.PushProvider = (*APNSPushProvider)(nil)
type PushFunc func(p0 []*mdm.Push) (map[string]*push.Response, error)
type APNSPushProvider struct {
PushFunc PushFunc
PushFuncInvoked bool
}
func (s *APNSPushProvider) Push(p0 []*mdm.Push) (map[string]*push.Response, error) {
s.PushFuncInvoked = true
return s.PushFunc(p0)
}

View file

@ -0,0 +1,23 @@
// Automatically generated by mockimpl. DO NOT EDIT!
package mock
import (
"crypto/tls"
"github.com/micromdm/nanomdm/push"
)
var _ push.PushProviderFactory = (*APNSPushProviderFactory)(nil)
type NewPushProviderFunc func(p0 *tls.Certificate) (push.PushProvider, error)
type APNSPushProviderFactory struct {
NewPushProviderFunc NewPushProviderFunc
NewPushProviderFuncInvoked bool
}
func (s *APNSPushProviderFactory) NewPushProvider(p0 *tls.Certificate) (push.PushProvider, error) {
s.NewPushProviderFuncInvoked = true
return s.NewPushProviderFunc(p0)
}

View file

@ -2,6 +2,7 @@ package service
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
@ -18,15 +19,19 @@ import (
"github.com/fleetdm/fleet/v4/server/contexts/license"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/logging"
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/service/async"
"github.com/fleetdm/fleet/v4/server/service/mock"
"github.com/fleetdm/fleet/v4/server/sso"
"github.com/fleetdm/fleet/v4/server/test"
kitlog "github.com/go-kit/kit/log"
"github.com/google/uuid"
nanodep_storage "github.com/micromdm/nanodep/storage"
"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/push"
nanomdm_push "github.com/micromdm/nanomdm/push"
nanomdm_storage "github.com/micromdm/nanomdm/storage"
scep_depot "github.com/micromdm/scep/v2/depot"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/throttled/throttled/v2"
@ -238,7 +243,7 @@ type TestServerOpts struct {
FleetConfig *config.FleetConfig
MDMStorage nanomdm_storage.AllStorage
DEPStorage nanodep_storage.AllStorage
SCEPStorage *apple_mdm.SCEPMySQLDepot
SCEPStorage scep_depot.Depot
MDMPusher nanomdm_push.Pusher
HTTPServerConfig *http.Server
StartCronSchedules []TestNewScheduleFunc
@ -491,3 +496,25 @@ func (nopEnrollHostLimiter) CanEnrollNewHost(ctx context.Context) (bool, error)
func (nopEnrollHostLimiter) SyncEnrolledHostIDs(ctx context.Context) error {
return nil
}
func newMockAPNSPushProviderFactory() (*mock.APNSPushProviderFactory, *mock.APNSPushProvider) {
provider := &mock.APNSPushProvider{}
provider.PushFunc = mockSuccessfulPush
factory := &mock.APNSPushProviderFactory{}
factory.NewPushProviderFunc = func(*tls.Certificate) (push.PushProvider, error) {
return provider, nil
}
return factory, provider
}
func mockSuccessfulPush(pushes []*mdm.Push) (map[string]*push.Response, error) {
res := make(map[string]*push.Response, len(pushes))
for _, p := range pushes {
res[p.Token.String()] = &push.Response{
Id: uuid.New().String(),
Err: nil,
}
}
return res, nil
}