mirror of
https://github.com/fleetdm/fleet
synced 2026-05-24 09:28:54 +00:00
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #42836 This is another hot path optimization. ## Before When a host submits policy results via `SubmitDistributedQueryResults`, the system needed to determine which policies "flipped" (changed from passing to failing or vice versa). Each consumer computed this independently: ``` SubmitDistributedQueryResults(policyResults) | +-- processScriptsForNewlyFailingPolicies | filter to failing policies with scripts | BUILD SUBSET of results | CALL FlippingPoliciesForHost(subset) <-- DB query #1 | convert result to set, filter, queue scripts | +-- processSoftwareForNewlyFailingPolicies | filter to failing policies with installers | BUILD SUBSET of results | CALL FlippingPoliciesForHost(subset) <-- DB query #2 | convert result to set, filter, queue installs | +-- processVPPForNewlyFailingPolicies | filter to failing policies with VPP apps | BUILD SUBSET of results | CALL FlippingPoliciesForHost(subset) <-- DB query #3 | convert result to set, filter, queue VPP | +-- webhook filtering | filter to webhook-enabled policies | CALL FlippingPoliciesForHost(subset) <-- DB query #4 | register flipped policies in Redis | +-- RecordPolicyQueryExecutions CALL FlippingPoliciesForHost(all results) <-- DB query #5 reset attempt counters for newly passing INSERT/UPDATE policy_membership ``` Each `FlippingPoliciesForHost` call runs `SELECT policy_id, passes FROM policy_membership WHERE host_id = ? AND policy_id IN (?)`. All 5 queries hit the same table for the same host before `policy_membership` is updated, so they all see identical state. Each consumer also built intermediate maps to narrow down to its subset before calling `FlippingPoliciesForHost`, then converted the result into yet another set for filtering. This meant 3-4 temporary maps per consumer. ## After ``` SubmitDistributedQueryResults(policyResults) | CALL FlippingPoliciesForHost(all results) <-- single DB query build newFailingSet, normalize newPassing | +-- processScriptsForNewlyFailingPolicies | filter to failing policies with scripts | CHECK newFailingSet (in-memory map lookup) | queue scripts | +-- processSoftwareForNewlyFailingPolicies | filter to failing policies with installers | CHECK newFailingSet (in-memory map lookup) | queue installs | +-- processVPPForNewlyFailingPolicies | filter to failing policies with VPP apps | CHECK newFailingSet (in-memory map lookup) | queue VPP | +-- webhook filtering | filter to webhook-enabled policies | FILTER newFailing/newPassing by policy IDs (in-memory) | register flipped policies in Redis | +-- RecordPolicyQueryExecutions USE pre-computed newPassing (skip DB query) reset attempt counters for newly passing INSERT/UPDATE policy_membership ``` The intermediate subset maps and per-consumer set conversions are removed. Each process function goes directly from "policies with associated automation" to "is this policy in newFailingSet?" in a single map lookup. # Checklist for submitter If some of the following don't apply, delete the relevant line. - [x] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. ## Testing - [x] Added/updated automated tests - [x] QA'd all new/changed functionality manually <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Performance Improvements** * Reduced redundant database queries during policy result submissions by computing flipping policies once per host check-in instead of multiple times. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
429 lines
14 KiB
Go
429 lines
14 KiB
Go
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/fleet/v4/server/ptr"
|
|
"github.com/fleetdm/fleet/v4/server/test"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestConditionalAccessBypass(t *testing.T) {
|
|
ds := CreateMySQLDS(t)
|
|
|
|
cases := []struct {
|
|
name string
|
|
fn func(t *testing.T, ds *Datastore)
|
|
}{
|
|
{"ConditionalAccessBypassDevice", testConditionalAccessBypassDevice},
|
|
{"ConditionalAccessBypassDeviceWithBlockingPolicy", testConditionalAccessBypassDeviceWithBlockingPolicy},
|
|
{"ConditionalAccessConsumeBypass", testConditionalAccessConsumeBypass},
|
|
{"ConditionalAccessClearBypasses", testConditionalAccessClearBypasses},
|
|
{"ConditionalAccessBypassDeletedWithHost", testConditionalAccessBypassDeletedWithHost},
|
|
{"ConditionalAccessBypassedAt", testConditionalAccessBypassedAt},
|
|
{"ConditionalAccessBypassAllowedWithNonCAFailingCriticalPolicy", testConditionalAccessBypassAllowedWithNonCAFailingCriticalPolicy},
|
|
{"ConditionalAccessBypassAllowedWithCAEnabledNonCriticalPolicy", testConditionalAccessBypassAllowedWithCAEnabledNonCriticalPolicy},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
defer TruncateTables(t, ds)
|
|
c.fn(t, ds)
|
|
})
|
|
}
|
|
}
|
|
|
|
func testConditionalAccessBypassDevice(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
host, err := ds.NewHost(ctx, &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now(),
|
|
NodeKey: ptr.String("1"),
|
|
UUID: "1",
|
|
Hostname: "foo.local",
|
|
PrimaryIP: "192.168.1.1",
|
|
PrimaryMac: "30-65-EC-6F-C4-58",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert a bypass record
|
|
err = ds.ConditionalAccessBypassDevice(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
|
|
// Verify the record exists
|
|
var count int
|
|
err = ds.writer(ctx).GetContext(ctx, &count, "SELECT COUNT(*) FROM host_conditional_access WHERE host_id = ?", host.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 1, count)
|
|
|
|
// Call again to test ON DUPLICATE KEY UPDATE behavior
|
|
err = ds.ConditionalAccessBypassDevice(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
|
|
// Verify still only one record exists
|
|
err = ds.writer(ctx).GetContext(ctx, &count, "SELECT COUNT(*) FROM host_conditional_access WHERE host_id = ?", host.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 1, count)
|
|
}
|
|
|
|
func testConditionalAccessConsumeBypass(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
host, err := ds.NewHost(ctx, &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now(),
|
|
NodeKey: ptr.String("1"),
|
|
UUID: "1",
|
|
Hostname: "foo.local",
|
|
PrimaryIP: "192.168.1.1",
|
|
PrimaryMac: "30-65-EC-6F-C4-58",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Consume when no bypass exists - should return nil without error
|
|
bypassedAt, err := ds.ConditionalAccessConsumeBypass(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
require.Nil(t, bypassedAt)
|
|
|
|
// Create a bypass record
|
|
err = ds.ConditionalAccessBypassDevice(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
|
|
// Consume the bypass
|
|
bypassedAt, err = ds.ConditionalAccessConsumeBypass(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bypassedAt)
|
|
|
|
// Verify the record was deleted
|
|
var count int
|
|
err = ds.writer(ctx).GetContext(ctx, &count, "SELECT COUNT(*) FROM host_conditional_access WHERE host_id = ?", host.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 0, count)
|
|
|
|
// Try to consume again - should return nil without error
|
|
bypassedAt, err = ds.ConditionalAccessConsumeBypass(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
require.Nil(t, bypassedAt)
|
|
}
|
|
|
|
func testConditionalAccessClearBypasses(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
// Clear when no bypasses exist - should succeed without error
|
|
err := ds.ConditionalAccessClearBypasses(ctx)
|
|
require.NoError(t, err)
|
|
|
|
// Create multiple hosts with bypass records
|
|
host1, err := ds.NewHost(ctx, &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now(),
|
|
NodeKey: ptr.String("1"),
|
|
UUID: "1",
|
|
Hostname: "foo1.local",
|
|
PrimaryIP: "192.168.1.1",
|
|
PrimaryMac: "30-65-EC-6F-C4-58",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
host2, err := ds.NewHost(ctx, &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now(),
|
|
NodeKey: ptr.String("2"),
|
|
UUID: "2",
|
|
Hostname: "foo2.local",
|
|
PrimaryIP: "192.168.1.2",
|
|
PrimaryMac: "30-65-EC-6F-C4-59",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
host3, err := ds.NewHost(ctx, &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now(),
|
|
NodeKey: ptr.String("3"),
|
|
UUID: "3",
|
|
Hostname: "foo3.local",
|
|
PrimaryIP: "192.168.1.3",
|
|
PrimaryMac: "30-65-EC-6F-C4-60",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Create bypass records for all hosts
|
|
err = ds.ConditionalAccessBypassDevice(ctx, host1.ID)
|
|
require.NoError(t, err)
|
|
err = ds.ConditionalAccessBypassDevice(ctx, host2.ID)
|
|
require.NoError(t, err)
|
|
err = ds.ConditionalAccessBypassDevice(ctx, host3.ID)
|
|
require.NoError(t, err)
|
|
|
|
// Verify all records exist
|
|
var count int
|
|
err = ds.writer(ctx).GetContext(ctx, &count, "SELECT COUNT(*) FROM host_conditional_access")
|
|
require.NoError(t, err)
|
|
require.Equal(t, 3, count)
|
|
|
|
// Clear all bypasses
|
|
err = ds.ConditionalAccessClearBypasses(ctx)
|
|
require.NoError(t, err)
|
|
|
|
// Verify all records were deleted
|
|
err = ds.writer(ctx).GetContext(ctx, &count, "SELECT COUNT(*) FROM host_conditional_access")
|
|
require.NoError(t, err)
|
|
require.Equal(t, 0, count)
|
|
}
|
|
|
|
func testConditionalAccessBypassDeletedWithHost(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
// Create a host
|
|
host, err := ds.NewHost(ctx, &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now(),
|
|
NodeKey: ptr.String("1"),
|
|
UUID: "1",
|
|
Hostname: "foo.local",
|
|
PrimaryIP: "192.168.1.1",
|
|
PrimaryMac: "30-65-EC-6F-C4-58",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Create a bypass record for the host
|
|
err = ds.ConditionalAccessBypassDevice(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
|
|
// Verify the bypass record exists
|
|
var count int
|
|
err = ds.writer(ctx).GetContext(ctx, &count, "SELECT COUNT(*) FROM host_conditional_access WHERE host_id = ?", host.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 1, count)
|
|
|
|
// Delete the host
|
|
err = ds.DeleteHost(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
|
|
// Verify the bypass record was also deleted
|
|
err = ds.writer(ctx).GetContext(ctx, &count, "SELECT COUNT(*) FROM host_conditional_access WHERE host_id = ?", host.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 0, count, "bypass record should be deleted when host is deleted")
|
|
}
|
|
|
|
func testConditionalAccessBypassedAt(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
host, err := ds.NewHost(ctx, &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now(),
|
|
NodeKey: ptr.String("1"),
|
|
UUID: "1",
|
|
Hostname: "foo.local",
|
|
PrimaryIP: "192.168.1.1",
|
|
PrimaryMac: "30-65-EC-6F-C4-58",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
bypassedAt, err := ds.ConditionalAccessBypassedAt(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
require.Nil(t, bypassedAt)
|
|
|
|
err = ds.ConditionalAccessBypassDevice(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
|
|
bypassedAt, err = ds.ConditionalAccessBypassedAt(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bypassedAt)
|
|
require.WithinDuration(t, time.Now(), *bypassedAt, 5*time.Second)
|
|
|
|
bypassedAtAgain, err := ds.ConditionalAccessBypassedAt(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bypassedAtAgain)
|
|
require.Equal(t, bypassedAt, bypassedAtAgain)
|
|
|
|
hostWithoutBypass, err := ds.NewHost(ctx, &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now(),
|
|
NodeKey: ptr.String("2"),
|
|
UUID: "2",
|
|
Hostname: "bar.local",
|
|
PrimaryIP: "192.168.1.2",
|
|
PrimaryMac: "30-65-EC-6F-C4-59",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
bypassedAtOther, err := ds.ConditionalAccessBypassedAt(ctx, hostWithoutBypass.ID)
|
|
require.NoError(t, err)
|
|
require.Nil(t, bypassedAtOther)
|
|
}
|
|
|
|
func testConditionalAccessBypassDeviceWithBlockingPolicy(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
user := test.NewUser(t, ds, "Alice", "alice@example.com", true)
|
|
|
|
host, err := ds.NewHost(ctx, &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now(),
|
|
NodeKey: ptr.String("blocking-policy-host"),
|
|
UUID: "blocking-policy-uuid",
|
|
Hostname: "blocking.local",
|
|
PrimaryIP: "192.168.1.10",
|
|
PrimaryMac: "30-65-EC-6F-C4-70",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Assign host to a team
|
|
team, err := ds.NewTeam(ctx, &fleet.Team{Name: "blocking-policy-team"})
|
|
require.NoError(t, err)
|
|
require.NoError(t, ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(&team.ID, []uint{host.ID})))
|
|
|
|
// Create a team CA-enabled critical policy that should block bypass
|
|
policy, err := ds.NewTeamPolicy(ctx, team.ID, &user.ID, fleet.PolicyPayload{
|
|
Name: "ca-critical-policy",
|
|
Query: "select 1;",
|
|
Critical: true,
|
|
ConditionalAccessEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Record a failing result for this policy on the host
|
|
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{policy.ID: new(false)}, time.Now(), false, nil)
|
|
require.NoError(t, err)
|
|
|
|
// Bypass should fail because the host has a failing CA-enabled policy
|
|
err = ds.ConditionalAccessBypassDevice(ctx, host.ID)
|
|
require.Error(t, err)
|
|
var badReqErr *fleet.BadRequestError
|
|
require.ErrorAs(t, err, &badReqErr)
|
|
|
|
// Verify no host_conditional_access row was created
|
|
var count int
|
|
innerErr := ds.writer(ctx).GetContext(ctx, &count, "SELECT COUNT(*) FROM host_conditional_access WHERE host_id = ?", host.ID)
|
|
require.NoError(t, innerErr)
|
|
require.Equal(t, 0, count)
|
|
}
|
|
|
|
func testConditionalAccessBypassAllowedWithNonCAFailingCriticalPolicy(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
user := test.NewUser(t, ds, "Bob", "bob@example.com", true)
|
|
|
|
host, err := ds.NewHost(ctx, &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now(),
|
|
NodeKey: ptr.String("non-ca-policy-host"),
|
|
UUID: "non-ca-policy-uuid",
|
|
Hostname: "non-ca.local",
|
|
PrimaryIP: "192.168.1.11",
|
|
PrimaryMac: "30-65-EC-6F-C4-71",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Assign host to a team
|
|
team, err := ds.NewTeam(ctx, &fleet.Team{Name: "non-ca-policy-team"})
|
|
require.NoError(t, err)
|
|
require.NoError(t, ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(&team.ID, []uint{host.ID})))
|
|
|
|
// CA policy — passing
|
|
caPolicy, err := ds.NewTeamPolicy(ctx, team.ID, &user.ID, fleet.PolicyPayload{
|
|
Name: "ca-policy-passing",
|
|
Query: "select 1;",
|
|
Critical: true,
|
|
ConditionalAccessEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Non-CA critical policy — failing
|
|
nonCAPolicy, err := ds.NewTeamPolicy(ctx, team.ID, &user.ID, fleet.PolicyPayload{
|
|
Name: "non-ca-policy-failing",
|
|
Query: "select 1;",
|
|
Critical: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{
|
|
caPolicy.ID: ptr.Bool(true), // passing
|
|
nonCAPolicy.ID: ptr.Bool(false), // failing
|
|
}, time.Now(), false, nil)
|
|
require.NoError(t, err)
|
|
|
|
// Bypass must succeed: the only failing policy is not CA-enabled
|
|
err = ds.ConditionalAccessBypassDevice(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
|
|
// Verify a host_conditional_access row was created
|
|
var count int
|
|
innerErr := ds.writer(ctx).GetContext(ctx, &count, "SELECT COUNT(*) FROM host_conditional_access WHERE host_id = ?", host.ID)
|
|
require.NoError(t, innerErr)
|
|
require.Equal(t, 1, count)
|
|
}
|
|
|
|
// testConditionalAccessBypassAllowedWithCAEnabledNonCriticalPolicy verifies that a CA-enabled but
|
|
// non-critical failing policy does NOT block bypass. Both critical=1 AND conditional_access_enabled=1
|
|
// are required to block.
|
|
func testConditionalAccessBypassAllowedWithCAEnabledNonCriticalPolicy(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
user := test.NewUser(t, ds, "Carol", "carol@example.com", true)
|
|
|
|
host, err := ds.NewHost(ctx, &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
PolicyUpdatedAt: time.Now(),
|
|
SeenTime: time.Now(),
|
|
NodeKey: ptr.String("ca-non-critical-host"),
|
|
UUID: "ca-non-critical-uuid",
|
|
Hostname: "ca-non-critical.local",
|
|
PrimaryIP: "192.168.1.12",
|
|
PrimaryMac: "30-65-EC-6F-C4-72",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
team, err := ds.NewTeam(ctx, &fleet.Team{Name: "ca-non-critical-team"})
|
|
require.NoError(t, err)
|
|
require.NoError(t, ds.AddHostsToTeam(ctx, fleet.NewAddHostsToTeamParams(&team.ID, []uint{host.ID})))
|
|
|
|
// CA-enabled but NOT critical — failing
|
|
nonCriticalCAPolicy, err := ds.NewTeamPolicy(ctx, team.ID, &user.ID, fleet.PolicyPayload{
|
|
Name: "ca-enabled-non-critical",
|
|
Query: "select 1;",
|
|
Critical: false,
|
|
ConditionalAccessEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = ds.RecordPolicyQueryExecutions(ctx, host, map[uint]*bool{
|
|
nonCriticalCAPolicy.ID: ptr.Bool(false), // failing
|
|
}, time.Now(), false, nil)
|
|
require.NoError(t, err)
|
|
|
|
// Bypass must succeed: policy is CA-enabled but not critical
|
|
err = ds.ConditionalAccessBypassDevice(ctx, host.ID)
|
|
require.NoError(t, err)
|
|
|
|
var count int
|
|
innerErr := ds.writer(ctx).GetContext(ctx, &count, "SELECT COUNT(*) FROM host_conditional_access WHERE host_id = ?", host.ID)
|
|
require.NoError(t, innerErr)
|
|
require.Equal(t, 1, count)
|
|
}
|