mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 13:37:30 +00:00
add mocks + tests and move things around (#9574)
#8948 - Add more go:generate commands for MDM mocks - Add unit and integration tests for MDM code - Move interfaces from their PoC location to match existing patterns
This commit is contained in:
parent
f3642b18da
commit
4c4c114e96
25 changed files with 1160 additions and 278 deletions
6
Makefile
6
Makefile
|
|
@ -173,8 +173,8 @@ generate-dev: .prefix
|
|||
NODE_ENV=development webpack --progress --colors --watch
|
||||
|
||||
generate-mock: .prefix
|
||||
go install github.com/groob/mockimpl@latest
|
||||
go generate github.com/fleetdm/fleet/v4/server/mock github.com/fleetdm/fleet/v4/server/mock/mockresult
|
||||
go install github.com/fleetdm/mockimpl@8d7943aa39d8f5f464d3d3618d9571d385f7bcc5
|
||||
go generate github.com/fleetdm/fleet/v4/server/mock github.com/fleetdm/fleet/v4/server/mock/mockresult github.com/fleetdm/fleet/v4/server/service/mock
|
||||
|
||||
generate-doc: .prefix
|
||||
go generate github.com/fleetdm/fleet/v4/server/fleet
|
||||
|
|
@ -403,4 +403,4 @@ db-replica-reset: fleet
|
|||
|
||||
# db-replica-run runs fleet serve with one main and one read MySQL instance.
|
||||
db-replica-run: fleet
|
||||
FLEET_MYSQL_ADDRESS=127.0.0.1:3308 FLEET_MYSQL_READ_REPLICA_ADDRESS=127.0.0.1:3309 FLEET_MYSQL_READ_REPLICA_USERNAME=fleet FLEET_MYSQL_READ_REPLICA_DATABASE=fleet FLEET_MYSQL_READ_REPLICA_PASSWORD=insecure ./build/fleet serve --dev --dev_license
|
||||
FLEET_MYSQL_ADDRESS=127.0.0.1:3308 FLEET_MYSQL_READ_REPLICA_ADDRESS=127.0.0.1:3309 FLEET_MYSQL_READ_REPLICA_USERNAME=fleet FLEET_MYSQL_READ_REPLICA_DATABASE=fleet FLEET_MYSQL_READ_REPLICA_PASSWORD=insecure ./build/fleet serve --dev --dev_license
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
|
||||
|
||||
eewebhooks "github.com/fleetdm/fleet/v4/ee/server/webhooks"
|
||||
"github.com/fleetdm/fleet/v4/server"
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
|
|
@ -17,7 +19,6 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/contexts/license"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
|
||||
"github.com/fleetdm/fleet/v4/server/policies"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/service/externalsvc"
|
||||
|
|
@ -32,9 +33,6 @@ import (
|
|||
kitlog "github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/micromdm/nanodep/godep"
|
||||
nanodep_log "github.com/micromdm/nanodep/log"
|
||||
depsync "github.com/micromdm/nanodep/sync"
|
||||
)
|
||||
|
||||
func errHandler(ctx context.Context, logger kitlog.Logger, msg string, err error) {
|
||||
|
|
@ -801,32 +799,6 @@ func trySendStatistics(ctx context.Context, ds fleet.Datastore, frequency time.D
|
|||
return ds.RecordStatisticsSent(ctx)
|
||||
}
|
||||
|
||||
// NanoDEPLogger is a logger adapter for nanodep.
|
||||
type NanoDEPLogger struct {
|
||||
logger kitlog.Logger
|
||||
}
|
||||
|
||||
func NewNanoDEPLogger(logger kitlog.Logger) *NanoDEPLogger {
|
||||
return &NanoDEPLogger{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *NanoDEPLogger) Info(keyvals ...interface{}) {
|
||||
level.Info(l.logger).Log(keyvals...)
|
||||
}
|
||||
|
||||
func (l *NanoDEPLogger) Debug(keyvals ...interface{}) {
|
||||
level.Debug(l.logger).Log(keyvals...)
|
||||
}
|
||||
|
||||
func (l *NanoDEPLogger) With(keyvals ...interface{}) nanodep_log.Logger {
|
||||
newLogger := kitlog.With(l.logger, keyvals...)
|
||||
return &NanoDEPLogger{
|
||||
logger: newLogger,
|
||||
}
|
||||
}
|
||||
|
||||
// newAppleMDMDEPProfileAssigner creates the schedule to run the DEP syncer+assigner.
|
||||
// The DEP syncer+assigner fetches devices from Apple Business Manager (aka ABM) and applies
|
||||
// the current configured DEP profile to them.
|
||||
|
|
@ -840,65 +812,13 @@ func newAppleMDMDEPProfileAssigner(
|
|||
loggingDebug bool,
|
||||
) (*schedule.Schedule, error) {
|
||||
const name = string(fleet.CronAppleMDMDEPProfileAssigner)
|
||||
depClient := fleet.NewDEPClient(depStorage, ds, logger)
|
||||
assignerOpts := []depsync.AssignerOption{
|
||||
depsync.WithAssignerLogger(NewNanoDEPLogger(kitlog.With(logger, "component", "nanodep-assigner"))),
|
||||
}
|
||||
if loggingDebug {
|
||||
assignerOpts = append(assignerOpts, depsync.WithDebug())
|
||||
}
|
||||
assigner := depsync.NewAssigner(
|
||||
depClient,
|
||||
apple_mdm.DEPName,
|
||||
depStorage,
|
||||
assignerOpts...,
|
||||
)
|
||||
syncer := depsync.NewSyncer(
|
||||
depClient,
|
||||
apple_mdm.DEPName,
|
||||
depStorage,
|
||||
depsync.WithLogger(NewNanoDEPLogger(kitlog.With(logger, "component", "nanodep-syncer"))),
|
||||
depsync.WithCallback(func(ctx context.Context, isFetch bool, resp *godep.DeviceResponse) error {
|
||||
n, err := ds.IngestMDMAppleDevicesFromDEPSync(ctx, resp.Devices)
|
||||
switch {
|
||||
case err != nil:
|
||||
level.Error(kitlog.With(logger, "cron", name, "component", "nanodep-syncer")).Log("err", err)
|
||||
sentry.CaptureException(err)
|
||||
case n > 0:
|
||||
level.Info(kitlog.With(logger, "cron", name, "component", "nanodep-syncer")).Log("msg", fmt.Sprintf("added %d new mdm device(s) to pending hosts", n))
|
||||
default:
|
||||
// ok
|
||||
}
|
||||
|
||||
return assigner.ProcessDeviceResponse(ctx, resp)
|
||||
}),
|
||||
)
|
||||
logger = kitlog.With(logger, "cron", name)
|
||||
logger = kitlog.With(logger, "cron", name, "component", "nanodep-syncer")
|
||||
fleetSyncer := apple_mdm.NewDEPSyncer(ds, depStorage, logger, loggingDebug)
|
||||
s := schedule.New(
|
||||
ctx, name, instanceID, periodicity, ds, ds,
|
||||
schedule.WithLogger(logger),
|
||||
schedule.WithJob("dep_syncer", func(ctx context.Context) error {
|
||||
profileUUID, profileModTime, err := depStorage.RetrieveAssignerProfile(ctx, apple_mdm.DEPName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if profileUUID == "" {
|
||||
logger.Log("msg", "DEP profile not set, nothing to do")
|
||||
return nil
|
||||
}
|
||||
cursor, cursorModTime, err := depStorage.RetrieveCursor(ctx, apple_mdm.DEPName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If the DEP Profile was changed since last sync then we clear
|
||||
// the cursor and perform a full sync of all devices and profile assigning.
|
||||
if cursor != "" && profileModTime.After(cursorModTime) {
|
||||
logger.Log("msg", "clearing device syncer cursor")
|
||||
if err := depStorage.StoreCursor(ctx, apple_mdm.DEPName, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return syncer.Run(ctx)
|
||||
return fleetSyncer.Run(ctx)
|
||||
}),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
scep_depot "github.com/micromdm/scep/v2/depot"
|
||||
|
||||
"github.com/WatchBeam/clock"
|
||||
"github.com/e-dard/netbug"
|
||||
"github.com/fleetdm/fleet/v4/ee/server/licensing"
|
||||
|
|
@ -40,7 +42,6 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/live_query"
|
||||
"github.com/fleetdm/fleet/v4/server/logging"
|
||||
"github.com/fleetdm/fleet/v4/server/mail"
|
||||
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
|
||||
"github.com/fleetdm/fleet/v4/server/pubsub"
|
||||
"github.com/fleetdm/fleet/v4/server/service"
|
||||
"github.com/fleetdm/fleet/v4/server/service/async"
|
||||
|
|
@ -468,7 +469,7 @@ the way that the Fleet server works.
|
|||
}
|
||||
|
||||
var (
|
||||
scepStorage *apple_mdm.SCEPMySQLDepot
|
||||
scepStorage scep_depot.Depot
|
||||
appleSCEPCertPEM []byte
|
||||
appleSCEPKeyPEM []byte
|
||||
appleAPNsCertPEM []byte
|
||||
|
|
@ -545,7 +546,7 @@ the way that the Fleet server works.
|
|||
initFatal(errors.New("Apple BM configuration must be provided to enable MDM"), "validate Apple MDM")
|
||||
}
|
||||
|
||||
scepStorage, err = mds.NewMDMAppleSCEPDepot(appleSCEPCertPEM, appleSCEPKeyPEM)
|
||||
scepStorage, err = mds.NewSCEPDepot(appleSCEPCertPEM, appleSCEPKeyPEM)
|
||||
if err != nil {
|
||||
initFatal(err, "initialize mdm apple scep storage")
|
||||
}
|
||||
|
|
|
|||
102
cmd/fleetctl/apple_mdm_test.go
Normal file
102
cmd/fleetctl/apple_mdm_test.go
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateMDMAppleBM(t *testing.T) {
|
||||
outdir, err := os.MkdirTemp("", t.Name())
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(outdir)
|
||||
publicKeyPath := filepath.Join(outdir, "public-key.crt")
|
||||
privateKeyPath := filepath.Join(outdir, "private-key.key")
|
||||
out := runAppForTest(t, []string{
|
||||
"generate", "mdm-apple-bm",
|
||||
"--public-key", publicKeyPath,
|
||||
"--private-key", privateKeyPath,
|
||||
})
|
||||
|
||||
require.Contains(t, out, fmt.Sprintf("Generated your public key at %s", outdir))
|
||||
require.Contains(t, out, fmt.Sprintf("Generated your private key at %s", outdir))
|
||||
|
||||
// validate that the keypair is valid
|
||||
cert, err := tls.LoadX509KeyPair(publicKeyPath, privateKeyPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsed, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "FleetDM", parsed.Issuer.CommonName)
|
||||
}
|
||||
|
||||
func TestGenerateMDMApple(t *testing.T) {
|
||||
t.Run("missing input", func(t *testing.T) {
|
||||
runAppCheckErr(t, []string{"generate", "mdm-apple"}, `Required flags "email, org" not set`)
|
||||
runAppCheckErr(t, []string{"generate", "mdm-apple", "--email", "user@example.com"}, `Required flag "org" not set`)
|
||||
runAppCheckErr(t, []string{"generate", "mdm-apple", "--org", "Acme"}, `Required flag "email" not set`)
|
||||
})
|
||||
|
||||
t.Run("CSR API call fails", func(t *testing.T) {
|
||||
_, _ = runServerWithMockedDS(t)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// fail this call
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("bad request"))
|
||||
}))
|
||||
t.Setenv("TEST_FLEETDM_API_URL", srv.URL)
|
||||
t.Cleanup(srv.Close)
|
||||
runAppCheckErr(
|
||||
t,
|
||||
[]string{
|
||||
"generate", "mdm-apple",
|
||||
"--email", "user@example.com",
|
||||
"--org", "Acme",
|
||||
},
|
||||
`POST /api/latest/fleet/mdm/apple/request_csr received status 502 Bad Gateway: FleetDM CSR request failed: api responded with 400: bad request`,
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("successful run", func(t *testing.T) {
|
||||
_, _ = runServerWithMockedDS(t)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
t.Setenv("TEST_FLEETDM_API_URL", srv.URL)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
outdir, err := os.MkdirTemp("", "TestGenerateMDMApple")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(outdir)
|
||||
apnsKeyPath := filepath.Join(outdir, "apns.key")
|
||||
scepCertPath := filepath.Join(outdir, "scep.crt")
|
||||
scepKeyPath := filepath.Join(outdir, "scep.key")
|
||||
out := runAppForTest(t, []string{
|
||||
"generate", "mdm-apple",
|
||||
"--email", "user@example.com",
|
||||
"--org", "Acme",
|
||||
"--apns-key", apnsKeyPath,
|
||||
"--scep-cert", scepCertPath,
|
||||
"--scep-key", scepKeyPath,
|
||||
})
|
||||
|
||||
require.Contains(t, out, fmt.Sprintf("Generated your APNs key at %s", apnsKeyPath))
|
||||
require.Contains(t, out, fmt.Sprintf("Generated your SCEP certificate at %s", scepCertPath))
|
||||
require.Contains(t, out, fmt.Sprintf("Generated your SCEP key at %s", scepKeyPath))
|
||||
|
||||
// validate that the keypair is valid
|
||||
scepCrt, err := tls.LoadX509KeyPair(scepCertPath, scepKeyPath)
|
||||
require.NoError(t, err)
|
||||
parsed, err := x509.ParseCertificate(scepCrt.Certificate[0])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "FleetDM", parsed.Issuer.CommonName)
|
||||
})
|
||||
}
|
||||
|
|
@ -47,7 +47,7 @@ func (svc *Service) GetAppleBM(ctx context.Context) (*fleet.AppleBM, error) {
|
|||
}
|
||||
|
||||
func getAppleBMAccountDetail(ctx context.Context, depStorage storage.AllStorage, ds fleet.Datastore, logger kitlog.Logger) (*fleet.AppleBM, error) {
|
||||
depClient := fleet.NewDEPClient(depStorage, ds, logger)
|
||||
depClient := apple_mdm.NewDEPClient(depStorage, ds, logger)
|
||||
res, err := depClient.AccountDetail(ctx, apple_mdm.DEPName)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "apple GET /account request failed")
|
||||
|
|
|
|||
|
|
@ -273,16 +273,19 @@ func ingestMDMAppleDeviceFromCheckinDB(
|
|||
case err != nil:
|
||||
return ctxerr.Wrap(ctx, err, "get mdm apple host by serial number or udid")
|
||||
|
||||
case foundHost.HardwareSerial != mdmHost.SerialNumber || foundHost.UUID != mdmHost.UDID:
|
||||
return updateMDMAppleHostDB(ctx, tx, foundHost.ID, mdmHost)
|
||||
|
||||
default:
|
||||
// ok, nothing to do here
|
||||
return nil
|
||||
return updateMDMAppleHostDB(ctx, tx, foundHost.ID, mdmHost, appCfg)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func updateMDMAppleHostDB(ctx context.Context, tx sqlx.ExtContext, hostID uint, mdmHost fleet.MDMAppleHostDetails) error {
|
||||
func updateMDMAppleHostDB(
|
||||
ctx context.Context,
|
||||
tx sqlx.ExtContext,
|
||||
hostID uint,
|
||||
mdmHost fleet.MDMAppleHostDetails,
|
||||
appCfg *fleet.AppConfig,
|
||||
) error {
|
||||
updateStmt := `
|
||||
UPDATE hosts SET
|
||||
hardware_serial = ?,
|
||||
|
|
@ -310,6 +313,10 @@ func updateMDMAppleHostDB(ctx context.Context, tx sqlx.ExtContext, hostID uint,
|
|||
return ctxerr.Wrap(ctx, err, "update mdm apple host")
|
||||
}
|
||||
|
||||
if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, false, hostID); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert MDM info")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -365,7 +372,7 @@ func insertMDMAppleHostDB(
|
|||
return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert label membership")
|
||||
}
|
||||
|
||||
if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, false, host); err != nil {
|
||||
if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, false, host.ID); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert MDM info")
|
||||
}
|
||||
return nil
|
||||
|
|
@ -468,7 +475,11 @@ func (ds *Datastore) IngestMDMAppleDevicesFromDEPSync(ctx context.Context, devic
|
|||
return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert label membership")
|
||||
}
|
||||
|
||||
if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, true, hosts...); err != nil {
|
||||
var ids []uint
|
||||
for _, h := range hosts {
|
||||
ids = append(ids, h.ID)
|
||||
}
|
||||
if err := upsertMDMAppleHostMDMInfoDB(ctx, tx, appCfg.ServerSettings.ServerURL, true, ids...); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "ingest mdm apple host upsert MDM info")
|
||||
}
|
||||
|
||||
|
|
@ -497,7 +508,7 @@ func upsertMDMAppleHostDisplayNamesDB(ctx context.Context, tx sqlx.ExtContext, h
|
|||
return nil
|
||||
}
|
||||
|
||||
func upsertMDMAppleHostMDMInfoDB(ctx context.Context, tx sqlx.ExtContext, serverURL string, fromSync bool, hosts ...fleet.Host) error {
|
||||
func upsertMDMAppleHostMDMInfoDB(ctx context.Context, tx sqlx.ExtContext, serverURL string, fromSync bool, hostIDs ...uint) error {
|
||||
result, err := tx.ExecContext(ctx, `
|
||||
INSERT INTO mobile_device_management_solutions (name, server_url) VALUES (?, ?)
|
||||
ON DUPLICATE KEY UPDATE server_url = VALUES(server_url)`,
|
||||
|
|
@ -522,8 +533,8 @@ func upsertMDMAppleHostMDMInfoDB(ctx context.Context, tx sqlx.ExtContext, server
|
|||
|
||||
args := []interface{}{}
|
||||
parts := []string{}
|
||||
for _, h := range hosts {
|
||||
args = append(args, enrolled, serverURL, fromSync, mdmID, false, h.ID)
|
||||
for _, id := range hostIDs {
|
||||
args = append(args, enrolled, serverURL, fromSync, mdmID, false, id)
|
||||
parts = append(parts, "(?, ?, ?, ?, ?, ?)")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/data"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/tables"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
|
||||
"github.com/fleetdm/goose"
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
|
|
@ -36,6 +35,7 @@ import (
|
|||
nanodep_client "github.com/micromdm/nanodep/client"
|
||||
nanodep_mysql "github.com/micromdm/nanodep/storage/mysql"
|
||||
nanomdm_mysql "github.com/micromdm/nanomdm/storage/mysql"
|
||||
scep_depot "github.com/micromdm/scep/v2/depot"
|
||||
"github.com/ngrok/sqlmw"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
||||
)
|
||||
|
|
@ -107,14 +107,10 @@ func (ds *Datastore) loadOrPrepareStmt(ctx context.Context, query string) *sqlx.
|
|||
return stmt
|
||||
}
|
||||
|
||||
// NewMDMAppleSCEPDepot returns a *apple_mdm.MySQLDepot that uses the Datastore
|
||||
// NewMDMAppleSCEPDepot returns a scep_depot.Depot that uses the Datastore
|
||||
// underlying MySQL writer *sql.DB.
|
||||
func (ds *Datastore) NewMDMAppleSCEPDepot(caCertPEM []byte, caKeyPEM []byte) (*apple_mdm.SCEPMySQLDepot, error) {
|
||||
depot, err := apple_mdm.NewSCEPMySQLDepot(ds.writer.DB, caCertPEM, caKeyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return depot, nil
|
||||
func (ds *Datastore) NewSCEPDepot(caCertPEM []byte, caKeyPEM []byte) (scep_depot.Depot, error) {
|
||||
return newSCEPDepot(ds.writer.DB, caCertPEM, caKeyPEM)
|
||||
}
|
||||
|
||||
// NewMDMAppleMDMStorage returns a MySQL nanomdm storage that uses the Datastore
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
package apple_mdm
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
|
|
@ -11,12 +11,13 @@ import (
|
|||
"fmt"
|
||||
"math/big"
|
||||
|
||||
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
|
||||
"github.com/micromdm/nanomdm/cryptoutil"
|
||||
"github.com/micromdm/scep/v2/depot"
|
||||
)
|
||||
|
||||
// SCEPMySQLDepot is a MySQL-backed SCEP certificate depot.
|
||||
type SCEPMySQLDepot struct {
|
||||
// SCEPDepot is a MySQL-backed SCEP certificate depot.
|
||||
type SCEPDepot struct {
|
||||
db *sql.DB
|
||||
|
||||
// caCrt holds the CA's certificate.
|
||||
|
|
@ -25,10 +26,10 @@ type SCEPMySQLDepot struct {
|
|||
caKey *rsa.PrivateKey
|
||||
}
|
||||
|
||||
var _ depot.Depot = (*SCEPMySQLDepot)(nil)
|
||||
var _ depot.Depot = (*SCEPDepot)(nil)
|
||||
|
||||
// NewSCEPMySQLDepot creates and returns a *SCEPMySQLDepot.
|
||||
func NewSCEPMySQLDepot(db *sql.DB, caCertPEM []byte, caKeyPEM []byte) (*SCEPMySQLDepot, error) {
|
||||
// newSCEPDepot creates and returns a *SCEPDepot.
|
||||
func newSCEPDepot(db *sql.DB, caCertPEM []byte, caKeyPEM []byte) (*SCEPDepot, error) {
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -40,7 +41,7 @@ func NewSCEPMySQLDepot(db *sql.DB, caCertPEM []byte, caKeyPEM []byte) (*SCEPMySQ
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SCEPMySQLDepot{
|
||||
return &SCEPDepot{
|
||||
db: db,
|
||||
caCrt: caCrt,
|
||||
caKey: caKey,
|
||||
|
|
@ -56,12 +57,12 @@ func decodeRSAKeyFromPEM(key []byte) (*rsa.PrivateKey, error) {
|
|||
}
|
||||
|
||||
// CA returns the CA's certificate and private key.
|
||||
func (d *SCEPMySQLDepot) CA(_ []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) {
|
||||
func (d *SCEPDepot) CA(_ []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) {
|
||||
return []*x509.Certificate{d.caCrt}, d.caKey, nil
|
||||
}
|
||||
|
||||
// Serial allocates and returns a new (increasing) serial number.
|
||||
func (d *SCEPMySQLDepot) Serial() (*big.Int, error) {
|
||||
func (d *SCEPDepot) Serial() (*big.Int, error) {
|
||||
result, err := d.db.Exec(`INSERT INTO scep_serials () VALUES ();`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -78,7 +79,7 @@ func (d *SCEPMySQLDepot) Serial() (*big.Int, error) {
|
|||
// TODO(lucas): Implement and use allowTime and revokeOldCertificate.
|
||||
// - allowTime are the maximum days before expiration to allow clients to do certificate renewal.
|
||||
// - revokeOldCertificate specifies whether to revoke the old certificate once renewed.
|
||||
func (d *SCEPMySQLDepot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error) {
|
||||
func (d *SCEPDepot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error) {
|
||||
var ct int
|
||||
row := d.db.QueryRow(`SELECT COUNT(*) FROM scep_certificates WHERE name = ?`, cn)
|
||||
if err := row.Scan(&ct); err != nil {
|
||||
|
|
@ -91,14 +92,14 @@ func (d *SCEPMySQLDepot) HasCN(cn string, allowTime int, cert *x509.Certificate,
|
|||
//
|
||||
// If the provided certificate has empty crt.Subject.CommonName,
|
||||
// then the hex sha256 of the crt.Raw is used as name.
|
||||
func (d *SCEPMySQLDepot) Put(name string, crt *x509.Certificate) error {
|
||||
func (d *SCEPDepot) Put(name string, crt *x509.Certificate) error {
|
||||
if crt.Subject.CommonName == "" {
|
||||
name = fmt.Sprintf("%x", sha256.Sum256(crt.Raw))
|
||||
}
|
||||
if !crt.SerialNumber.IsInt64() {
|
||||
return errors.New("cannot represent serial number as int64")
|
||||
}
|
||||
certPEM := EncodeCertPEM(crt)
|
||||
certPEM := apple_mdm.EncodeCertPEM(crt)
|
||||
_, err := d.db.Exec(`
|
||||
INSERT INTO scep_certificates
|
||||
(serial, name, not_valid_before, not_valid_after, certificate_pem)
|
||||
|
|
@ -1,6 +1,4 @@
|
|||
// use a different package name to avoid
|
||||
// import cycle
|
||||
package apple_mdm_test
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
|
|
@ -8,19 +6,19 @@ import (
|
|||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
|
||||
"github.com/micromdm/nanodep/tokenpki"
|
||||
scep_depot "github.com/micromdm/scep/v2/depot"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setup(t *testing.T) *apple_mdm.SCEPMySQLDepot {
|
||||
ds := mysql.CreateNamedMySQLDS(t, t.Name())
|
||||
func setup(t *testing.T) scep_depot.Depot {
|
||||
ds := CreateNamedMySQLDS(t, t.Name())
|
||||
cert, key, err := apple_mdm.NewSCEPCACertKey()
|
||||
require.NoError(t, err)
|
||||
publicKeyPEM := tokenpki.PEMCertificate(cert.Raw)
|
||||
privateKeyPEM := tokenpki.PEMRSAPrivateKey(key)
|
||||
depot, err := ds.NewMDMAppleSCEPDepot(publicKeyPEM, privateKeyPEM)
|
||||
depot, err := ds.NewSCEPDepot(publicKeyPEM, privateKeyPEM)
|
||||
require.NoError(t, err)
|
||||
return depot
|
||||
}
|
||||
|
|
@ -3,11 +3,6 @@ package fleet
|
|||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
|
||||
kitlog "github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
"github.com/micromdm/nanodep/godep"
|
||||
)
|
||||
|
||||
type AppleMDM struct {
|
||||
|
|
@ -52,42 +47,3 @@ type AppConfigUpdater interface {
|
|||
AppConfig(ctx context.Context) (*AppConfig, error)
|
||||
SaveAppConfig(ctx context.Context, info *AppConfig) error
|
||||
}
|
||||
|
||||
// NewDEPClient creates an Apple DEP API HTTP client based on the provided
|
||||
// storage that will flag the AppConfig's AppleBMTermsExpired field whenever
|
||||
// the status of the terms changes.
|
||||
func NewDEPClient(storage godep.ClientStorage, appCfgUpdater AppConfigUpdater, logger kitlog.Logger) *godep.Client {
|
||||
return godep.NewClient(storage, fleethttp.NewClient(), godep.WithAfterHook(func(ctx context.Context, reqErr error) error {
|
||||
// if the request failed due to terms not signed, or if it succeeded,
|
||||
// update the app config flag accordingly. If it failed for any other
|
||||
// reason, do not update the flag.
|
||||
termsExpired := reqErr != nil && godep.IsTermsNotSigned(reqErr)
|
||||
if reqErr == nil || termsExpired {
|
||||
appCfg, err := appCfgUpdater.AppConfig(ctx)
|
||||
if err != nil {
|
||||
level.Error(logger).Log("msg", "Apple DEP client: failed to get app config", "err", err)
|
||||
return reqErr
|
||||
}
|
||||
|
||||
var mustSaveAppCfg bool
|
||||
if termsExpired && !appCfg.MDM.AppleBMTermsExpired {
|
||||
// flag the AppConfig that the terms have changed and must be accepted
|
||||
appCfg.MDM.AppleBMTermsExpired = true
|
||||
mustSaveAppCfg = true
|
||||
} else if reqErr == nil && appCfg.MDM.AppleBMTermsExpired {
|
||||
// flag the AppConfig that the terms have been accepted
|
||||
appCfg.MDM.AppleBMTermsExpired = false
|
||||
mustSaveAppCfg = true
|
||||
}
|
||||
|
||||
if mustSaveAppCfg {
|
||||
if err := appCfgUpdater.SaveAppConfig(ctx, appCfg); err != nil {
|
||||
level.Error(logger).Log("msg", "Apple DEP client: failed to save app config", "err", err)
|
||||
}
|
||||
level.Debug(logger).Log("msg", "Apple DEP client: updated app config Terms Expired flag",
|
||||
"apple_bm_terms_expired", appCfg.MDM.AppleBMTermsExpired)
|
||||
}
|
||||
}
|
||||
return reqErr
|
||||
}))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,25 +11,13 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
|
||||
"github.com/fleetdm/fleet/v4/server/mock"
|
||||
nanodep_mock "github.com/fleetdm/fleet/v4/server/mock/nanodep"
|
||||
"github.com/go-kit/log"
|
||||
nanodep_client "github.com/micromdm/nanodep/client"
|
||||
"github.com/micromdm/nanodep/godep"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockStorage struct {
|
||||
token string
|
||||
url string
|
||||
}
|
||||
|
||||
func (s mockStorage) RetrieveAuthTokens(ctx context.Context, name string) (*nanodep_client.OAuth1Tokens, error) {
|
||||
return &nanodep_client.OAuth1Tokens{AccessToken: s.token}, nil
|
||||
}
|
||||
|
||||
func (s mockStorage) RetrieveConfig(context.Context, string) (*nanodep_client.Config, error) {
|
||||
return &nanodep_client.Config{BaseURL: s.url}, nil
|
||||
}
|
||||
|
||||
func TestDEPClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
|
@ -141,8 +129,15 @@ func TestDEPClient(t *testing.T) {
|
|||
for i, c := range cases {
|
||||
t.Logf("case %d", i)
|
||||
|
||||
store := mockStorage{token: c.token, url: srv.URL}
|
||||
dep := fleet.NewDEPClient(store, ds, logger)
|
||||
store := &nanodep_mock.Storage{}
|
||||
store.RetrieveAuthTokensFunc = func(ctx context.Context, name string) (*nanodep_client.OAuth1Tokens, error) {
|
||||
return &nanodep_client.OAuth1Tokens{AccessToken: c.token}, nil
|
||||
}
|
||||
store.RetrieveConfigFunc = func(context.Context, string) (*nanodep_client.Config, error) {
|
||||
return &nanodep_client.Config{BaseURL: srv.URL}, nil
|
||||
}
|
||||
|
||||
dep := apple_mdm.NewDEPClient(store, ds, logger)
|
||||
res, err := dep.AccountDetail(ctx, apple_mdm.DEPName)
|
||||
|
||||
if c.wantErr {
|
||||
|
|
@ -163,6 +158,8 @@ func TestDEPClient(t *testing.T) {
|
|||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", res.AdminID)
|
||||
require.True(t, store.RetrieveAuthTokensFuncInvoked)
|
||||
require.True(t, store.RetrieveConfigFuncInvoked)
|
||||
}
|
||||
checkDSCalled(c.readInvoked, c.writeInvoked)
|
||||
require.Equal(t, c.termsFlag, appCfg.MDM.AppleBMTermsExpired)
|
||||
|
|
|
|||
33
server/logging/nanodep.go
Normal file
33
server/logging/nanodep.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
kitlog "github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/log/level"
|
||||
nanodep_log "github.com/micromdm/nanodep/log"
|
||||
)
|
||||
|
||||
// NanoDEPLogger is a logger adapter for nanodep.
|
||||
type NanoDEPLogger struct {
|
||||
logger kitlog.Logger
|
||||
}
|
||||
|
||||
func NewNanoDEPLogger(logger kitlog.Logger) *NanoDEPLogger {
|
||||
return &NanoDEPLogger{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *NanoDEPLogger) Info(keyvals ...interface{}) {
|
||||
level.Info(l.logger).Log(keyvals...)
|
||||
}
|
||||
|
||||
func (l *NanoDEPLogger) Debug(keyvals ...interface{}) {
|
||||
level.Debug(l.logger).Log(keyvals...)
|
||||
}
|
||||
|
||||
func (l *NanoDEPLogger) With(keyvals ...interface{}) nanodep_log.Logger {
|
||||
newLogger := kitlog.With(l.logger, keyvals...)
|
||||
return &NanoDEPLogger{
|
||||
logger: newLogger,
|
||||
}
|
||||
}
|
||||
|
|
@ -1,8 +1,21 @@
|
|||
package apple_mdm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/logging"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/go-kit/log/level"
|
||||
"github.com/micromdm/nanodep/godep"
|
||||
|
||||
kitlog "github.com/go-kit/kit/log"
|
||||
nanodep_storage "github.com/micromdm/nanodep/storage"
|
||||
depsync "github.com/micromdm/nanodep/sync"
|
||||
)
|
||||
|
||||
// DEPName is the identifier/name used in nanodep MySQL storage which
|
||||
|
|
@ -47,3 +60,118 @@ func resolveURL(serverURL, relPath string) (string, error) {
|
|||
u.Path = path.Join(u.Path, relPath)
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
type DEPSyncer struct {
|
||||
depStorage nanodep_storage.AllStorage
|
||||
syncer *depsync.Syncer
|
||||
logger kitlog.Logger
|
||||
}
|
||||
|
||||
func (d *DEPSyncer) Run(ctx context.Context) error {
|
||||
profileUUID, profileModTime, err := d.depStorage.RetrieveAssignerProfile(ctx, DEPName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if profileUUID == "" {
|
||||
d.logger.Log("msg", "DEP profile not set, nothing to do")
|
||||
return nil
|
||||
}
|
||||
cursor, cursorModTime, err := d.depStorage.RetrieveCursor(ctx, DEPName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If the DEP Profile was changed since last sync then we clear
|
||||
// the cursor and perform a full sync of all devices and profile assigning.
|
||||
if cursor != "" && profileModTime.After(cursorModTime) {
|
||||
d.logger.Log("msg", "clearing device syncer cursor")
|
||||
if err := d.depStorage.StoreCursor(ctx, DEPName, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return d.syncer.Run(ctx)
|
||||
}
|
||||
|
||||
func NewDEPSyncer(
|
||||
ds fleet.Datastore,
|
||||
depStorage nanodep_storage.AllStorage,
|
||||
logger kitlog.Logger,
|
||||
loggingDebug bool,
|
||||
) *DEPSyncer {
|
||||
depClient := NewDEPClient(depStorage, ds, logger)
|
||||
assignerOpts := []depsync.AssignerOption{
|
||||
depsync.WithAssignerLogger(logging.NewNanoDEPLogger(kitlog.With(logger, "component", "nanodep-assigner"))),
|
||||
}
|
||||
if loggingDebug {
|
||||
assignerOpts = append(assignerOpts, depsync.WithDebug())
|
||||
}
|
||||
assigner := depsync.NewAssigner(
|
||||
depClient,
|
||||
DEPName,
|
||||
depStorage,
|
||||
assignerOpts...,
|
||||
)
|
||||
|
||||
syncer := depsync.NewSyncer(
|
||||
depClient,
|
||||
DEPName,
|
||||
depStorage,
|
||||
depsync.WithLogger(logging.NewNanoDEPLogger(kitlog.With(logger, "component", "nanodep-syncer"))),
|
||||
depsync.WithCallback(func(ctx context.Context, isFetch bool, resp *godep.DeviceResponse) error {
|
||||
n, err := ds.IngestMDMAppleDevicesFromDEPSync(ctx, resp.Devices)
|
||||
switch {
|
||||
case err != nil:
|
||||
level.Error(kitlog.With(logger)).Log("err", err)
|
||||
sentry.CaptureException(err)
|
||||
case n > 0:
|
||||
level.Info(kitlog.With(logger)).Log("msg", fmt.Sprintf("added %d new mdm device(s) to pending hosts", n))
|
||||
}
|
||||
|
||||
return assigner.ProcessDeviceResponse(ctx, resp)
|
||||
}),
|
||||
)
|
||||
|
||||
return &DEPSyncer{
|
||||
syncer: syncer,
|
||||
depStorage: depStorage,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDEPClient creates an Apple DEP API HTTP client based on the provided
|
||||
// storage that will flag the AppConfig's AppleBMTermsExpired field
|
||||
// whenever the status of the terms changes.
|
||||
func NewDEPClient(storage godep.ClientStorage, appCfgUpdater fleet.AppConfigUpdater, logger kitlog.Logger) *godep.Client {
|
||||
return godep.NewClient(storage, fleethttp.NewClient(), godep.WithAfterHook(func(ctx context.Context, reqErr error) error {
|
||||
// if the request failed due to terms not signed, or if it succeeded,
|
||||
// update the app config flag accordingly. If it failed for any other
|
||||
// reason, do not update the flag.
|
||||
termsExpired := reqErr != nil && godep.IsTermsNotSigned(reqErr)
|
||||
if reqErr == nil || termsExpired {
|
||||
appCfg, err := appCfgUpdater.AppConfig(ctx)
|
||||
if err != nil {
|
||||
level.Error(logger).Log("msg", "Apple DEP client: failed to get app config", "err", err)
|
||||
return reqErr
|
||||
}
|
||||
|
||||
var mustSaveAppCfg bool
|
||||
if termsExpired && !appCfg.MDM.AppleBMTermsExpired {
|
||||
// flag the AppConfig that the terms have changed and must be accepted
|
||||
appCfg.MDM.AppleBMTermsExpired = true
|
||||
mustSaveAppCfg = true
|
||||
} else if reqErr == nil && appCfg.MDM.AppleBMTermsExpired {
|
||||
// flag the AppConfig that the terms have been accepted
|
||||
appCfg.MDM.AppleBMTermsExpired = false
|
||||
mustSaveAppCfg = true
|
||||
}
|
||||
|
||||
if mustSaveAppCfg {
|
||||
if err := appCfgUpdater.SaveAppConfig(ctx, appCfg); err != nil {
|
||||
level.Error(logger).Log("msg", "Apple DEP client: failed to save app config", "err", err)
|
||||
}
|
||||
level.Debug(logger).Log("msg", "Apple DEP client: updated app config Terms Expired flag",
|
||||
"apple_bm_terms_expired", appCfg.MDM.AppleBMTermsExpired)
|
||||
}
|
||||
}
|
||||
return reqErr
|
||||
}))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,9 @@ import (
|
|||
|
||||
//go:generate mockimpl -o datastore_mock.go "s *DataStore" "fleet.Datastore"
|
||||
//go:generate mockimpl -o datastore_installers.go "s *InstallerStore" "fleet.InstallerStore"
|
||||
//go:generate mockimpl -o nanomdm/storage.go "s *Storage" "github.com/micromdm/nanomdm/storage.AllStorage"
|
||||
//go:generate mockimpl -o nanodep/storage.go "s *Storage" "github.com/micromdm/nanodep/storage.AllStorage"
|
||||
//go:generate mockimpl -o scep/depot.go "d *Depot" "depot.Depot"
|
||||
|
||||
var _ fleet.Datastore = (*Store)(nil)
|
||||
|
||||
|
|
|
|||
115
server/mock/nanodep/storage.go
Normal file
115
server/mock/nanodep/storage.go
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
// Automatically generated by mockimpl. DO NOT EDIT!
|
||||
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/micromdm/nanodep/client"
|
||||
"github.com/micromdm/nanodep/storage"
|
||||
)
|
||||
|
||||
var _ storage.AllStorage = (*Storage)(nil)
|
||||
|
||||
type RetrieveAuthTokensFunc func(ctx context.Context, name string) (*client.OAuth1Tokens, error)
|
||||
|
||||
type RetrieveConfigFunc func(p0 context.Context, p1 string) (*client.Config, error)
|
||||
|
||||
type RetrieveAssignerProfileFunc func(ctx context.Context, name string) (profileUUID string, modTime time.Time, err error)
|
||||
|
||||
type RetrieveCursorFunc func(ctx context.Context, name string) (cursor string, modTime time.Time, err error)
|
||||
|
||||
type StoreCursorFunc func(ctx context.Context, name string, cursor string) error
|
||||
|
||||
type StoreAuthTokensFunc func(ctx context.Context, name string, tokens *client.OAuth1Tokens) error
|
||||
|
||||
type StoreConfigFunc func(ctx context.Context, name string, config *client.Config) error
|
||||
|
||||
type StoreTokenPKIFunc func(ctx context.Context, name string, pemCert []byte, pemKey []byte) error
|
||||
|
||||
type RetrieveTokenPKIFunc func(ctx context.Context, name string) (pemCert []byte, pemKey []byte, err error)
|
||||
|
||||
type StoreAssignerProfileFunc func(ctx context.Context, name string, profileUUID string) error
|
||||
|
||||
type Storage struct {
|
||||
RetrieveAuthTokensFunc RetrieveAuthTokensFunc
|
||||
RetrieveAuthTokensFuncInvoked bool
|
||||
|
||||
RetrieveConfigFunc RetrieveConfigFunc
|
||||
RetrieveConfigFuncInvoked bool
|
||||
|
||||
RetrieveAssignerProfileFunc RetrieveAssignerProfileFunc
|
||||
RetrieveAssignerProfileFuncInvoked bool
|
||||
|
||||
RetrieveCursorFunc RetrieveCursorFunc
|
||||
RetrieveCursorFuncInvoked bool
|
||||
|
||||
StoreCursorFunc StoreCursorFunc
|
||||
StoreCursorFuncInvoked bool
|
||||
|
||||
StoreAuthTokensFunc StoreAuthTokensFunc
|
||||
StoreAuthTokensFuncInvoked bool
|
||||
|
||||
StoreConfigFunc StoreConfigFunc
|
||||
StoreConfigFuncInvoked bool
|
||||
|
||||
StoreTokenPKIFunc StoreTokenPKIFunc
|
||||
StoreTokenPKIFuncInvoked bool
|
||||
|
||||
RetrieveTokenPKIFunc RetrieveTokenPKIFunc
|
||||
RetrieveTokenPKIFuncInvoked bool
|
||||
|
||||
StoreAssignerProfileFunc StoreAssignerProfileFunc
|
||||
StoreAssignerProfileFuncInvoked bool
|
||||
}
|
||||
|
||||
func (s *Storage) RetrieveAuthTokens(ctx context.Context, name string) (*client.OAuth1Tokens, error) {
|
||||
s.RetrieveAuthTokensFuncInvoked = true
|
||||
return s.RetrieveAuthTokensFunc(ctx, name)
|
||||
}
|
||||
|
||||
func (s *Storage) RetrieveConfig(p0 context.Context, p1 string) (*client.Config, error) {
|
||||
s.RetrieveConfigFuncInvoked = true
|
||||
return s.RetrieveConfigFunc(p0, p1)
|
||||
}
|
||||
|
||||
func (s *Storage) RetrieveAssignerProfile(ctx context.Context, name string) (profileUUID string, modTime time.Time, err error) {
|
||||
s.RetrieveAssignerProfileFuncInvoked = true
|
||||
return s.RetrieveAssignerProfileFunc(ctx, name)
|
||||
}
|
||||
|
||||
func (s *Storage) RetrieveCursor(ctx context.Context, name string) (cursor string, modTime time.Time, err error) {
|
||||
s.RetrieveCursorFuncInvoked = true
|
||||
return s.RetrieveCursorFunc(ctx, name)
|
||||
}
|
||||
|
||||
func (s *Storage) StoreCursor(ctx context.Context, name string, cursor string) error {
|
||||
s.StoreCursorFuncInvoked = true
|
||||
return s.StoreCursorFunc(ctx, name, cursor)
|
||||
}
|
||||
|
||||
func (s *Storage) StoreAuthTokens(ctx context.Context, name string, tokens *client.OAuth1Tokens) error {
|
||||
s.StoreAuthTokensFuncInvoked = true
|
||||
return s.StoreAuthTokensFunc(ctx, name, tokens)
|
||||
}
|
||||
|
||||
func (s *Storage) StoreConfig(ctx context.Context, name string, config *client.Config) error {
|
||||
s.StoreConfigFuncInvoked = true
|
||||
return s.StoreConfigFunc(ctx, name, config)
|
||||
}
|
||||
|
||||
func (s *Storage) StoreTokenPKI(ctx context.Context, name string, pemCert []byte, pemKey []byte) error {
|
||||
s.StoreTokenPKIFuncInvoked = true
|
||||
return s.StoreTokenPKIFunc(ctx, name, pemCert, pemKey)
|
||||
}
|
||||
|
||||
func (s *Storage) RetrieveTokenPKI(ctx context.Context, name string) (pemCert []byte, pemKey []byte, err error) {
|
||||
s.RetrieveTokenPKIFuncInvoked = true
|
||||
return s.RetrieveTokenPKIFunc(ctx, name)
|
||||
}
|
||||
|
||||
func (s *Storage) StoreAssignerProfile(ctx context.Context, name string, profileUUID string) error {
|
||||
s.StoreAssignerProfileFuncInvoked = true
|
||||
return s.StoreAssignerProfileFunc(ctx, name, profileUUID)
|
||||
}
|
||||
215
server/mock/nanomdm/storage.go
Normal file
215
server/mock/nanomdm/storage.go
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
// Automatically generated by mockimpl. DO NOT EDIT!
|
||||
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/micromdm/nanomdm/mdm"
|
||||
"github.com/micromdm/nanomdm/storage"
|
||||
)
|
||||
|
||||
var _ storage.AllStorage = (*Storage)(nil)
|
||||
|
||||
type StoreAuthenticateFunc func(r *mdm.Request, msg *mdm.Authenticate) error
|
||||
|
||||
type StoreTokenUpdateFunc func(r *mdm.Request, msg *mdm.TokenUpdate) error
|
||||
|
||||
type StoreUserAuthenticateFunc func(r *mdm.Request, msg *mdm.UserAuthenticate) error
|
||||
|
||||
type DisableFunc func(r *mdm.Request) error
|
||||
|
||||
type StoreCommandReportFunc func(r *mdm.Request, report *mdm.CommandResults) error
|
||||
|
||||
type RetrieveNextCommandFunc func(r *mdm.Request, skipNotNow bool) (*mdm.Command, error)
|
||||
|
||||
type ClearQueueFunc func(r *mdm.Request) error
|
||||
|
||||
type StoreBootstrapTokenFunc func(r *mdm.Request, msg *mdm.SetBootstrapToken) error
|
||||
|
||||
type RetrieveBootstrapTokenFunc func(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error)
|
||||
|
||||
type RetrievePushInfoFunc func(p0 context.Context, p1 []string) (map[string]*mdm.Push, error)
|
||||
|
||||
type IsPushCertStaleFunc func(ctx context.Context, topic string, staleToken string) (bool, error)
|
||||
|
||||
type RetrievePushCertFunc func(ctx context.Context, topic string) (cert *tls.Certificate, staleToken string, err error)
|
||||
|
||||
type StorePushCertFunc func(ctx context.Context, pemCert []byte, pemKey []byte) error
|
||||
|
||||
type EnqueueCommandFunc func(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error)
|
||||
|
||||
type HasCertHashFunc func(r *mdm.Request, hash string) (bool, error)
|
||||
|
||||
type EnrollmentHasCertHashFunc func(r *mdm.Request, hash string) (bool, error)
|
||||
|
||||
type IsCertHashAssociatedFunc func(r *mdm.Request, hash string) (bool, error)
|
||||
|
||||
type AssociateCertHashFunc func(r *mdm.Request, hash string) error
|
||||
|
||||
type RetrieveMigrationCheckinsFunc func(p0 context.Context, p1 chan<- interface{}) error
|
||||
|
||||
type RetrieveTokenUpdateTallyFunc func(ctx context.Context, id string) (int, error)
|
||||
|
||||
type Storage struct {
|
||||
StoreAuthenticateFunc StoreAuthenticateFunc
|
||||
StoreAuthenticateFuncInvoked bool
|
||||
|
||||
StoreTokenUpdateFunc StoreTokenUpdateFunc
|
||||
StoreTokenUpdateFuncInvoked bool
|
||||
|
||||
StoreUserAuthenticateFunc StoreUserAuthenticateFunc
|
||||
StoreUserAuthenticateFuncInvoked bool
|
||||
|
||||
DisableFunc DisableFunc
|
||||
DisableFuncInvoked bool
|
||||
|
||||
StoreCommandReportFunc StoreCommandReportFunc
|
||||
StoreCommandReportFuncInvoked bool
|
||||
|
||||
RetrieveNextCommandFunc RetrieveNextCommandFunc
|
||||
RetrieveNextCommandFuncInvoked bool
|
||||
|
||||
ClearQueueFunc ClearQueueFunc
|
||||
ClearQueueFuncInvoked bool
|
||||
|
||||
StoreBootstrapTokenFunc StoreBootstrapTokenFunc
|
||||
StoreBootstrapTokenFuncInvoked bool
|
||||
|
||||
RetrieveBootstrapTokenFunc RetrieveBootstrapTokenFunc
|
||||
RetrieveBootstrapTokenFuncInvoked bool
|
||||
|
||||
RetrievePushInfoFunc RetrievePushInfoFunc
|
||||
RetrievePushInfoFuncInvoked bool
|
||||
|
||||
IsPushCertStaleFunc IsPushCertStaleFunc
|
||||
IsPushCertStaleFuncInvoked bool
|
||||
|
||||
RetrievePushCertFunc RetrievePushCertFunc
|
||||
RetrievePushCertFuncInvoked bool
|
||||
|
||||
StorePushCertFunc StorePushCertFunc
|
||||
StorePushCertFuncInvoked bool
|
||||
|
||||
EnqueueCommandFunc EnqueueCommandFunc
|
||||
EnqueueCommandFuncInvoked bool
|
||||
|
||||
HasCertHashFunc HasCertHashFunc
|
||||
HasCertHashFuncInvoked bool
|
||||
|
||||
EnrollmentHasCertHashFunc EnrollmentHasCertHashFunc
|
||||
EnrollmentHasCertHashFuncInvoked bool
|
||||
|
||||
IsCertHashAssociatedFunc IsCertHashAssociatedFunc
|
||||
IsCertHashAssociatedFuncInvoked bool
|
||||
|
||||
AssociateCertHashFunc AssociateCertHashFunc
|
||||
AssociateCertHashFuncInvoked bool
|
||||
|
||||
RetrieveMigrationCheckinsFunc RetrieveMigrationCheckinsFunc
|
||||
RetrieveMigrationCheckinsFuncInvoked bool
|
||||
|
||||
RetrieveTokenUpdateTallyFunc RetrieveTokenUpdateTallyFunc
|
||||
RetrieveTokenUpdateTallyFuncInvoked bool
|
||||
}
|
||||
|
||||
func (s *Storage) StoreAuthenticate(r *mdm.Request, msg *mdm.Authenticate) error {
|
||||
s.StoreAuthenticateFuncInvoked = true
|
||||
return s.StoreAuthenticateFunc(r, msg)
|
||||
}
|
||||
|
||||
func (s *Storage) StoreTokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate) error {
|
||||
s.StoreTokenUpdateFuncInvoked = true
|
||||
return s.StoreTokenUpdateFunc(r, msg)
|
||||
}
|
||||
|
||||
func (s *Storage) StoreUserAuthenticate(r *mdm.Request, msg *mdm.UserAuthenticate) error {
|
||||
s.StoreUserAuthenticateFuncInvoked = true
|
||||
return s.StoreUserAuthenticateFunc(r, msg)
|
||||
}
|
||||
|
||||
func (s *Storage) Disable(r *mdm.Request) error {
|
||||
s.DisableFuncInvoked = true
|
||||
return s.DisableFunc(r)
|
||||
}
|
||||
|
||||
func (s *Storage) StoreCommandReport(r *mdm.Request, report *mdm.CommandResults) error {
|
||||
s.StoreCommandReportFuncInvoked = true
|
||||
return s.StoreCommandReportFunc(r, report)
|
||||
}
|
||||
|
||||
func (s *Storage) RetrieveNextCommand(r *mdm.Request, skipNotNow bool) (*mdm.Command, error) {
|
||||
s.RetrieveNextCommandFuncInvoked = true
|
||||
return s.RetrieveNextCommandFunc(r, skipNotNow)
|
||||
}
|
||||
|
||||
func (s *Storage) ClearQueue(r *mdm.Request) error {
|
||||
s.ClearQueueFuncInvoked = true
|
||||
return s.ClearQueueFunc(r)
|
||||
}
|
||||
|
||||
func (s *Storage) StoreBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapToken) error {
|
||||
s.StoreBootstrapTokenFuncInvoked = true
|
||||
return s.StoreBootstrapTokenFunc(r, msg)
|
||||
}
|
||||
|
||||
func (s *Storage) RetrieveBootstrapToken(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) {
|
||||
s.RetrieveBootstrapTokenFuncInvoked = true
|
||||
return s.RetrieveBootstrapTokenFunc(r, msg)
|
||||
}
|
||||
|
||||
func (s *Storage) RetrievePushInfo(p0 context.Context, p1 []string) (map[string]*mdm.Push, error) {
|
||||
s.RetrievePushInfoFuncInvoked = true
|
||||
return s.RetrievePushInfoFunc(p0, p1)
|
||||
}
|
||||
|
||||
func (s *Storage) IsPushCertStale(ctx context.Context, topic string, staleToken string) (bool, error) {
|
||||
s.IsPushCertStaleFuncInvoked = true
|
||||
return s.IsPushCertStaleFunc(ctx, topic, staleToken)
|
||||
}
|
||||
|
||||
func (s *Storage) RetrievePushCert(ctx context.Context, topic string) (cert *tls.Certificate, staleToken string, err error) {
|
||||
s.RetrievePushCertFuncInvoked = true
|
||||
return s.RetrievePushCertFunc(ctx, topic)
|
||||
}
|
||||
|
||||
func (s *Storage) StorePushCert(ctx context.Context, pemCert []byte, pemKey []byte) error {
|
||||
s.StorePushCertFuncInvoked = true
|
||||
return s.StorePushCertFunc(ctx, pemCert, pemKey)
|
||||
}
|
||||
|
||||
func (s *Storage) EnqueueCommand(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) {
|
||||
s.EnqueueCommandFuncInvoked = true
|
||||
return s.EnqueueCommandFunc(ctx, id, cmd)
|
||||
}
|
||||
|
||||
func (s *Storage) HasCertHash(r *mdm.Request, hash string) (bool, error) {
|
||||
s.HasCertHashFuncInvoked = true
|
||||
return s.HasCertHashFunc(r, hash)
|
||||
}
|
||||
|
||||
func (s *Storage) EnrollmentHasCertHash(r *mdm.Request, hash string) (bool, error) {
|
||||
s.EnrollmentHasCertHashFuncInvoked = true
|
||||
return s.EnrollmentHasCertHashFunc(r, hash)
|
||||
}
|
||||
|
||||
func (s *Storage) IsCertHashAssociated(r *mdm.Request, hash string) (bool, error) {
|
||||
s.IsCertHashAssociatedFuncInvoked = true
|
||||
return s.IsCertHashAssociatedFunc(r, hash)
|
||||
}
|
||||
|
||||
func (s *Storage) AssociateCertHash(r *mdm.Request, hash string) error {
|
||||
s.AssociateCertHashFuncInvoked = true
|
||||
return s.AssociateCertHashFunc(r, hash)
|
||||
}
|
||||
|
||||
func (s *Storage) RetrieveMigrationCheckins(p0 context.Context, p1 chan<- interface{}) error {
|
||||
s.RetrieveMigrationCheckinsFuncInvoked = true
|
||||
return s.RetrieveMigrationCheckinsFunc(p0, p1)
|
||||
}
|
||||
|
||||
func (s *Storage) RetrieveTokenUpdateTally(ctx context.Context, id string) (int, error) {
|
||||
s.RetrieveTokenUpdateTallyFuncInvoked = true
|
||||
return s.RetrieveTokenUpdateTallyFunc(ctx, id)
|
||||
}
|
||||
55
server/mock/scep/depot.go
Normal file
55
server/mock/scep/depot.go
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
// Automatically generated by mockimpl. DO NOT EDIT!
|
||||
|
||||
package mock
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"math/big"
|
||||
|
||||
"github.com/micromdm/scep/v2/depot"
|
||||
)
|
||||
|
||||
var _ depot.Depot = (*Depot)(nil)
|
||||
|
||||
type CAFunc func(pass []byte) ([]*x509.Certificate, *rsa.PrivateKey, error)
|
||||
|
||||
type PutFunc func(name string, crt *x509.Certificate) error
|
||||
|
||||
type SerialFunc func() (*big.Int, error)
|
||||
|
||||
type HasCNFunc func(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error)
|
||||
|
||||
type Depot struct {
|
||||
CAFunc CAFunc
|
||||
CAFuncInvoked bool
|
||||
|
||||
PutFunc PutFunc
|
||||
PutFuncInvoked bool
|
||||
|
||||
SerialFunc SerialFunc
|
||||
SerialFuncInvoked bool
|
||||
|
||||
HasCNFunc HasCNFunc
|
||||
HasCNFuncInvoked bool
|
||||
}
|
||||
|
||||
func (d *Depot) CA(pass []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) {
|
||||
d.CAFuncInvoked = true
|
||||
return d.CAFunc(pass)
|
||||
}
|
||||
|
||||
func (d *Depot) Put(name string, crt *x509.Certificate) error {
|
||||
d.PutFuncInvoked = true
|
||||
return d.PutFunc(name, crt)
|
||||
}
|
||||
|
||||
func (d *Depot) Serial() (*big.Int, error) {
|
||||
d.SerialFuncInvoked = true
|
||||
return d.SerialFunc()
|
||||
}
|
||||
|
||||
func (d *Depot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error) {
|
||||
d.HasCNFuncInvoked = true
|
||||
return d.HasCNFunc(cn, allowTime, cert, revokeOldCertificate)
|
||||
}
|
||||
|
|
@ -121,7 +121,7 @@ func (svc *Service) setDEPProfile(ctx context.Context, enrollmentProfile *fleet.
|
|||
depProfileRequest.URL = enrollURL
|
||||
depProfileRequest.ConfigurationWebURL = enrollURL
|
||||
|
||||
depClient := fleet.NewDEPClient(svc.depStorage, svc.ds, svc.logger)
|
||||
depClient := apple_mdm.NewDEPClient(svc.depStorage, svc.ds, svc.logger)
|
||||
res, err := depClient.DefineProfile(ctx, apple_mdm.DEPName, &depProfileRequest)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "apple POST /profile request failed")
|
||||
|
|
@ -430,7 +430,7 @@ func (svc *Service) ListMDMAppleDEPDevices(ctx context.Context) ([]fleet.MDMAppl
|
|||
if err := svc.authz.Authorize(ctx, &fleet.MDMAppleDEPDevice{}, fleet.ActionWrite); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err)
|
||||
}
|
||||
depClient := fleet.NewDEPClient(svc.depStorage, svc.ds, svc.logger)
|
||||
depClient := apple_mdm.NewDEPClient(svc.depStorage, svc.ds, svc.logger)
|
||||
|
||||
// TODO(lucas): Use cursors and limit to fetch in multiple requests.
|
||||
// This single-request version supports up to 1000 devices (max to return in one call).
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@ package service
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
|
@ -13,50 +15,20 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mock"
|
||||
nanodep_mock "github.com/fleetdm/fleet/v4/server/mock/nanodep"
|
||||
nanomdm_mock "github.com/fleetdm/fleet/v4/server/mock/nanomdm"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/test"
|
||||
"github.com/google/uuid"
|
||||
kitlog "github.com/go-kit/kit/log"
|
||||
nanodep_client "github.com/micromdm/nanodep/client"
|
||||
nanodep_storage "github.com/micromdm/nanodep/storage"
|
||||
"github.com/micromdm/nanomdm/mdm"
|
||||
nanomdm_push "github.com/micromdm/nanomdm/push"
|
||||
"github.com/micromdm/nanomdm/storage"
|
||||
nanomdm_pushsvc "github.com/micromdm/nanomdm/push/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dummyDEPStorage struct {
|
||||
nanodep_storage.AllStorage
|
||||
testAuthAddr string
|
||||
}
|
||||
|
||||
func (d dummyDEPStorage) RetrieveAuthTokens(ctx context.Context, name string) (*nanodep_client.OAuth1Tokens, error) {
|
||||
return &nanodep_client.OAuth1Tokens{}, nil
|
||||
}
|
||||
|
||||
func (d dummyDEPStorage) RetrieveConfig(context.Context, string) (*nanodep_client.Config, error) {
|
||||
return &nanodep_client.Config{
|
||||
BaseURL: d.testAuthAddr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type dummyMDMStorage struct {
|
||||
*mysql.NanoMDMStorage
|
||||
}
|
||||
|
||||
func (d dummyMDMStorage) EnqueueCommand(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type dummyMDMPusher struct{}
|
||||
|
||||
func (d dummyMDMPusher) Push(context.Context, []string) (map[string]*nanomdm_push.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func setupAppleMDMService(t *testing.T, mdmStorage storage.AllStorage, depStorage nanodep_storage.AllStorage, mdmPusher nanomdm_push.Pusher) (fleet.Service, context.Context, *mock.Store) {
|
||||
func setupAppleMDMService(t *testing.T) (fleet.Service, context.Context, *mock.Store) {
|
||||
ds := new(mock.Store)
|
||||
cfg := config.TestConfig()
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -72,23 +44,55 @@ func setupAppleMDMService(t *testing.T, mdmStorage storage.AllStorage, depStorag
|
|||
}
|
||||
}))
|
||||
|
||||
mdmStorage := &nanomdm_mock.Storage{}
|
||||
depStorage := &nanodep_mock.Storage{}
|
||||
pushFactory, _ := newMockAPNSPushProviderFactory()
|
||||
pusher := nanomdm_pushsvc.New(
|
||||
mdmStorage,
|
||||
mdmStorage,
|
||||
pushFactory,
|
||||
NewNanoMDMLogger(kitlog.NewJSONLogger(os.Stdout)),
|
||||
)
|
||||
|
||||
opts := &TestServerOpts{
|
||||
FleetConfig: &cfg,
|
||||
MDMStorage: dummyMDMStorage{},
|
||||
DEPStorage: dummyDEPStorage{testAuthAddr: ts.URL},
|
||||
MDMPusher: dummyMDMPusher{},
|
||||
}
|
||||
if mdmStorage != nil {
|
||||
opts.MDMStorage = mdmStorage
|
||||
}
|
||||
if depStorage != nil {
|
||||
opts.DEPStorage = depStorage
|
||||
}
|
||||
if mdmPusher != nil {
|
||||
opts.MDMPusher = mdmPusher
|
||||
MDMStorage: mdmStorage,
|
||||
DEPStorage: depStorage,
|
||||
MDMPusher: pusher,
|
||||
}
|
||||
svc, ctx := newTestServiceWithConfig(t, ds, cfg, nil, nil, opts)
|
||||
|
||||
mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) {
|
||||
return nil, nil
|
||||
}
|
||||
mdmStorage.RetrievePushInfoFunc = func(ctx context.Context, tokens []string) (map[string]*mdm.Push, error) {
|
||||
res := make(map[string]*mdm.Push, len(tokens))
|
||||
for _, t := range tokens {
|
||||
res[t] = &mdm.Push{
|
||||
PushMagic: "",
|
||||
Token: []byte(t),
|
||||
Topic: "",
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) {
|
||||
cert, err := tls.LoadX509KeyPair("testdata/server.pem", "testdata/server.key")
|
||||
return &cert, "", err
|
||||
}
|
||||
mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
depStorage.RetrieveAuthTokensFunc = func(ctx context.Context, name string) (*nanodep_client.OAuth1Tokens, error) {
|
||||
return &nanodep_client.OAuth1Tokens{}, nil
|
||||
}
|
||||
depStorage.RetrieveConfigFunc = func(context.Context, string) (*nanodep_client.Config, error) {
|
||||
return &nanodep_client.Config{
|
||||
BaseURL: ts.URL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
return &fleet.AppConfig{
|
||||
OrgInfo: fleet.OrgInfo{
|
||||
|
|
@ -148,7 +152,7 @@ func setupAppleMDMService(t *testing.T, mdmStorage storage.AllStorage, depStorag
|
|||
}
|
||||
|
||||
func TestAppleMDMAuthorization(t *testing.T) {
|
||||
svc, ctx, _ := setupAppleMDMService(t, nil, nil, nil)
|
||||
svc, ctx, _ := setupAppleMDMService(t)
|
||||
|
||||
checkAuthErr := func(t *testing.T, err error, shouldFailWithAuth bool) {
|
||||
t.Helper()
|
||||
|
|
@ -252,7 +256,7 @@ func TestMDMAppleEnrollURL(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAppleMDMEnrollmentProfile(t *testing.T) {
|
||||
svc, ctx, _ := setupAppleMDMService(t, nil, nil, nil)
|
||||
svc, ctx, _ := setupAppleMDMService(t)
|
||||
|
||||
// Only global admins can create enrollment profiles.
|
||||
ctx = test.UserContext(ctx, test.UserAdmin)
|
||||
|
|
@ -273,22 +277,8 @@ func TestAppleMDMEnrollmentProfile(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type noErrorPusher struct{}
|
||||
|
||||
// Push simulates successful push responses. The result maps each of the provided deviceIDs to a
|
||||
// internally generated UUID, which is intended here to mock the APNs API response.
|
||||
func (nep *noErrorPusher) Push(ctx context.Context, deviceIDs []string) (map[string]*nanomdm_push.Response, error) {
|
||||
res := make(map[string]*nanomdm_push.Response)
|
||||
for _, s := range deviceIDs {
|
||||
res[s] = &nanomdm_push.Response{Id: uuid.New().String()}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func TestMDMCommandAuthz(t *testing.T) {
|
||||
pusher := noErrorPusher{}
|
||||
|
||||
svc, ctx, ds := setupAppleMDMService(t, nil, nil, &pusher)
|
||||
svc, ctx, ds := setupAppleMDMService(t)
|
||||
|
||||
ds.HostLiteFunc = func(ctx context.Context, hostID uint) (*fleet.Host, error) {
|
||||
switch hostID {
|
||||
|
|
|
|||
|
|
@ -668,7 +668,7 @@ func RegisterAppleMDMProtocolServices(
|
|||
mux *http.ServeMux,
|
||||
scepConfig config.MDMAppleSCEPConfig,
|
||||
mdmStorage nanomdm_storage.AllStorage,
|
||||
scepStorage *apple_mdm.SCEPMySQLDepot,
|
||||
scepStorage scep_depot.Depot,
|
||||
logger kitlog.Logger,
|
||||
checkinAndCommandService nanomdm_service.CheckinAndCommandService,
|
||||
) error {
|
||||
|
|
@ -692,7 +692,7 @@ func registerSCEP(
|
|||
scepConfig config.MDMAppleSCEPConfig,
|
||||
scepCert *x509.Certificate,
|
||||
scepKey *rsa.PrivateKey,
|
||||
scepStorage *apple_mdm.SCEPMySQLDepot,
|
||||
scepStorage scep_depot.Depot,
|
||||
logger kitlog.Logger,
|
||||
) error {
|
||||
var signer scepserver.CSRSigner = scep_depot.NewSigner(
|
||||
|
|
|
|||
|
|
@ -17,21 +17,30 @@ import (
|
|||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/micromdm/nanomdm/mdm"
|
||||
"github.com/micromdm/nanomdm/push"
|
||||
nanomdm_pushsvc "github.com/micromdm/nanomdm/push/service"
|
||||
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
|
||||
"github.com/fleetdm/fleet/v4/server/service/mock"
|
||||
"github.com/fleetdm/fleet/v4/server/service/schedule"
|
||||
kitlog "github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
"github.com/google/uuid"
|
||||
"github.com/groob/plist"
|
||||
"github.com/jmoiron/sqlx"
|
||||
nanodep_client "github.com/micromdm/nanodep/client"
|
||||
"github.com/micromdm/nanodep/godep"
|
||||
nanodep_storage "github.com/micromdm/nanodep/storage"
|
||||
"github.com/micromdm/nanodep/tokenpki"
|
||||
scepclient "github.com/micromdm/scep/v2/client"
|
||||
"github.com/micromdm/scep/v2/cryptoutil/x509util"
|
||||
|
|
@ -48,10 +57,13 @@ func TestIntegrationsMDM(t *testing.T) {
|
|||
}
|
||||
|
||||
type integrationMDMTestSuite struct {
|
||||
withServer
|
||||
suite.Suite
|
||||
withServer
|
||||
fleetCfg config.FleetConfig
|
||||
fleetDMFailCSR atomic.Bool
|
||||
pushProvider *mock.APNSPushProvider
|
||||
depStorage nanodep_storage.AllStorage
|
||||
depSchedule *schedule.Schedule
|
||||
}
|
||||
|
||||
func (s *integrationMDMTestSuite) SetupSuite() {
|
||||
|
|
@ -69,9 +81,18 @@ func (s *integrationMDMTestSuite) SetupSuite() {
|
|||
require.NoError(s.T(), err)
|
||||
depStorage, err := s.ds.NewMDMAppleDEPStorage(*testBMToken)
|
||||
require.NoError(s.T(), err)
|
||||
scepStorage, err := s.ds.NewMDMAppleSCEPDepot(testCertPEM, testKeyPEM)
|
||||
scepStorage, err := s.ds.NewSCEPDepot(testCertPEM, testKeyPEM)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
pushFactory, pushProvider := newMockAPNSPushProviderFactory()
|
||||
mdmPushService := nanomdm_pushsvc.New(
|
||||
mdmStorage,
|
||||
mdmStorage,
|
||||
pushFactory,
|
||||
NewNanoMDMLogger(kitlog.NewJSONLogger(os.Stdout)),
|
||||
)
|
||||
|
||||
var depSchedule *schedule.Schedule
|
||||
config := TestServerOpts{
|
||||
License: &fleet.LicenseInfo{
|
||||
Tier: fleet.TierPremium,
|
||||
|
|
@ -80,7 +101,24 @@ func (s *integrationMDMTestSuite) SetupSuite() {
|
|||
MDMStorage: mdmStorage,
|
||||
DEPStorage: depStorage,
|
||||
SCEPStorage: scepStorage,
|
||||
MDMPusher: dummyMDMPusher{},
|
||||
MDMPusher: mdmPushService,
|
||||
StartCronSchedules: []TestNewScheduleFunc{
|
||||
func(ctx context.Context, ds fleet.Datastore) fleet.NewCronScheduleFunc {
|
||||
return func() (fleet.CronSchedule, error) {
|
||||
const name = string(fleet.CronAppleMDMDEPProfileAssigner)
|
||||
logger := kitlog.NewJSONLogger(os.Stdout)
|
||||
fleetSyncer := apple_mdm.NewDEPSyncer(ds, depStorage, logger, true)
|
||||
depSchedule = schedule.New(
|
||||
ctx, name, s.T().Name(), 1*time.Hour, ds, ds,
|
||||
schedule.WithLogger(logger),
|
||||
schedule.WithJob("dep_syncer", func(ctx context.Context) error {
|
||||
return fleetSyncer.Run(ctx)
|
||||
}),
|
||||
)
|
||||
return depSchedule, nil
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
users, server := RunServerForTestsWithDS(s.T(), s.ds, &config)
|
||||
s.server = server
|
||||
|
|
@ -88,8 +126,11 @@ func (s *integrationMDMTestSuite) SetupSuite() {
|
|||
s.token = s.getTestAdminToken()
|
||||
s.cachedAdminToken = s.token
|
||||
s.fleetCfg = fleetCfg
|
||||
s.pushProvider = pushProvider
|
||||
s.depStorage = depStorage
|
||||
s.depSchedule = depSchedule
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fleetdmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if s.fleetDMFailCSR.Swap(false) {
|
||||
// fail this call
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
|
@ -99,8 +140,8 @@ func (s *integrationMDMTestSuite) SetupSuite() {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
s.T().Setenv("TEST_FLEETDM_API_URL", srv.URL)
|
||||
s.T().Cleanup(srv.Close)
|
||||
s.T().Setenv("TEST_FLEETDM_API_URL", fleetdmSrv.URL)
|
||||
s.T().Cleanup(fleetdmSrv.Close)
|
||||
}
|
||||
|
||||
func (s *integrationMDMTestSuite) FailNextCSRRequest() {
|
||||
|
|
@ -115,6 +156,18 @@ func (s *integrationMDMTestSuite) TearDownTest() {
|
|||
s.withServer.commonTearDownTest(s.T())
|
||||
}
|
||||
|
||||
func (s *integrationMDMTestSuite) mockDEPResponse(handler http.Handler) {
|
||||
t := s.T()
|
||||
srv := httptest.NewServer(handler)
|
||||
err := s.depStorage.StoreConfig(context.Background(), apple_mdm.DEPName, &nanodep_client.Config{BaseURL: srv.URL})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
srv.Close()
|
||||
err := s.depStorage.StoreConfig(context.Background(), apple_mdm.DEPName, &nanodep_client.Config{BaseURL: nanodep_client.DefaultBaseURL})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *integrationMDMTestSuite) TestAppleGetAppleMDM() {
|
||||
t := s.T()
|
||||
|
||||
|
|
@ -126,11 +179,173 @@ func (s *integrationMDMTestSuite) TestAppleGetAppleMDM() {
|
|||
require.Equal(t, "FleetDM", mdmResp.CommonName)
|
||||
require.NotZero(t, mdmResp.RenewDate)
|
||||
|
||||
// GET /api/latest/fleet/mdm/apple_bm is not tested because it makes a call
|
||||
// to an Apple API that would a) fail because we use dummy token/certs and b)
|
||||
// could get us in trouble with many invalid requests.
|
||||
// TODO: eventually add a way to mock the apple API, maybe with a test http
|
||||
// server running and a way to use its URL instead of Apple's. (#8948)
|
||||
s.mockDEPResponse(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
switch r.URL.Path {
|
||||
case "/session":
|
||||
_, _ = w.Write([]byte(`{"auth_session_token": "xyz"}`))
|
||||
case "/account":
|
||||
_, _ = w.Write([]byte(`{"admin_id": "abc", "org_name": "test_org"}`))
|
||||
}
|
||||
}))
|
||||
var getAppleBMResp getAppleBMResponse
|
||||
s.DoJSON("GET", "/api/latest/fleet/mdm/apple_bm", nil, http.StatusOK, &getAppleBMResp)
|
||||
require.NoError(t, getAppleBMResp.Err)
|
||||
require.Equal(t, "abc", getAppleBMResp.AppleID)
|
||||
require.Equal(t, "test_org", getAppleBMResp.OrgName)
|
||||
require.Equal(t, "https://example.org/mdm/apple/mdm", getAppleBMResp.MDMServerURL)
|
||||
require.Empty(t, getAppleBMResp.DefaultTeam)
|
||||
|
||||
// create a new team
|
||||
tm, err := s.ds.NewTeam(context.Background(), &fleet.Team{
|
||||
Name: t.Name(),
|
||||
Description: "desc",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// set the default bm assignment to that team
|
||||
acResp := appConfigResponse{}
|
||||
s.DoJSON("PATCH", "/api/latest/fleet/config", json.RawMessage(fmt.Sprintf(`{
|
||||
"mdm": {
|
||||
"apple_bm_default_team": %q
|
||||
}
|
||||
}`, tm.Name)), http.StatusOK, &acResp)
|
||||
|
||||
// try again, this time we get a default team in the response
|
||||
getAppleBMResp = getAppleBMResponse{}
|
||||
s.DoJSON("GET", "/api/latest/fleet/mdm/apple_bm", nil, http.StatusOK, &getAppleBMResp)
|
||||
require.NoError(t, getAppleBMResp.Err)
|
||||
require.Equal(t, "abc", getAppleBMResp.AppleID)
|
||||
require.Equal(t, "test_org", getAppleBMResp.OrgName)
|
||||
require.Equal(t, "https://example.org/mdm/apple/mdm", getAppleBMResp.MDMServerURL)
|
||||
require.Equal(t, tm.Name, getAppleBMResp.DefaultTeam)
|
||||
}
|
||||
|
||||
func (s *integrationMDMTestSuite) TestABMExpiredToken() {
|
||||
t := s.T()
|
||||
s.mockDEPResponse(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte(`{"code": "T_C_NOT_SIGNED"}`))
|
||||
}))
|
||||
|
||||
config := s.getConfig()
|
||||
require.False(t, config.MDM.AppleBMTermsExpired)
|
||||
|
||||
var getAppleBMResp getAppleBMResponse
|
||||
s.DoJSON("GET", "/api/latest/fleet/mdm/apple_bm", nil, http.StatusInternalServerError, &getAppleBMResp)
|
||||
|
||||
config = s.getConfig()
|
||||
require.True(t, config.MDM.AppleBMTermsExpired)
|
||||
}
|
||||
|
||||
func (s *integrationMDMTestSuite) TestDEPProfileAssignment() {
|
||||
t := s.T()
|
||||
devices := []godep.Device{
|
||||
{SerialNumber: uuid.New().String(), Model: "MacBook Pro", OS: "osx", OpType: "added"},
|
||||
{SerialNumber: uuid.New().String(), Model: "MacBook Mini", OS: "osx", OpType: "added"},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
s.mockDEPResponse(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
encoder := json.NewEncoder(w)
|
||||
switch r.URL.Path {
|
||||
case "/session":
|
||||
err := encoder.Encode(map[string]string{"auth_session_token": "xyz"})
|
||||
require.NoError(t, err)
|
||||
case "/profile":
|
||||
err := encoder.Encode(godep.ProfileResponse{ProfileUUID: "abc"})
|
||||
require.NoError(t, err)
|
||||
case "/server/devices":
|
||||
// This endpoint is used to get an initial list of
|
||||
// devices, return a single device
|
||||
err := encoder.Encode(godep.DeviceResponse{Devices: devices[:1]})
|
||||
require.NoError(t, err)
|
||||
case "/devices/sync":
|
||||
// This endpoint is polled over time to sync devices from
|
||||
// ABM, send a repeated serial and a new one
|
||||
err := encoder.Encode(godep.DeviceResponse{Devices: devices})
|
||||
require.NoError(t, err)
|
||||
case "/profile/devices":
|
||||
wg.Done()
|
||||
_, _ = w.Write([]byte(`{}`))
|
||||
default:
|
||||
_, _ = w.Write([]byte(`{}`))
|
||||
}
|
||||
}))
|
||||
|
||||
// create a DEP enrollment profile
|
||||
profile := json.RawMessage("{}")
|
||||
var createProfileResp createMDMAppleEnrollmentProfileResponse
|
||||
createProfileReq := createMDMAppleEnrollmentProfileRequest{
|
||||
Type: "automatic",
|
||||
DEPProfile: &profile,
|
||||
}
|
||||
s.DoJSON("POST", "/api/latest/fleet/mdm/apple/enrollmentprofiles", createProfileReq, http.StatusOK, &createProfileResp)
|
||||
|
||||
// query all hosts
|
||||
listHostsRes := listHostsResponse{}
|
||||
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &listHostsRes)
|
||||
require.Empty(t, listHostsRes.Hosts)
|
||||
|
||||
// trigger a profile sync
|
||||
_, err := s.depSchedule.Trigger()
|
||||
require.NoError(t, err)
|
||||
wg.Wait()
|
||||
|
||||
// both hosts should be returned from the hosts endpoint
|
||||
listHostsRes = listHostsResponse{}
|
||||
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &listHostsRes)
|
||||
require.Len(t, listHostsRes.Hosts, 2)
|
||||
require.Equal(t, listHostsRes.Hosts[0].HardwareSerial, devices[0].SerialNumber)
|
||||
require.Equal(t, listHostsRes.Hosts[1].HardwareSerial, devices[1].SerialNumber)
|
||||
require.EqualValues(
|
||||
t,
|
||||
[]string{devices[0].SerialNumber, devices[1].SerialNumber},
|
||||
[]string{listHostsRes.Hosts[0].HardwareSerial, listHostsRes.Hosts[1].HardwareSerial},
|
||||
)
|
||||
|
||||
// create a new host
|
||||
createHostAndDeviceToken(t, s.ds, "not-dep")
|
||||
listHostsRes = listHostsResponse{}
|
||||
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &listHostsRes)
|
||||
require.Len(t, listHostsRes.Hosts, 3)
|
||||
|
||||
// filtering by MDM status works
|
||||
listHostsRes = listHostsResponse{}
|
||||
s.DoJSON("GET", "/api/latest/fleet/hosts?mdm_enrollment_status=pending", nil, http.StatusOK, &listHostsRes)
|
||||
require.Len(t, listHostsRes.Hosts, 2)
|
||||
|
||||
// enroll one of the hosts
|
||||
d := newDevice(s)
|
||||
d.serial = devices[0].SerialNumber
|
||||
d.mdmEnroll(s)
|
||||
|
||||
// only one shows up as pending
|
||||
listHostsRes = listHostsResponse{}
|
||||
s.DoJSON("GET", "/api/latest/fleet/hosts?mdm_enrollment_status=pending", nil, http.StatusOK, &listHostsRes)
|
||||
require.Len(t, listHostsRes.Hosts, 1)
|
||||
|
||||
activities := listActivitiesResponse{}
|
||||
s.DoJSON("GET", "/api/latest/fleet/activities", nil, http.StatusOK, &activities, "order_key", "created_at")
|
||||
found := false
|
||||
for _, activity := range activities.Activities {
|
||||
if activity.Type == "mdm_enrolled" &&
|
||||
strings.Contains(string(*activity.Details), devices[0].SerialNumber) {
|
||||
found = true
|
||||
require.Nil(t, activity.ActorID)
|
||||
require.Nil(t, activity.ActorFullName)
|
||||
require.JSONEq(
|
||||
t,
|
||||
fmt.Sprintf(
|
||||
`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": true}`,
|
||||
devices[0].SerialNumber, devices[0].Model, devices[0].SerialNumber,
|
||||
),
|
||||
string(*activity.Details),
|
||||
)
|
||||
}
|
||||
}
|
||||
require.True(t, found)
|
||||
}
|
||||
|
||||
func (s *integrationMDMTestSuite) TestDeviceMDMManualEnroll() {
|
||||
|
|
@ -164,7 +379,7 @@ func (s *integrationMDMTestSuite) TestDeviceMDMManualEnroll() {
|
|||
require.Equal(t, apple_mdm.FleetPayloadIdentifier, profile.PayloadIdentifier)
|
||||
}
|
||||
|
||||
func (s *integrationMDMTestSuite) TestDeviceEnrollment() {
|
||||
func (s *integrationMDMTestSuite) TestAppleMDMDeviceEnrollment() {
|
||||
t := s.T()
|
||||
|
||||
// Enroll two devices into MDM
|
||||
|
|
@ -212,8 +427,8 @@ func (s *integrationMDMTestSuite) TestDeviceEnrollment() {
|
|||
}
|
||||
}
|
||||
require.Len(t, details, 2)
|
||||
require.JSONEq(t, fmt.Sprintf(`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": false}`, deviceA.serial, deviceA.model, deviceA.serial), string(*details[0]))
|
||||
require.JSONEq(t, fmt.Sprintf(`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": false}`, deviceB.serial, deviceB.model, deviceB.serial), string(*details[1]))
|
||||
require.JSONEq(t, fmt.Sprintf(`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": false}`, deviceA.serial, deviceA.model, deviceA.serial), string(*details[len(details)-2]))
|
||||
require.JSONEq(t, fmt.Sprintf(`{"host_serial": "%s", "host_display_name": "%s (%s)", "installed_from_dep": false}`, deviceB.serial, deviceB.model, deviceB.serial), string(*details[len(details)-1]))
|
||||
|
||||
// set an enroll secret
|
||||
var applyResp applyEnrollSecretSpecResponse
|
||||
|
|
@ -317,6 +532,60 @@ func (s *integrationMDMTestSuite) TestAppleMDMCSRRequest() {
|
|||
require.Contains(t, string(reqCSRResp.SCEPKey), "-----BEGIN RSA PRIVATE KEY-----\n")
|
||||
}
|
||||
|
||||
func (s *integrationMDMTestSuite) TestMDMAppleUnenroll() {
|
||||
t := s.T()
|
||||
// enroll into mdm
|
||||
d := newMDMEnrolledDevice(s)
|
||||
|
||||
// set an enroll secret
|
||||
var applyResp applyEnrollSecretSpecResponse
|
||||
s.DoJSON("POST", "/api/latest/fleet/spec/enroll_secret", applyEnrollSecretSpecRequest{
|
||||
Spec: &fleet.EnrollSecretSpec{
|
||||
Secrets: []*fleet.EnrollSecret{{Secret: t.Name()}},
|
||||
},
|
||||
}, http.StatusOK, &applyResp)
|
||||
|
||||
// simulate a matching host enrolling via osquery
|
||||
j, err := json.Marshal(&enrollAgentRequest{
|
||||
EnrollSecret: t.Name(),
|
||||
HostIdentifier: d.uuid,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
var enrollResp enrollAgentResponse
|
||||
hres := s.DoRawNoAuth("POST", "/api/osquery/enroll", j, http.StatusOK)
|
||||
defer hres.Body.Close()
|
||||
require.NoError(t, json.NewDecoder(hres.Body).Decode(&enrollResp))
|
||||
require.NotEmpty(t, enrollResp.NodeKey)
|
||||
|
||||
listHostsRes := listHostsResponse{}
|
||||
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &listHostsRes)
|
||||
require.Len(t, listHostsRes.Hosts, 1)
|
||||
h := listHostsRes.Hosts[0]
|
||||
|
||||
// try to unenroll the host, fails since the host doesn't respond
|
||||
s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusGatewayTimeout)
|
||||
|
||||
// we're going to modify this mock, make sure we restore its default
|
||||
originalPushMock := s.pushProvider.PushFunc
|
||||
defer func() { s.pushProvider.PushFunc = originalPushMock }()
|
||||
|
||||
// TODO: this is not working as expected, we're still waiting on the
|
||||
// device to unenroll even if we weren't able to send a push
|
||||
// notification, we should return an error instead.
|
||||
//
|
||||
// the APNs service returns an error
|
||||
// s.pushProvider.PushFunc = mockFailedPush
|
||||
// s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusOK)
|
||||
|
||||
// try again, but this time the host is online and answers
|
||||
s.pushProvider.PushFunc = func(pushes []*mdm.Push) (map[string]*push.Response, error) {
|
||||
res, err := mockSuccessfulPush(pushes)
|
||||
d.checkout()
|
||||
return res, err
|
||||
}
|
||||
s.Do("PATCH", fmt.Sprintf("/api/latest/fleet/mdm/hosts/%d/unenroll", h.ID), nil, http.StatusOK)
|
||||
}
|
||||
|
||||
type device struct {
|
||||
uuid string
|
||||
serial string
|
||||
|
|
@ -338,13 +607,18 @@ func newDevice(s *integrationMDMTestSuite) *device {
|
|||
|
||||
func newMDMEnrolledDevice(s *integrationMDMTestSuite) *device {
|
||||
d := newDevice(s)
|
||||
d.scepEnroll()
|
||||
d.authenticate()
|
||||
d.mdmEnroll(s)
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *device) mdmEnroll(s *integrationMDMTestSuite) {
|
||||
d.scepEnroll()
|
||||
d.authenticate()
|
||||
d.tokenUpdate()
|
||||
}
|
||||
|
||||
func (d *device) authenticate() {
|
||||
payload := map[string]string{
|
||||
payload := map[string]any{
|
||||
"MessageType": "Authenticate",
|
||||
"UDID": d.uuid,
|
||||
"Model": d.model,
|
||||
|
|
@ -356,8 +630,21 @@ func (d *device) authenticate() {
|
|||
d.request("application/x-apple-aspen-mdm-checkin", payload)
|
||||
}
|
||||
|
||||
func (d *device) tokenUpdate() {
|
||||
payload := map[string]any{
|
||||
"MessageType": "TokenUpdate",
|
||||
"UDID": d.uuid,
|
||||
"Topic": "com.apple.mgmt.External." + d.uuid,
|
||||
"EnrollmentID": "testenrollmentid-" + d.uuid,
|
||||
"NotOnConsole": "false",
|
||||
"PushMagic": "pushmagic" + d.serial,
|
||||
"Token": []byte("token" + d.serial),
|
||||
}
|
||||
d.request("application/x-apple-aspen-mdm-checkin", payload)
|
||||
}
|
||||
|
||||
func (d *device) checkout() {
|
||||
payload := map[string]string{
|
||||
payload := map[string]any{
|
||||
"MessageType": "CheckOut",
|
||||
"Topic": "com.apple.mgmt.External." + d.uuid,
|
||||
"UDID": d.uuid,
|
||||
|
|
@ -366,7 +653,7 @@ func (d *device) checkout() {
|
|||
d.request("application/x-apple-aspen-mdm-checkin", payload)
|
||||
}
|
||||
|
||||
func (d *device) request(reqType string, payload map[string]string) {
|
||||
func (d *device) request(reqType string, payload map[string]any) {
|
||||
body, err := plist.Marshal(payload)
|
||||
require.NoError(d.s.T(), err)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
package mock
|
||||
|
||||
//go:generate mockimpl -o service_osquery.go "s *TLSService" "fleet.OsqueryService"
|
||||
//go:generate mockimpl -o service_pusher_factory.go "s *APNSPushProviderFactory" "github.com/micromdm/nanomdm/push.PushProviderFactory"
|
||||
//go:generate mockimpl -o service_push_provider.go "s *APNSPushProvider" "github.com/micromdm/nanomdm/push.PushProvider"
|
||||
|
|
|
|||
22
server/service/mock/service_push_provider.go
Normal file
22
server/service/mock/service_push_provider.go
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
// Automatically generated by mockimpl. DO NOT EDIT!
|
||||
|
||||
package mock
|
||||
|
||||
import (
|
||||
"github.com/micromdm/nanomdm/mdm"
|
||||
"github.com/micromdm/nanomdm/push"
|
||||
)
|
||||
|
||||
var _ push.PushProvider = (*APNSPushProvider)(nil)
|
||||
|
||||
type PushFunc func(p0 []*mdm.Push) (map[string]*push.Response, error)
|
||||
|
||||
type APNSPushProvider struct {
|
||||
PushFunc PushFunc
|
||||
PushFuncInvoked bool
|
||||
}
|
||||
|
||||
func (s *APNSPushProvider) Push(p0 []*mdm.Push) (map[string]*push.Response, error) {
|
||||
s.PushFuncInvoked = true
|
||||
return s.PushFunc(p0)
|
||||
}
|
||||
23
server/service/mock/service_pusher_factory.go
Normal file
23
server/service/mock/service_pusher_factory.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
// Automatically generated by mockimpl. DO NOT EDIT!
|
||||
|
||||
package mock
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/micromdm/nanomdm/push"
|
||||
)
|
||||
|
||||
var _ push.PushProviderFactory = (*APNSPushProviderFactory)(nil)
|
||||
|
||||
type NewPushProviderFunc func(p0 *tls.Certificate) (push.PushProvider, error)
|
||||
|
||||
type APNSPushProviderFactory struct {
|
||||
NewPushProviderFunc NewPushProviderFunc
|
||||
NewPushProviderFuncInvoked bool
|
||||
}
|
||||
|
||||
func (s *APNSPushProviderFactory) NewPushProvider(p0 *tls.Certificate) (push.PushProvider, error) {
|
||||
s.NewPushProviderFuncInvoked = true
|
||||
return s.NewPushProviderFunc(p0)
|
||||
}
|
||||
|
|
@ -2,6 +2,7 @@ package service
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -18,15 +19,19 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/contexts/license"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/logging"
|
||||
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/service/async"
|
||||
"github.com/fleetdm/fleet/v4/server/service/mock"
|
||||
"github.com/fleetdm/fleet/v4/server/sso"
|
||||
"github.com/fleetdm/fleet/v4/server/test"
|
||||
kitlog "github.com/go-kit/kit/log"
|
||||
"github.com/google/uuid"
|
||||
nanodep_storage "github.com/micromdm/nanodep/storage"
|
||||
"github.com/micromdm/nanomdm/mdm"
|
||||
"github.com/micromdm/nanomdm/push"
|
||||
nanomdm_push "github.com/micromdm/nanomdm/push"
|
||||
nanomdm_storage "github.com/micromdm/nanomdm/storage"
|
||||
scep_depot "github.com/micromdm/scep/v2/depot"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/throttled/throttled/v2"
|
||||
|
|
@ -238,7 +243,7 @@ type TestServerOpts struct {
|
|||
FleetConfig *config.FleetConfig
|
||||
MDMStorage nanomdm_storage.AllStorage
|
||||
DEPStorage nanodep_storage.AllStorage
|
||||
SCEPStorage *apple_mdm.SCEPMySQLDepot
|
||||
SCEPStorage scep_depot.Depot
|
||||
MDMPusher nanomdm_push.Pusher
|
||||
HTTPServerConfig *http.Server
|
||||
StartCronSchedules []TestNewScheduleFunc
|
||||
|
|
@ -491,3 +496,25 @@ func (nopEnrollHostLimiter) CanEnrollNewHost(ctx context.Context) (bool, error)
|
|||
func (nopEnrollHostLimiter) SyncEnrolledHostIDs(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newMockAPNSPushProviderFactory() (*mock.APNSPushProviderFactory, *mock.APNSPushProvider) {
|
||||
provider := &mock.APNSPushProvider{}
|
||||
provider.PushFunc = mockSuccessfulPush
|
||||
factory := &mock.APNSPushProviderFactory{}
|
||||
factory.NewPushProviderFunc = func(*tls.Certificate) (push.PushProvider, error) {
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
return factory, provider
|
||||
}
|
||||
|
||||
func mockSuccessfulPush(pushes []*mdm.Push) (map[string]*push.Response, error) {
|
||||
res := make(map[string]*push.Response, len(pushes))
|
||||
for _, p := range pushes {
|
||||
res[p.Token.String()] = &push.Response{
|
||||
Id: uuid.New().String(),
|
||||
Err: nil,
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue