fleet/server/datastore/mysql/conditional_access_scep_test.go
Victor Lyuboslavsky 5cfc28ae5a
Okta IdP factor (#35143)
<!-- Add the related story/sub-task/bug number, like Resolves #123, or
remove if NA -->
**Related issue:** Resolves #34544 

Demo video: https://www.youtube.com/watch?v=VzOkISWmEKw
[Original research
doc](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/research/orchestration/okta-conditional-access.md)
[Victor's POC
branch](https://github.com/fleetdm/fleet/tree/victor/33165-okta-conditional-access-poc)

# Checklist for submitter

- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files)
for more information.

## Testing

- [x] Added/updated automated tests
- [x] Where appropriate, [automated tests simulate multiple hosts and
test for host
isolation](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/reference/patterns-backend.md#unit-testing)
(updates to one host's records do not affect another)

- [x] QA'd all new/changed functionality manually

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Conditional Access IdP integration added (IdP metadata & SSO) with
device-health aware session checks.
  * Endpoint to download the IdP signing certificate (PEM) added.
* Automatic revocation of old conditional access certificates with a
configurable grace period.

* **Tests**
* Extensive tests for certificate rotation, lifecycle, SSO flows, URL
construction, and IdP metadata.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-11-07 16:19:25 -06:00

224 lines
8.2 KiB
Go

package mysql
import (
"context"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConditionalAccessSCEP is the entry point for the conditional access SCEP
// datastore tests. It creates one MySQL test datastore and runs every subtest
// against it; each subtest defers TruncateTables so cleanup runs when that
// subtest finishes.
func TestConditionalAccessSCEP(t *testing.T) {
	ds := CreateMySQLDS(t)

	type subtest struct {
		name string
		fn   func(t *testing.T, ds *Datastore)
	}
	subtests := []subtest{
		{name: "GetCertBySerialAndCreatedAt", fn: testGetConditionalAccessCertBySerialAndCreatedAt},
		{name: "RevokedCertsNotReturned", fn: testRevokedConditionalAccessCertsNotReturned},
		{name: "ExpiredCertsNotReturned", fn: testExpiredConditionalAccessCertsNotReturned},
		{name: "CertificateLifecycle", fn: testConditionalAccessCertificateLifecycle},
	}

	for _, st := range subtests {
		t.Run(st.name, func(t *testing.T) {
			// Deferred so the tables are truncated even if the subtest fails.
			defer TruncateTables(t, ds)
			st.fn(t, ds)
		})
	}
}
// testGetConditionalAccessCertBySerialAndCreatedAt covers the two lookup
// paths for a single valid certificate: host ID by serial number, and
// created-at timestamp by host ID. It also checks that both lookups report a
// not-found error for identifiers that were never inserted.
func testGetConditionalAccessCertBySerialAndCreatedAt(t *testing.T, ds *Datastore) {
	ctx := t.Context()

	// Host that will own the certificate.
	newHost := &fleet.Host{
		OsqueryHostID:   ptr.String("test-host-1"),
		NodeKey:         ptr.String("test-node-key-1"),
		UUID:            "test-uuid-1",
		Hostname:        "test-hostname-1",
		Platform:        "darwin",
		DetailUpdatedAt: time.Now(),
	}
	host, err := ds.NewHost(ctx, newHost)
	require.NoError(t, err)
	require.NotNil(t, host)

	// Certificate that became valid a day ago and expires in a year, not revoked.
	now := time.Now()
	validFrom := now.Add(-24 * time.Hour)
	validUntil := now.Add(365 * 24 * time.Hour)
	serial := insertConditionalAccessCert(t, ds, ctx, host.ID, "test-cn", validFrom, validUntil, false)

	// Serial number resolves to the owning host.
	gotHostID, err := ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, serial)
	require.NoError(t, err)
	assert.Equal(t, host.ID, gotHostID)

	// Host ID resolves to a creation timestamp.
	createdAt, err := ds.GetConditionalAccessCertCreatedAtByHostID(ctx, host.ID)
	require.NoError(t, err)
	require.NotNil(t, createdAt)

	// Sanity-check the timestamp: in the past, but no more than 24 hours ago.
	// The helper's INSERT does not set a creation time, so this value
	// presumably comes from a column default; verify against the schema.
	assert.True(t, createdAt.Before(time.Now()))
	assert.True(t, createdAt.After(time.Now().Add(-24*time.Hour)))

	// Identifiers that were never inserted must yield not-found errors.
	const (
		missingSerial = 999
		missingHostID = 999999
	)

	_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, missingSerial)
	require.Error(t, err)
	assert.True(t, fleet.IsNotFound(err))

	_, err = ds.GetConditionalAccessCertCreatedAtByHostID(ctx, missingHostID)
	require.Error(t, err)
	assert.True(t, fleet.IsNotFound(err))
}
// testRevokedConditionalAccessCertsNotReturned verifies that a certificate
// inserted with revoked=true cannot be resolved to a host by its serial
// number: the lookup must fail with a not-found error.
func testRevokedConditionalAccessCertsNotReturned(t *testing.T, ds *Datastore) {
	// Use the test-scoped context, matching the other subtests in this file,
	// so datastore calls are tied to the lifetime of this test.
	ctx := t.Context()

	// Create a test host
	host, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:   ptr.String("test-host-6"),
		NodeKey:         ptr.String("test-node-key-6"),
		UUID:            "test-uuid-6",
		Hostname:        "test-hostname-6",
		Platform:        "darwin",
		DetailUpdatedAt: time.Now(),
	})
	require.NoError(t, err)

	// Insert a certificate whose validity window is current but which is
	// flagged as revoked.
	now := time.Now()
	serialNumber := insertConditionalAccessCert(t, ds, ctx, host.ID, "revoked-cert", now.Add(-24*time.Hour), now.Add(365*24*time.Hour), true)

	// Revoked certs should not be returned by serial number lookup
	_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, serialNumber)
	require.Error(t, err)
	assert.True(t, fleet.IsNotFound(err))

	// Note: GetConditionalAccessCertCreatedAtByHostID doesn't filter by revoked status
	// since it's used for rate limiting checks, not authentication
}
// testExpiredConditionalAccessCertsNotReturned verifies that a certificate
// whose not_valid_after is already in the past cannot be resolved to a host
// by its serial number: the lookup must fail with a not-found error.
func testExpiredConditionalAccessCertsNotReturned(t *testing.T, ds *Datastore) {
	// Use the test-scoped context, matching the other subtests in this file,
	// so datastore calls are tied to the lifetime of this test.
	ctx := t.Context()

	// Create a test host
	host, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID:   ptr.String("test-host-7"),
		NodeKey:         ptr.String("test-node-key-7"),
		UUID:            "test-uuid-7",
		Hostname:        "test-hostname-7",
		Platform:        "darwin",
		DetailUpdatedAt: time.Now(),
	})
	require.NoError(t, err)

	// Insert a non-revoked certificate that was valid from 400 days ago
	// until one day ago, i.e. it is expired now.
	now := time.Now()
	serialNumber := insertConditionalAccessCert(t, ds, ctx, host.ID, "expired-cert", now.Add(-400*24*time.Hour), now.Add(-24*time.Hour), false)

	// Expired certs should not be returned by serial number lookup
	_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, serialNumber)
	require.Error(t, err)
	assert.True(t, fleet.IsNotFound(err))

	// Note: GetConditionalAccessCertCreatedAtByHostID doesn't filter by expiration status
	// since it's used for rate limiting checks, not authentication
}
// insertConditionalAccessCert is a test helper that writes one conditional
// access SCEP certificate row directly via SQL.
//
// It first inserts an empty row into conditional_access_scep_serials and uses
// the resulting last-insert ID as the serial number, then inserts the
// certificate row for hostID with the given name, validity window and revoked
// flag. The certificate_pem column receives a fixed placeholder PEM string.
//
// Returns the serial number of the inserted certificate.
func insertConditionalAccessCert(t *testing.T, ds *Datastore, ctx context.Context, hostID uint, name string, notValidBefore, notValidAfter time.Time, revoked bool) uint64 {
	t.Helper()

	// Raw string literals below are kept flush-left on purpose: their exact
	// bytes are what gets sent to the database.
	const placeholderPEM = `-----BEGIN CERTIFICATE-----
MIICEjCCAXsCAg36MA0GCSqGSIb3DQEBBQUAMIGbMQswCQYDVQQGEwJKUDEOMAwG
-----END CERTIFICATE-----`
	const insertSerialStmt = `INSERT INTO conditional_access_scep_serials () VALUES ()`
	const insertCertStmt = `
INSERT INTO conditional_access_scep_certificates
(serial, host_id, name, not_valid_before, not_valid_after, certificate_pem, revoked)
VALUES (?, ?, ?, ?, ?, ?, ?)
`

	var serialNumber uint64
	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		// Allocate the serial number.
		res, err := q.ExecContext(ctx, insertSerialStmt)
		require.NoError(t, err)
		id, err := res.LastInsertId()
		require.NoError(t, err)
		serialNumber = uint64(id) // nolint:gosec,G115

		// Insert the certificate row keyed by that serial.
		_, err = q.ExecContext(ctx, insertCertStmt,
			serialNumber, hostID, name, notValidBefore, notValidAfter, placeholderPEM, revoked)
		return err
	})
	return serialNumber
}
// testConditionalAccessCertificateLifecycle exercises certificate rotation
// and cleanup across two hosts:
//
//   - host1 has two non-revoked certs whose validity started 3h and 2h ago;
//   - host2 has two non-revoked certs (3h and 30m ago) plus one revoked cert.
//
// It checks that every non-revoked cert resolves by serial number, that the
// revoked one does not, that RevokeOldConditionalAccessCerts with a 1-hour
// grace period revokes exactly one cert (host1's older one) and is a no-op on
// a second run, and that a created-at timestamp is available for host2.
func testConditionalAccessCertificateLifecycle(t *testing.T, ds *Datastore) {
	ctx := t.Context()
	now := time.Now()

	// Create two test hosts
	host1, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID: ptr.String("test-host-lifecycle-1"), NodeKey: ptr.String("test-node-key-1"),
		UUID: "test-uuid-1", Hostname: "test-hostname-1", Platform: "darwin", DetailUpdatedAt: now,
	})
	require.NoError(t, err)
	host2, err := ds.NewHost(ctx, &fleet.Host{
		OsqueryHostID: ptr.String("test-host-lifecycle-2"), NodeKey: ptr.String("test-node-key-2"),
		UUID: "test-uuid-2", Hostname: "test-hostname-2", Platform: "darwin", DetailUpdatedAt: now,
	})
	require.NoError(t, err)

	// Host 1: Multiple valid certs (simulating rotation)
	h1Old := insertConditionalAccessCert(t, ds, ctx, host1.ID, "h1-old", now.Add(-3*time.Hour), now.Add(24*time.Hour), false)
	h1New := insertConditionalAccessCert(t, ds, ctx, host1.ID, "h1-new", now.Add(-2*time.Hour), now.Add(24*time.Hour), false)

	// Host 2: New cert within grace period + 1 revoked cert
	h2Old := insertConditionalAccessCert(t, ds, ctx, host2.ID, "h2-old", now.Add(-3*time.Hour), now.Add(24*time.Hour), false)
	h2New := insertConditionalAccessCert(t, ds, ctx, host2.ID, "h2-new", now.Add(-30*time.Minute), now.Add(24*time.Hour), false)
	h2Revoked := insertConditionalAccessCert(t, ds, ctx, host2.ID, "h2-revoked", now.Add(-4*time.Hour), now.Add(24*time.Hour), true)

	// All valid certs authenticate
	for _, serial := range []uint64{h1Old, h1New, h2Old, h2New} {
		_, err := ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, serial)
		require.NoError(t, err)
	}

	// Revoked cert does not authenticate
	_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, h2Revoked)
	require.Error(t, err)
	assert.True(t, fleet.IsNotFound(err))

	// Cleanup with 1-hour grace: only host1's old cert eligible (host1's new cert is 2h old)
	count, err := ds.RevokeOldConditionalAccessCerts(ctx, 1*time.Hour)
	require.NoError(t, err)
	assert.Equal(t, int64(1), count)

	// Host1's old cert now revoked, new cert still works.
	// require.Error first, as in the other not-found checks, so a nil error
	// fails with a clear message instead of only a false IsNotFound.
	_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, h1Old)
	require.Error(t, err)
	assert.True(t, fleet.IsNotFound(err))
	_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, h1New)
	require.NoError(t, err)

	// Host2's certs still work (new cert within grace period)
	_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, h2Old)
	require.NoError(t, err)
	_, err = ds.GetConditionalAccessCertHostIDBySerialNumber(ctx, h2New)
	require.NoError(t, err)

	// Second cleanup revokes nothing
	count, err = ds.RevokeOldConditionalAccessCerts(ctx, 1*time.Hour)
	require.NoError(t, err)
	assert.Equal(t, int64(0), count)

	// GetConditionalAccessCertCreatedAtByHostID returns a creation timestamp
	// for host2. This only checks that the timestamp lies in a window around
	// `now` (later than 1h before, earlier than 1m after); it does not by
	// itself distinguish which of host2's certs the timestamp belongs to.
	createdAt, err := ds.GetConditionalAccessCertCreatedAtByHostID(ctx, host2.ID)
	require.NoError(t, err)
	assert.True(t, createdAt.After(now.Add(-1*time.Hour)) && createdAt.Before(now.Add(1*time.Minute)))
}