mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 21:47:20 +00:00
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #34544 Demo video: https://www.youtube.com/watch?v=VzOkISWmEKw [Original research doc](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/research/orchestration/okta-conditional-access.md) [Victor's POC branch](https://github.com/fleetdm/fleet/tree/victor/33165-okta-conditional-access-poc) # Checklist for submitter - [x] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files) for more information. ## Testing - [x] Added/updated automated tests - [x] Where appropriate, [automated tests simulate multiple hosts and test for host isolation](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/reference/patterns-backend.md#unit-testing) (updates to one host's records do not affect another) - [x] QA'd all new/changed functionality manually <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Conditional Access IdP integration added (IdP metadata & SSO) with device-health aware session checks. * Endpoint to download the IdP signing certificate (PEM) added. * Automatic revocation of old conditional access certificates with a configurable grace period. * **Tests** * Extensive tests for certificate rotation, lifecycle, SSO flows, URL construction, and IdP metadata. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
224 lines
8.2 KiB
Go
224 lines
8.2 KiB
Go
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/fleet/v4/server/ptr"
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestConditionalAccessSCEP(t *testing.T) {
|
|
ds := CreateMySQLDS(t)
|
|
|
|
cases := []struct {
|
|
name string
|
|
fn func(t *testing.T, ds *Datastore)
|
|
}{
|
|
{"GetCertBySerialAndCreatedAt", testGetConditionalAccessCertBySerialAndCreatedAt},
|
|
{"RevokedCertsNotReturned", testRevokedConditionalAccessCertsNotReturned},
|
|
{"ExpiredCertsNotReturned", testExpiredConditionalAccessCertsNotReturned},
|
|
{"CertificateLifecycle", testConditionalAccessCertificateLifecycle},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
defer TruncateTables(t, ds)
|
|
c.fn(t, ds)
|
|
})
|
|
}
|
|
}
|
|
|
|
func testGetConditionalAccessCertBySerialAndCreatedAt(t *testing.T, ds *Datastore) {
|
|
ctx := t.Context()
|
|
|
|
// Create a test host
|
|
host, err := ds.NewHost(ctx, &fleet.Host{
|
|
OsqueryHostID: ptr.String("test-host-1"),
|
|
NodeKey: ptr.String("test-node-key-1"),
|
|
UUID: "test-uuid-1",
|
|
Hostname: "test-hostname-1",
|
|
Platform: "darwin",
|
|
DetailUpdatedAt: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotNil(t, host)
|
|
|
|
// Insert a valid test certificate
|
|
now := time.Now()
|
|
serialNumber := insertConditionalAccessCert(t, ds, ctx, host.ID, "test-cn", now.Add(-24*time.Hour), now.Add(365*24*time.Hour), false)
|
|
|
|
// Test retrieval by serial number
|
|
hostID, err := ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, serialNumber)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, host.ID, hostID)
|
|
|
|
// Test retrieval of created_at by host ID
|
|
createdAt, err := ds.GetConditionalAccessCertCreatedAtByHostID(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, createdAt)
|
|
// Verify timestamp is reasonable (created in the past, within last 24 hours)
|
|
assert.True(t, createdAt.Before(time.Now()))
|
|
assert.True(t, createdAt.After(time.Now().Add(-24*time.Hour)))
|
|
|
|
// Test non-existent serial
|
|
_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, 999)
|
|
require.Error(t, err)
|
|
assert.True(t, fleet.IsNotFound(err))
|
|
|
|
// Test non-existent host
|
|
_, err = ds.GetConditionalAccessCertCreatedAtByHostID(ctx, 999999)
|
|
require.Error(t, err)
|
|
assert.True(t, fleet.IsNotFound(err))
|
|
}
|
|
|
|
func testRevokedConditionalAccessCertsNotReturned(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
// Create a test host
|
|
host, err := ds.NewHost(ctx, &fleet.Host{
|
|
OsqueryHostID: ptr.String("test-host-6"),
|
|
NodeKey: ptr.String("test-node-key-6"),
|
|
UUID: "test-uuid-6",
|
|
Hostname: "test-hostname-6",
|
|
Platform: "darwin",
|
|
DetailUpdatedAt: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert a revoked certificate
|
|
now := time.Now()
|
|
serialNumber := insertConditionalAccessCert(t, ds, ctx, host.ID, "revoked-cert", now.Add(-24*time.Hour), now.Add(365*24*time.Hour), true)
|
|
|
|
// Revoked certs should not be returned by serial number lookup
|
|
_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, serialNumber)
|
|
require.Error(t, err)
|
|
assert.True(t, fleet.IsNotFound(err))
|
|
|
|
// Note: GetConditionalAccessCertCreatedAtByHostID doesn't filter by revoked status
|
|
// since it's used for rate limiting checks, not authentication
|
|
}
|
|
|
|
func testExpiredConditionalAccessCertsNotReturned(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
// Create a test host
|
|
host, err := ds.NewHost(ctx, &fleet.Host{
|
|
OsqueryHostID: ptr.String("test-host-7"),
|
|
NodeKey: ptr.String("test-node-key-7"),
|
|
UUID: "test-uuid-7",
|
|
Hostname: "test-hostname-7",
|
|
Platform: "darwin",
|
|
DetailUpdatedAt: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert an expired certificate
|
|
now := time.Now()
|
|
serialNumber := insertConditionalAccessCert(t, ds, ctx, host.ID, "expired-cert", now.Add(-400*24*time.Hour), now.Add(-24*time.Hour), false)
|
|
|
|
// Expired certs should not be returned by serial number lookup
|
|
_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, serialNumber)
|
|
require.Error(t, err)
|
|
assert.True(t, fleet.IsNotFound(err))
|
|
|
|
// Note: GetConditionalAccessCertCreatedAtByHostID doesn't filter by expiration status
|
|
// since it's used for rate limiting checks, not authentication
|
|
}
|
|
|
|
// insertConditionalAccessCert inserts a conditional access SCEP certificate for testing.
|
|
// Returns the serial number of the inserted certificate.
|
|
func insertConditionalAccessCert(t *testing.T, ds *Datastore, ctx context.Context, hostID uint, name string, notValidBefore, notValidAfter time.Time, revoked bool) uint64 {
|
|
t.Helper()
|
|
|
|
certPEM := `-----BEGIN CERTIFICATE-----
|
|
MIICEjCCAXsCAg36MA0GCSqGSIb3DQEBBQUAMIGbMQswCQYDVQQGEwJKUDEOMAwG
|
|
-----END CERTIFICATE-----`
|
|
|
|
var serialNumber uint64
|
|
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
|
|
result, err := q.ExecContext(ctx, `INSERT INTO conditional_access_scep_serials () VALUES ()`)
|
|
require.NoError(t, err)
|
|
|
|
lastID, err := result.LastInsertId()
|
|
require.NoError(t, err)
|
|
serialNumber = uint64(lastID) // nolint:gosec,G115
|
|
|
|
_, err = q.ExecContext(ctx, `
|
|
INSERT INTO conditional_access_scep_certificates
|
|
(serial, host_id, name, not_valid_before, not_valid_after, certificate_pem, revoked)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
`, serialNumber, hostID, name, notValidBefore, notValidAfter, certPEM, revoked)
|
|
return err
|
|
})
|
|
|
|
return serialNumber
|
|
}
|
|
|
|
func testConditionalAccessCertificateLifecycle(t *testing.T, ds *Datastore) {
|
|
ctx := t.Context()
|
|
now := time.Now()
|
|
|
|
// Create two test hosts
|
|
host1, err := ds.NewHost(ctx, &fleet.Host{
|
|
OsqueryHostID: ptr.String("test-host-lifecycle-1"), NodeKey: ptr.String("test-node-key-1"),
|
|
UUID: "test-uuid-1", Hostname: "test-hostname-1", Platform: "darwin", DetailUpdatedAt: now,
|
|
})
|
|
require.NoError(t, err)
|
|
host2, err := ds.NewHost(ctx, &fleet.Host{
|
|
OsqueryHostID: ptr.String("test-host-lifecycle-2"), NodeKey: ptr.String("test-node-key-2"),
|
|
UUID: "test-uuid-2", Hostname: "test-hostname-2", Platform: "darwin", DetailUpdatedAt: now,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Host 1: Multiple valid certs (simulating rotation)
|
|
s1_old := insertConditionalAccessCert(t, ds, ctx, host1.ID, "h1-old", now.Add(-3*time.Hour), now.Add(24*time.Hour), false)
|
|
s1_new := insertConditionalAccessCert(t, ds, ctx, host1.ID, "h1-new", now.Add(-2*time.Hour), now.Add(24*time.Hour), false)
|
|
|
|
// Host 2: New cert within grace period + 1 revoked cert
|
|
s2_old := insertConditionalAccessCert(t, ds, ctx, host2.ID, "h2-old", now.Add(-3*time.Hour), now.Add(24*time.Hour), false)
|
|
s2_new := insertConditionalAccessCert(t, ds, ctx, host2.ID, "h2-new", now.Add(-30*time.Minute), now.Add(24*time.Hour), false)
|
|
s2_revoked := insertConditionalAccessCert(t, ds, ctx, host2.ID, "h2-revoked", now.Add(-4*time.Hour), now.Add(24*time.Hour), true)
|
|
|
|
// All valid certs authenticate
|
|
for _, s := range []uint64{s1_old, s1_new, s2_old, s2_new} {
|
|
_, err := ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, s)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// Revoked cert does not authenticate
|
|
_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, s2_revoked)
|
|
require.Error(t, err)
|
|
assert.True(t, fleet.IsNotFound(err))
|
|
|
|
// Cleanup with 1-hour grace: only host1's old cert eligible (host1's new cert is 2h old)
|
|
count, err := ds.RevokeOldConditionalAccessCerts(ctx, 1*time.Hour)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(1), count)
|
|
|
|
// Host1's old cert now revoked, new cert still works
|
|
_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, s1_old)
|
|
assert.True(t, fleet.IsNotFound(err))
|
|
_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, s1_new)
|
|
require.NoError(t, err)
|
|
|
|
// Host2's certs still work (new cert within grace period)
|
|
_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, s2_old)
|
|
require.NoError(t, err)
|
|
_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, s2_new)
|
|
require.NoError(t, err)
|
|
|
|
// Second cleanup revokes nothing
|
|
count, err = ds.RevokeOldConditionalAccessCerts(ctx, 1*time.Hour)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(0), count)
|
|
|
|
// GetConditionalAccessCertCreatedAtByHostID returns most recent cert
|
|
createdAt, err := ds.GetConditionalAccessCertCreatedAtByHostID(ctx, host2.ID)
|
|
require.NoError(t, err)
|
|
assert.True(t, createdAt.After(now.Add(-1*time.Hour)) && createdAt.Before(now.Add(1*time.Minute)))
|
|
}
|