fleet/server/service/hosts_test.go
Ian Littman 8e4e89f4e9
API + auth + UI changes for team labels (#37208)
Covers #36760, #36758.

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files)
for more information.

- [x] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)

## Testing

- [x] Added/updated automated tests
- [x] Where appropriate, [automated tests simulate multiple hosts and
test for host
isolation](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/reference/patterns-backend.md#unit-testing)
(updates to one host's records do not affect another)

- [ ] QA'd all new/changed functionality manually
2025-12-29 21:28:45 -06:00

3792 lines
133 KiB
Go

package service
import (
"context"
"crypto/x509"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"iter"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/config"
authzctx "github.com/fleetdm/fleet/v4/server/contexts/authz"
"github.com/fleetdm/fleet/v4/server/contexts/capabilities"
hostctx "github.com/fleetdm/fleet/v4/server/contexts/host"
"github.com/fleetdm/fleet/v4/server/contexts/license"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mdm"
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
"github.com/fleetdm/fleet/v4/server/mdm/apple/mobileconfig"
"github.com/fleetdm/fleet/v4/server/mdm/nanodep/tokenpki"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/test"
kitlog "github.com/go-kit/log"
"github.com/jmoiron/sqlx"
"github.com/smallstep/pkcs7"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Fragile test: This test is fragile because of the large reliance on Datastore mocks. Consider refactoring test/logic or removing the test. It may be slowing us down more than helping us.
func TestHostDetails(t *testing.T) {
ds := new(mock.Store)
svc := &Service{ds: ds}
host := &fleet.Host{ID: 3}
expectedLabels := []*fleet.Label{
{
Name: "foobar",
Description: "the foobar label",
},
}
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
return &fleet.AppConfig{}, nil
}
ds.ListLabelsForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Label, error) {
return expectedLabels, nil
}
expectedPacks := []*fleet.Pack{
{
Name: "pack1",
},
{
Name: "pack2",
},
}
ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
return expectedPacks, nil
}
ds.LoadHostSoftwareFunc = func(ctx context.Context, host *fleet.Host, includeCVEScores bool) error {
return nil
}
ds.ListPoliciesForHostFunc = func(ctx context.Context, host *fleet.Host) ([]*fleet.HostPolicy, error) {
return nil, nil
}
dsBats := []*fleet.HostBattery{{HostID: host.ID, SerialNumber: "a", CycleCount: 999, Health: "Normal"}, {HostID: host.ID, SerialNumber: "b", CycleCount: 1001, Health: "Service recommended"}}
ds.ListHostBatteriesFunc = func(ctx context.Context, hostID uint) ([]*fleet.HostBattery, error) {
return dsBats, nil
}
ds.ListUpcomingHostMaintenanceWindowsFunc = func(ctx context.Context, hid uint) ([]*fleet.HostMaintenanceWindow, error) {
return nil, nil
}
ds.GetHostLockWipeStatusFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
return &fleet.HostLockWipeStatus{}, nil
}
ds.ScimUserByHostIDFunc = func(ctx context.Context, hostID uint) (*fleet.ScimUser, error) {
return nil, nil
}
ds.ListHostDeviceMappingFunc = func(ctx context.Context, id uint) ([]*fleet.HostDeviceMapping, error) {
return nil, nil
}
ds.IsHostDiskEncryptionKeyArchivedFunc = func(ctx context.Context, hostID uint) (bool, error) {
return false, nil
}
opts := fleet.HostDetailOptions{
IncludeCVEScores: false,
IncludePolicies: false,
}
hostDetail, err := svc.getHostDetails(test.UserContext(context.Background(), test.UserAdmin), host, opts)
require.NoError(t, err)
assert.Equal(t, expectedLabels, hostDetail.Labels)
assert.Equal(t, expectedPacks, hostDetail.Packs)
require.NotNil(t, hostDetail.Batteries)
assert.Equal(t, dsBats, *hostDetail.Batteries)
require.Nil(t, hostDetail.MDM.MacOSSettings)
}
// Fragile test: This test is fragile because of the large reliance on Datastore mocks. Consider refactoring test/logic or removing the test. It may be slowing us down more than helping us.
//
// TestHostDetailsMDMAppleDiskEncryption drives getHostDetails for a darwin
// host through combinations of the FileVault profile state (status and
// operation type) and the host's raw_decryptable value. For each case it
// checks the reported disk encryption status, the required action, and the
// delivery status of the returned profile.
func TestHostDetailsMDMAppleDiskEncryption(t *testing.T) {
	ds := new(mock.Store)
	svc := &Service{ds: ds}
	ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
		return &fleet.AppConfig{MDM: fleet.MDM{EnabledAndConfigured: true}}, nil
	}
	ds.ListLabelsForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Label, error) {
		return nil, nil
	}
	ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
		return nil, nil
	}
	ds.LoadHostSoftwareFunc = func(ctx context.Context, host *fleet.Host, includeCVEScores bool) error {
		return nil
	}
	ds.ListPoliciesForHostFunc = func(ctx context.Context, host *fleet.Host) ([]*fleet.HostPolicy, error) {
		return nil, nil
	}
	ds.ListHostBatteriesFunc = func(ctx context.Context, hostID uint) ([]*fleet.HostBattery, error) {
		return nil, nil
	}
	ds.ListUpcomingHostMaintenanceWindowsFunc = func(ctx context.Context, hid uint) ([]*fleet.HostMaintenanceWindow, error) {
		return nil, nil
	}
	ds.GetHostLockWipeStatusFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
		return &fleet.HostLockWipeStatus{}, nil
	}
	ds.ScimUserByHostIDFunc = func(ctx context.Context, hostID uint) (*fleet.ScimUser, error) {
		return nil, nil
	}
	ds.ListHostDeviceMappingFunc = func(ctx context.Context, id uint) ([]*fleet.HostDeviceMapping, error) {
		return nil, nil
	}
	ds.GetNanoMDMEnrollmentTimesFunc = func(ctx context.Context, hostUUID string) (*time.Time, *time.Time, error) {
		return nil, nil, nil
	}
	ds.IsHostDiskEncryptionKeyArchivedFunc = func(ctx context.Context, hostID uint) (bool, error) {
		return false, nil
	}

	// rawDecrypt is the host's raw_decryptable value (nil encodes as JSON null);
	// fvProf is the FileVault profile returned by the datastore (nil means none).
	cases := []struct {
		name       string
		rawDecrypt *int
		fvProf     *fleet.HostMDMAppleProfile
		wantState  fleet.DiskEncryptionStatus
		wantAction fleet.ActionRequiredState
		wantStatus *fleet.MDMDeliveryStatus
	}{
		{"no profile", ptr.Int(-1), nil, "", "", nil},
		{
			"installed profile, no key",
			ptr.Int(-1),
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryVerifying,
				OperationType: fleet.MDMOperationTypeInstall,
			},
			fleet.DiskEncryptionActionRequired,
			fleet.ActionRequiredRotateKey,
			&fleet.MDMDeliveryPending,
		},
		{
			"installed profile, unknown decryptable",
			nil,
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryVerifying,
				OperationType: fleet.MDMOperationTypeInstall,
			},
			fleet.DiskEncryptionVerifying,
			"",
			&fleet.MDMDeliveryVerifying,
		},
		{
			"installed profile, not decryptable",
			ptr.Int(0),
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryVerifying,
				OperationType: fleet.MDMOperationTypeInstall,
			},
			fleet.DiskEncryptionActionRequired,
			fleet.ActionRequiredRotateKey,
			&fleet.MDMDeliveryPending,
		},
		{
			"installed profile, decryptable",
			ptr.Int(1),
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryVerifying,
				OperationType: fleet.MDMOperationTypeInstall,
			},
			fleet.DiskEncryptionVerifying,
			"",
			&fleet.MDMDeliveryVerifying,
		},
		{
			"installed profile, decryptable, verified",
			ptr.Int(1),
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryVerified,
				OperationType: fleet.MDMOperationTypeInstall,
			},
			fleet.DiskEncryptionVerified,
			"",
			&fleet.MDMDeliveryVerified,
		},
		{
			"pending install, decryptable",
			ptr.Int(1),
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryPending,
				OperationType: fleet.MDMOperationTypeInstall,
			},
			fleet.DiskEncryptionEnforcing,
			"",
			&fleet.MDMDeliveryPending,
		},
		{
			"pending install, unknown decryptable",
			nil,
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryPending,
				OperationType: fleet.MDMOperationTypeInstall,
			},
			fleet.DiskEncryptionEnforcing,
			"",
			&fleet.MDMDeliveryPending,
		},
		{
			"pending install, no key",
			ptr.Int(-1),
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryPending,
				OperationType: fleet.MDMOperationTypeInstall,
			},
			fleet.DiskEncryptionEnforcing,
			"",
			&fleet.MDMDeliveryPending,
		},
		{
			"failed install, no key",
			ptr.Int(-1),
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryFailed,
				OperationType: fleet.MDMOperationTypeInstall,
				Detail:        "some mdm profile install error",
			},
			fleet.DiskEncryptionFailed,
			"",
			&fleet.MDMDeliveryFailed,
		},
		{
			"failed install, not decryptable",
			ptr.Int(0),
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryFailed,
				OperationType: fleet.MDMOperationTypeInstall,
			},
			fleet.DiskEncryptionFailed,
			"",
			&fleet.MDMDeliveryFailed,
		},
		{
			"pending remove, decryptable",
			ptr.Int(1),
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryPending,
				OperationType: fleet.MDMOperationTypeRemove,
			},
			fleet.DiskEncryptionRemovingEnforcement,
			"",
			&fleet.MDMDeliveryPending,
		},
		{
			"pending remove, no key",
			ptr.Int(-1),
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryPending,
				OperationType: fleet.MDMOperationTypeRemove,
			},
			fleet.DiskEncryptionRemovingEnforcement,
			"",
			&fleet.MDMDeliveryPending,
		},
		{
			"failed remove, unknown decryptable",
			nil,
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryFailed,
				OperationType: fleet.MDMOperationTypeRemove,
				Detail:        "some mdm profile removal error",
			},
			fleet.DiskEncryptionFailed,
			"",
			&fleet.MDMDeliveryFailed,
		},
		{
			"removed profile, not decryptable",
			ptr.Int(0),
			&fleet.HostMDMAppleProfile{
				HostUUID:      "abc",
				Identifier:    mobileconfig.FleetFileVaultPayloadIdentifier,
				Status:        &fleet.MDMDeliveryVerifying,
				OperationType: fleet.MDMOperationTypeRemove,
			},
			"",
			"",
			&fleet.MDMDeliveryVerifying,
		},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			// Build the host's MDM data from a raw_decryptable JSON value;
			// a nil rawDecrypt is encoded as JSON null.
			rawDecrypt := "null"
			if c.rawDecrypt != nil {
				rawDecrypt = strconv.Itoa(*c.rawDecrypt)
			}
			var mdmData fleet.MDMHostData
			require.NoError(t, mdmData.Scan([]byte(fmt.Sprintf(`{"raw_decryptable": %s}`, rawDecrypt))))
			host := &fleet.Host{ID: 3, MDM: mdmData, UUID: "abc", Platform: "darwin"}
			opts := fleet.HostDetailOptions{
				IncludeCVEScores: false,
				IncludePolicies:  false,
			}
			ds.GetHostMDMAppleProfilesFunc = func(ctx context.Context, uuid string) ([]fleet.HostMDMAppleProfile, error) {
				if c.fvProf == nil {
					return nil, nil
				}
				return []fleet.HostMDMAppleProfile{*c.fvProf}, nil
			}

			hostDetail, err := svc.getHostDetails(test.UserContext(context.Background(), test.UserAdmin), host, opts)
			require.NoError(t, err)
			require.NotNil(t, hostDetail.MDM.MacOSSettings)

			if c.wantState == "" {
				require.Nil(t, hostDetail.MDM.MacOSSettings.DiskEncryption)
				require.Nil(t, hostDetail.MDM.OSSettings.DiskEncryption.Status)
				require.Empty(t, hostDetail.MDM.OSSettings.DiskEncryption.Detail)
			} else {
				require.NotNil(t, hostDetail.MDM.MacOSSettings.DiskEncryption)
				require.Equal(t, c.wantState, *hostDetail.MDM.MacOSSettings.DiskEncryption)
				require.NotNil(t, hostDetail.MDM.OSSettings.DiskEncryption.Status)
				require.Equal(t, c.wantState, *hostDetail.MDM.OSSettings.DiskEncryption.Status)
				require.Equal(t, c.fvProf.Detail, hostDetail.MDM.OSSettings.DiskEncryption.Detail)
			}

			if c.wantAction == "" {
				require.Nil(t, hostDetail.MDM.MacOSSettings.ActionRequired)
			} else {
				require.NotNil(t, hostDetail.MDM.MacOSSettings.ActionRequired)
				require.Equal(t, c.wantAction, *hostDetail.MDM.MacOSSettings.ActionRequired)
			}

			// Guard every index and pointer dereference below so a regression
			// fails the subtest cleanly instead of panicking the whole run.
			require.NotNil(t, hostDetail.MDM.Profiles)
			profs := *hostDetail.MDM.Profiles
			if c.wantStatus != nil {
				require.NotEmpty(t, profs)
				require.NotNil(t, profs[0].Status)
				require.EqualValues(t, *c.wantStatus, *profs[0].Status)
				require.Equal(t, c.fvProf.Detail, profs[0].Detail)
			} else {
				require.Nil(t, profs)
			}
		})
	}
}
// TestHostDetailsMDMTimestamps checks that getHostDetails looks up the nano
// MDM enrollment and check-in times only for Apple platforms (darwin, ios,
// ipados). For those platforms it checks that the times are reported as
// LastMDMEnrolledAt and LastMDMCheckedInAt. For every other platform it checks
// that the lookup is skipped and both fields stay nil.
func TestHostDetailsMDMTimestamps(t *testing.T) {
	ds := new(mock.Store)
	svc := &Service{ds: ds}
	ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
		return &fleet.AppConfig{MDM: fleet.MDM{EnabledAndConfigured: true, WindowsEnabledAndConfigured: true}}, nil
	}
	ds.ListLabelsForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Label, error) {
		return nil, nil
	}
	ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
		return nil, nil
	}
	ds.LoadHostSoftwareFunc = func(ctx context.Context, host *fleet.Host, includeCVEScores bool) error {
		return nil
	}
	ds.ListPoliciesForHostFunc = func(ctx context.Context, host *fleet.Host) ([]*fleet.HostPolicy, error) {
		return nil, nil
	}
	ds.ListHostBatteriesFunc = func(ctx context.Context, hostID uint) ([]*fleet.HostBattery, error) {
		return nil, nil
	}
	ds.ListUpcomingHostMaintenanceWindowsFunc = func(ctx context.Context, hid uint) ([]*fleet.HostMaintenanceWindow, error) {
		return nil, nil
	}
	ds.GetHostLockWipeStatusFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
		return &fleet.HostLockWipeStatus{}, nil
	}
	ds.ScimUserByHostIDFunc = func(ctx context.Context, hostID uint) (*fleet.ScimUser, error) {
		return nil, nil
	}
	ds.ListHostDeviceMappingFunc = func(ctx context.Context, id uint) ([]*fleet.HostDeviceMapping, error) {
		return nil, nil
	}
	ds.GetHostMDMAppleProfilesFunc = func(ctx context.Context, uuid string) ([]fleet.HostMDMAppleProfile, error) {
		return nil, nil
	}
	ds.GetHostMDMWindowsProfilesFunc = func(ctx context.Context, uuid string) ([]fleet.HostMDMWindowsProfile, error) {
		return nil, nil
	}
	ds.GetConfigEnableDiskEncryptionFunc = func(ctx context.Context, teamID *uint) (fleet.DiskEncryptionConfig, error) {
		return fleet.DiskEncryptionConfig{}, nil
	}
	ds.IsHostDiskEncryptionKeyArchivedFunc = func(ctx context.Context, hostID uint) (bool, error) {
		return false, nil
	}

	// The mock returns (enrolledAt, checkedInAt) in that order.
	enrolledAt := time.Now().Add(-1 * time.Hour).UTC()
	checkedInAt := time.Now().Add(-2 * time.Hour).UTC()
	ds.GetNanoMDMEnrollmentTimesFunc = func(ctx context.Context, hostUUID string) (*time.Time, *time.Time, error) {
		return &enrolledAt, &checkedInAt, nil
	}

	cases := []struct {
		platform        string
		platformIsApple bool
	}{
		{"darwin", true},
		{"ios", true},
		{"ipados", true},
		{"windows", false},
		{"ubuntu", false},
		{"centos", false},
		{"rhel", false},
		{"debian", false},
	}
	for _, tc := range cases {
		t.Run("test MDM timestamps on platform "+tc.platform, func(t *testing.T) {
			ds.GetNanoMDMEnrollmentTimesFuncInvoked = false
			host := &fleet.Host{ID: 3, MDM: fleet.MDMHostData{}, Platform: tc.platform, UUID: "abc123"}
			opts := fleet.HostDetailOptions{
				IncludeCVEScores:                    false,
				IncludePolicies:                     false,
				ExcludeSoftware:                     true,
				IncludeCriticalVulnerabilitiesCount: false,
			}
			hostDetail, err := svc.getHostDetails(test.UserContext(context.Background(), test.UserAdmin), host, opts)
			require.NoError(t, err)

			if tc.platformIsApple {
				assert.True(t, ds.GetNanoMDMEnrollmentTimesFuncInvoked)
				// testify's argument order is (expected, actual).
				require.NotNil(t, hostDetail.LastMDMEnrolledAt)
				assert.Equal(t, enrolledAt, *hostDetail.LastMDMEnrolledAt)
				require.NotNil(t, hostDetail.LastMDMCheckedInAt)
				assert.Equal(t, checkedInAt, *hostDetail.LastMDMCheckedInAt)
			} else {
				assert.False(t, ds.GetNanoMDMEnrollmentTimesFuncInvoked)
				assert.Nil(t, hostDetail.LastMDMEnrolledAt)
				assert.Nil(t, hostDetail.LastMDMCheckedInAt)
			}
		})
	}
}
// Fragile test: This test is fragile because of the large reliance on Datastore mocks. Consider refactoring test/logic or removing the test. It may be slowing us down more than helping us.
//
// TestHostDetailsOSSettings checks, per host platform and license tier, which
// disk-encryption datastore lookups getHostDetails performs. It also checks
// the resulting MDM.OSSettings.DiskEncryption.Status.
func TestHostDetailsOSSettings(t *testing.T) {
	ds := new(mock.Store)
	svc := &Service{ds: ds}
	ctx := context.Background()
	ds.ListLabelsForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Label, error) {
		return nil, nil
	}
	ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
		return nil, nil
	}
	ds.LoadHostSoftwareFunc = func(ctx context.Context, host *fleet.Host, includeCVEScores bool) error {
		return nil
	}
	ds.ListPoliciesForHostFunc = func(ctx context.Context, host *fleet.Host) ([]*fleet.HostPolicy, error) {
		return nil, nil
	}
	ds.ListHostBatteriesFunc = func(ctx context.Context, hostID uint) ([]*fleet.HostBattery, error) {
		return nil, nil
	}
	ds.ListUpcomingHostMaintenanceWindowsFunc = func(ctx context.Context, hid uint) ([]*fleet.HostMaintenanceWindow, error) {
		return nil, nil
	}
	ds.GetHostMDMMacOSSetupFunc = func(ctx context.Context, hid uint) (*fleet.HostMDMMacOSSetup, error) {
		return nil, nil
	}
	ds.GetHostLockWipeStatusFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
		return &fleet.HostLockWipeStatus{}, nil
	}
	ds.GetHostDiskEncryptionKeyFunc = func(ctx context.Context, hostID uint) (*fleet.HostDiskEncryptionKey, error) {
		return &fleet.HostDiskEncryptionKey{}, nil
	}
	ds.GetHostArchivedDiskEncryptionKeyFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostArchivedDiskEncryptionKey, error) {
		return &fleet.HostArchivedDiskEncryptionKey{}, nil
	}
	ds.ScimUserByHostIDFunc = func(ctx context.Context, hostID uint) (*fleet.ScimUser, error) {
		return nil, nil
	}
	ds.ListHostDeviceMappingFunc = func(ctx context.Context, id uint) ([]*fleet.HostDeviceMapping, error) {
		return nil, nil
	}
	ds.GetNanoMDMEnrollmentTimesFunc = func(ctx context.Context, hostUUID string) (*time.Time, *time.Time, error) {
		return nil, nil, nil
	}
	ds.IsHostDiskEncryptionKeyArchivedFunc = func(ctx context.Context, hostID uint) (bool, error) {
		return false, nil
	}

	type testCase struct {
		name        string
		host        *fleet.Host
		licenseTier string
		wantStatus  fleet.DiskEncryptionStatus
	}
	cases := []testCase{
		{"windows", &fleet.Host{ID: 42, Platform: "windows"}, fleet.TierPremium, fleet.DiskEncryptionEnforcing},
		{"darwin", &fleet.Host{ID: 42, Platform: "darwin"}, fleet.TierPremium, ""},
		// TeamID necessary to check whether disk encryption is enabled for Linux hosts, in lieu of
		// MDM-related logic which doesn't apply to Linux hosts
		{"ubuntu", &fleet.Host{ID: 42, Platform: "ubuntu", TeamID: ptr.Uint(1)}, fleet.TierPremium, ""},
		{"not premium", &fleet.Host{ID: 42, Platform: "windows"}, fleet.TierFree, ""},
	}

	// setupDS resets the invocation flags checked by the subtests and installs
	// the per-case mocks. The BitLocker mock reports c.wantStatus, or nil when
	// it is empty.
	setupDS := func(c testCase) {
		ds.AppConfigFuncInvoked = false
		ds.GetMDMWindowsBitLockerStatusFuncInvoked = false
		ds.GetHostMDMAppleProfilesFuncInvoked = false
		ds.GetHostMDMWindowsProfilesFuncInvoked = false
		ds.GetHostMDMFuncInvoked = false
		ds.GetConfigEnableDiskEncryptionFuncInvoked = false
		ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
			return &fleet.AppConfig{MDM: fleet.MDM{EnabledAndConfigured: true, WindowsEnabledAndConfigured: true}}, nil
		}
		ds.GetMDMWindowsBitLockerStatusFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostMDMDiskEncryption, error) {
			if c.wantStatus == "" {
				return nil, nil
			}
			return &fleet.HostMDMDiskEncryption{Status: &c.wantStatus, Detail: ""}, nil
		}
		ds.GetHostMDMAppleProfilesFunc = func(ctx context.Context, uuid string) ([]fleet.HostMDMAppleProfile, error) {
			return nil, nil
		}
		ds.GetHostMDMWindowsProfilesFunc = func(ctx context.Context, uuid string) ([]fleet.HostMDMWindowsProfile, error) {
			return nil, nil
		}
		ds.GetHostMDMFunc = func(ctx context.Context, hostID uint) (*fleet.HostMDM, error) {
			hmdm := fleet.HostMDM{Enrolled: true, IsServer: false}
			return &hmdm, nil
		}
		ds.GetConfigEnableDiskEncryptionFunc = func(ctx context.Context, teamID *uint) (fleet.DiskEncryptionConfig, error) {
			// testing API response when not enabled
			return fleet.DiskEncryptionConfig{}, nil
		}
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			setupDS(c)
			// Derive a per-subtest context instead of reassigning the shared
			// ctx, so license contexts don't stack up across subtests.
			licCtx := license.NewContext(ctx, &fleet.LicenseInfo{Tier: c.licenseTier})
			hostDetail, err := svc.getHostDetails(test.UserContext(licCtx, test.UserAdmin), c.host, fleet.HostDetailOptions{
				IncludeCVEScores: false,
				IncludePolicies:  false,
			})
			require.NoError(t, err)
			require.NotNil(t, hostDetail)
			require.True(t, ds.AppConfigFuncInvoked)

			switch c.host.Platform {
			case "windows":
				require.False(t, ds.GetHostMDMAppleProfilesFuncInvoked)
				if c.licenseTier == fleet.TierPremium {
					require.True(t, ds.GetHostMDMFuncInvoked)
				} else {
					require.False(t, ds.GetHostMDMFuncInvoked)
				}
				if c.wantStatus != "" {
					require.True(t, ds.GetMDMWindowsBitLockerStatusFuncInvoked)
					require.NotNil(t, hostDetail.MDM.OSSettings.DiskEncryption.Status)
					require.Equal(t, c.wantStatus, *hostDetail.MDM.OSSettings.DiskEncryption.Status)
				} else {
					require.False(t, ds.GetMDMWindowsBitLockerStatusFuncInvoked)
					require.Nil(t, hostDetail.MDM.OSSettings.DiskEncryption.Status)
				}
			case "ubuntu":
				require.False(t, ds.GetHostMDMAppleProfilesFuncInvoked)
				require.False(t, ds.GetMDMWindowsBitLockerStatusFuncInvoked)
				// service should call this function to check whether disk encryption is enabled for a Linux host
				require.True(t, ds.GetConfigEnableDiskEncryptionFuncInvoked)
				// `hostDetail.MDM.OSSettings` and `hostDetail.MDM.OSSettings.DiskEncryption` will actually not
				// be `nil` here due to the way those fields are initialized by `svc.ds.Host`, so we can't
				// expect them to be `nil` in these tests. However, since the relevant struct tags are set to
				// `omitempty`, the resulting API response WILL omit these fields/subfields when empty,
				// which is confirmed at the integration layer.
				require.Nil(t, hostDetail.MDM.OSSettings.DiskEncryption.Status)
			case "darwin":
				require.True(t, ds.GetHostMDMAppleProfilesFuncInvoked)
				require.False(t, ds.GetMDMWindowsBitLockerStatusFuncInvoked)
				require.Nil(t, hostDetail.MDM.OSSettings.DiskEncryption.Status)
			default:
				require.False(t, ds.GetHostMDMAppleProfilesFuncInvoked)
				require.False(t, ds.GetMDMWindowsBitLockerStatusFuncInvoked)
			}
		})
	}
}
// Fragile test: This test is fragile because of the large reliance on Datastore mocks. Consider refactoring test/logic or removing the test. It may be slowing us down more than helping us.
//
// TestHostDetailsOSSettingsWindowsOnly checks a Windows host on a Premium
// license with only Windows MDM enabled. It verifies that getHostDetails
// reports the BitLocker status returned by the datastore and never asks for
// Apple MDM profiles.
func TestHostDetailsOSSettingsWindowsOnly(t *testing.T) {
	ds := new(mock.Store)

	// Only Windows MDM is turned on; Apple MDM stays at its zero value.
	ds.AppConfigFunc = func(_ context.Context) (*fleet.AppConfig, error) {
		return &fleet.AppConfig{MDM: fleet.MDM{WindowsEnabledAndConfigured: true}}, nil
	}
	ds.GetHostMDMFunc = func(_ context.Context, _ uint) (*fleet.HostMDM, error) {
		return &fleet.HostMDM{Enrolled: true, IsServer: false}, nil
	}
	ds.GetMDMWindowsBitLockerStatusFunc = func(_ context.Context, _ *fleet.Host) (*fleet.HostMDMDiskEncryption, error) {
		status := fleet.DiskEncryptionVerified
		return &fleet.HostMDMDiskEncryption{Status: &status, Detail: ""}, nil
	}

	// Remaining mocks only need to succeed with empty results.
	ds.ListLabelsForHostFunc = func(_ context.Context, _ uint) ([]*fleet.Label, error) {
		return nil, nil
	}
	ds.ListPacksForHostFunc = func(_ context.Context, _ uint) ([]*fleet.Pack, error) {
		return nil, nil
	}
	ds.LoadHostSoftwareFunc = func(_ context.Context, _ *fleet.Host, _ bool) error {
		return nil
	}
	ds.ListPoliciesForHostFunc = func(_ context.Context, _ *fleet.Host) ([]*fleet.HostPolicy, error) {
		return nil, nil
	}
	ds.ListHostBatteriesFunc = func(_ context.Context, _ uint) ([]*fleet.HostBattery, error) {
		return nil, nil
	}
	ds.ListUpcomingHostMaintenanceWindowsFunc = func(_ context.Context, _ uint) ([]*fleet.HostMaintenanceWindow, error) {
		return nil, nil
	}
	ds.GetHostMDMMacOSSetupFunc = func(_ context.Context, _ uint) (*fleet.HostMDMMacOSSetup, error) {
		return nil, nil
	}
	ds.GetHostMDMAppleProfilesFunc = func(_ context.Context, _ string) ([]fleet.HostMDMAppleProfile, error) {
		return nil, nil
	}
	ds.GetHostMDMWindowsProfilesFunc = func(_ context.Context, _ string) ([]fleet.HostMDMWindowsProfile, error) {
		return nil, nil
	}
	ds.GetHostLockWipeStatusFunc = func(_ context.Context, _ *fleet.Host) (*fleet.HostLockWipeStatus, error) {
		return &fleet.HostLockWipeStatus{}, nil
	}
	ds.ScimUserByHostIDFunc = func(_ context.Context, _ uint) (*fleet.ScimUser, error) {
		return nil, nil
	}
	ds.ListHostDeviceMappingFunc = func(_ context.Context, _ uint) ([]*fleet.HostDeviceMapping, error) {
		return nil, nil
	}
	ds.IsHostDiskEncryptionKeyArchivedFunc = func(_ context.Context, _ uint) (bool, error) {
		return false, nil
	}

	svc := &Service{ds: ds}
	premiumCtx := license.NewContext(context.Background(), &fleet.LicenseInfo{Tier: fleet.TierPremium})
	windowsHost := &fleet.Host{ID: 42, Platform: "windows"}
	opts := fleet.HostDetailOptions{IncludeCVEScores: false, IncludePolicies: false}

	details, err := svc.getHostDetails(test.UserContext(premiumCtx, test.UserAdmin), windowsHost, opts)
	require.NoError(t, err)
	require.NotNil(t, details)

	require.True(t, ds.AppConfigFuncInvoked)
	require.False(t, ds.GetHostMDMAppleProfilesFuncInvoked)
	require.True(t, ds.GetMDMWindowsBitLockerStatusFuncInvoked)

	gotStatus := details.MDM.OSSettings.DiskEncryption.Status
	require.NotNil(t, gotStatus)
	require.Equal(t, fleet.DiskEncryptionVerified, *gotStatus)
}
// Fragile test: This test is fragile because of the large reliance on Datastore mocks. Consider refactoring test/logic or removing the test. It may be slowing us down more than helping us.
//
// TestHostAuth checks the authorization rules of the host service methods for
// global and team roles. Host ID 1 (identifier "1") resolves to a host in
// team 1; any other ID resolves to a host with no team. Each case states
// whether the user is expected to be denied read or write access to the team
// host and to the global host.
func TestHostAuth(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)
	teamHost := &fleet.Host{TeamID: ptr.Uint(1)}
	globalHost := &fleet.Host{}
	ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
		return &fleet.AppConfig{}, nil
	}
	ds.DeleteHostFunc = func(ctx context.Context, hid uint) error {
		return nil
	}
	// Set once; the original assigned NewActivityFunc twice and the first
	// assignment was dead.
	ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails, details []byte, createdAt time.Time) error {
		return nil
	}
	ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
		if id == 1 {
			return teamHost, nil
		}
		return globalHost, nil
	}
	ds.HostFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
		if id == 1 {
			return teamHost, nil
		}
		return globalHost, nil
	}
	ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) {
		if identifier == "1" {
			return teamHost, nil
		}
		return globalHost, nil
	}
	ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
		return nil, nil
	}
	ds.LoadHostSoftwareFunc = func(ctx context.Context, host *fleet.Host, includeCVEScores bool) error {
		return nil
	}
	ds.ListLabelsForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Label, error) {
		return nil, nil
	}
	ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) (packs []*fleet.Pack, err error) {
		return nil, nil
	}
	ds.AddHostsToTeamFunc = func(ctx context.Context, params *fleet.AddHostsToTeamParams) error {
		return nil
	}
	ds.ListPoliciesForHostFunc = func(ctx context.Context, host *fleet.Host) ([]*fleet.HostPolicy, error) {
		return nil, nil
	}
	ds.ListHostBatteriesFunc = func(ctx context.Context, hostID uint) ([]*fleet.HostBattery, error) {
		return nil, nil
	}
	ds.ListUpcomingHostMaintenanceWindowsFunc = func(ctx context.Context, hid uint) ([]*fleet.HostMaintenanceWindow, error) {
		return nil, nil
	}
	ds.DeleteHostsFunc = func(ctx context.Context, ids []uint) error {
		return nil
	}
	ds.UpdateHostRefetchRequestedFunc = func(ctx context.Context, id uint, value bool) error {
		if id == 1 {
			teamHost.RefetchRequested = true
		} else {
			globalHost.RefetchRequested = true
		}
		return nil
	}
	ds.BulkSetPendingMDMHostProfilesFunc = func(ctx context.Context, hids, tids []uint, puuids, uuids []string) (updates fleet.MDMProfilesUpdates, err error) {
		return fleet.MDMProfilesUpdates{}, nil
	}
	ds.ListMDMAppleDEPSerialsInHostIDsFunc = func(ctx context.Context, hids []uint) ([]string, error) {
		return nil, nil
	}
	ds.TeamWithExtrasFunc = func(ctx context.Context, id uint) (*fleet.Team, error) {
		return &fleet.Team{ID: id}, nil
	}
	ds.TeamLiteFunc = func(ctx context.Context, id uint) (*fleet.TeamLite, error) {
		return &fleet.TeamLite{ID: id}, nil
	}
	ds.ListHostsLiteByIDsFunc = func(ctx context.Context, ids []uint) ([]*fleet.Host, error) {
		return nil, nil
	}
	ds.SetOrUpdateCustomHostDeviceMappingFunc = func(ctx context.Context, hostID uint, email, source string) ([]*fleet.HostDeviceMapping, error) {
		return nil, nil
	}
	ds.ListHostUpcomingActivitiesFunc = func(ctx context.Context, hostID uint, opt fleet.ListOptions) ([]*fleet.UpcomingActivity, *fleet.PaginationMetadata, error) {
		return nil, nil, nil
	}
	ds.GetHostLockWipeStatusFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
		return &fleet.HostLockWipeStatus{}, nil
	}
	ds.ListHostSoftwareFunc = func(ctx context.Context, host *fleet.Host, opts fleet.HostSoftwareTitleListOptions) ([]*fleet.HostSoftwareWithInstaller, *fleet.PaginationMetadata, error) {
		return nil, nil, nil
	}
	ds.IsHostConnectedToFleetMDMFunc = func(ctx context.Context, host *fleet.Host) (bool, error) {
		return true, nil
	}
	ds.ListHostCertificatesFunc = func(ctx context.Context, hostID uint, opts fleet.ListOptions) ([]*fleet.HostCertificateRecord, *fleet.PaginationMetadata, error) {
		return nil, nil, nil
	}
	ds.ScimUserByHostIDFunc = func(ctx context.Context, hostID uint) (*fleet.ScimUser, error) {
		return nil, nil
	}
	ds.ListHostDeviceMappingFunc = func(ctx context.Context, id uint) ([]*fleet.HostDeviceMapping, error) {
		return nil, nil
	}
	ds.GetCategoriesForSoftwareTitlesFunc = func(ctx context.Context, softwareTitleIDs []uint, team_id *uint) (map[uint][]string, error) {
		return map[uint][]string{}, nil
	}
	ds.UpdateHostIssuesFailingPoliciesFunc = func(ctx context.Context, hostIDs []uint) error {
		return nil
	}
	ds.UpdateHostIssuesFailingPoliciesForSingleHostFunc = func(ctx context.Context, hostID uint) error {
		return nil
	}
	ds.GetHostIssuesLastUpdatedFunc = func(ctx context.Context, hostId uint) (time.Time, error) {
		return time.Time{}, nil
	}
	ds.IsHostDiskEncryptionKeyArchivedFunc = func(ctx context.Context, hostID uint) (bool, error) {
		return false, nil
	}
	ds.ListMDMAndroidUUIDsToHostIDsFunc = func(ctx context.Context, hostIDs []uint) (map[string]uint, error) {
		return map[string]uint{}, nil
	}

	// Keyed fields make it obvious which of the four expectations each value
	// controls; positional booleans are easy to transpose.
	testCases := []struct {
		name                  string
		user                  *fleet.User
		shouldFailGlobalWrite bool
		shouldFailGlobalRead  bool
		shouldFailTeamWrite   bool
		shouldFailTeamRead    bool
	}{
		{
			name:                  "global admin",
			user:                  &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
			shouldFailGlobalWrite: false,
			shouldFailGlobalRead:  false,
			shouldFailTeamWrite:   false,
			shouldFailTeamRead:    false,
		},
		{
			name:                  "global maintainer",
			user:                  &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
			shouldFailGlobalWrite: false,
			shouldFailGlobalRead:  false,
			shouldFailTeamWrite:   false,
			shouldFailTeamRead:    false,
		},
		{
			name:                  "global observer",
			user:                  &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
			shouldFailGlobalWrite: true,
			shouldFailGlobalRead:  false,
			shouldFailTeamWrite:   true,
			shouldFailTeamRead:    false,
		},
		{
			name:                  "team admin, belongs to team",
			user:                  &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}},
			shouldFailGlobalWrite: true,
			shouldFailGlobalRead:  true,
			shouldFailTeamWrite:   false,
			shouldFailTeamRead:    false,
		},
		{
			name:                  "team maintainer, belongs to team",
			user:                  &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}},
			shouldFailGlobalWrite: true,
			shouldFailGlobalRead:  true,
			shouldFailTeamWrite:   false,
			shouldFailTeamRead:    false,
		},
		{
			name:                  "team observer, belongs to team",
			user:                  &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}},
			shouldFailGlobalWrite: true,
			shouldFailGlobalRead:  true,
			shouldFailTeamWrite:   true,
			shouldFailTeamRead:    false,
		},
		{
			name:                  "team admin, DOES NOT belong to team",
			user:                  &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleAdmin}}},
			shouldFailGlobalWrite: true,
			shouldFailGlobalRead:  true,
			shouldFailTeamWrite:   true,
			shouldFailTeamRead:    true,
		},
		{
			name:                  "team maintainer, DOES NOT belong to team",
			user:                  &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleMaintainer}}},
			shouldFailGlobalWrite: true,
			shouldFailGlobalRead:  true,
			shouldFailTeamWrite:   true,
			shouldFailTeamRead:    true,
		},
		{
			name:                  "team observer, DOES NOT belong to team",
			user:                  &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleObserver}}},
			shouldFailGlobalWrite: true,
			shouldFailGlobalRead:  true,
			shouldFailTeamWrite:   true,
			shouldFailTeamRead:    true,
		},
	}
	for _, tt := range testCases {
		t.Run(tt.name, func(t *testing.T) {
			ctx := viewer.NewContext(ctx, viewer.Viewer{User: tt.user})
			opts := fleet.HostDetailOptions{
				IncludeCVEScores: false,
				IncludePolicies:  false,
			}

			// Reads of the team host (ID 1).
			_, err := svc.GetHost(ctx, 1, opts)
			checkAuthErr(t, tt.shouldFailTeamRead, err)
			_, err = svc.GetHostLite(ctx, 1)
			checkAuthErr(t, tt.shouldFailTeamRead, err)
			_, err = svc.HostByIdentifier(ctx, "1", opts)
			checkAuthErr(t, tt.shouldFailTeamRead, err)
			_, _, err = svc.ListHostUpcomingActivities(ctx, 1, fleet.ListOptions{})
			checkAuthErr(t, tt.shouldFailTeamRead, err)

			// Reads of the global host (ID 2).
			_, err = svc.GetHost(ctx, 2, opts)
			checkAuthErr(t, tt.shouldFailGlobalRead, err)
			_, err = svc.GetHostLite(ctx, 2)
			checkAuthErr(t, tt.shouldFailGlobalRead, err)
			_, err = svc.HostByIdentifier(ctx, "2", opts)
			checkAuthErr(t, tt.shouldFailGlobalRead, err)
			_, _, err = svc.ListHostUpcomingActivities(ctx, 2, fleet.ListOptions{})
			checkAuthErr(t, tt.shouldFailGlobalRead, err)

			// Deletions and team transfers.
			err = svc.DeleteHost(ctx, 1)
			checkAuthErr(t, tt.shouldFailTeamWrite, err)
			err = svc.DeleteHost(ctx, 2)
			checkAuthErr(t, tt.shouldFailGlobalWrite, err)
			err = svc.DeleteHosts(ctx, []uint{1}, nil)
			checkAuthErr(t, tt.shouldFailTeamWrite, err)
			err = svc.DeleteHosts(ctx, []uint{2}, nil)
			checkAuthErr(t, tt.shouldFailGlobalWrite, err)
			err = svc.AddHostsToTeam(ctx, ptr.Uint(1), []uint{1}, false)
			checkAuthErr(t, tt.shouldFailTeamWrite, err)
			emptyFilter := make(map[string]interface{})
			err = svc.AddHostsToTeamByFilter(ctx, ptr.Uint(1), &emptyFilter)
			checkAuthErr(t, tt.shouldFailTeamWrite, err)

			// Refetch only needs read access to the host.
			err = svc.RefetchHost(ctx, 1)
			checkAuthErr(t, tt.shouldFailTeamRead, err)

			_, err = svc.SetHostDeviceMapping(ctx, 1, "a@b.c", "custom")
			checkAuthErr(t, tt.shouldFailTeamWrite, err)
			_, err = svc.SetHostDeviceMapping(ctx, 2, "a@b.c", "custom")
			checkAuthErr(t, tt.shouldFailGlobalWrite, err)

			_, _, err = svc.ListHostSoftware(ctx, 1, fleet.HostSoftwareTitleListOptions{})
			checkAuthErr(t, tt.shouldFailTeamRead, err)
			_, _, err = svc.ListHostSoftware(ctx, 2, fleet.HostSoftwareTitleListOptions{})
			checkAuthErr(t, tt.shouldFailGlobalRead, err)

			_, _, err = svc.ListHostCertificates(ctx, 1, fleet.ListOptions{})
			checkAuthErr(t, tt.shouldFailTeamRead, err)
			_, _, err = svc.ListHostCertificates(ctx, 2, fleet.ListOptions{})
			checkAuthErr(t, tt.shouldFailGlobalRead, err)
		})
	}
	// ListHosts and GetHostSummary work for all roles and are not exercised here.
}
// TestListHosts checks that listing hosts requires an authenticated user and
// that the vulnerability-details flag is only honored on Premium: on the free
// tier software is loaded without CVE scores, on Premium CVE scores are loaded
// only when explicitly requested.
func TestListHosts(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)

	ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
		return []*fleet.Host{{ID: 1}}, nil
	}

	adminCtx := test.UserContext(ctx, test.UserAdmin)
	hosts, err := svc.ListHosts(adminCtx, fleet.HostListOptions{})
	require.NoError(t, err)
	require.Len(t, hosts, 1)

	// Without a user in the context the call must be rejected.
	_, err = svc.ListHosts(ctx, fleet.HostListOptions{})
	require.Error(t, err)
	require.Contains(t, err.Error(), authz.ForbiddenErrorMessage)

	// wantCVEScores is the includeCVEScores value LoadHostSoftware must receive.
	var wantCVEScores bool
	ds.LoadHostSoftwareFunc = func(ctx context.Context, host *fleet.Host, includeCVEScores bool) error {
		require.Equal(t, wantCVEScores, includeCVEScores)
		return nil
	}

	// listWithSoftware lists hosts with software populated, checks that the
	// software loader ran, and clears the invocation flag for the next step.
	listWithSoftware := func(callCtx context.Context, vulnDetails bool) {
		got, err := svc.ListHosts(callCtx, fleet.HostListOptions{
			PopulateSoftware:                     true,
			PopulateSoftwareVulnerabilityDetails: vulnDetails,
		})
		require.NoError(t, err)
		require.Len(t, got, 1)
		require.True(t, ds.LoadHostSoftwareFuncInvoked)
		ds.LoadHostSoftwareFuncInvoked = false
	}

	// Free tier: asking for vulnerability details does not yield CVE scores.
	listWithSoftware(adminCtx, true)

	premiumCtx := license.NewContext(adminCtx, &fleet.LicenseInfo{Tier: fleet.TierPremium})

	// Premium: vulnerability details may be skipped...
	listWithSoftware(premiumCtx, false)

	// ...or requested, in which case CVE scores are included.
	wantCVEScores = true
	listWithSoftware(premiumCtx, true)
}
// TestStreamHosts exercises streamHostsResponse.HijackRender, which writes the
// list-hosts payload incrementally. It covers the happy path and the minimal
// (empty) response. It also checks that marshaling or iteration failures are
// reported in an "error" field of a JSON document that still decodes and is
// served with HTTP 200.
func TestStreamHosts(t *testing.T) {
	// decodeBody asserts the recorder saw HTTP 200 and decodes its body as a JSON object.
	decodeBody := func(t *testing.T, rr *httptest.ResponseRecorder) map[string]any {
		t.Helper()
		require.Equal(t, http.StatusOK, rr.Code)
		var results map[string]any
		require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &results), "failed to unmarshal response body")
		return results
	}

	// requireObjectID asserts that results[key] is a JSON object whose "id" is wantID.
	// JSON numbers decode into float64 when the target is `any`.
	requireObjectID := func(t *testing.T, results map[string]any, key string, wantID float64) {
		t.Helper()
		obj, ok := results[key].(map[string]any)
		require.Truef(t, ok, "expected %q to be a JSON object, got %T", key, results[key])
		require.Equal(t, wantID, obj["id"])
	}

	// fullResponse wraps the given hosts iterator in a response that has every
	// optional top-level object set, with IDs 1 through 4.
	fullResponse := func(hosts iter.Seq2[*fleet.HostResponse, error]) streamHostsResponse {
		return streamHostsResponse{
			HostResponseIterator: hosts,
			listHostsResponse: listHostsResponse{
				Software:      &fleet.Software{ID: 1},
				SoftwareTitle: &fleet.SoftwareTitle{ID: 2},
				MDMSolution:   &fleet.MDMSolution{ID: 3},
				MunkiIssue:    &fleet.MunkiIssue{ID: 4},
			},
		}
	}

	t.Run("Happy path", func(t *testing.T) {
		// Yields hosts with IDs 1, 2 and 3.
		var hosts iter.Seq2[*fleet.HostResponse, error] = func(yield func(*fleet.HostResponse, error) bool) {
			for i := 1; i <= 3; i++ {
				host := &fleet.HostResponse{Host: &fleet.Host{ID: uint(i)}} // nolint:gosec
				if !yield(host, nil) {
					return
				}
			}
		}
		resp := fullResponse(hosts)

		rr := httptest.NewRecorder()
		resp.HijackRender(context.Background(), rr)
		results := decodeBody(t, rr)

		requireObjectID(t, results, "software", 1)
		requireObjectID(t, results, "software_title", 2)
		requireObjectID(t, results, "mobile_device_management_solution", 3)
		requireObjectID(t, results, "munki_issue", 4)

		hostList, ok := results["hosts"].([]any)
		require.Truef(t, ok, "expected hosts to be a JSON array, got %T", results["hosts"])
		require.Len(t, hostList, 3)
		for i, host := range hostList {
			hostMap, ok := host.(map[string]any)
			require.Truef(t, ok, "expected host %d to be a JSON object, got %T", i, host)
			require.Equal(t, float64(i+1), hostMap["id"])
		}

		require.NotContains(t, results, "error")
	})

	t.Run("Minimal data", func(t *testing.T) {
		// Yields no hosts at all.
		var hosts iter.Seq2[*fleet.HostResponse, error] = func(yield func(*fleet.HostResponse, error) bool) {}
		resp := streamHostsResponse{
			HostResponseIterator: hosts,
			listHostsResponse:    listHostsResponse{},
		}

		rr := httptest.NewRecorder()
		resp.HijackRender(context.Background(), rr)
		results := decodeBody(t, rr)

		// Unset optional objects must be omitted entirely.
		for _, key := range []string{"software", "software_title", "mobile_device_management_solution", "munki_issue"} {
			require.NotContains(t, results, key)
		}

		hostList, ok := results["hosts"].([]any)
		require.Truef(t, ok, "expected hosts to be a JSON array, got %T", results["hosts"])
		require.Empty(t, hostList)

		require.NotContains(t, results, "error")
	})

	errorTestCases := []struct {
		Name          string
		ExpectedError string
	}{
		{"Error marshalling Software", "marshaling software"},
		{"Error marshalling SoftwareTitle", "marshaling software_title"},
		{"Error marshalling MDMSolution", "marshaling mobile_device_management_solution"},
		{"Error marshalling MunkiIssue", "marshaling munki_issue"},
		{"Error iterating over Hosts", "getting host"},
		{"Error marshalling Hosts", "marshaling host response"},
	}
	for _, tc := range errorTestCases {
		t.Run(tc.Name, func(t *testing.T) {
			// Yields host 1, then either an iteration error or host 2.
			var hosts iter.Seq2[*fleet.HostResponse, error] = func(yield func(*fleet.HostResponse, error) bool) {
				if !yield(&fleet.HostResponse{Host: &fleet.Host{ID: 1}}, nil) {
					return
				}
				if tc.Name == "Error iterating over Hosts" {
					yield(nil, errors.New("getting host"))
					return
				}
				yield(&fleet.HostResponse{Host: &fleet.Host{ID: 2}}, nil)
			}

			resp := fullResponse(hosts)
			// Inject a marshaling failure for the value type this case targets;
			// everything else is marshaled normally.
			resp.MarshalJSON = func(v any) ([]byte, error) {
				switch v.(type) {
				case *fleet.Software:
					if tc.Name == "Error marshalling Software" {
						return nil, errors.New(`got some "error" marshaling {software}`)
					}
				case *fleet.SoftwareTitle:
					if tc.Name == "Error marshalling SoftwareTitle" {
						return nil, errors.New(`got some "error" marshaling {software title}`)
					}
				case *fleet.MDMSolution:
					if tc.Name == "Error marshalling MDMSolution" {
						return nil, errors.New(`got some "error" marshaling {mdm solution}`)
					}
				case *fleet.MunkiIssue:
					if tc.Name == "Error marshalling MunkiIssue" {
						return nil, errors.New(`got some "error" marshaling {munki issue}`)
					}
				case *fleet.HostResponse:
					if tc.Name == "Error marshalling Hosts" {
						return nil, errors.New(`got some "error" marshaling {host response}`)
					}
				}
				return json.Marshal(v)
			}

			rr := httptest.NewRecorder()
			resp.HijackRender(context.Background(), rr)
			// Even on failure the body must be valid JSON served with HTTP 200.
			results := decodeBody(t, rr)

			require.Contains(t, results["error"], tc.ExpectedError)

			hostList, isArray := results["hosts"].([]any)
			switch tc.Name {
			case "Error iterating over Hosts":
				// The host yielded before the iteration error is still present.
				require.True(t, isArray)
				require.Len(t, hostList, 1)
			case "Error marshalling Hosts":
				// Every host fails to marshal, so the array is empty.
				require.True(t, isArray)
				require.Empty(t, hostList)
			default:
				// A top-level object failed, so no hosts array is rendered.
				require.False(t, isArray)
			}
		})
	}
}
// TestGetHostSummary checks that GetHostSummary passes through the datastore
// statistics. It also checks that the service derives AllLinuxCount from the
// Linux platforms and keeps only builtin labels. An authenticated user is
// required.
func TestGetHostSummary(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)

	ds.GenerateHostStatusStatisticsFunc = func(ctx context.Context, filter fleet.TeamFilter, now time.Time, platform *string, lowDiskSpace *int) (*fleet.HostSummary, error) {
		platforms := []*fleet.HostSummaryPlatform{
			{Platform: "darwin", HostsCount: 1},
			{Platform: "debian", HostsCount: 2},
			{Platform: "centos", HostsCount: 3},
			{Platform: "ubuntu", HostsCount: 4},
		}
		return &fleet.HostSummary{
			OnlineCount:      1,
			OfflineCount:     5, // offline hosts also includes mia hosts as of Fleet 4.15
			MIACount:         3,
			NewCount:         4,
			TotalsHostsCount: 5,
			Platforms:        platforms,
		}, nil
	}
	ds.LabelsSummaryFunc = func(ctx context.Context, filter fleet.TeamFilter) ([]*fleet.LabelSummary, error) {
		builtin := &fleet.LabelSummary{ID: 1, Name: "All hosts", Description: "All hosts enrolled in Fleet", LabelType: fleet.LabelTypeBuiltIn}
		regular := &fleet.LabelSummary{ID: 10, Name: "Other label", Description: "Not a builtin label", LabelType: fleet.LabelTypeRegular}
		return []*fleet.LabelSummary{builtin, regular}, nil
	}

	adminCtx := test.UserContext(ctx, test.UserAdmin)
	summary, err := svc.GetHostSummary(adminCtx, nil, nil, nil)
	require.NoError(t, err)

	require.Nil(t, summary.TeamID)
	require.Equal(t, uint(1), summary.OnlineCount)
	require.Equal(t, uint(5), summary.OfflineCount)
	require.Equal(t, uint(3), summary.MIACount)
	require.Equal(t, uint(4), summary.NewCount)
	require.Equal(t, uint(5), summary.TotalsHostsCount)
	require.Len(t, summary.Platforms, 4)
	// debian (2) + centos (3) + ubuntu (4)
	require.Equal(t, uint(9), summary.AllLinuxCount)
	require.Nil(t, summary.LowDiskSpaceCount)

	// Only the builtin label survives.
	require.Len(t, summary.BuiltinLabels, 1)
	require.Equal(t, "All hosts", summary.BuiltinLabels[0].Name)

	// Without a user in the context the call must be rejected.
	_, err = svc.GetHostSummary(ctx, nil, nil, nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), authz.ForbiddenErrorMessage)
}
// TestDeleteHost checks that a global admin can delete a host through the
// service and that the host no longer appears when listing hosts.
func TestDeleteHost(t *testing.T) {
	ds := mysql.CreateMySQLDS(t)
	defer ds.Close()
	svc, ctx := newTestService(t, ds, nil, nil)

	// A persisted user is needed for the deletion (activity creation).
	user, err := ds.NewUser(ctx, &fleet.User{
		Name:       "Test User",
		Email:      "testuser@example.com",
		GlobalRole: ptr.String(fleet.RoleAdmin),
		Password:   []byte("password"),
		Salt:       "salt",
	})
	require.NoError(t, err)

	mockClock := clock.NewMockClock()
	host := test.NewHost(t, ds, "foo", "192.168.1.10", "1", "1", mockClock.Now())
	// Stop here rather than attempting to delete an invalid host ID.
	require.NotZero(t, host.ID)

	require.NoError(t, svc.DeleteHost(test.UserContext(ctx, user), host.ID))

	hosts, err := ds.ListHosts(ctx, fleet.TeamFilter{User: user}, fleet.HostListOptions{})
	require.NoError(t, err)
	require.Empty(t, hosts)
}
// TestDeleteHostCreatesActivity verifies that deleting a single host through
// the service records exactly one new "deleted host" activity. The activity
// must carry the host's ID, display name (its computer name here), hardware
// serial, and a manual trigger.
func TestDeleteHostCreatesActivity(t *testing.T) {
	ds := mysql.CreateMySQLDS(t)
	defer ds.Close()
	svc, ctx := newTestService(t, ds, nil, nil)

	// Create a user to perform the deletion.
	user, err := ds.NewUser(ctx, &fleet.User{
		Name:       "Test User",
		Email:      "testuser@example.com",
		GlobalRole: ptr.String(fleet.RoleAdmin),
		Password:   []byte("password"),
		Salt:       "salt",
	})
	require.NoError(t, err)

	mockClock := clock.NewMockClock()
	host := test.NewHost(t, ds, "foo", "192.168.1.10", "1", "1", mockClock.Now())
	host.HardwareSerial = "ABC123"
	host.ComputerName = "Test Computer"
	require.NoError(t, ds.UpdateHost(ctx, host))

	prevActivities, _, err := ds.ListActivities(ctx, fleet.ListActivitiesOptions{})
	require.NoError(t, err)

	require.NoError(t, svc.DeleteHost(test.UserContext(ctx, user), host.ID))

	// Compare unpaginated counts so a missing or duplicated activity is detected.
	allActivities, _, err := ds.ListActivities(ctx, fleet.ListActivitiesOptions{})
	require.NoError(t, err)
	require.Len(t, allActivities, len(prevActivities)+1)

	// The newest activity must be the deletion.
	latest, _, err := ds.ListActivities(ctx, fleet.ListActivitiesOptions{
		ListOptions: fleet.ListOptions{
			OrderKey:       "id",
			OrderDirection: fleet.OrderDescending,
			PerPage:        1,
		},
	})
	require.NoError(t, err)
	require.Len(t, latest, 1)

	activity := latest[0]
	require.Equal(t, fleet.ActivityTypeDeletedHost{}.ActivityName(), activity.Type)
	require.NotNil(t, activity.Details)

	var details fleet.ActivityTypeDeletedHost
	require.NoError(t, json.Unmarshal(*activity.Details, &details))
	require.Equal(t, host.ID, details.HostID)
	require.Equal(t, "Test Computer", details.HostDisplayName)
	require.Equal(t, "ABC123", details.HostSerial)
	require.Equal(t, fleet.DeletedHostTriggeredByManual, details.TriggeredBy)
}
// TestDeleteHostsCreatesActivities verifies that bulk-deleting two hosts
// records a "deleted host" activity for each of them. It checks that each
// activity carries that host's own display name and serial, with a manual
// trigger.
func TestDeleteHostsCreatesActivities(t *testing.T) {
	ds := mysql.CreateMySQLDS(t)
	defer ds.Close()
	svc, ctx := newTestService(t, ds, nil, nil)

	// Create a user to perform the deletion.
	user, err := ds.NewUser(ctx, &fleet.User{
		Name:       "Test User",
		Email:      "testuser@example.com",
		GlobalRole: ptr.String(fleet.RoleAdmin),
		Password:   []byte("password"),
		Salt:       "salt",
	})
	require.NoError(t, err)

	mockClock := clock.NewMockClock()

	host1 := test.NewHost(t, ds, "host1", "192.168.1.10", "1", "1", mockClock.Now())
	host1.HardwareSerial = "SERIAL1"
	host1.ComputerName = "Computer 1"
	require.NoError(t, ds.UpdateHost(ctx, host1))

	host2 := test.NewHost(t, ds, "host2", "192.168.1.11", "2", "2", mockClock.Now())
	host2.HardwareSerial = "SERIAL2"
	host2.ComputerName = "Computer 2"
	require.NoError(t, ds.UpdateHost(ctx, host2))

	prevActivities, _, err := ds.ListActivities(ctx, fleet.ListActivitiesOptions{})
	require.NoError(t, err)

	require.NoError(t, svc.DeleteHosts(test.UserContext(ctx, user), []uint{host1.ID, host2.ID}, nil))

	// At least two new activities must exist. Compare unpaginated counts so
	// the check is independent of how many activities existed before.
	allActivities, _, err := ds.ListActivities(ctx, fleet.ListActivitiesOptions{})
	require.NoError(t, err)
	require.GreaterOrEqual(t, len(allActivities), len(prevActivities)+2)

	activities, _, err := ds.ListActivities(ctx, fleet.ListActivitiesOptions{
		ListOptions: fleet.ListOptions{
			OrderKey:       "id",
			OrderDirection: fleet.OrderDescending,
			PerPage:        10,
		},
	})
	require.NoError(t, err)
	require.GreaterOrEqual(t, len(activities), 2)

	// Expected details keyed by host ID, so each activity is matched to its own host.
	want := map[uint]struct{ displayName, serial string }{
		host1.ID: {"Computer 1", "SERIAL1"},
		host2.ID: {"Computer 2", "SERIAL2"},
	}

	// The two newest activities must be deletions of two distinct hosts.
	expectedActivityType := fleet.ActivityTypeDeletedHost{}.ActivityName()
	seenHostIDs := make([]uint, 0, 2)
	for _, activity := range activities[:2] {
		require.Equal(t, expectedActivityType, activity.Type)
		require.NotNil(t, activity.Details)

		var details fleet.ActivityTypeDeletedHost
		require.NoError(t, json.Unmarshal(*activity.Details, &details))

		exp, ok := want[details.HostID]
		require.Truef(t, ok, "unexpected host ID %d in deleted host activity", details.HostID)
		require.Equal(t, exp.displayName, details.HostDisplayName)
		require.Equal(t, exp.serial, details.HostSerial)
		require.Equal(t, fleet.DeletedHostTriggeredByManual, details.TriggeredBy)

		seenHostIDs = append(seenHostIDs, details.HostID)
	}
	require.ElementsMatch(t, []uint{host1.ID, host2.ID}, seenHostIDs)
}
// TestCleanupExpiredHostsActivities verifies that CleanupExpiredHosts deletes
// hosts that are past their expiry window. It must record exactly one
// expiration-triggered "deleted host" activity per host. Each activity carries
// the window that applied: the team's custom window when configured, otherwise
// the global one.
func TestCleanupExpiredHostsActivities(t *testing.T) {
	ds := mysql.CreateMySQLDS(t)
	defer ds.Close()
	svc, ctx := newTestService(t, ds, nil, nil)

	// Expiry windows, in days.
	const globalExpiryWindow = 10
	const team1ExpiryWindow = 5
	const team2ExpiryWindow = 15

	// Enable host expiry globally.
	ac, err := ds.AppConfig(ctx)
	require.NoError(t, err)
	ac.HostExpirySettings.HostExpiryEnabled = true
	ac.HostExpirySettings.HostExpiryWindow = globalExpiryWindow
	require.NoError(t, ds.SaveAppConfig(ctx, ac))

	// Team 1 and team 2 override the global window.
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	team1.Config.HostExpirySettings.HostExpiryEnabled = true
	team1.Config.HostExpirySettings.HostExpiryWindow = team1ExpiryWindow
	_, err = ds.SaveTeam(ctx, team1)
	require.NoError(t, err)

	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
	require.NoError(t, err)
	team2.Config.HostExpirySettings.HostExpiryEnabled = true
	team2.Config.HostExpirySettings.HostExpiryWindow = team2ExpiryWindow
	_, err = ds.SaveTeam(ctx, team2)
	require.NoError(t, err)

	// Team 3 has no custom host expiry settings, so the global window applies.
	team3, err := ds.NewTeam(ctx, &fleet.Team{Name: "team3"})
	require.NoError(t, err)

	mockClock := clock.NewMockClock()

	// expiredAt returns a timestamp one day beyond the given window (in days).
	expiredAt := func(windowDays int) time.Time {
		return mockClock.Now().Add(-time.Duration(windowDays+1) * 24 * time.Hour)
	}

	// newExpiredHost creates a host last seen at seenAt and persists its
	// serial, computer name and team assignment.
	newExpiredHost := func(name, ip, key string, seenAt time.Time, serial, computerName string, teamID *uint) *fleet.Host {
		h := test.NewHost(t, ds, name, ip, key, key, seenAt)
		h.HardwareSerial = serial
		h.ComputerName = computerName
		h.TeamID = teamID
		require.NoError(t, ds.UpdateHost(ctx, h))
		return h
	}

	host1 := newExpiredHost("team1-host1", "192.168.1.10", "1", expiredAt(team1ExpiryWindow), "TEAM1_SERIAL1", "Team 1 Computer 1", &team1.ID)
	host2 := newExpiredHost("team1-host2", "192.168.1.11", "2", expiredAt(team1ExpiryWindow), "TEAM1_SERIAL2", "Team 1 Computer 2", &team1.ID)
	host3 := newExpiredHost("team2-host1", "192.168.1.12", "3", expiredAt(team2ExpiryWindow), "TEAM2_SERIAL1", "Team 2 Computer 1", &team2.ID)
	host4 := newExpiredHost("team3-host1", "192.168.1.13", "4", expiredAt(globalExpiryWindow), "TEAM3_SERIAL1", "Team 3 Computer 1", &team3.ID)
	host5 := newExpiredHost("no-team-host", "192.168.1.14", "5", expiredAt(globalExpiryWindow), "NOTEAM_SERIAL1", "No Team Computer 1", nil)

	prevActivities, _, err := ds.ListActivities(ctx, fleet.ListActivitiesOptions{})
	require.NoError(t, err)

	deletedHosts, err := svc.CleanupExpiredHosts(ctx)
	require.NoError(t, err)
	require.Len(t, deletedHosts, 5, "Should have deleted 5 hosts")

	activities, _, err := ds.ListActivities(ctx, fleet.ListActivitiesOptions{
		ListOptions: fleet.ListOptions{
			OrderKey:       "id",
			OrderDirection: fleet.OrderDescending,
			PerPage:        20,
		},
	})
	require.NoError(t, err)
	require.Greater(t, len(activities), len(prevActivities), "Should have new activities")

	// Collect every expiration-triggered deleted host activity.
	type hostActivity struct {
		hostID       uint
		displayName  string
		serial       string
		expiryWindow int
	}
	var deletedHostActivities []hostActivity
	expectedActivityType := fleet.ActivityTypeDeletedHost{}.ActivityName()
	for _, activity := range activities {
		if activity.Type != expectedActivityType {
			continue
		}
		require.NotNil(t, activity.Details)

		var details fleet.ActivityTypeDeletedHost
		require.NoError(t, json.Unmarshal(*activity.Details, &details))
		if details.TriggeredBy != fleet.DeletedHostTriggeredByExpiration {
			continue
		}
		require.NotNil(t, details.HostExpiryWindow, "HostExpiryWindow should be set for expired hosts")
		deletedHostActivities = append(deletedHostActivities, hostActivity{
			hostID:       details.HostID,
			displayName:  details.HostDisplayName,
			serial:       details.HostSerial,
			expiryWindow: *details.HostExpiryWindow,
		})
	}
	require.Len(t, deletedHostActivities, 5, "Should have 5 deleted host activities")

	// Verify each host's details and effective expiry window.
	seenHostIDs := make([]uint, 0, len(deletedHostActivities))
	for _, ha := range deletedHostActivities {
		seenHostIDs = append(seenHostIDs, ha.hostID)
		switch ha.hostID {
		case host1.ID:
			require.Equal(t, "Team 1 Computer 1", ha.displayName)
			require.Equal(t, "TEAM1_SERIAL1", ha.serial)
			require.Equal(t, team1ExpiryWindow, ha.expiryWindow, "Team 1 host should have team1 expiry window")
		case host2.ID:
			require.Equal(t, "Team 1 Computer 2", ha.displayName)
			require.Equal(t, "TEAM1_SERIAL2", ha.serial)
			require.Equal(t, team1ExpiryWindow, ha.expiryWindow, "Team 1 host should have team1 expiry window")
		case host3.ID:
			require.Equal(t, "Team 2 Computer 1", ha.displayName)
			require.Equal(t, "TEAM2_SERIAL1", ha.serial)
			require.Equal(t, team2ExpiryWindow, ha.expiryWindow, "Team 2 host should have team2 expiry window")
		case host4.ID:
			require.Equal(t, "Team 3 Computer 1", ha.displayName)
			require.Equal(t, "TEAM3_SERIAL1", ha.serial)
			require.Equal(t, globalExpiryWindow, ha.expiryWindow, "Team 3 host should use global expiry window")
		case host5.ID:
			require.Equal(t, "No Team Computer 1", ha.displayName)
			require.Equal(t, "NOTEAM_SERIAL1", ha.serial)
			require.Equal(t, globalExpiryWindow, ha.expiryWindow, "No team host should use global expiry window")
		default:
			t.Fatalf("Unexpected host ID in activities: %d", ha.hostID)
		}
	}
	// Each expired host must appear exactly once. Duplicates for one host
	// would otherwise mask a missing activity for another.
	require.ElementsMatch(t, []uint{host1.ID, host2.ID, host3.ID, host4.ID, host5.ID}, seenHostIDs)
}
// TestAddHostsToTeamByFilter checks that an empty filter selects hosts via
// ListHosts. It also checks that every matching host ID is handed to
// AddHostsToTeam with the requested (nil, i.e. "no team") target.
func TestAddHostsToTeamByFilter(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)

	wantHostIDs := []uint{1, 2, 4}
	var wantTeam *uint

	ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
		return &fleet.AppConfig{}, nil
	}
	ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
		hosts := make([]*fleet.Host, 0, len(wantHostIDs))
		for _, hostID := range wantHostIDs {
			hosts = append(hosts, &fleet.Host{ID: hostID})
		}
		return hosts, nil
	}
	ds.AddHostsToTeamFunc = func(ctx context.Context, params *fleet.AddHostsToTeamParams) error {
		assert.Equal(t, wantTeam, params.TeamID)
		assert.Equal(t, wantHostIDs, params.HostIDs)
		return nil
	}
	// The remaining stubs are no-op collaborators of the team transfer.
	ds.BulkSetPendingMDMHostProfilesFunc = func(ctx context.Context, hids, tids []uint, puuids, uuids []string,
	) (updates fleet.MDMProfilesUpdates, err error) {
		return fleet.MDMProfilesUpdates{}, nil
	}
	ds.ListMDMAppleDEPSerialsInHostIDsFunc = func(ctx context.Context, hids []uint) ([]string, error) {
		return nil, nil
	}
	ds.NewActivityFunc = func(
		ctx context.Context, user *fleet.User, activity fleet.ActivityDetails, details []byte, createdAt time.Time,
	) error {
		return nil
	}

	filter := map[string]any{}
	err := svc.AddHostsToTeamByFilter(test.UserContext(ctx, test.UserAdmin), wantTeam, &filter)
	require.NoError(t, err)

	assert.True(t, ds.ListHostsFuncInvoked)
	assert.True(t, ds.AddHostsToTeamFuncInvoked)
}
func TestAddHostsToTeamByFilterLabel(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil)
expectedHostIDs := []uint{6}
expectedTeam := ptr.Uint(1)
expectedLabel := float64(2)
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
return &fleet.AppConfig{}, nil
}
ds.ListHostsInLabelFunc = func(ctx context.Context, filter fleet.TeamFilter, lid uint, opt fleet.HostListOptions) ([]*fleet.Host, error) {
assert.Equal(t, uint(expectedLabel), lid)
var hosts []*fleet.Host
for _, id := range expectedHostIDs {
hosts = append(hosts, &fleet.Host{ID: id})
}
return hosts, nil
}
ds.AddHostsToTeamFunc = func(ctx context.Context, params *fleet.AddHostsToTeamParams) error {
assert.Equal(t, expectedHostIDs, params.HostIDs)
return nil
}
ds.BulkSetPendingMDMHostProfilesFunc = func(ctx context.Context, hids, tids []uint, puuids, uuids []string,
) (updates fleet.MDMProfilesUpdates, err error) {
return fleet.MDMProfilesUpdates{}, nil
}
ds.ListMDMAppleDEPSerialsInHostIDsFunc = func(ctx context.Context, hids []uint) ([]string, error) {
return nil, nil
}
ds.TeamLiteFunc = func(ctx context.Context, id uint) (*fleet.TeamLite, error) {
return &fleet.TeamLite{ID: id}, nil
}
ds.NewActivityFunc = func(
ctx context.Context, user *fleet.User, activity fleet.ActivityDetails, details []byte, createdAt time.Time,
) error {
return nil
}
filter := &map[string]interface{}{"label_id": expectedLabel}
require.NoError(t, svc.AddHostsToTeamByFilter(test.UserContext(ctx, test.UserAdmin), expectedTeam, filter))
assert.True(t, ds.ListHostsInLabelFuncInvoked)
assert.True(t, ds.AddHostsToTeamFuncInvoked)
}
// TestAddHostsToTeamByFilterEmptyHosts checks that when the filter matches no
// hosts, the service succeeds without calling AddHostsToTeam.
func TestAddHostsToTeamByFilterEmptyHosts(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)

	// The filter matches nothing.
	ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
		return []*fleet.Host{}, nil
	}
	ds.AddHostsToTeamFunc = func(ctx context.Context, params *fleet.AddHostsToTeamParams) error {
		return nil
	}
	ds.BulkSetPendingMDMHostProfilesFunc = func(ctx context.Context, hids, tids []uint, puuids, uuids []string,
	) (updates fleet.MDMProfilesUpdates, err error) {
		return fleet.MDMProfilesUpdates{}, nil
	}

	filter := map[string]any{}
	err := svc.AddHostsToTeamByFilter(test.UserContext(ctx, test.UserAdmin), nil, &filter)
	require.NoError(t, err)

	assert.True(t, ds.ListHostsFuncInvoked)
	assert.False(t, ds.AddHostsToTeamFuncInvoked, "no hosts matched, so none should be transferred")
}
// TestRefetchHost checks that every global role (admin, observer,
// observer+, maintainer) may request a host refetch. The request must flag
// the host with refetch_requested=true.
func TestRefetchHost(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)

	target := &fleet.Host{ID: 3}
	ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
		return target, nil
	}
	ds.UpdateHostRefetchRequestedFunc = func(ctx context.Context, id uint, value bool) error {
		assert.Equal(t, target.ID, id)
		assert.True(t, value)
		return nil
	}

	globalUsers := []*fleet.User{
		test.UserAdmin,
		test.UserObserver,
		test.UserObserverPlus,
		test.UserMaintainer,
	}
	for _, u := range globalUsers {
		require.NoError(t, svc.RefetchHost(test.UserContext(ctx, u), target.ID))
	}

	assert.True(t, ds.HostLiteFuncInvoked)
	assert.True(t, ds.UpdateHostRefetchRequestedFuncInvoked)
}
// TestRefetchHostUserInTeams checks that team-scoped maintainers and
// observers may request a refetch of a host in their team (team 4).
func TestRefetchHostUserInTeams(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)

	teamHost := &fleet.Host{ID: 3, TeamID: ptr.Uint(4)}
	ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
		return teamHost, nil
	}
	ds.UpdateHostRefetchRequestedFunc = func(ctx context.Context, id uint, value bool) error {
		assert.Equal(t, teamHost.ID, id)
		assert.True(t, value)
		return nil
	}

	for _, role := range []string{fleet.RoleMaintainer, fleet.RoleObserver} {
		// Reset the invocation flags so each role is checked independently.
		ds.HostLiteFuncInvoked, ds.UpdateHostRefetchRequestedFuncInvoked = false, false

		teamUser := &fleet.User{
			Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 4}, Role: role}},
		}
		require.NoError(t, svc.RefetchHost(test.UserContext(ctx, teamUser), teamHost.ID))

		assert.True(t, ds.HostLiteFuncInvoked)
		assert.True(t, ds.UpdateHostRefetchRequestedFuncInvoked)
	}
}
// TestEmptyTeamOSVersions covers the team-scoped OS versions lookup. A team
// with stats returns them. An existing team with no stats (not-found from the
// datastore) returns an empty list. A missing team reports "does not exist".
// Any other datastore error is returned unchanged.
func TestEmptyTeamOSVersions(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)

	testVersions := []fleet.OSVersion{{HostsCount: 1, Name: "macOS 12.1", Platform: "darwin"}}

	// Team 3 is the only team that does not exist.
	ds.TeamExistsFunc = func(ctx context.Context, teamID uint) (bool, error) {
		return teamID != 3, nil
	}
	ds.OSVersionsFunc = func(
		ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string,
	) (*fleet.OSVersions, error) {
		switch *teamFilter.TeamID {
		case 1:
			return &fleet.OSVersions{CountsUpdatedAt: time.Now(), OSVersions: testVersions}, nil
		case 4:
			return nil, errors.New("some unknown error")
		default:
			return nil, newNotFoundError()
		}
	}
	ds.ListVulnsByMultipleOSVersionsFunc = func(ctx context.Context, osVersions []fleet.OSVersion, includeCVSS bool,
		teamID *uint, maxVulnerabilities *int,
	) (map[string]fleet.OSVulnerabilitiesWithCount, error) {
		return nil, nil
	}

	adminCtx := test.UserContext(ctx, test.UserAdmin)

	// Team 1: exists and has stats.
	vers, _, _, err := svc.OSVersions(adminCtx, ptr.Uint(1), ptr.String("darwin"), nil, nil, fleet.ListOptions{}, false, nil)
	require.NoError(t, err)
	assert.Len(t, vers.OSVersions, 1)

	// Team 2: exists but has no stats yet.
	vers, _, _, err = svc.OSVersions(adminCtx, ptr.Uint(2), ptr.String("darwin"), nil, nil, fleet.ListOptions{}, false, nil)
	require.NoError(t, err)
	assert.Empty(t, vers.OSVersions)

	// Team 3: does not exist.
	_, _, _, err = svc.OSVersions(adminCtx, ptr.Uint(3), ptr.String("darwin"), nil, nil, fleet.ListOptions{}, false, nil)
	require.Error(t, err)
	require.Contains(t, fmt.Sprint(err), "does not exist")

	// Team 4: unexpected datastore error is surfaced as-is.
	_, _, _, err = svc.OSVersions(adminCtx, ptr.Uint(4), ptr.String("darwin"), nil, nil, fleet.ListOptions{}, false, nil)
	require.Error(t, err)
	require.Equal(t, "some unknown error", fmt.Sprint(err))
}
// TestOSVersionsListOptions checks in-memory sorting and pagination of OS
// versions by host count. It covers the default descending order, ascending
// order, pages, an oversized page and an out-of-range page. In every case the
// datastore's CountsUpdatedAt must be passed through.
func TestOSVersionsListOptions(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)

	testVersions := []fleet.OSVersion{
		{HostsCount: 4, NameOnly: "Windows 11 Pro 22H2", Platform: "windows"},
		{HostsCount: 1, NameOnly: "macOS 12.1", Platform: "darwin"},
		{HostsCount: 2, NameOnly: "macOS 12.2", Platform: "darwin"},
		{HostsCount: 3, NameOnly: "Windows 11 Pro 21H2", Platform: "windows"},
		{HostsCount: 5, NameOnly: "Ubuntu 20.04", Platform: "ubuntu"},
		{HostsCount: 6, NameOnly: "Ubuntu 21.04", Platform: "ubuntu"},
	}
	now := time.Now()

	ds.OSVersionsFunc = func(
		ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string,
	) (*fleet.OSVersions, error) {
		return &fleet.OSVersions{CountsUpdatedAt: now, OSVersions: testVersions}, nil
	}
	ds.ListVulnsByMultipleOSVersionsFunc = func(ctx context.Context, osVersions []fleet.OSVersion, includeCVSS bool,
		teamID *uint, maxVulnerabilities *int,
	) (map[string]fleet.OSVulnerabilitiesWithCount, error) {
		return nil, nil
	}

	ascending := fleet.ListOptions{OrderKey: "hosts_count", OrderDirection: fleet.OrderAscending}
	ascendingFirstPage := fleet.ListOptions{Page: 0, PerPage: 2, OrderKey: "hosts_count", OrderDirection: fleet.OrderAscending}

	cases := []struct {
		desc      string
		opts      fleet.ListOptions
		wantLen   int
		wantNames []string // expected NameOnly values, in order; nil skips the order check
	}{
		{
			desc:      "default descending count sort",
			opts:      fleet.ListOptions{},
			wantLen:   6,
			wantNames: []string{"Ubuntu 21.04", "Ubuntu 20.04", "Windows 11 Pro 22H2", "Windows 11 Pro 21H2", "macOS 12.2", "macOS 12.1"},
		},
		{
			desc:      "ascending count sort",
			opts:      ascending,
			wantLen:   6,
			wantNames: []string{"macOS 12.1", "macOS 12.2", "Windows 11 Pro 21H2", "Windows 11 Pro 22H2", "Ubuntu 20.04", "Ubuntu 21.04"},
		},
		{
			desc:      "first page",
			opts:      fleet.ListOptions{Page: 0, PerPage: 2},
			wantLen:   2,
			wantNames: []string{"Ubuntu 21.04", "Ubuntu 20.04"},
		},
		{
			desc:      "second page",
			opts:      fleet.ListOptions{Page: 1, PerPage: 2},
			wantLen:   2,
			wantNames: []string{"Windows 11 Pro 22H2", "Windows 11 Pro 21H2"},
		},
		{
			desc:      "pagination + ascending hosts_count sort",
			opts:      ascendingFirstPage,
			wantLen:   2,
			wantNames: []string{"macOS 12.1", "macOS 12.2"},
		},
		{
			desc:    "per page too high",
			opts:    fleet.ListOptions{Page: 0, PerPage: 1000},
			wantLen: 6,
		},
		{
			desc:    "page number too high",
			opts:    fleet.ListOptions{Page: 1000, PerPage: 2},
			wantLen: 0,
		},
	}

	adminCtx := test.UserContext(ctx, test.UserAdmin)
	for _, c := range cases {
		vers, _, _, err := svc.OSVersions(adminCtx, nil, nil, nil, nil, c.opts, false, nil)
		require.NoError(t, err, c.desc)
		assert.Len(t, vers.OSVersions, c.wantLen, c.desc)
		for i, name := range c.wantNames {
			assert.Equal(t, name, vers.OSVersions[i].NameOnly, c.desc)
		}
		assert.Equal(t, now, vers.CountsUpdatedAt, c.desc)
	}
}
// TestOSVersionsDefaultPagination verifies that OSVersions falls back to its
// default ordering (descending hosts count) and default page (page 0, 20 per
// page) when the caller supplies zero-valued ListOptions.
func TestOSVersionsDefaultPagination(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)

	// Fixture: 50 Windows versions whose hosts count equals their index, so the
	// highest-count entry is "Version 49".
	const numVersions = 50
	versions := make([]fleet.OSVersion, 0, numVersions)
	for i := 0; i < numVersions; i++ {
		versions = append(versions, fleet.OSVersion{
			NameOnly:   fmt.Sprintf("Version %02d", i),
			HostsCount: i,
			Platform:   "windows",
		})
	}

	ds.OSVersionsFunc = func(
		_ context.Context, _ *fleet.TeamFilter, _ *string, _ *string, _ *string,
	) (*fleet.OSVersions, error) {
		return &fleet.OSVersions{CountsUpdatedAt: time.Now(), OSVersions: versions}, nil
	}
	ds.ListVulnsByMultipleOSVersionsFunc = func(_ context.Context, _ []fleet.OSVersion, _ bool,
		_ *uint, _ *int,
	) (map[string]fleet.OSVulnerabilitiesWithCount, error) {
		return nil, nil
	}

	adminCtx := test.UserContext(ctx, test.UserAdmin)
	got, _, _, err := svc.OSVersions(adminCtx, nil, nil, nil, nil, fleet.ListOptions{}, false, nil)
	require.NoError(t, err)

	// First page of 20, highest hosts count first: Version 49 down to Version 30.
	assert.Len(t, got.OSVersions, 20)
	first, last := got.OSVersions[0], got.OSVersions[19]
	assert.Equal(t, "Version 49", first.NameOnly)
	assert.Equal(t, "Version 30", last.NameOnly)
}
// TestHostEncryptionKey exercises Service.HostEncryptionKey. It covers:
//   - authorization for a global host and a team-1 host across global and team roles;
//   - propagation of datastore errors (host lookup, key lookup, activity creation);
//   - the per-platform MDM-configured checks for macOS and Windows hosts;
//   - decryption of Linux passphrases encrypted with the server private key.
func TestHostEncryptionKey(t *testing.T) {
	cases := []struct {
		name            string
		host            *fleet.Host
		allowedUsers    []*fleet.User
		disallowedUsers []*fleet.User
	}{
		{
			name: "global host",
			host: &fleet.Host{
				ID:       1,
				Platform: "darwin",
				NodeKey:  ptr.String("test_key"),
				Hostname: "test_hostname",
				UUID:     "test_uuid",
				TeamID:   nil,
			},
			allowedUsers: []*fleet.User{
				test.UserAdmin,
				test.UserMaintainer,
				test.UserObserver,
				test.UserObserverPlus,
			},
			disallowedUsers: []*fleet.User{
				test.UserTeamAdminTeam1,
				test.UserTeamMaintainerTeam1,
				test.UserTeamObserverTeam1,
				test.UserNoRoles,
			},
		},
		{
			name: "team host",
			host: &fleet.Host{
				ID:       2,
				Platform: "darwin",
				NodeKey:  ptr.String("test_key_2"),
				Hostname: "test_hostname_2",
				UUID:     "test_uuid_2",
				TeamID:   ptr.Uint(1),
			},
			allowedUsers: []*fleet.User{
				test.UserAdmin,
				test.UserMaintainer,
				test.UserObserver,
				test.UserObserverPlus,
				test.UserTeamAdminTeam1,
				test.UserTeamMaintainerTeam1,
				test.UserTeamObserverTeam1,
				test.UserTeamObserverPlusTeam1,
			},
			disallowedUsers: []*fleet.User{
				test.UserTeamAdminTeam2,
				test.UserTeamMaintainerTeam2,
				test.UserTeamObserverTeam2,
				test.UserTeamObserverPlusTeam2,
				test.UserNoRoles,
			},
		},
	}

	testCert, testKey, err := apple_mdm.NewSCEPCACertKey()
	require.NoError(t, err)
	testCertPEM := tokenpki.PEMCertificate(testCert.Raw)
	testKeyPEM := tokenpki.PEMRSAPrivateKey(testKey)

	fleetCfg := config.TestConfig()
	config.SetTestMDMConfig(t, &fleetCfg, testCertPEM, testKeyPEM, "")

	// The same recovery key is encrypted twice: once to the SCEP CA cert
	// (Apple path) and once to the WSTEP leaf cert (Windows path).
	recoveryKey := "AAA-BBB-CCC"
	encryptedKey, err := pkcs7.Encrypt([]byte(recoveryKey), []*x509.Certificate{testCert})
	require.NoError(t, err)
	base64EncryptedKey := base64.StdEncoding.EncodeToString(encryptedKey)

	wstep, _, _, err := fleetCfg.MDM.MicrosoftWSTEP()
	require.NoError(t, err)
	winEncryptedKey, err := pkcs7.Encrypt([]byte(recoveryKey), []*x509.Certificate{wstep.Leaf})
	require.NoError(t, err)
	winBase64EncryptedKey := base64.StdEncoding.EncodeToString(winEncryptedKey)

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			ds := new(mock.Store)
			ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
				return &fleet.AppConfig{MDM: fleet.MDM{EnabledAndConfigured: true}}, nil
			}
			svc, ctx := newTestServiceWithConfig(t, ds, fleetCfg, nil, nil)

			ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
				require.Equal(t, tt.host.ID, id)
				return tt.host, nil
			}
			ds.GetHostDiskEncryptionKeyFunc = func(ctx context.Context, id uint) (*fleet.HostDiskEncryptionKey, error) {
				return &fleet.HostDiskEncryptionKey{
					Base64Encrypted: base64EncryptedKey,
					Decryptable:     ptr.Bool(true),
				}, nil
			}
			ds.GetHostArchivedDiskEncryptionKeyFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostArchivedDiskEncryptionKey, error) {
				return &fleet.HostArchivedDiskEncryptionKey{}, nil
			}
			// Reading the key must record a "read disk encryption key" activity for this host.
			ds.NewActivityFunc = func(
				ctx context.Context, user *fleet.User, activity fleet.ActivityDetails, details []byte, createdAt time.Time,
			) error {
				act := activity.(fleet.ActivityTypeReadHostDiskEncryptionKey)
				require.Equal(t, tt.host.ID, act.HostID)
				require.Equal(t, []uint{tt.host.ID}, act.HostIDs())
				require.EqualValues(t, act.HostDisplayName, tt.host.DisplayName())
				return nil
			}
			ds.GetAllMDMConfigAssetsByNameFunc = func(ctx context.Context, assetNames []fleet.MDMAssetName,
				_ sqlx.QueryerContext,
			) (map[fleet.MDMAssetName]fleet.MDMConfigAsset, error) {
				return map[fleet.MDMAssetName]fleet.MDMConfigAsset{
					fleet.MDMAssetCACert: {Name: fleet.MDMAssetCACert, Value: testCertPEM},
					fleet.MDMAssetCAKey:  {Name: fleet.MDMAssetCAKey, Value: testKeyPEM},
				}, nil
			}

			t.Run("allowed users", func(t *testing.T) {
				for _, u := range tt.allowedUsers {
					_, err := svc.HostEncryptionKey(test.UserContext(ctx, u), tt.host.ID)
					require.NoError(t, err)
				}
			})

			// The error text must contain the forbidden message (haystack first,
			// needle second); the previous form had the arguments reversed.
			t.Run("disallowed users", func(t *testing.T) {
				for _, u := range tt.disallowedUsers {
					_, err := svc.HostEncryptionKey(test.UserContext(ctx, u), tt.host.ID)
					require.ErrorContains(t, err, authz.ForbiddenErrorMessage)
				}
			})

			t.Run("no user in context", func(t *testing.T) {
				_, err := svc.HostEncryptionKey(ctx, tt.host.ID)
				require.ErrorContains(t, err, authz.ForbiddenErrorMessage)
			})
		})
	}

	t.Run("test error cases", func(t *testing.T) {
		ds := new(mock.Store)
		ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
			return &fleet.AppConfig{MDM: fleet.MDM{EnabledAndConfigured: true}}, nil
		}
		svc, ctx := newTestServiceWithConfig(t, ds, fleetCfg, nil, nil)
		ctx = test.UserContext(ctx, test.UserAdmin)

		// Host lookup failure is returned unchanged.
		hostErr := errors.New("host error")
		ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
			return nil, hostErr
		}
		_, err := svc.HostEncryptionKey(ctx, 1)
		require.ErrorIs(t, err, hostErr)

		// Key lookup failure is returned unchanged.
		ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
			return &fleet.Host{}, nil
		}
		keyErr := errors.New("key error")
		ds.GetHostDiskEncryptionKeyFunc = func(ctx context.Context, id uint) (*fleet.HostDiskEncryptionKey, error) {
			return nil, keyErr
		}
		ds.GetHostArchivedDiskEncryptionKeyFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostArchivedDiskEncryptionKey, error) {
			return &fleet.HostArchivedDiskEncryptionKey{}, nil
		}
		ds.GetAllMDMConfigAssetsByNameFunc = func(ctx context.Context, assetNames []fleet.MDMAssetName,
			_ sqlx.QueryerContext,
		) (map[fleet.MDMAssetName]fleet.MDMConfigAsset, error) {
			return map[fleet.MDMAssetName]fleet.MDMConfigAsset{
				fleet.MDMAssetCACert: {Name: fleet.MDMAssetCACert, Value: testCertPEM},
				fleet.MDMAssetCAKey:  {Name: fleet.MDMAssetCAKey, Value: testKeyPEM},
			}, nil
		}
		_, err = svc.HostEncryptionKey(ctx, 1)
		require.ErrorIs(t, err, keyErr)

		// A key is returned but activity creation fails; the call must still error.
		ds.GetHostDiskEncryptionKeyFunc = func(ctx context.Context, id uint) (*fleet.HostDiskEncryptionKey, error) {
			return &fleet.HostDiskEncryptionKey{Base64Encrypted: "key"}, nil
		}
		ds.NewActivityFunc = func(
			ctx context.Context, user *fleet.User, activity fleet.ActivityDetails, details []byte, createdAt time.Time,
		) error {
			return errors.New("activity error")
		}
		_, err = svc.HostEncryptionKey(ctx, 1)
		require.Error(t, err)
	})

	t.Run("host platform mdm enabled", func(t *testing.T) {
		cases := []struct {
			hostPlatform  string
			macMDMEnabled bool
			winMDMEnabled bool
			shouldFail    bool
		}{
			{"windows", true, false, true},
			{"windows", false, true, false},
			{"windows", true, true, false},
			{"darwin", true, false, false},
			{"darwin", false, true, true},
			{"darwin", true, true, false},
		}
		for _, c := range cases {
			t.Run(fmt.Sprintf("%s: mac mdm: %t; win mdm: %t", c.hostPlatform, c.macMDMEnabled, c.winMDMEnabled), func(t *testing.T) {
				ds := new(mock.Store)
				ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
					return &fleet.AppConfig{MDM: fleet.MDM{EnabledAndConfigured: c.macMDMEnabled, WindowsEnabledAndConfigured: c.winMDMEnabled}}, nil
				}
				ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
					return &fleet.Host{Platform: c.hostPlatform}, nil
				}
				// Each platform gets the key encrypted to the certificate its decryption path uses.
				ds.GetHostDiskEncryptionKeyFunc = func(ctx context.Context, id uint) (*fleet.HostDiskEncryptionKey, error) {
					key := base64EncryptedKey
					if c.hostPlatform == "windows" {
						key = winBase64EncryptedKey
					}
					return &fleet.HostDiskEncryptionKey{
						Base64Encrypted: key,
						Decryptable:     ptr.Bool(true),
					}, nil
				}
				ds.GetHostArchivedDiskEncryptionKeyFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostArchivedDiskEncryptionKey, error) {
					return &fleet.HostArchivedDiskEncryptionKey{}, nil
				}
				ds.NewActivityFunc = func(
					ctx context.Context, user *fleet.User, activity fleet.ActivityDetails, details []byte, createdAt time.Time,
				) error {
					return nil
				}
				ds.GetAllMDMConfigAssetsByNameFunc = func(ctx context.Context, assetNames []fleet.MDMAssetName,
					_ sqlx.QueryerContext,
				) (map[fleet.MDMAssetName]fleet.MDMConfigAsset, error) {
					return map[fleet.MDMAssetName]fleet.MDMConfigAsset{
						fleet.MDMAssetCACert: {Name: fleet.MDMAssetCACert, Value: testCertPEM},
						fleet.MDMAssetCAKey:  {Name: fleet.MDMAssetCAKey, Value: testKeyPEM},
					}, nil
				}

				svc, ctx := newTestServiceWithConfig(t, ds, fleetCfg, nil, nil)
				ctx = test.UserContext(ctx, test.UserAdmin)
				_, err := svc.HostEncryptionKey(ctx, 1)
				if !c.shouldFail {
					require.NoError(t, err)
					return
				}
				require.Error(t, err)
				// A Windows host with only macOS MDM configured gets the
				// Windows-specific error; every other failure is the generic one.
				if c.macMDMEnabled && !c.winMDMEnabled && c.hostPlatform == "windows" {
					require.ErrorContains(t, err, fleet.ErrWindowsMDMNotConfigured.Error())
				} else {
					require.ErrorContains(t, err, fleet.ErrMDMNotConfigured.Error())
				}
			})
		}
	})

	t.Run("Linux encryption", func(t *testing.T) {
		// Work on a copy so changing the server private key here cannot leak
		// into the fleetCfg shared with the other subtests.
		linuxCfg := fleetCfg

		ds := new(mock.Store)
		host := &fleet.Host{ID: 1, Platform: "ubuntu"}
		symmetricKey := "this_is_a_32_byte_symmetric_key!"
		passphrase := "this_is_a_passphrase"
		base64EncryptedKey, err := mdm.EncryptAndEncode(passphrase, symmetricKey)
		require.NoError(t, err)

		ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
			return host, nil
		}
		ds.GetHostArchivedDiskEncryptionKeyFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostArchivedDiskEncryptionKey, error) {
			return &fleet.HostArchivedDiskEncryptionKey{}, nil
		}
		ds.NewActivityFunc = func(
			ctx context.Context, user *fleet.User, activity fleet.ActivityDetails, details []byte, createdAt time.Time,
		) error {
			return nil
		}
		ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { // needed for new activity
			return &fleet.AppConfig{}, nil
		}

		// NOTE: the trailing string argument to require.Error below is only the
		// failure message; it is not matched against the returned error.

		// error when no server private key
		linuxCfg.Server.PrivateKey = ""
		svc, ctx := newTestServiceWithConfig(t, ds, linuxCfg, nil, nil)
		ctx = test.UserContext(ctx, test.UserAdmin)
		key, err := svc.HostEncryptionKey(ctx, 1)
		require.Error(t, err, "private key is unavailable")
		require.Nil(t, key)

		// error when key is not set
		ds.GetHostDiskEncryptionKeyFunc = func(ctx context.Context, id uint) (*fleet.HostDiskEncryptionKey, error) {
			return &fleet.HostDiskEncryptionKey{}, nil
		}
		linuxCfg.Server.PrivateKey = symmetricKey
		svc, ctx = newTestServiceWithConfig(t, ds, linuxCfg, nil, nil)
		ctx = test.UserContext(ctx, test.UserAdmin)
		key, err = svc.HostEncryptionKey(ctx, 1)
		require.Error(t, err, "host encryption key is not set")
		require.Nil(t, key)

		// error when the stored key cannot be decrypted
		ds.GetHostDiskEncryptionKeyFunc = func(ctx context.Context, id uint) (*fleet.HostDiskEncryptionKey, error) {
			return &fleet.HostDiskEncryptionKey{
				Base64Encrypted: "thisIsWrong",
				Decryptable:     ptr.Bool(true),
			}, nil
		}
		svc, ctx = newTestServiceWithConfig(t, ds, linuxCfg, nil, nil)
		ctx = test.UserContext(ctx, test.UserAdmin)
		key, err = svc.HostEncryptionKey(ctx, 1)
		require.Error(t, err, "decrypt host encryption key")
		require.Nil(t, key)

		// happy path
		ds.GetHostDiskEncryptionKeyFunc = func(ctx context.Context, id uint) (*fleet.HostDiskEncryptionKey, error) {
			return &fleet.HostDiskEncryptionKey{
				Base64Encrypted: base64EncryptedKey,
				Decryptable:     ptr.Bool(true),
			}, nil
		}
		svc, ctx = newTestServiceWithConfig(t, ds, linuxCfg, nil, nil)
		ctx = test.UserContext(ctx, test.UserAdmin)
		key, err = svc.HostEncryptionKey(ctx, 1)
		require.NoError(t, err)
		require.Equal(t, passphrase, key.DecryptedValue)
	})
}
// Fragile test: This test is fragile because of the large reliance on Datastore mocks. Consider refactoring test/logic or removing the test. It may be slowing us down more than helping us.
//
// TestHostMDMProfileDetail checks how GetHost reports the detail of a single
// failed Apple profile. Arbitrary stored detail strings are passed through
// unchanged. The well-known "failed was verifying" and "failed was verified"
// details are replaced with their Message() text.
func TestHostMDMProfileDetail(t *testing.T) {
	ds := new(mock.Store)

	caCert, caKey, err := apple_mdm.NewSCEPCACertKey()
	require.NoError(t, err)
	cfg := config.TestConfig()
	config.SetTestMDMConfig(t, &cfg, tokenpki.PEMCertificate(caCert.Raw), tokenpki.PEMRSAPrivateKey(caKey), "")

	svc, baseCtx := newTestServiceWithConfig(t, ds, cfg, nil, nil)
	ctx := test.UserContext(baseCtx, test.UserAdmin)

	// The host under test: a single macOS host with ID 1.
	ds.HostFunc = func(_ context.Context, _ uint) (*fleet.Host, error) {
		return &fleet.Host{ID: 1, Platform: "darwin"}, nil
	}
	ds.AppConfigFunc = func(_ context.Context) (*fleet.AppConfig, error) {
		return &fleet.AppConfig{MDM: fleet.MDM{EnabledAndConfigured: true}}, nil
	}
	ds.GetHostLockWipeStatusFunc = func(_ context.Context, _ *fleet.Host) (*fleet.HostLockWipeStatus, error) {
		return &fleet.HostLockWipeStatus{}, nil
	}

	// Everything else GetHost touches returns an empty result.
	ds.LoadHostSoftwareFunc = func(_ context.Context, _ *fleet.Host, _ bool) error { return nil }
	ds.ListLabelsForHostFunc = func(_ context.Context, _ uint) ([]*fleet.Label, error) { return nil, nil }
	ds.ListPacksForHostFunc = func(_ context.Context, _ uint) ([]*fleet.Pack, error) { return nil, nil }
	ds.ListHostBatteriesFunc = func(_ context.Context, _ uint) ([]*fleet.HostBattery, error) { return nil, nil }
	ds.ListUpcomingHostMaintenanceWindowsFunc = func(_ context.Context, _ uint) ([]*fleet.HostMaintenanceWindow, error) {
		return nil, nil
	}
	ds.GetHostMDMMacOSSetupFunc = func(_ context.Context, _ uint) (*fleet.HostMDMMacOSSetup, error) { return nil, nil }
	ds.ScimUserByHostIDFunc = func(_ context.Context, _ uint) (*fleet.ScimUser, error) { return nil, nil }
	ds.ListHostDeviceMappingFunc = func(_ context.Context, _ uint) ([]*fleet.HostDeviceMapping, error) { return nil, nil }
	ds.GetNanoMDMEnrollmentTimesFunc = func(_ context.Context, _ string) (*time.Time, *time.Time, error) {
		return nil, nil, nil
	}
	ds.UpdateHostIssuesFailingPoliciesFunc = func(_ context.Context, _ []uint) error { return nil }
	ds.UpdateHostIssuesFailingPoliciesForSingleHostFunc = func(_ context.Context, _ uint) error { return nil }
	ds.GetHostIssuesLastUpdatedFunc = func(_ context.Context, _ uint) (time.Time, error) { return time.Time{}, nil }
	ds.IsHostDiskEncryptionKeyArchivedFunc = func(_ context.Context, _ uint) (bool, error) { return false, nil }

	testCases := []struct {
		name   string
		stored string // detail as persisted in the datastore
		want   string // detail as reported by GetHost
	}{
		{name: "no detail", stored: "", want: ""},
		{name: "other detail", stored: "other detail", want: "other detail"},
		{
			name:   "failed was verifying",
			stored: string(fleet.HostMDMProfileDetailFailedWasVerifying),
			want:   fleet.HostMDMProfileDetailFailedWasVerifying.Message(),
		},
		{
			name:   "failed was verified",
			stored: string(fleet.HostMDMProfileDetailFailedWasVerified),
			want:   fleet.HostMDMProfileDetailFailedWasVerified.Message(),
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			ds.GetHostMDMAppleProfilesFunc = func(_ context.Context, _ string) ([]fleet.HostMDMAppleProfile, error) {
				failed := fleet.HostMDMAppleProfile{
					Name:          "test",
					Identifier:    "test",
					OperationType: fleet.MDMOperationTypeInstall,
					Status:        &fleet.MDMDeliveryFailed,
					Detail:        tc.stored,
				}
				return []fleet.HostMDMAppleProfile{failed}, nil
			}

			host, err := svc.GetHost(ctx, 1, fleet.HostDetailOptions{})
			require.NoError(t, err)
			require.NotNil(t, host.MDM.Profiles)

			got := *host.MDM.Profiles
			require.Len(t, got, 1)
			require.Equal(t, tc.want, got[0].Detail)
		})
	}
}
// Fragile test: This test is fragile because of the large reliance on Datastore mocks. Consider refactoring test/logic or removing the test. It may be slowing us down more than helping us.
//
// TestHostMDMProfileScopes checks the Scope and ManagedLocalAccount fields that
// GetHost reports for host MDM profiles:
//   - Apple profiles map PayloadScopeSystem to "device" and PayloadScopeUser to
//     "user", and always carry a non-nil ManagedLocalAccount.
//   - Windows profiles leave both fields nil.
func TestHostMDMProfileScopes(t *testing.T) {
	ds := new(mock.Store)
	testCert, testKey, err := apple_mdm.NewSCEPCACertKey()
	require.NoError(t, err)
	testCertPEM := tokenpki.PEMCertificate(testCert.Raw)
	testKeyPEM := tokenpki.PEMRSAPrivateKey(testKey)
	fleetCfg := config.TestConfig()
	config.SetTestMDMConfig(t, &fleetCfg, testCertPEM, testKeyPEM, "")
	svc, ctx := newTestServiceWithConfig(t, ds, fleetCfg, nil, nil)
	ctx = test.UserContext(ctx, test.UserAdmin)

	appleHost := &fleet.Host{
		ID:       1,
		UUID:     "apple-host-uuid",
		Platform: "darwin",
	}
	windowsHost := &fleet.Host{
		ID:       2,
		UUID:     "windows-host-uuid",
		Platform: "windows",
	}

	ds.HostFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
		if id == appleHost.ID {
			return appleHost, nil
		}
		// Expected value first so a failure message reports the IDs correctly.
		require.Equal(t, windowsHost.ID, id, "Host should only be called with Apple or Windows host IDs")
		return windowsHost, nil
	}
	ds.LoadHostSoftwareFunc = func(ctx context.Context, host *fleet.Host, includeCVEScores bool) error {
		return nil
	}
	ds.ListLabelsForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Label, error) {
		return nil, nil
	}
	ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
		return nil, nil
	}
	ds.ListHostBatteriesFunc = func(ctx context.Context, hid uint) ([]*fleet.HostBattery, error) {
		return nil, nil
	}
	ds.ListUpcomingHostMaintenanceWindowsFunc = func(ctx context.Context, hid uint) ([]*fleet.HostMaintenanceWindow, error) {
		return nil, nil
	}
	ds.GetHostMDMMacOSSetupFunc = func(ctx context.Context, hid uint) (*fleet.HostMDMMacOSSetup, error) {
		return nil, nil
	}
	ds.GetHostLockWipeStatusFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
		return &fleet.HostLockWipeStatus{}, nil
	}
	ds.UpdateHostIssuesFailingPoliciesForSingleHostFunc = func(ctx context.Context, hostID uint) error {
		return nil
	}
	ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
		return &fleet.AppConfig{
			MDM: fleet.MDM{
				EnabledAndConfigured:        true,
				WindowsEnabledAndConfigured: true,
			},
		}, nil
	}
	ds.ScimUserByHostIDFunc = func(ctx context.Context, hostID uint) (*fleet.ScimUser, error) {
		return nil, nil
	}
	ds.ListHostDeviceMappingFunc = func(ctx context.Context, id uint) ([]*fleet.HostDeviceMapping, error) {
		return nil, nil
	}
	ds.GetNanoMDMEnrollmentTimesFunc = func(ctx context.Context, hostUUID string) (*time.Time, *time.Time, error) {
		return nil, nil, nil
	}
	ds.UpdateHostIssuesFailingPoliciesFunc = func(ctx context.Context, hostIDs []uint) error {
		return nil
	}
	ds.GetHostIssuesLastUpdatedFunc = func(ctx context.Context, hostId uint) (time.Time, error) {
		return time.Time{}, nil
	}
	ds.IsHostDiskEncryptionKeyArchivedFunc = func(ctx context.Context, hostID uint) (bool, error) {
		return false, nil
	}

	appleCases := []struct {
		name             string
		storedProfiles   []fleet.HostMDMAppleProfile
		expectedProfiles []fleet.HostMDMProfile
	}{
		{
			name:             "no profiles",
			storedProfiles:   nil,
			expectedProfiles: nil,
		},
		{
			name:             "system scoped profile",
			storedProfiles:   []fleet.HostMDMAppleProfile{{OperationType: fleet.MDMOperationTypeInstall, HostUUID: appleHost.UUID, ProfileUUID: "profile-uuid1", Name: "Profile1", Status: &fleet.MDMDeliveryVerified, Scope: fleet.PayloadScopeSystem}},
			expectedProfiles: []fleet.HostMDMProfile{{OperationType: fleet.MDMOperationTypeInstall, HostUUID: appleHost.UUID, ProfileUUID: "profile-uuid1", Name: "Profile1", Status: ptr.String(string(fleet.MDMDeliveryVerified)), Scope: ptr.String("device"), ManagedLocalAccount: ptr.String("")}},
		},
		{
			name:             "User scoped profile with username",
			storedProfiles:   []fleet.HostMDMAppleProfile{{OperationType: fleet.MDMOperationTypeInstall, HostUUID: appleHost.UUID, ProfileUUID: "profile-uuid1", Name: "Profile1", Status: &fleet.MDMDeliveryVerified, Scope: fleet.PayloadScopeUser, ManagedLocalAccount: "fleetie"}},
			expectedProfiles: []fleet.HostMDMProfile{{OperationType: fleet.MDMOperationTypeInstall, HostUUID: appleHost.UUID, ProfileUUID: "profile-uuid1", Name: "Profile1", Status: ptr.String(string(fleet.MDMDeliveryVerified)), Scope: ptr.String("user"), ManagedLocalAccount: ptr.String("fleetie")}},
		},
		{
			name:             "User scoped profile without username for some reason",
			storedProfiles:   []fleet.HostMDMAppleProfile{{OperationType: fleet.MDMOperationTypeInstall, HostUUID: appleHost.UUID, ProfileUUID: "profile-uuid1", Name: "Profile1", Status: &fleet.MDMDeliveryVerified, Scope: fleet.PayloadScopeUser}},
			expectedProfiles: []fleet.HostMDMProfile{{OperationType: fleet.MDMOperationTypeInstall, HostUUID: appleHost.UUID, ProfileUUID: "profile-uuid1", Name: "Profile1", Status: ptr.String(string(fleet.MDMDeliveryVerified)), Scope: ptr.String("user"), ManagedLocalAccount: ptr.String("")}},
		},
		{
			name:             "system + user scoped profiles",
			storedProfiles:   []fleet.HostMDMAppleProfile{{OperationType: fleet.MDMOperationTypeInstall, HostUUID: appleHost.UUID, ProfileUUID: "profile-uuid1", Name: "Profile1", Status: &fleet.MDMDeliveryVerified, Scope: fleet.PayloadScopeSystem}, {OperationType: fleet.MDMOperationTypeInstall, HostUUID: appleHost.UUID, ProfileUUID: "profile-uuid2", Name: "Profile2", Status: &fleet.MDMDeliveryVerified, Scope: fleet.PayloadScopeUser, ManagedLocalAccount: "fleetie"}},
			expectedProfiles: []fleet.HostMDMProfile{{OperationType: fleet.MDMOperationTypeInstall, HostUUID: appleHost.UUID, ProfileUUID: "profile-uuid1", Name: "Profile1", Status: ptr.String(string(fleet.MDMDeliveryVerified)), Scope: ptr.String("device"), ManagedLocalAccount: ptr.String("")}, {OperationType: fleet.MDMOperationTypeInstall, HostUUID: appleHost.UUID, ProfileUUID: "profile-uuid2", Name: "Profile2", Status: ptr.String(string(fleet.MDMDeliveryVerified)), Scope: ptr.String("user"), ManagedLocalAccount: ptr.String("fleetie")}},
		},
	}
	windowsCases := []struct {
		name             string
		storedProfiles   []fleet.HostMDMWindowsProfile
		expectedProfiles []fleet.HostMDMProfile
	}{
		{
			name:             "no profiles",
			storedProfiles:   nil,
			expectedProfiles: nil,
		},
		// Windows does not support scopes or managed local accounts yet but we should not error and
		// should set these to nil which is checked below
		{
			name:             "example profile",
			storedProfiles:   []fleet.HostMDMWindowsProfile{{OperationType: fleet.MDMOperationTypeInstall, HostUUID: windowsHost.UUID, ProfileUUID: "profile-uuid1", Name: "Profile1", Status: &fleet.MDMDeliveryVerified}},
			expectedProfiles: []fleet.HostMDMProfile{{OperationType: fleet.MDMOperationTypeInstall, HostUUID: windowsHost.UUID, ProfileUUID: "profile-uuid1", Name: "Profile1", Status: ptr.String(string(fleet.MDMDeliveryVerified))}},
		},
	}

	// Each platform is grouped under its own parent subtest so that case names
	// shared between platforms (e.g. "no profiles") stay distinct and can be
	// selected individually with -run.
	t.Run("apple", func(t *testing.T) {
		for _, tt := range appleCases {
			t.Run(tt.name, func(t *testing.T) {
				ds.GetHostMDMAppleProfilesFunc = func(ctx context.Context, host_uuid string) ([]fleet.HostMDMAppleProfile, error) {
					return tt.storedProfiles, nil
				}
				h, err := svc.GetHost(ctx, appleHost.ID, fleet.HostDetailOptions{})
				require.NoError(t, err)
				if tt.storedProfiles == nil {
					require.NotNil(t, h.MDM.Profiles)
					require.Empty(t, *h.MDM.Profiles)
					return
				}
				profs := *h.MDM.Profiles
				require.Len(t, profs, len(tt.expectedProfiles))
				for i := range profs {
					require.Equal(t, tt.expectedProfiles[i].OperationType, profs[i].OperationType)
					require.Equal(t, tt.expectedProfiles[i].HostUUID, profs[i].HostUUID)
					require.Equal(t, tt.expectedProfiles[i].ProfileUUID, profs[i].ProfileUUID)
					require.Equal(t, tt.expectedProfiles[i].Name, profs[i].Name)
					require.Equal(t, tt.expectedProfiles[i].Status, profs[i].Status)
					require.NotNil(t, profs[i].Scope)
					require.Equal(t, *tt.expectedProfiles[i].Scope, *profs[i].Scope)
					require.NotNil(t, profs[i].ManagedLocalAccount)
					require.Equal(t, *tt.expectedProfiles[i].ManagedLocalAccount, *profs[i].ManagedLocalAccount)
				}
			})
		}
	})

	t.Run("windows", func(t *testing.T) {
		for _, tt := range windowsCases {
			t.Run(tt.name, func(t *testing.T) {
				ds.GetHostMDMWindowsProfilesFunc = func(ctx context.Context, host_uuid string) ([]fleet.HostMDMWindowsProfile, error) {
					return tt.storedProfiles, nil
				}
				h, err := svc.GetHost(ctx, windowsHost.ID, fleet.HostDetailOptions{})
				require.NoError(t, err)
				if tt.storedProfiles == nil {
					require.NotNil(t, h.MDM.Profiles)
					require.Empty(t, *h.MDM.Profiles)
					return
				}
				profs := *h.MDM.Profiles
				require.Len(t, profs, len(tt.expectedProfiles))
				for i := range profs {
					require.Equal(t, tt.expectedProfiles[i].OperationType, profs[i].OperationType)
					require.Equal(t, tt.expectedProfiles[i].HostUUID, profs[i].HostUUID)
					require.Equal(t, tt.expectedProfiles[i].ProfileUUID, profs[i].ProfileUUID)
					require.Equal(t, tt.expectedProfiles[i].Name, profs[i].Name)
					require.Equal(t, tt.expectedProfiles[i].Status, profs[i].Status)
					require.Nil(t, profs[i].Scope)
					require.Nil(t, profs[i].ManagedLocalAccount)
				}
			})
		}
	})
}
// TestLockUnlockWipeHostAuth checks, for a range of global and team roles,
// whether LockHost, UnlockHost, and WipeHost are authorized. Each role is tried
// against a host with no team and against a host on team 1. Scripts are
// disabled in the app config to confirm that setting does not block these
// actions.
func TestLockUnlockWipeHostAuth(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: &fleet.LicenseInfo{Tier: fleet.TierPremium}})

	const (
		teamHostID   = 1
		globalHostID = 2
	)
	teamHost := &fleet.Host{TeamID: ptr.Uint(1), Platform: "darwin"}
	globalHost := &fleet.Host{Platform: "darwin"}

	// Host lookups: teamHostID resolves to the team host; anything else resolves to the global host.
	ds.HostByIdentifierFunc = func(_ context.Context, identifier string) (*fleet.Host, error) {
		if identifier != fmt.Sprint(teamHostID) {
			return globalHost, nil
		}
		return teamHost, nil
	}
	// Some service functions use Host and others HostLite; both resolve the same way.
	ds.HostFunc = func(_ context.Context, hostID uint) (*fleet.Host, error) {
		if hostID != teamHostID {
			return globalHost, nil
		}
		return teamHost, nil
	}
	ds.HostLiteFunc = mock.HostLiteFunc(ds.HostFunc)

	// Lock/wipe status fixtures. The host starts unlocked; the locked variant
	// reports an acknowledged lock command so that UnlockHost has something to undo.
	unlockedStatus := func(_ context.Context, _ *fleet.Host) (*fleet.HostLockWipeStatus, error) {
		return &fleet.HostLockWipeStatus{}, nil
	}
	lockedStatus := func(_ context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
		return &fleet.HostLockWipeStatus{
			HostFleetPlatform:    host.FleetPlatform(),
			LockMDMCommand:       &fleet.MDMCommand{},
			LockMDMCommandResult: &fleet.MDMCommandResult{Status: fleet.MDMAppleStatusAcknowledged},
		}, nil
	}
	ds.GetHostLockWipeStatusFunc = unlockedStatus

	// Host detail lookups that return nothing interesting.
	ds.LoadHostSoftwareFunc = func(_ context.Context, _ *fleet.Host, _ bool) error { return nil }
	ds.ListPacksForHostFunc = func(_ context.Context, _ uint) ([]*fleet.Pack, error) { return nil, nil }
	ds.ListHostBatteriesFunc = func(_ context.Context, _ uint) ([]*fleet.HostBattery, error) { return nil, nil }
	ds.ListUpcomingHostMaintenanceWindowsFunc = func(_ context.Context, _ uint) ([]*fleet.HostMaintenanceWindow, error) {
		return nil, nil
	}
	ds.ListPoliciesForHostFunc = func(_ context.Context, _ *fleet.Host) ([]*fleet.HostPolicy, error) { return nil, nil }
	ds.ListLabelsForHostFunc = func(_ context.Context, _ uint) ([]*fleet.Label, error) { return nil, nil }
	ds.GetHostMDMAppleProfilesFunc = func(_ context.Context, _ string) ([]fleet.HostMDMAppleProfile, error) {
		return nil, nil
	}
	ds.GetHostMDMWindowsProfilesFunc = func(_ context.Context, _ string) ([]fleet.HostMDMWindowsProfile, error) {
		return nil, nil
	}
	ds.GetHostMDMMacOSSetupFunc = func(_ context.Context, _ uint) (*fleet.HostMDMMacOSSetup, error) { return nil, nil }
	ds.GetMDMWindowsBitLockerStatusFunc = func(_ context.Context, _ *fleet.Host) (*fleet.HostMDMDiskEncryption, error) {
		return nil, nil
	}
	ds.GetNanoMDMEnrollmentTimesFunc = func(_ context.Context, _ string) (*time.Time, *time.Time, error) {
		return nil, nil, nil
	}

	// MDM state: every host is enrolled in Fleet MDM and connected.
	ds.GetHostMDMFunc = func(_ context.Context, _ uint) (*fleet.HostMDM, error) {
		return &fleet.HostMDM{Enrolled: true, Name: fleet.WellKnownMDMFleet}, nil
	}
	ds.IsHostConnectedToFleetMDMFunc = func(_ context.Context, _ *fleet.Host) (bool, error) { return true, nil }

	// Write operations succeed unconditionally; only authorization is under test.
	ds.LockHostViaScriptFunc = func(_ context.Context, _ *fleet.HostScriptRequestPayload, _ string) error { return nil }
	ds.UnlockHostManuallyFunc = func(_ context.Context, _ uint, _ string, _ time.Time) error { return nil }
	ds.NewActivityFunc = func(
		_ context.Context, _ *fleet.User, _ fleet.ActivityDetails, _ []byte, _ time.Time,
	) error {
		return nil
	}

	ds.AppConfigFunc = func(_ context.Context) (*fleet.AppConfig, error) {
		return &fleet.AppConfig{
			MDM:            fleet.MDM{EnabledAndConfigured: true, WindowsEnabledAndConfigured: true},
			ServerSettings: fleet.ServerSettings{ScriptsDisabled: true}, // scripts being disabled shouldn't stop lock/unlock/wipe
		}, nil
	}

	// teamRole builds a user holding the given role on the given team only.
	teamRole := func(teamID uint, role string) *fleet.User {
		return &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: teamID}, Role: role}}}
	}
	// globalRole builds a user holding the given global role.
	globalRole := func(role string) *fleet.User {
		return &fleet.User{GlobalRole: ptr.String(role)}
	}

	testCases := []struct {
		name                  string
		user                  *fleet.User
		shouldFailGlobalWrite bool
		shouldFailTeamWrite   bool
	}{
		{"global observer", globalRole(fleet.RoleObserver), true, true},
		{"team observer", teamRole(1, fleet.RoleObserver), true, true},
		{"global observer plus", globalRole(fleet.RoleObserverPlus), true, true},
		{"team observer plus", teamRole(1, fleet.RoleObserverPlus), true, true},
		{"global admin", globalRole(fleet.RoleAdmin), false, false},
		{"team admin", teamRole(1, fleet.RoleAdmin), true, false},
		{"global maintainer", globalRole(fleet.RoleMaintainer), false, false},
		{"team maintainer", teamRole(1, fleet.RoleMaintainer), true, false},
		{"team admin wrong team", teamRole(42, fleet.RoleAdmin), true, true},
		{"team maintainer wrong team", teamRole(42, fleet.RoleMaintainer), true, true},
		{"global gitops", globalRole(fleet.RoleGitOps), true, true},
		{"team gitops", teamRole(1, fleet.RoleGitOps), true, true},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			userCtx := viewer.NewContext(ctx, viewer.Viewer{User: tc.user})

			_, err := svc.LockHost(userCtx, globalHostID, false)
			checkAuthErr(t, tc.shouldFailGlobalWrite, err)
			_, err = svc.LockHost(userCtx, teamHostID, false)
			checkAuthErr(t, tc.shouldFailTeamWrite, err)

			// Pretend the hosts are now locked so unlock is a valid operation.
			ds.GetHostLockWipeStatusFunc = lockedStatus
			_, err = svc.UnlockHost(userCtx, globalHostID)
			checkAuthErr(t, tc.shouldFailGlobalWrite, err)
			_, err = svc.UnlockHost(userCtx, teamHostID)
			checkAuthErr(t, tc.shouldFailTeamWrite, err)

			// Back to unlocked before attempting a wipe.
			ds.GetHostLockWipeStatusFunc = unlockedStatus
			err = svc.WipeHost(userCtx, globalHostID, nil)
			checkAuthErr(t, tc.shouldFailGlobalWrite, err)
			err = svc.WipeHost(userCtx, teamHostID, nil)
			checkAuthErr(t, tc.shouldFailTeamWrite, err)
		})
	}
}
// TestBulkOperationFilterValidation verifies that the bulk host operations
// (AddHostsToTeamByFilter and DeleteHosts) reject malformed or unsupported
// filters with a *fleet.BadRequestError and accept well-formed ones.
//
// Filter values mirror what the HTTP layer produces after JSON decoding, so
// numeric values are written as float64 rather than int.
func TestBulkOperationFilterValidation(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)
	viewerCtx := test.UserContext(ctx, test.UserAdmin)

	// Host lookups available to valid filters; both return no hosts.
	ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
		return []*fleet.Host{}, nil
	}
	ds.ListHostsInLabelFunc = func(ctx context.Context, filter fleet.TeamFilter, lid uint, opt fleet.HostListOptions) ([]*fleet.Host, error) {
		return []*fleet.Host{}, nil
	}

	// TODO(sarah): Future improvement to auto-generate a list of all possible filter values
	// from `fleet.HostListOptions` and iterate to test that only a limited subset of filter (i.e.
	// label_id, team_id, status, query) are allowed for bulk operations.
	tc := []struct {
		name      string
		filters   *map[string]interface{}
		has400Err bool
	}{
		{
			name: "valid status filter",
			filters: &map[string]interface{}{
				"status": "new",
			},
		},
		{
			name: "invalid status",
			filters: &map[string]interface{}{
				"status": "invalid",
			},
			has400Err: true,
		},
		{
			name: "empty status is invalid",
			filters: &map[string]interface{}{
				"status": "",
			},
			has400Err: true,
		},
		{
			name: "valid team filter",
			filters: &map[string]interface{}{
				"team_id": float64(1), // json unmarshals to float64
			},
		},
		{
			name: "invalid team_id type",
			filters: &map[string]interface{}{
				"team_id": "invalid",
			},
			has400Err: true,
		},
		{
			name: "valid label_id filter",
			filters: &map[string]interface{}{
				"label_id": float64(1),
			},
		},
		{
			name: "invalid label_id type",
			filters: &map[string]interface{}{
				"label_id": "invalid",
			},
			has400Err: true,
		},
		{
			name: "invalid status type",
			filters: &map[string]interface{}{
				"status": float64(1),
			},
			has400Err: true,
		},
		{
			name:    "empty filter",
			filters: &map[string]interface{}{},
		},
		{
			name: "valid query filter",
			filters: &map[string]interface{}{
				"query": "test",
			},
		},
		{
			name: "invalid query type",
			filters: &map[string]interface{}{
				"query": float64(1),
			},
			has400Err: true,
		},
		{
			name: "empty query is invalid",
			filters: &map[string]interface{}{
				"query": "",
			},
			has400Err: true,
		},
		{
			name: "multiple valid filters",
			filters: &map[string]interface{}{
				"status":  "new",
				"team_id": float64(1),
				"query":   "test",
			},
		},
		{
			name: "mixed valid and invalid filters",
			filters: &map[string]interface{}{
				"status":  "new",
				"team_id": "invalid",
			},
			has400Err: true,
		},
		{
			name: "mixed invalid filters and valid filters (different order)",
			filters: &map[string]interface{}{
				"status": "invalid",
				// float64 as produced by JSON decoding, so team_id is genuinely
				// valid and the 400 can only come from the invalid status.
				"team_id": float64(1),
			},
			has400Err: true,
		},
		{
			name: "mixed valid and unknown filters",
			filters: &map[string]interface{}{
				"status":  "new",
				"unknown": "filter",
			},
			has400Err: true,
		},
		{
			name: "unknown filter",
			filters: &map[string]interface{}{
				"unknown": "filter",
			},
			has400Err: true,
		},
	}

	// checkErr asserts err is a *fleet.BadRequestError when has400Err is set,
	// and that err is nil otherwise.
	checkErr := func(t *testing.T, err error, has400Err bool) {
		t.Helper()
		if !has400Err {
			require.NoError(t, err)
			return
		}
		require.Error(t, err)
		var be *fleet.BadRequestError
		require.ErrorAs(t, err, &be)
	}

	for _, tt := range tc {
		t.Run(tt.name, func(t *testing.T) {
			checkErr(t, svc.AddHostsToTeamByFilter(viewerCtx, nil, tt.filters), tt.has400Err)
			checkErr(t, svc.DeleteHosts(viewerCtx, nil, tt.filters), tt.has400Err)
		})
	}
}
// TestSetDiskEncryptionNotifications exercises Service.setDiskEncryptionNotifications.
// It covers the preconditions (MDM enabled, Fleet MDM connection, osquery
// enrollment, disk encryption configured) and the per-platform key states.
// It checks both the macOS RotateDiskEncryptionKey and the Windows
// EnforceBitLockerEncryption notifications.
func TestSetDiskEncryptionNotifications(t *testing.T) {
	ds := new(mock.Store)
	baseCtx := context.Background()
	svc := &Service{ds: ds, logger: kitlog.NewNopLogger()}

	// mdmEnabled returns a fresh app config with MDM enabled and configured.
	mdmEnabled := func() *fleet.AppConfig {
		return &fleet.AppConfig{MDM: fleet.MDM{EnabledAndConfigured: true}}
	}
	// keyNotFound stubs GetHostDiskEncryptionKey as if no key was escrowed.
	keyNotFound := func(ctx context.Context, id uint) (*fleet.HostDiskEncryptionKey, error) {
		return nil, newNotFoundError()
	}
	// keyWithDecryptable stubs GetHostDiskEncryptionKey to return a key whose
	// Decryptable flag is set to the given value.
	keyWithDecryptable := func(decryptable bool) func(context.Context, uint) (*fleet.HostDiskEncryptionKey, error) {
		return func(ctx context.Context, id uint) (*fleet.HostDiskEncryptionKey, error) {
			return &fleet.HostDiskEncryptionKey{Decryptable: ptr.Bool(decryptable)}, nil
		}
	}

	tests := []struct {
		name                     string
		host                     *fleet.Host
		appConfig                *fleet.AppConfig
		diskEncryptionConfigured bool
		isConnectedToFleetMDM    bool
		mdmInfo                  *fleet.HostMDM
		// nil means the case is expected to short-circuit before the key is fetched.
		getHostDiskEncryptionKey func(context.Context, uint) (*fleet.HostDiskEncryptionKey, error)
		expectedNotifications    *fleet.OrbitConfigNotifications
		expectedError            bool
		// disableCapability simulates an older orbit client that does not
		// advertise the Escrow Buddy capability.
		disableCapability bool
	}{
		{
			name:                     "no MDM configured",
			host:                     &fleet.Host{ID: 1, Platform: "darwin", OsqueryHostID: ptr.String("foo")},
			appConfig:                &fleet.AppConfig{MDM: fleet.MDM{EnabledAndConfigured: false}},
			diskEncryptionConfigured: true,
			isConnectedToFleetMDM:    true,
			expectedNotifications:    &fleet.OrbitConfigNotifications{},
		},
		{
			name:                     "not connected to Fleet MDM",
			host:                     &fleet.Host{ID: 1, Platform: "darwin", OsqueryHostID: ptr.String("foo")},
			appConfig:                mdmEnabled(),
			diskEncryptionConfigured: true,
			isConnectedToFleetMDM:    false,
			expectedNotifications:    &fleet.OrbitConfigNotifications{},
		},
		{
			name:                     "host not enrolled in osquery",
			host:                     &fleet.Host{ID: 1, Platform: "darwin", OsqueryHostID: nil},
			appConfig:                mdmEnabled(),
			diskEncryptionConfigured: true,
			isConnectedToFleetMDM:    true,
			expectedNotifications:    &fleet.OrbitConfigNotifications{},
		},
		{
			name:                     "disk encryption not configured",
			host:                     &fleet.Host{ID: 1, Platform: "darwin", OsqueryHostID: ptr.String("foo")},
			appConfig:                mdmEnabled(),
			diskEncryptionConfigured: false,
			isConnectedToFleetMDM:    true,
			expectedNotifications:    &fleet.OrbitConfigNotifications{},
		},
		{
			name:                     "darwin with decryptable key",
			host:                     &fleet.Host{ID: 1, Platform: "darwin", OsqueryHostID: ptr.String("foo")},
			appConfig:                mdmEnabled(),
			diskEncryptionConfigured: true,
			isConnectedToFleetMDM:    true,
			getHostDiskEncryptionKey: keyWithDecryptable(true),
			expectedNotifications:    &fleet.OrbitConfigNotifications{RotateDiskEncryptionKey: false},
		},
		{
			name:                     "darwin needs rotation but client is old",
			host:                     &fleet.Host{ID: 1, Platform: "darwin", OsqueryHostID: ptr.String("foo")},
			appConfig:                mdmEnabled(),
			diskEncryptionConfigured: true,
			isConnectedToFleetMDM:    true,
			getHostDiskEncryptionKey: keyWithDecryptable(false),
			// The key is not decryptable, but the client lacks the Escrow Buddy
			// capability, so no rotation notification is expected.
			expectedNotifications: &fleet.OrbitConfigNotifications{RotateDiskEncryptionKey: false},
			disableCapability:     true,
		},
		{
			name:                     "darwin needs rotation",
			host:                     &fleet.Host{ID: 1, Platform: "darwin", OsqueryHostID: ptr.String("foo")},
			appConfig:                mdmEnabled(),
			diskEncryptionConfigured: true,
			isConnectedToFleetMDM:    true,
			getHostDiskEncryptionKey: keyWithDecryptable(false),
			expectedNotifications:    &fleet.OrbitConfigNotifications{RotateDiskEncryptionKey: true},
		},
		{
			name:                     "windows server with no encryption needed",
			host:                     &fleet.Host{ID: 1, Platform: "windows", DiskEncryptionEnabled: ptr.Bool(true), OsqueryHostID: ptr.String("foo")},
			appConfig:                mdmEnabled(),
			diskEncryptionConfigured: true,
			isConnectedToFleetMDM:    true,
			mdmInfo:                  &fleet.HostMDM{IsServer: true},
			getHostDiskEncryptionKey: keyNotFound,
			expectedNotifications:    &fleet.OrbitConfigNotifications{EnforceBitLockerEncryption: false},
		},
		{
			name:                     "windows with encryption enabled but key missing",
			host:                     &fleet.Host{ID: 1, Platform: "windows", DiskEncryptionEnabled: ptr.Bool(true), OsqueryHostID: ptr.String("foo")},
			appConfig:                mdmEnabled(),
			diskEncryptionConfigured: true,
			isConnectedToFleetMDM:    true,
			mdmInfo:                  &fleet.HostMDM{IsServer: false},
			getHostDiskEncryptionKey: keyNotFound,
			expectedNotifications:    &fleet.OrbitConfigNotifications{EnforceBitLockerEncryption: true},
		},
		{
			name:                     "darwin with missing encryption key",
			host:                     &fleet.Host{ID: 1, Platform: "darwin", OsqueryHostID: ptr.String("foo")},
			appConfig:                mdmEnabled(),
			diskEncryptionConfigured: true,
			isConnectedToFleetMDM:    true,
			getHostDiskEncryptionKey: keyNotFound,
			expectedNotifications:    &fleet.OrbitConfigNotifications{RotateDiskEncryptionKey: false},
		},
		{
			name:                     "windows with encryption key and not decryptable",
			host:                     &fleet.Host{ID: 1, Platform: "windows", DiskEncryptionEnabled: ptr.Bool(true), OsqueryHostID: ptr.String("foo")},
			appConfig:                mdmEnabled(),
			diskEncryptionConfigured: true,
			isConnectedToFleetMDM:    true,
			mdmInfo:                  &fleet.HostMDM{IsServer: false},
			getHostDiskEncryptionKey: keyWithDecryptable(false),
			expectedNotifications:    &fleet.OrbitConfigNotifications{EnforceBitLockerEncryption: true},
		},
		{
			name:                     "windows with enforce BitLocker",
			host:                     &fleet.Host{ID: 1, Platform: "windows", DiskEncryptionEnabled: ptr.Bool(false), OsqueryHostID: ptr.String("foo")},
			appConfig:                mdmEnabled(),
			diskEncryptionConfigured: true,
			isConnectedToFleetMDM:    true,
			mdmInfo:                  &fleet.HostMDM{IsServer: false},
			getHostDiskEncryptionKey: keyNotFound,
			expectedNotifications:    &fleet.OrbitConfigNotifications{EnforceBitLockerEncryption: true},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Always overwrite the stub so no case inherits another case's key.
			ds.GetHostDiskEncryptionKeyFunc = tt.getHostDiskEncryptionKey
			ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
				return tt.appConfig, nil
			}
			ds.GetHostArchivedDiskEncryptionKeyFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostArchivedDiskEncryptionKey, error) {
				return &fleet.HostArchivedDiskEncryptionKey{}, nil
			}

			// Derive the context per case. Reassigning a shared ctx here would
			// leak the capability into every later case and make
			// disableCapability a no-op.
			ctx := baseCtx
			if !tt.disableCapability {
				r := http.Request{
					Header: http.Header{fleet.CapabilitiesHeader: []string{string(fleet.CapabilityEscrowBuddy)}},
				}
				ctx = capabilities.NewContext(ctx, &r)
			}

			notifs := &fleet.OrbitConfigNotifications{}
			err := svc.setDiskEncryptionNotifications(
				ctx,
				notifs,
				tt.host,
				tt.appConfig,
				tt.diskEncryptionConfigured,
				tt.isConnectedToFleetMDM,
				tt.mdmInfo,
			)
			if tt.expectedError {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
			require.Equal(t, tt.expectedNotifications.RotateDiskEncryptionKey, notifs.RotateDiskEncryptionKey)
			require.Equal(t, tt.expectedNotifications.EnforceBitLockerEncryption, notifs.EnforceBitLockerEncryption)
		})
	}
}
// TestGetHostDetailsExcludeSoftwareFlag checks that getHostDetails honors
// HostDetailOptions.ExcludeSoftware. When the flag is set, LoadHostSoftware is
// never called and Software is an empty, non-nil slice. When the flag is unset,
// Software holds whatever LoadHostSoftware populated.
func TestGetHostDetailsExcludeSoftwareFlag(t *testing.T) {
	ds := new(mock.Store)
	svc := &Service{ds: ds}
	host := &fleet.Host{ID: 42}

	// adminCtx returns a fresh context authenticated as a global admin.
	adminCtx := func() context.Context {
		return test.UserContext(context.Background(), test.UserAdmin)
	}

	// Datastore stubs shared by both subtests; each yields an empty result.
	ds.IsHostDiskEncryptionKeyArchivedFunc = func(ctx context.Context, hostID uint) (bool, error) {
		return false, nil
	}
	ds.ListHostDeviceMappingFunc = func(ctx context.Context, id uint) ([]*fleet.HostDeviceMapping, error) {
		return nil, nil
	}
	ds.ScimUserByHostIDFunc = func(ctx context.Context, hostID uint) (*fleet.ScimUser, error) {
		return nil, nil
	}
	ds.GetHostLockWipeStatusFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
		return &fleet.HostLockWipeStatus{}, nil
	}
	ds.ListUpcomingHostMaintenanceWindowsFunc = func(ctx context.Context, hid uint) ([]*fleet.HostMaintenanceWindow, error) {
		return nil, nil
	}
	ds.ListHostBatteriesFunc = func(ctx context.Context, hostID uint) ([]*fleet.HostBattery, error) {
		return nil, nil
	}
	ds.ListPoliciesForHostFunc = func(ctx context.Context, host *fleet.Host) ([]*fleet.HostPolicy, error) {
		return nil, nil
	}
	ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
		return nil, nil
	}
	ds.ListLabelsForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Label, error) {
		return nil, nil
	}
	ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
		return &fleet.AppConfig{}, nil
	}

	t.Run("ExcludeSoftware=true returns empty slice", func(t *testing.T) {
		ds.LoadHostSoftwareFuncInvoked = false
		// Any call here means the flag was ignored.
		ds.LoadHostSoftwareFunc = func(ctx context.Context, h *fleet.Host, includeCVEScores bool) error {
			t.Fatalf("LoadHostSoftwareFunc should not be called when ExcludeSoftware is true")
			return nil
		}

		details, err := svc.getHostDetails(adminCtx(), host, fleet.HostDetailOptions{ExcludeSoftware: true})
		require.NoError(t, err)
		require.NotNil(t, details.Software, "Software slice should not be nil")
		assert.Len(t, details.Software, 0, "Software slice should be empty when excluded")
	})

	t.Run("ExcludeSoftware=false returns filled slice", func(t *testing.T) {
		testApp := fleet.HostSoftwareEntry{
			Software:       fleet.Software{ID: 1, Name: "test-app", Version: "1.0.0", Source: "apps"},
			InstalledPaths: []string{"/Applications/test-app.app"},
		}
		anotherApp := fleet.HostSoftwareEntry{
			Software:       fleet.Software{ID: 2, Name: "another-app", Version: "2.3.4", Source: "apps"},
			InstalledPaths: []string{"/Applications/another-app.app"},
		}
		want := []fleet.HostSoftwareEntry{testApp, anotherApp}

		ds.LoadHostSoftwareFuncInvoked = false
		ds.LoadHostSoftwareFunc = func(ctx context.Context, h *fleet.Host, includeCVEScores bool) error {
			h.HostSoftware.Software = want
			return nil
		}

		details, err := svc.getHostDetails(adminCtx(), host, fleet.HostDetailOptions{ExcludeSoftware: false})
		require.NoError(t, err)
		require.NotNil(t, details.Software)
		assert.Equal(t, want, details.Software)
		assert.True(t, ds.LoadHostSoftwareFuncInvoked, "LoadHostSoftwareFunc should have been called")
	})
}
func TestSetHostDeviceMapping(t *testing.T) {
t.Run("custom source success", func(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil)
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
return &fleet.Host{ID: 1}, nil
}
ds.SetOrUpdateCustomHostDeviceMappingFunc = func(ctx context.Context, hostID uint, email, source string) ([]*fleet.HostDeviceMapping, error) {
return []*fleet.HostDeviceMapping{{HostID: hostID, Email: email, Source: source}}, nil
}
userCtx := test.UserContext(ctx, test.UserAdmin)
result, err := svc.SetHostDeviceMapping(userCtx, 1, "user@example.com", "custom")
require.NoError(t, err)
require.True(t, ds.SetOrUpdateCustomHostDeviceMappingFuncInvoked)
require.NotNil(t, result)
require.Len(t, result, 1)
assert.Equal(t, uint(1), result[0].HostID)
assert.Equal(t, "user@example.com", result[0].Email)
})
t.Run("empty source defaults to custom", func(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil)
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
return &fleet.Host{ID: 1}, nil
}
ds.SetOrUpdateCustomHostDeviceMappingFunc = func(ctx context.Context, hostID uint, email, source string) ([]*fleet.HostDeviceMapping, error) {
require.Equal(t, fleet.DeviceMappingCustomOverride, source) // Should store as custom_override for user-authenticated calls
return []*fleet.HostDeviceMapping{{HostID: hostID, Email: email, Source: fleet.DeviceMappingCustomReplacement}}, nil // But return as "custom" for display
}
userCtx := test.UserContext(ctx, test.UserAdmin)
result, err := svc.SetHostDeviceMapping(userCtx, 1, "user@example.com", "")
require.NoError(t, err)
require.True(t, ds.SetOrUpdateCustomHostDeviceMappingFuncInvoked)
require.NotNil(t, result)
})
t.Run("IDP source success with valid SCIM user", func(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: &fleet.LicenseInfo{Tier: fleet.TierPremium}})
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
return &fleet.Host{ID: 1}, nil
}
ds.ScimUserByUserNameOrEmailFunc = func(ctx context.Context, userName, email string) (*fleet.ScimUser, error) {
return &fleet.ScimUser{ID: 1, UserName: "user@example.com"}, nil
}
ds.SetOrUpdateHostSCIMUserMappingFunc = func(ctx context.Context, hostID uint, scimUserID uint) error {
return nil
}
ds.SetOrUpdateIDPHostDeviceMappingFunc = func(ctx context.Context, hostID uint, email string) error {
return nil
}
ds.ListHostDeviceMappingFunc = func(ctx context.Context, hostID uint) ([]*fleet.HostDeviceMapping, error) {
return []*fleet.HostDeviceMapping{{HostID: hostID, Email: "user@example.com", Source: fleet.DeviceMappingMDMIdpAccounts}}, nil
}
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails, details []byte, createdAt time.Time) error {
return nil
}
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
return &fleet.AppConfig{}, nil
}
userCtx := test.UserContext(ctx, test.UserAdmin)
result, err := svc.SetHostDeviceMapping(userCtx, 1, "user@example.com", fleet.DeviceMappingIDP)
require.NoError(t, err)
require.True(t, ds.SetOrUpdateIDPHostDeviceMappingFuncInvoked)
require.True(t, ds.SetOrUpdateHostSCIMUserMappingFuncInvoked) // Should be called since SCIM user exists
require.NotNil(t, result)
require.Len(t, result, 1)
assert.Equal(t, uint(1), result[0].HostID)
assert.Equal(t, "user@example.com", result[0].Email)
assert.Equal(t, fleet.DeviceMappingMDMIdpAccounts, result[0].Source)
})
t.Run("IDP source success with any username when SCIM user not found", func(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: &fleet.LicenseInfo{Tier: fleet.TierPremium}})
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
return &fleet.Host{ID: 1}, nil
}
ds.ScimUserByUserNameOrEmailFunc = func(ctx context.Context, userName, email string) (*fleet.ScimUser, error) {
return nil, sql.ErrNoRows // SCIM user not found
}
ds.SetOrUpdateIDPHostDeviceMappingFunc = func(ctx context.Context, hostID uint, email string) error {
return nil
}
ds.DeleteHostSCIMUserMappingFunc = func(ctx context.Context, hostID uint) error {
return nil
}
ds.ListHostDeviceMappingFunc = func(ctx context.Context, hostID uint) ([]*fleet.HostDeviceMapping, error) {
return []*fleet.HostDeviceMapping{{HostID: hostID, Email: "any@username.com", Source: fleet.DeviceMappingMDMIdpAccounts}}, nil
}
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails, details []byte, createdAt time.Time) error {
return nil
}
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
return &fleet.AppConfig{}, nil
}
userCtx := test.UserContext(ctx, test.UserAdmin)
result, err := svc.SetHostDeviceMapping(userCtx, 1, "any@username.com", fleet.DeviceMappingIDP)
require.NoError(t, err)
require.True(t, ds.SetOrUpdateIDPHostDeviceMappingFuncInvoked)
require.False(t, ds.SetOrUpdateHostSCIMUserMappingFuncInvoked) // Should NOT be called since SCIM user doesn't exist
require.True(t, ds.DeleteHostSCIMUserMappingFuncInvoked) // Should be called to remove any existing SCIM mapping
require.NotNil(t, result)
require.Len(t, result, 1)
assert.Equal(t, uint(1), result[0].HostID)
assert.Equal(t, "any@username.com", result[0].Email)
assert.Equal(t, fleet.DeviceMappingMDMIdpAccounts, result[0].Source)
})
t.Run("IDP source fails without premium license", func(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: &fleet.LicenseInfo{Tier: fleet.TierFree}})
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
return &fleet.Host{ID: 1}, nil
}
userCtx := test.UserContext(ctx, test.UserAdmin)
_, err := svc.SetHostDeviceMapping(userCtx, 1, "user@example.com", fleet.DeviceMappingIDP)
require.Error(t, err)
assert.Equal(t, fleet.ErrMissingLicense, err)
})
t.Run("invalid source returns validation error", func(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil)
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
return &fleet.Host{ID: 1}, nil
}
userCtx := test.UserContext(ctx, test.UserAdmin)
_, err := svc.SetHostDeviceMapping(userCtx, 1, "user@example.com", "invalid")
require.Error(t, err)
require.Contains(t, err.Error(), "must be 'custom' or 'idp'")
require.True(t, ds.HostLiteFuncInvoked) // Authorization was checked
})
t.Run("authorization failure for observer user", func(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil)
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
return &fleet.Host{ID: 1}, nil
}
// Use observer user who shouldn't have write permission
user := &fleet.User{
ID: 42,
Email: "observer@example.com",
GlobalRole: ptr.String(fleet.RoleObserver),
}
userCtx := viewer.NewContext(ctx, viewer.Viewer{User: user})
_, err := svc.SetHostDeviceMapping(userCtx, 1, "user@example.com", "custom")
require.Error(t, err)
require.Contains(t, err.Error(), "forbidden")
})
t.Run("host not found returns error", func(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil)
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
return nil, sql.ErrNoRows
}
userCtx := test.UserContext(ctx, test.UserAdmin)
_, err := svc.SetHostDeviceMapping(userCtx, 999, "user@example.com", "custom")
require.Error(t, err)
assert.Contains(t, err.Error(), "get host")
})
t.Run("orbit installer source override", func(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil)
// Create orbit context that simulates installer authentication
authzCtx := &authzctx.AuthorizationContext{}
orbitCtx := authzctx.NewContext(ctx, authzCtx)
orbitCtx = hostctx.NewContext(orbitCtx, &fleet.Host{ID: 1})
if ac, ok := authzctx.FromContext(orbitCtx); ok {
ac.SetAuthnMethod(authzctx.AuthnOrbitToken)
}
ds.SetOrUpdateCustomHostDeviceMappingFunc = func(ctx context.Context, hostID uint, email, source string) ([]*fleet.HostDeviceMapping, error) {
// Should use installer source for orbit token
require.Equal(t, fleet.DeviceMappingCustomInstaller, source)
return []*fleet.HostDeviceMapping{{HostID: hostID, Email: email, Source: source}}, nil
}
result, err := svc.SetHostDeviceMapping(orbitCtx, 1, "user@example.com", "custom")
require.NoError(t, err)
require.True(t, ds.SetOrUpdateCustomHostDeviceMappingFuncInvoked)
require.NotNil(t, result)
})
}
// TestDeleteHostDeviceIDPMapping covers Service.DeleteHostIDP. It checks that
// the operation succeeds for an admin on a premium license and is rejected with
// fleet.ErrMissingLicense on the free tier. It also checks which global and
// team roles may remove a host's IdP mapping for global and team hosts.
func TestDeleteHostDeviceIDPMapping(t *testing.T) {
	// stubDeleteDeps installs successful no-op stubs for DeleteHostIDP,
	// AppConfig and NewActivity on the given store.
	stubDeleteDeps := func(ds *mock.Store) {
		ds.DeleteHostIDPFunc = func(ctx context.Context, id uint) error { return nil }
		ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { return &fleet.AppConfig{}, nil }
		ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails, details []byte, createdAt time.Time) error {
			return nil
		}
	}
	// hostOne stubs HostLite so every lookup yields host 1.
	hostOne := func(ctx context.Context, id uint) (*fleet.Host, error) {
		return &fleet.Host{ID: 1}, nil
	}

	t.Run("success by admin on premium", func(t *testing.T) {
		ds := new(mock.Store)
		svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: &fleet.LicenseInfo{Tier: fleet.TierPremium}})
		ds.HostLiteFunc = hostOne
		stubDeleteDeps(ds)

		err := svc.DeleteHostIDP(test.UserContext(ctx, test.UserAdmin), 1)
		require.True(t, ds.DeleteHostIDPFuncInvoked)
		require.True(t, ds.NewActivityFuncInvoked)
		require.NoError(t, err)
	})

	t.Run("failure by admin on free", func(t *testing.T) {
		ds := new(mock.Store)
		svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: &fleet.LicenseInfo{Tier: fleet.TierFree}})
		ds.HostLiteFunc = hostOne
		stubDeleteDeps(ds)

		err := svc.DeleteHostIDP(test.UserContext(ctx, test.UserAdmin), 1)
		// The free tier must be rejected before anything is deleted or logged.
		assert.Equal(t, fleet.ErrMissingLicense, err)
		require.False(t, ds.DeleteHostIDPFuncInvoked)
		require.False(t, ds.NewActivityFuncInvoked)
	})

	t.Run("authorization tests", func(t *testing.T) {
		teamHost := &fleet.Host{ID: 1, TeamID: ptr.Uint(1)}
		globalHost := &fleet.Host{ID: 2}

		ds := new(mock.Store)
		svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: &fleet.LicenseInfo{Tier: fleet.TierPremium}})
		stubDeleteDeps(ds)

		cases := []struct {
			name       string
			user       *fleet.User
			host       *fleet.Host
			shouldFail bool
		}{
			// Global roles
			{"global admin can delete global host IDP", test.UserAdmin, globalHost, false},
			{"global admin can delete team host IDP", test.UserAdmin, teamHost, false},
			{"global maintainer can delete global host IDP", test.UserMaintainer, globalHost, false},
			{"global maintainer can delete team host IDP", test.UserMaintainer, teamHost, false},
			{"global observer cannot delete global host IDP", test.UserObserver, globalHost, true},
			{"global observer cannot delete team host IDP", test.UserObserver, teamHost, true},
			{"global observer plus cannot delete global host IDP", test.UserObserverPlus, globalHost, true},
			{"global observer plus cannot delete team host IDP", test.UserObserverPlus, teamHost, true},
			{"global gitops cannot delete global host IDP", test.UserGitOps, globalHost, true},
			{"global gitops cannot delete team host IDP", test.UserGitOps, teamHost, true},
			// Team roles - correct team
			{"team admin can delete team host IDP", test.UserTeamAdminTeam1, teamHost, false},
			{"team admin cannot delete global host IDP", test.UserTeamAdminTeam1, globalHost, true},
			{"team maintainer can delete team host IDP", test.UserTeamMaintainerTeam1, teamHost, false},
			{"team maintainer cannot delete global host IDP", test.UserTeamMaintainerTeam1, globalHost, true},
			{"team observer cannot delete team host IDP", test.UserTeamObserverTeam1, teamHost, true},
			{"team observer cannot delete global host IDP", test.UserTeamObserverTeam1, globalHost, true},
			{"team observer plus cannot delete team host IDP", test.UserTeamObserverPlusTeam1, teamHost, true},
			{"team observer plus cannot delete global host IDP", test.UserTeamObserverPlusTeam1, globalHost, true},
			{"team gitops cannot delete team host IDP", test.UserTeamGitOpsTeam1, teamHost, true},
			{"team gitops cannot delete global host IDP", test.UserTeamGitOpsTeam1, globalHost, true},
			// Team roles - wrong team
			{"team admin from different team cannot delete team host IDP", test.UserTeamAdminTeam2, teamHost, true},
			{"team maintainer from different team cannot delete team host IDP", test.UserTeamMaintainerTeam2, teamHost, true},
			// No roles
			{"user with no roles cannot delete global host IDP", test.UserNoRoles, globalHost, true},
			{"user with no roles cannot delete team host IDP", test.UserNoRoles, teamHost, true},
		}

		for _, c := range cases {
			// Subtests share one store, so clear the invocation flags each time.
			ds.DeleteHostIDPFuncInvoked = false
			ds.NewActivityFuncInvoked = false
			// HostLite must return this case's host; the lookup is always by
			// c.host.ID, so the not-found branch is only a safety net.
			ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
				if id != c.host.ID {
					return nil, sql.ErrNoRows
				}
				return c.host, nil
			}

			t.Run(c.name, func(t *testing.T) {
				err := svc.DeleteHostIDP(test.UserContext(ctx, c.user), c.host.ID)
				if !c.shouldFail {
					require.NoError(t, err)
					require.True(t, ds.DeleteHostIDPFuncInvoked)
					require.True(t, ds.NewActivityFuncInvoked)
					return
				}
				require.Error(t, err)
				require.Contains(t, err.Error(), authz.ForbiddenErrorMessage)
				require.False(t, ds.DeleteHostIDPFuncInvoked)
				require.False(t, ds.NewActivityFuncInvoked)
			})
		}
	})
}