mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 13:37:30 +00:00
Use fleetctl new templates for new instances (#42768)
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #41409 # Details This PR updates the `ApplyStarterLibrary` method and functionality to rely on the same templates and mechanisms as `fleetctl new`. The end result is that running `fleetctl new` and `fleetctl gitops` on a new instance should be a no-op; no changes should be made. Similarly, changing the templates in a Fleet release will automatically affect `fleetctl new` and `ApplyStarterLibrary` in the same exact way for that release. > Note that this moves the template files out of `fleetctl` and into their own shared package. This move comprises the majority of the file changes in the PR. # Checklist for submitter If some of the following don't apply, delete the relevant line. - [X] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files) for more information. ## Testing - [X] Added/updated automated tests Note that <img width="668" height="44" alt="image" src="https://github.com/user-attachments/assets/066cd566-f91d-4661-84fc-2aabbfce2ef9" /> will fail until the 4.83 Fleet docker image is published, since it's trying to push 4.83 config (including `exceptions`) to a 4.82 server. - [X] QA'd all new/changed functionality manually - [X] Created a new instance and validated that the fleets, policies and labels created matched the ones created by `fleetctl new` - [X] Ran `fleetctl new` and verified that it created the expected folders and files - [X] Ran `fleetctl gitops` with the files created by `fleetctl new` and verified that the instance was unchanged. - [X] Ran `fleetctl preview` successfully using a dev build of the Fleet server image (since it won't work against the latest published build, which doesn't support `exceptions`). Verified it shows the expected teams, policies and labels
This commit is contained in:
parent
a2e7c95c6c
commit
c4aa6f5529
15 changed files with 424 additions and 1048 deletions
39
.github/workflows/fleetctl-preview-latest.yml
vendored
39
.github/workflows/fleetctl-preview-latest.yml
vendored
|
|
@ -1,7 +1,8 @@
|
|||
name: Test latest changes in fleetctl preview
|
||||
|
||||
# Tests the `fleetctl preview` command with latest changes in fleetctl and
|
||||
# docs/01-Using-Fleet/starter-library/starter-library.yml
|
||||
# Tests the `fleetctl preview` command with the Fleet server and fleetctl
|
||||
# built from the same commit, ensuring the starter library and GitOps
|
||||
# pipeline work end-to-end.
|
||||
|
||||
on:
|
||||
push:
|
||||
|
|
@ -16,7 +17,6 @@ on:
|
|||
- 'server/context/**.go'
|
||||
- 'orbit/**.go'
|
||||
- 'ee/fleetctl/**.go'
|
||||
- 'docs/01-Using-Fleet/starter-library/starter-library.yml'
|
||||
- '.github/workflows/fleetctl-preview-latest.yml'
|
||||
- 'tools/osquery/in-a-box'
|
||||
pull_request:
|
||||
|
|
@ -27,7 +27,6 @@ on:
|
|||
- 'server/context/**.go'
|
||||
- 'orbit/**.go'
|
||||
- 'ee/fleetctl/**.go'
|
||||
- 'docs/01-Using-Fleet/starter-library/starter-library.yml'
|
||||
- '.github/workflows/fleetctl-preview-latest.yml'
|
||||
- 'tools/osquery/in-a-box'
|
||||
workflow_dispatch: # Manual
|
||||
|
|
@ -71,15 +70,43 @@ jobs:
|
|||
with:
|
||||
go-version-file: 'go.mod'
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@5e21ff4d9bc1a8cf6de233a3057d20ec6b3fb69d # v3.8.1
|
||||
with:
|
||||
node-version-file: package.json
|
||||
check-latest: true
|
||||
|
||||
- name: Install JS dependencies
|
||||
run: make deps
|
||||
|
||||
- name: Generate assets
|
||||
run: make generate
|
||||
|
||||
- name: Build Fleetctl
|
||||
run: make fleetctl
|
||||
|
||||
- name: Build Fleet server Docker image
|
||||
run: |
|
||||
make fleet-static
|
||||
cp ./build/fleet fleet
|
||||
docker build -t fleetdm/fleet:dev -f tools/fleet-docker/Dockerfile .
|
||||
rm fleet
|
||||
|
||||
- name: Prepare preview config
|
||||
run: |
|
||||
# Copy the in-a-box config and set pull_policy so Docker uses the
|
||||
# locally built image instead of trying to pull from Docker Hub.
|
||||
cp -a tools/osquery/in-a-box /tmp/preview-config
|
||||
# Add pull_policy: never to fleet01 and fleet02 services
|
||||
sed -i '/^ fleet01:/,/^ [^ ]/{s/^\( image: fleetdm\/fleet.*\)/\1\n pull_policy: never/}' /tmp/preview-config/docker-compose.yml
|
||||
sed -i '/^ fleet02:/,/^ [^ ]/{s/^\( image: fleetdm\/fleet.*\)/\1\n pull_policy: never/}' /tmp/preview-config/docker-compose.yml
|
||||
|
||||
- name: Run fleetctl preview
|
||||
run: |
|
||||
./build/fleetctl preview \
|
||||
--tag dev \
|
||||
--disable-open-browser \
|
||||
--starter-library-file-path $(pwd)/docs/01-Using-Fleet/starter-library/starter-library.yml \
|
||||
--preview-config-path ./tools/osquery/in-a-box
|
||||
--preview-config-path /tmp/preview-config
|
||||
sleep 10
|
||||
./build/fleetctl get hosts | tee hosts.txt
|
||||
[ $( cat hosts.txt | grep online | wc -l) -eq 8 ]
|
||||
|
|
|
|||
9
Makefile
9
Makefile
|
|
@ -58,14 +58,16 @@ ifdef CIRCLE_TAG
|
|||
DOCKER_IMAGE_TAG = ${CIRCLE_TAG}
|
||||
endif
|
||||
|
||||
LDFLAGS_VERSION = "\
|
||||
LDFLAGS_VERSION_RAW = \
|
||||
-X github.com/fleetdm/fleet/v4/server/version.appName=${APP_NAME} \
|
||||
-X github.com/fleetdm/fleet/v4/server/version.version=${VERSION} \
|
||||
-X github.com/fleetdm/fleet/v4/server/version.branch=${BRANCH} \
|
||||
-X github.com/fleetdm/fleet/v4/server/version.revision=${REVISION} \
|
||||
-X github.com/fleetdm/fleet/v4/server/version.buildDate=${NOW} \
|
||||
-X github.com/fleetdm/fleet/v4/server/version.buildUser=${USER} \
|
||||
-X github.com/fleetdm/fleet/v4/server/version.goVersion=${GOVERSION}"
|
||||
-X github.com/fleetdm/fleet/v4/server/version.goVersion=${GOVERSION}
|
||||
LDFLAGS_VERSION = "${LDFLAGS_VERSION_RAW}"
|
||||
LDFLAGS_VERSION_STATIC = "${LDFLAGS_VERSION_RAW} -extldflags '-static'"
|
||||
|
||||
# Macro to allow targets to filter out their own arguments from the arguments
|
||||
# passed to the final command.
|
||||
|
|
@ -198,6 +200,9 @@ endif
|
|||
fleet: .prefix .pre-build .pre-fleet
|
||||
CGO_ENABLED=1 go build -race=${GO_BUILD_RACE_ENABLED_VAR} -tags full,fts5,netgo -o build/${OUTPUT} -ldflags ${LDFLAGS_VERSION} ./cmd/fleet
|
||||
|
||||
fleet-static: .prefix .pre-build .pre-fleet
|
||||
CGO_ENABLED=1 go build -tags full,fts5,netgo -trimpath -o build/${OUTPUT} -ldflags ${LDFLAGS_VERSION_STATIC} ./cmd/fleet
|
||||
|
||||
fleet-dev: GO_BUILD_RACE_ENABLED_VAR=true
|
||||
fleet-dev: fleet
|
||||
|
||||
|
|
|
|||
1
changes/41409-use-fleetctl-new-templates-as-starter-lib
Normal file
1
changes/41409-use-fleetctl-new-templates-as-starter-lib
Normal file
|
|
@ -0,0 +1 @@
|
|||
- Use the same templates for `fleetctl new` and new instance initialization
|
||||
|
|
@ -23,6 +23,7 @@ import (
|
|||
|
||||
"github.com/WatchBeam/clock"
|
||||
"github.com/e-dard/netbug"
|
||||
"github.com/fleetdm/fleet/v4/cmd/fleetctl/fleetctl"
|
||||
"github.com/fleetdm/fleet/v4/ee/server/licensing"
|
||||
"github.com/fleetdm/fleet/v4/ee/server/scim"
|
||||
eeservice "github.com/fleetdm/fleet/v4/ee/server/service"
|
||||
|
|
@ -1496,7 +1497,15 @@ func runServeCmd(cmd *cobra.Command, configManager configpkg.Manager, debug, dev
|
|||
// By performing the same check inside main, we can make server startups
|
||||
// more efficient after the first startup.
|
||||
if setupRequired {
|
||||
apiHandler = service.WithSetup(svc, logger, apiHandler)
|
||||
// Pass in a closure to run the fleetctl command, so that the service layer
|
||||
// doesn't need to import the CLI package.
|
||||
applyStarterLibrary := func(ctx context.Context, serverURL, token string) error {
|
||||
return service.ApplyStarterLibrary(ctx, serverURL, token, logger, func(args []string) error {
|
||||
_, err := fleetctl.RunApp(args)
|
||||
return err
|
||||
})
|
||||
}
|
||||
apiHandler = service.WithSetup(svc, logger, applyStarterLibrary, apiHandler)
|
||||
frontendHandler = service.RedirectLoginToSetup(svc, logger, frontendHandler, config.Server.URLPrefix)
|
||||
} else {
|
||||
frontendHandler = service.RedirectSetupToLogin(svc, logger, frontendHandler, config.Server.URLPrefix)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
package fleetctl
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
eefleetctl "github.com/fleetdm/fleet/v4/ee/fleetctl"
|
||||
"github.com/fleetdm/fleet/v4/server/version"
|
||||
|
|
@ -78,3 +80,14 @@ func CreateApp(
|
|||
}
|
||||
return app
|
||||
}
|
||||
|
||||
func RunApp(args []string) (*bytes.Buffer, error) {
|
||||
// first arg must be the binary name. Allow tests to omit it.
|
||||
args = append([]string{""}, args...)
|
||||
|
||||
w := new(bytes.Buffer)
|
||||
app := CreateApp(nil, w, os.Stderr, noopExitErrHandler)
|
||||
StashRawArgs(app, args)
|
||||
err := app.Run(args)
|
||||
return w, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -400,9 +400,10 @@ Use the stop and reset subcommands to manage the server and dependencies once st
|
|||
address,
|
||||
token,
|
||||
logger,
|
||||
fleethttp.NewClient,
|
||||
service.NewClient,
|
||||
nil, // No mock ApplyGroup for production code
|
||||
func(args []string) error {
|
||||
_, err := RunApp(args)
|
||||
return err
|
||||
},
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to apply starter library: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package fleetctl
|
|||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -23,17 +22,6 @@ func RunAppCheckErr(t *testing.T, args []string, errorMsg string) string {
|
|||
return w.String()
|
||||
}
|
||||
|
||||
func RunAppNoChecks(args []string) (*bytes.Buffer, error) {
|
||||
// first arg must be the binary name. Allow tests to omit it.
|
||||
args = append([]string{""}, args...)
|
||||
|
||||
w := new(bytes.Buffer)
|
||||
app := CreateApp(nil, w, os.Stderr, noopExitErrHandler)
|
||||
StashRawArgs(app, args)
|
||||
err := app.Run(args)
|
||||
return w, err
|
||||
}
|
||||
|
||||
func RunWithErrWriter(args []string, errWriter io.Writer) (*bytes.Buffer, error) {
|
||||
args = append([]string{""}, args...)
|
||||
|
||||
|
|
@ -45,3 +33,9 @@ func RunWithErrWriter(args []string, errWriter io.Writer) (*bytes.Buffer, error)
|
|||
}
|
||||
|
||||
func noopExitErrHandler(c *cli.Context, err error) {}
|
||||
|
||||
// Alias for RunApp; added rather than changing all existing calls to `RunApp`,
|
||||
// to avoid confusion and in case the behavior of `RunApp` needs to diverge in the future.
|
||||
func RunAppNoChecks(args []string) (*bytes.Buffer, error) {
|
||||
return RunApp(args)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,117 @@
|
|||
package gitops
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/cmd/fleetctl/fleetctl"
|
||||
"github.com/fleetdm/fleet/v4/cmd/fleetctl/integrationtest"
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/redis/redistest"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/service"
|
||||
"github.com/fleetdm/fleet/v4/server/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
func TestIntegrationsEnterpriseStarterLibrary(t *testing.T) {
|
||||
testingSuite := new(starterLibraryIntegrationEnterpriseTestSuite)
|
||||
testingSuite.WithServer.Suite = &testingSuite.Suite
|
||||
suite.Run(t, testingSuite)
|
||||
}
|
||||
|
||||
type starterLibraryIntegrationEnterpriseTestSuite struct {
|
||||
suite.Suite
|
||||
integrationtest.WithServer
|
||||
}
|
||||
|
||||
func (s *starterLibraryIntegrationEnterpriseTestSuite) SetupSuite() {
|
||||
s.WithDS.SetupSuite("starterLibraryIntegrationEnterpriseTestSuite")
|
||||
|
||||
appConf, err := s.DS.AppConfig(context.Background())
|
||||
s.Require().NoError(err)
|
||||
err = s.DS.SaveAppConfig(context.Background(), appConf)
|
||||
s.Require().NoError(err)
|
||||
|
||||
fleetCfg := config.TestConfig()
|
||||
fleetCfg.Osquery.EnrollCooldown = 0
|
||||
|
||||
redisPool := redistest.SetupRedis(s.T(), "starter_library_enterprise", false, false, false)
|
||||
|
||||
serverConfig := service.TestServerOpts{
|
||||
License: &fleet.LicenseInfo{
|
||||
Tier: fleet.TierPremium,
|
||||
},
|
||||
FleetConfig: &fleetCfg,
|
||||
Pool: redisPool,
|
||||
}
|
||||
|
||||
users, server := service.RunServerForTestsWithDS(s.T(), s.DS, &serverConfig)
|
||||
s.T().Setenv("FLEET_SERVER_ADDRESS", server.URL)
|
||||
s.Server = server
|
||||
s.Users = users
|
||||
|
||||
appConf, err = s.DS.AppConfig(context.Background())
|
||||
s.Require().NoError(err)
|
||||
appConf.ServerSettings.ServerURL = server.URL
|
||||
appConf.OrgInfo.OrgName = "Test Org"
|
||||
appConf.GitOpsConfig.Exceptions = fleet.GitOpsExceptions{}
|
||||
err = s.DS.SaveAppConfig(context.Background(), appConf)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
// TestApplyStarterLibraryPremium verifies that ApplyStarterLibrary applies the
|
||||
// global config and team configs when using a premium license.
|
||||
func (s *starterLibraryIntegrationEnterpriseTestSuite) TestApplyStarterLibraryPremium() {
|
||||
t := s.T()
|
||||
ctx := context.Background()
|
||||
|
||||
token := s.GetTestToken("admin1@example.com", test.GoodPassword)
|
||||
logger := slog.New(slog.DiscardHandler)
|
||||
|
||||
err := service.ApplyStarterLibrary(
|
||||
ctx,
|
||||
s.Server.URL,
|
||||
token,
|
||||
logger,
|
||||
func(args []string) error {
|
||||
_, err := fleetctl.RunAppNoChecks(args)
|
||||
return err
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the org name was applied.
|
||||
appConfig, err := s.DS.AppConfig(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test Org", appConfig.OrgInfo.OrgName)
|
||||
|
||||
// Verify that the teams from the starter templates were created.
|
||||
teams, err := s.DS.ListTeams(ctx, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String("admin")}}, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
teamNames := make([]string, len(teams))
|
||||
for i, tm := range teams {
|
||||
teamNames[i] = tm.Name
|
||||
}
|
||||
assert.Contains(t, teamNames, "💻 Workstations")
|
||||
assert.Contains(t, teamNames, "📱🔐 Personal mobile devices")
|
||||
|
||||
// Verify labels were created (global labels, team_id=0).
|
||||
labelSpecs, err := s.DS.GetLabelSpecs(ctx, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String("admin")}})
|
||||
require.NoError(t, err)
|
||||
var customLabelNames []string
|
||||
for _, l := range labelSpecs {
|
||||
if l.LabelType != fleet.LabelTypeBuiltIn {
|
||||
customLabelNames = append(customLabelNames, l.Name)
|
||||
}
|
||||
}
|
||||
assert.Contains(t, customLabelNames, "Apple Silicon macOS hosts")
|
||||
assert.Contains(t, customLabelNames, "ARM-based Windows hosts")
|
||||
assert.Contains(t, customLabelNames, "Debian-based Linux hosts")
|
||||
assert.Contains(t, customLabelNames, "x86-based Windows hosts")
|
||||
}
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
package gitops
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/cmd/fleetctl/fleetctl"
|
||||
"github.com/fleetdm/fleet/v4/cmd/fleetctl/integrationtest"
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/redis/redistest"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/service"
|
||||
"github.com/fleetdm/fleet/v4/server/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
func TestIntegrationsStarterLibrary(t *testing.T) {
|
||||
testingSuite := new(starterLibraryIntegrationTestSuite)
|
||||
testingSuite.WithServer.Suite = &testingSuite.Suite
|
||||
suite.Run(t, testingSuite)
|
||||
}
|
||||
|
||||
type starterLibraryIntegrationTestSuite struct {
|
||||
suite.Suite
|
||||
integrationtest.WithServer
|
||||
}
|
||||
|
||||
func (s *starterLibraryIntegrationTestSuite) SetupSuite() {
|
||||
s.WithDS.SetupSuite("starterLibraryIntegrationTestSuite")
|
||||
|
||||
appConf, err := s.DS.AppConfig(context.Background())
|
||||
s.Require().NoError(err)
|
||||
err = s.DS.SaveAppConfig(context.Background(), appConf)
|
||||
s.Require().NoError(err)
|
||||
|
||||
fleetCfg := config.TestConfig()
|
||||
fleetCfg.Osquery.EnrollCooldown = 0
|
||||
|
||||
redisPool := redistest.SetupRedis(s.T(), "starter_library", false, false, false)
|
||||
|
||||
serverConfig := service.TestServerOpts{
|
||||
FleetConfig: &fleetCfg,
|
||||
Pool: redisPool,
|
||||
}
|
||||
|
||||
users, server := service.RunServerForTestsWithDS(s.T(), s.DS, &serverConfig)
|
||||
s.T().Setenv("FLEET_SERVER_ADDRESS", server.URL)
|
||||
s.Server = server
|
||||
s.Users = users
|
||||
|
||||
appConf, err = s.DS.AppConfig(context.Background())
|
||||
s.Require().NoError(err)
|
||||
appConf.ServerSettings.ServerURL = server.URL
|
||||
appConf.OrgInfo.OrgName = "Test Org"
|
||||
appConf.GitOpsConfig.Exceptions = fleet.GitOpsExceptions{}
|
||||
err = s.DS.SaveAppConfig(context.Background(), appConf)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
// TestApplyStarterLibraryFree verifies that ApplyStarterLibrary applies only
|
||||
// the global config (no teams) when using a free license.
|
||||
func (s *starterLibraryIntegrationTestSuite) TestApplyStarterLibraryFree() {
|
||||
t := s.T()
|
||||
ctx := context.Background()
|
||||
|
||||
token := s.GetTestToken("admin1@example.com", test.GoodPassword)
|
||||
logger := slog.New(slog.DiscardHandler)
|
||||
|
||||
err := service.ApplyStarterLibrary(
|
||||
ctx,
|
||||
s.Server.URL,
|
||||
token,
|
||||
logger,
|
||||
func(args []string) error {
|
||||
_, err := fleetctl.RunAppNoChecks(args)
|
||||
return err
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the org name was applied.
|
||||
appConfig, err := s.DS.AppConfig(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test Org", appConfig.OrgInfo.OrgName)
|
||||
|
||||
// Verify that no teams were created for a free license.
|
||||
teams, err := s.DS.ListTeams(ctx, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String("admin")}}, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, teams)
|
||||
|
||||
// Verify labels were created (global labels, team_id=0).
|
||||
labelSpecs, err := s.DS.GetLabelSpecs(ctx, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String("admin")}})
|
||||
require.NoError(t, err)
|
||||
var customLabelNames []string
|
||||
for _, l := range labelSpecs {
|
||||
if l.LabelType != fleet.LabelTypeBuiltIn {
|
||||
customLabelNames = append(customLabelNames, l.Name)
|
||||
}
|
||||
}
|
||||
assert.Contains(t, customLabelNames, "Apple Silicon macOS hosts")
|
||||
assert.Contains(t, customLabelNames, "ARM-based Windows hosts")
|
||||
assert.Contains(t, customLabelNames, "Debian-based Linux hosts")
|
||||
assert.Contains(t, customLabelNames, "x86-based Windows hosts")
|
||||
}
|
||||
|
|
@ -9,10 +9,12 @@ import (
|
|||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/platform/endpointer"
|
||||
|
|
@ -23,6 +25,37 @@ import (
|
|||
|
||||
var yamlSeparator = regexp.MustCompile(`(?m:^---[\t ]*)`)
|
||||
|
||||
var (
|
||||
envOverridesMu sync.RWMutex
|
||||
envOverrides map[string]string
|
||||
)
|
||||
|
||||
// SetEnvOverrides sets environment variable overrides that take precedence over
|
||||
// os.LookupEnv during env expansion in GitOps file parsing. Pass nil to clear.
|
||||
func SetEnvOverrides(overrides map[string]string) {
|
||||
envOverridesMu.Lock()
|
||||
defer envOverridesMu.Unlock()
|
||||
if overrides == nil {
|
||||
envOverrides = nil
|
||||
return
|
||||
}
|
||||
envOverrides = make(map[string]string, len(overrides))
|
||||
maps.Copy(envOverrides, overrides)
|
||||
}
|
||||
|
||||
// lookupEnv checks env overrides first, then falls back to os.LookupEnv.
|
||||
func lookupEnv(key string) (string, bool) {
|
||||
envOverridesMu.RLock()
|
||||
if envOverrides != nil {
|
||||
if v, ok := envOverrides[key]; ok {
|
||||
envOverridesMu.RUnlock()
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
envOverridesMu.RUnlock()
|
||||
return os.LookupEnv(key)
|
||||
}
|
||||
|
||||
// Group holds a set of "specs" that can be applied to a Fleet server.
|
||||
type Group struct {
|
||||
Queries []*fleet.QuerySpec
|
||||
|
|
@ -301,7 +334,7 @@ func expandEnv(s string, secretMode secretHandling) (string, error) {
|
|||
switch secretMode {
|
||||
case secretsExpand:
|
||||
// Expand secrets for client-side validation
|
||||
v, ok := os.LookupEnv(env)
|
||||
v, ok := lookupEnv(env)
|
||||
if ok {
|
||||
if !documentIsXML {
|
||||
return v, true
|
||||
|
|
@ -334,7 +367,7 @@ func expandEnv(s string, secretMode secretHandling) (string, error) {
|
|||
}
|
||||
}
|
||||
|
||||
v, ok := os.LookupEnv(env)
|
||||
v, ok := lookupEnv(env)
|
||||
if !ok {
|
||||
err = multierror.Append(err, fmt.Errorf("environment variable %q not set", env))
|
||||
return "", false
|
||||
|
|
@ -398,7 +431,7 @@ func LookupEnvSecrets(s string, secretsMap map[string]string) error {
|
|||
_ = fleet.MaybeExpand(s, func(env string, startPos, endPos int) (string, bool) {
|
||||
if strings.HasPrefix(env, fleet.ServerSecretPrefix) {
|
||||
// lookup the secret and save it, but don't replace
|
||||
v, ok := os.LookupEnv(env)
|
||||
v, ok := lookupEnv(env)
|
||||
if !ok {
|
||||
err = multierror.Append(err, fmt.Errorf("environment variable %q not set", env))
|
||||
return "", false
|
||||
|
|
|
|||
|
|
@ -674,10 +674,9 @@ func (c *Client) ApplyGroup(
|
|||
specs.AppConfig.(map[string]interface{})["yara_rules"] = rulePayloads
|
||||
}
|
||||
|
||||
// Keep any existing GitOps mode config rather than attempting to set via GitOps.
|
||||
if appconfig != nil {
|
||||
specs.AppConfig.(map[string]any)["gitops"] = appconfig.GitOpsConfig
|
||||
}
|
||||
// GitOps mode config is managed server-side; remove it from the PATCH
|
||||
// payload so the existing settings are preserved.
|
||||
delete(specs.AppConfig.(map[string]any), "gitops")
|
||||
|
||||
if err := c.ApplyAppConfig(specs.AppConfig, opts.ApplySpecOptions); err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("applying fleet config: %w", err)
|
||||
|
|
|
|||
|
|
@ -2,17 +2,11 @@ package service
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
|
||||
"github.com/fleetdm/fleet/v4/pkg/spec"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
|
|
@ -20,11 +14,6 @@ import (
|
|||
"github.com/go-kit/kit/endpoint"
|
||||
)
|
||||
|
||||
const (
|
||||
starterLibraryURL = "https://raw.githubusercontent.com/fleetdm/fleet/main/docs/01-Using-Fleet/starter-library/starter-library.yml"
|
||||
scriptsBaseURL = "https://raw.githubusercontent.com/fleetdm/fleet/main/"
|
||||
)
|
||||
|
||||
type setupRequest struct {
|
||||
Admin *fleet.UserPayload `json:"admin"`
|
||||
OrgInfo *fleet.OrgInfo `json:"org_info"`
|
||||
|
|
@ -41,11 +30,9 @@ type setupResponse struct {
|
|||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type applyGroupFunc func(context.Context, *spec.Group) error
|
||||
|
||||
func (r setupResponse) Error() error { return r.Err }
|
||||
|
||||
func makeSetupEndpoint(svc fleet.Service, logger *slog.Logger) endpoint.Endpoint {
|
||||
func makeSetupEndpoint(svc fleet.Service, logger *slog.Logger, applyStarterLibrary func(ctx context.Context, serverURL, token string) error) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(setupRequest)
|
||||
config := &fleet.AppConfig{}
|
||||
|
|
@ -94,15 +81,7 @@ func makeSetupEndpoint(svc fleet.Service, logger *slog.Logger) endpoint.Endpoint
|
|||
|
||||
// Apply starter library using the admin token we just created
|
||||
if req.ServerURL != nil {
|
||||
if err := ApplyStarterLibrary(
|
||||
ctx,
|
||||
*req.ServerURL,
|
||||
session.Key,
|
||||
logger,
|
||||
fleethttp.NewClient,
|
||||
NewClient,
|
||||
nil, // No mock ApplyGroup for production code
|
||||
); err != nil {
|
||||
if err := applyStarterLibrary(ctx, *req.ServerURL, session.Key); err != nil {
|
||||
logger.DebugContext(ctx, "setup apply starter library", "endpoint", "setup", "op", "applyStarterLibrary", "err", err)
|
||||
// Continue even if there's an error applying the starter library
|
||||
}
|
||||
|
|
@ -120,309 +99,89 @@ func makeSetupEndpoint(svc fleet.Service, logger *slog.Logger) endpoint.Endpoint
|
|||
}
|
||||
}
|
||||
|
||||
// ApplyStarterLibrary downloads the starter library from GitHub
|
||||
// and applies it to the Fleet server using an authenticated client.
|
||||
// TODO: Move the apply starter library logic to use the serve command as an entry point to simplify and leverage the entire fleet.Service.
|
||||
// Entry point: https://github.com/fleetdm/fleet/blob/2dfadc0971c6ba45c19dad2f5f1f4cd0f1b89b20/cmd/fleet/serve.go#L1099-L1100
|
||||
// ApplyStarterLibrary scaffolds the starter GitOps templates via `fleetctl new`
|
||||
// and applies them via `fleetctl gitops`, producing the same result as a user
|
||||
// running those commands manually.
|
||||
//
|
||||
// The runFleetctl callback should run the fleetctl CLI with the given arguments.
|
||||
// This keeps the CLI dependency out of the service package.
|
||||
func ApplyStarterLibrary(
|
||||
ctx context.Context,
|
||||
serverURL string,
|
||||
token string,
|
||||
logger *slog.Logger,
|
||||
httpClientFactory func(opts ...fleethttp.ClientOpt) *http.Client,
|
||||
clientFactory func(serverURL string, insecureSkipVerify bool, rootCA, urlPrefix string, options ...ClientOption) (*Client, error),
|
||||
// For testing only - if provided, this function will be used instead of client.ApplyGroup
|
||||
mockApplyGroup func(ctx context.Context, specs *spec.Group) error,
|
||||
runFleetctl func(args []string) error,
|
||||
) error {
|
||||
logger.DebugContext(ctx, "Applying starter library")
|
||||
|
||||
// Create a request with context for downloading the starter library
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, starterLibraryURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request for starter library: %w", err)
|
||||
}
|
||||
|
||||
// Download the starter library from GitHub using the provided HTTP client factory
|
||||
httpClient := httpClientFactory(fleethttp.WithTimeout(5 * time.Second))
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download starter library: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("failed to download starter library, status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read starter library response body: %w", err)
|
||||
}
|
||||
|
||||
// Create a temporary directory to store downloaded scripts
|
||||
tempDir, err := os.MkdirTemp("", "fleet-scripts-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temporary directory: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir) // Clean up the temporary directory when done
|
||||
|
||||
logger.DebugContext(ctx, "Created temporary directory for scripts", "path", tempDir)
|
||||
|
||||
// Parse the YAML content into specs
|
||||
specs, err := spec.GroupFromBytes(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse starter library: %w", err)
|
||||
}
|
||||
|
||||
// Find all script references in the YAML and download them
|
||||
scriptNames := ExtractScriptNames(specs)
|
||||
logger.DebugContext(ctx, "Found script references in starter library", "count", len(scriptNames))
|
||||
|
||||
// Download scripts and update references in specs
|
||||
if len(scriptNames) > 0 {
|
||||
err = DownloadAndUpdateScripts(ctx, specs, scriptNames, tempDir, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download and update scripts: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create an authenticated client and apply specs using the provided client factory
|
||||
client, err := clientFactory(serverURL, true, "", "")
|
||||
// Create an authenticated client to fetch app config.
|
||||
client, err := NewClient(serverURL, true, "", "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create client: %w", err)
|
||||
}
|
||||
client.SetToken(token)
|
||||
|
||||
// Always check if license is free and skip teams for free licenses
|
||||
appConfig, err := client.GetAppConfig()
|
||||
if err != nil {
|
||||
logger.DebugContext(ctx, "Error getting app config", "err", err)
|
||||
// Continue even if there's an error getting the app config
|
||||
} else if appConfig.License == nil || !appConfig.License.IsPremium() {
|
||||
// Remove teams from specs to avoid applying them
|
||||
logger.DebugContext(ctx, "Free license detected, skipping teams and team-related content in starter library")
|
||||
specs.Teams = nil
|
||||
return fmt.Errorf("failed to get app config: %w", err)
|
||||
}
|
||||
|
||||
// Filter out policies that reference teams
|
||||
if specs.Policies != nil {
|
||||
var filteredPolicies []*fleet.PolicySpec
|
||||
for _, policy := range specs.Policies {
|
||||
// Keep only policies that don't reference a team
|
||||
if policy.Team == "" {
|
||||
filteredPolicies = append(filteredPolicies, policy)
|
||||
}
|
||||
}
|
||||
specs.Policies = filteredPolicies
|
||||
orgName := appConfig.OrgInfo.OrgName
|
||||
if orgName == "" {
|
||||
orgName = "Fleet"
|
||||
}
|
||||
|
||||
// Create a temp directory for the rendered templates.
|
||||
tempDir, err := os.MkdirTemp("", "fleet-starter-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
outDir := filepath.Join(tempDir, "gitops")
|
||||
|
||||
// Render templates using `fleetctl new`.
|
||||
if err := runFleetctl([]string{"new", "--org-name", orgName, "--dir", outDir}); err != nil {
|
||||
return fmt.Errorf("fleetctl new: %w", err)
|
||||
}
|
||||
|
||||
// Set env overrides so GitOpsFromFile can expand $FLEET_URL without
|
||||
// polluting the process environment.
|
||||
spec.SetEnvOverrides(map[string]string{
|
||||
"FLEET_URL": serverURL,
|
||||
})
|
||||
defer spec.SetEnvOverrides(nil)
|
||||
|
||||
// Write a temporary fleetctl config file with auth credentials.
|
||||
configFile, err := os.CreateTemp(tempDir, "fleetctl-config-*.yml")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create fleetctl config: %w", err)
|
||||
}
|
||||
fmt.Fprintf(configFile, "contexts:\n default:\n address: %s\n tls-skip-verify: true\n token: %s\n",
|
||||
serverURL, token)
|
||||
configFile.Close()
|
||||
|
||||
// Build the gitops args: global config first, then team configs (premium only).
|
||||
args := []string{"gitops", "--config", configFile.Name(), "-f", filepath.Join(outDir, "default.yml")}
|
||||
|
||||
if appConfig.License != nil && appConfig.License.IsPremium() {
|
||||
fleetDir := filepath.Join(outDir, "fleets")
|
||||
entries, err := os.ReadDir(fleetDir)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to read fleets directory: %w", err)
|
||||
}
|
||||
|
||||
// Note: QuerySpec doesn't have a Team field, so we can't filter queries by team
|
||||
|
||||
// Remove scripts from AppConfig if present
|
||||
if specs.AppConfig != nil {
|
||||
appConfigMap, ok := specs.AppConfig.(map[string]interface{})
|
||||
if ok {
|
||||
// Remove scripts from AppConfig
|
||||
delete(appConfigMap, "scripts")
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || filepath.Ext(entry.Name()) != ".yml" {
|
||||
continue
|
||||
}
|
||||
args = append(args, "-f", filepath.Join(fleetDir, entry.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
// Log function for ApplyGroup (minimal logging)
|
||||
logf := func(format string, a ...interface{}) {}
|
||||
|
||||
// Assign the real implementation
|
||||
var applyGroupFn applyGroupFunc = func(ctx context.Context, specs *spec.Group) error {
|
||||
teamsSoftwareInstallers := make(map[string][]fleet.SoftwarePackageResponse)
|
||||
teamsScripts := make(map[string][]fleet.ScriptResponse)
|
||||
teamsVPPApps := make(map[string][]fleet.VPPAppResponse)
|
||||
|
||||
_, _, _, _, err := client.ApplyGroup(
|
||||
ctx,
|
||||
false,
|
||||
specs,
|
||||
tempDir,
|
||||
logf,
|
||||
nil,
|
||||
fleet.ApplyClientSpecOptions{},
|
||||
teamsSoftwareInstallers,
|
||||
teamsVPPApps,
|
||||
teamsScripts,
|
||||
nil,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// Apply mock if mockApplyGroup is supplied
|
||||
if mockApplyGroup != nil {
|
||||
applyGroupFn = mockApplyGroup
|
||||
}
|
||||
|
||||
if err := applyGroupFn(ctx, specs); err != nil {
|
||||
return fmt.Errorf("failed to apply starter library: %w", err)
|
||||
if err := runFleetctl(args); err != nil {
|
||||
return fmt.Errorf("fleetctl gitops: %w", err)
|
||||
}
|
||||
|
||||
logger.DebugContext(ctx, "Starter library applied successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtractScriptNames extracts all script names from the specs
|
||||
func ExtractScriptNames(specs *spec.Group) []string {
|
||||
var scriptNames []string
|
||||
scriptMap := make(map[string]bool) // Use a map to deduplicate script names
|
||||
|
||||
// Process team specs
|
||||
for _, teamRaw := range specs.Teams {
|
||||
var teamData map[string]interface{}
|
||||
if err := json.Unmarshal(teamRaw, &teamData); err != nil {
|
||||
continue // Skip if we can't unmarshal
|
||||
}
|
||||
|
||||
if scripts, ok := teamData["scripts"].([]interface{}); ok {
|
||||
for _, script := range scripts {
|
||||
if scriptName, ok := script.(string); ok && !scriptMap[scriptName] {
|
||||
scriptMap[scriptName] = true
|
||||
scriptNames = append(scriptNames, scriptName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return scriptNames
|
||||
}
|
||||
|
||||
// DownloadAndUpdateScripts downloads scripts from URLs and updates the specs to reference local files
|
||||
func DownloadAndUpdateScripts(ctx context.Context, specs *spec.Group, scriptNames []string, tempDir string, logger *slog.Logger) error {
|
||||
// Create a single HTTP client to be reused for all requests
|
||||
httpClient := fleethttp.NewClient(fleethttp.WithTimeout(5 * time.Second))
|
||||
|
||||
// Map to store local paths for each script
|
||||
scriptPaths := make(map[string]string, len(scriptNames))
|
||||
|
||||
// Download each script sequentially
|
||||
for _, scriptName := range scriptNames {
|
||||
// Sanitize the script name to prevent path traversal
|
||||
sanitizedName := filepath.Clean(scriptName)
|
||||
if strings.HasPrefix(sanitizedName, "..") || filepath.IsAbs(sanitizedName) {
|
||||
return fmt.Errorf("invalid script name %s: must be a relative path", scriptName)
|
||||
}
|
||||
|
||||
localPath := filepath.Join(tempDir, sanitizedName)
|
||||
scriptPaths[scriptName] = localPath
|
||||
|
||||
// Create parent directories if they don't exist
|
||||
parentDir := filepath.Dir(localPath)
|
||||
if err := os.MkdirAll(parentDir, 0o755); err != nil {
|
||||
return fmt.Errorf("failed to create parent directories for script %s: %w", scriptName, err)
|
||||
}
|
||||
|
||||
scriptURL := fmt.Sprintf("%s/%s", scriptsBaseURL, scriptName)
|
||||
logger.DebugContext(ctx, "Downloading script", "name", scriptName, "url", scriptURL, "local_path", localPath)
|
||||
|
||||
// Create the request with context
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, scriptURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request for script %s: %w", scriptName, err)
|
||||
}
|
||||
|
||||
// Download the script using the shared HTTP client
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download script %s: %w", scriptName, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("failed to download script %s, status: %d", scriptName, resp.StatusCode)
|
||||
}
|
||||
|
||||
// Create the local file
|
||||
file, err := os.Create(localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create local file for script %s: %w", scriptName, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Copy the content to the local file
|
||||
_, err = io.Copy(file, resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write script %s to local file: %w", scriptName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Read script contents and store them in memory
|
||||
scriptContents := make(map[string][]byte, len(scriptNames))
|
||||
for _, scriptName := range scriptNames {
|
||||
localPath := scriptPaths[scriptName]
|
||||
content, err := os.ReadFile(localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read script %s from local file: %w", scriptName, err)
|
||||
}
|
||||
scriptContents[scriptName] = content
|
||||
}
|
||||
|
||||
// Extract scripts from AppConfig if present
|
||||
appConfigScripts := extractAppCfgScripts(specs.AppConfig)
|
||||
if appConfigScripts != nil {
|
||||
// Replace script paths with actual script contents
|
||||
appScripts := make([]string, 0, len(appConfigScripts))
|
||||
for _, scriptPath := range appConfigScripts {
|
||||
if content, exists := scriptContents[scriptPath]; exists {
|
||||
// Create a temporary file with the script content
|
||||
tempFile, err := os.CreateTemp(tempDir, "script-*")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temporary script file: %w", err)
|
||||
}
|
||||
if _, err := tempFile.Write(content); err != nil {
|
||||
tempFile.Close()
|
||||
return fmt.Errorf("failed to write script content to temporary file: %w", err)
|
||||
}
|
||||
tempFile.Close()
|
||||
|
||||
// Add the temporary file path to the list
|
||||
appScripts = append(appScripts, tempFile.Name())
|
||||
} else {
|
||||
// Keep the original path if it's not one of our downloaded scripts
|
||||
appScripts = append(appScripts, scriptPath)
|
||||
}
|
||||
}
|
||||
|
||||
// Update the AppConfig with the new script paths
|
||||
if specs.AppConfig != nil {
|
||||
specs.AppConfig.(map[string]interface{})["scripts"] = appScripts
|
||||
}
|
||||
}
|
||||
|
||||
// Update script references in the specs to point to local files
|
||||
for i, teamRaw := range specs.Teams {
|
||||
var teamData map[string]interface{}
|
||||
if err := json.Unmarshal(teamRaw, &teamData); err != nil {
|
||||
continue // Skip if we can't unmarshal
|
||||
}
|
||||
|
||||
if scripts, ok := teamData["scripts"].([]interface{}); ok {
|
||||
for j, script := range scripts {
|
||||
if scriptName, ok := script.(string); ok {
|
||||
// Update the script reference to the local path from our map
|
||||
if localPath, exists := scriptPaths[scriptName]; exists {
|
||||
scripts[j] = localPath
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update the team data with modified scripts
|
||||
teamData["scripts"] = scripts
|
||||
|
||||
// Marshal back to JSON
|
||||
updatedTeamRaw, err := json.Marshal(teamData)
|
||||
if err != nil {
|
||||
logger.DebugContext(ctx, "Failed to marshal updated team data", "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Update the team in the specs
|
||||
specs.Teams[i] = updatedTeamRaw
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,712 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
|
||||
"github.com/fleetdm/fleet/v4/pkg/spec"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockRoundTripper is a custom http.RoundTripper that redirects requests to a mock server
|
||||
type mockRoundTripper struct {
|
||||
mockServer string
|
||||
origBaseURL string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// If the request URL contains the original base URL, replace it with the mock server URL
|
||||
if strings.Contains(req.URL.String(), rt.origBaseURL) {
|
||||
// Extract the path from the original URL
|
||||
path := strings.TrimPrefix(req.URL.Path, "/")
|
||||
|
||||
// Create a new URL with the mock server
|
||||
newURL := fmt.Sprintf("%s/%s", rt.mockServer, path)
|
||||
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Copy headers
|
||||
newReq.Header = req.Header
|
||||
|
||||
// Use the next transport to perform the request
|
||||
return rt.next.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
// For other requests, use the next transport
|
||||
return rt.next.RoundTrip(req)
|
||||
}
|
||||
|
||||
// Helper function to create HTTP responses for testing
|
||||
func createTestResponse(code int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: code,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
// testRoundTripper2 is a mock implementation of http.RoundTripper for tracking URL calls
|
||||
type testRoundTripper2 struct {
|
||||
RoundTripFunc func(req *http.Request) (*http.Response, error)
|
||||
calls []string // Track URLs that were called
|
||||
}
|
||||
|
||||
func (m *testRoundTripper2) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
m.calls = append(m.calls, req.URL.String())
|
||||
return m.RoundTripFunc(req)
|
||||
}
|
||||
|
||||
func TestExtractScriptNames(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
teams []map[string]interface{}
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "multiple teams with scripts",
|
||||
teams: []map[string]interface{}{
|
||||
{
|
||||
"name": "Team1",
|
||||
"scripts": []interface{}{"script1.sh", "script2.sh"},
|
||||
},
|
||||
{
|
||||
"name": "Team2",
|
||||
"scripts": []interface{}{"script2.sh", "script3.sh"}, // Note: script2.sh is duplicated
|
||||
},
|
||||
{
|
||||
"name": "Team3", // No scripts
|
||||
},
|
||||
},
|
||||
expected: []string{"script1.sh", "script2.sh", "script3.sh"},
|
||||
},
|
||||
{
|
||||
name: "no teams",
|
||||
teams: []map[string]interface{}{},
|
||||
expected: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create spec group from test data
|
||||
var teams []json.RawMessage
|
||||
for _, team := range tt.teams {
|
||||
teamRaw, err := json.Marshal(team)
|
||||
require.NoError(t, err)
|
||||
teams = append(teams, teamRaw)
|
||||
}
|
||||
specs := &spec.Group{Teams: teams}
|
||||
|
||||
// Call the function
|
||||
scriptNames := ExtractScriptNames(specs)
|
||||
|
||||
// Verify the results
|
||||
assert.Len(t, scriptNames, len(tt.expected))
|
||||
for _, name := range tt.expected {
|
||||
assert.Contains(t, scriptNames, name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAndUpdateScripts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scriptNames []string
|
||||
scriptPaths []string
|
||||
}{
|
||||
{
|
||||
name: "single script",
|
||||
scriptNames: []string{"test-script.sh"},
|
||||
scriptPaths: []string{"test-script.sh"},
|
||||
},
|
||||
{
|
||||
name: "multiple scripts with nested path",
|
||||
scriptNames: []string{"test-script.sh", "subfolder/nested-script.sh"},
|
||||
scriptPaths: []string{"test-script.sh", "subfolder/nested-script.sh"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a mock server to serve the scripts
|
||||
scriptContent := "#!/bin/bash\necho 'Hello, World!'"
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(scriptContent))
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
// Save the original HTTP transport
|
||||
origTransport := http.DefaultTransport
|
||||
|
||||
// Create a custom transport that redirects requests to our mock server
|
||||
mockTransport := &mockRoundTripper{
|
||||
mockServer: mockServer.URL,
|
||||
origBaseURL: scriptsBaseURL,
|
||||
next: origTransport,
|
||||
}
|
||||
|
||||
// Replace the default transport with our mock transport
|
||||
http.DefaultTransport = mockTransport
|
||||
|
||||
// Restore the original transport when the test is done
|
||||
defer func() {
|
||||
http.DefaultTransport = origTransport
|
||||
}()
|
||||
|
||||
// Create a temporary directory
|
||||
tempDir, err := os.MkdirTemp("", "fleet-test-scripts-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a test spec group
|
||||
teamData := map[string]interface{}{
|
||||
"name": "Team1",
|
||||
"scripts": []interface{}{tt.scriptNames[0]},
|
||||
}
|
||||
teamRaw, err := json.Marshal(teamData)
|
||||
require.NoError(t, err)
|
||||
|
||||
specs := &spec.Group{
|
||||
Teams: []json.RawMessage{teamRaw},
|
||||
}
|
||||
|
||||
// Call the actual production function
|
||||
err = DownloadAndUpdateScripts(context.Background(), specs, tt.scriptNames, tempDir, slog.New(slog.DiscardHandler))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the scripts were downloaded
|
||||
for _, scriptName := range tt.scriptPaths {
|
||||
scriptPath := filepath.Join(tempDir, scriptName)
|
||||
_, err := os.Stat(scriptPath)
|
||||
assert.NoError(t, err, "Script should exist: %s", scriptPath)
|
||||
|
||||
// Verify the content
|
||||
content, err := os.ReadFile(scriptPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, scriptContent, string(content))
|
||||
}
|
||||
|
||||
// Verify the specs were updated
|
||||
var updatedTeamData map[string]interface{}
|
||||
err = json.Unmarshal(specs.Teams[0], &updatedTeamData)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedScripts, ok := updatedTeamData["scripts"].([]interface{})
|
||||
require.True(t, ok)
|
||||
|
||||
// The scripts should now be local paths
|
||||
for i, script := range updatedScripts {
|
||||
scriptPath, ok := script.(string)
|
||||
require.True(t, ok)
|
||||
assert.Contains(t, scriptPath, tempDir)
|
||||
assert.Contains(t, scriptPath, tt.scriptNames[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAndUpdateScriptsWithInvalidPaths(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scriptNames []string
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "path traversal attempt",
|
||||
scriptNames: []string{"../test-script.sh"},
|
||||
errorMsg: "invalid script name",
|
||||
},
|
||||
{
|
||||
name: "absolute path attempt",
|
||||
scriptNames: []string{"/etc/passwd"},
|
||||
errorMsg: "invalid script name",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a mock server to serve the scripts
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("script content"))
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
// Save the original HTTP transport
|
||||
origTransport := http.DefaultTransport
|
||||
|
||||
// Create a custom transport that redirects requests to our mock server
|
||||
mockTransport := &mockRoundTripper{
|
||||
mockServer: mockServer.URL,
|
||||
origBaseURL: scriptsBaseURL,
|
||||
next: origTransport,
|
||||
}
|
||||
|
||||
// Replace the default transport with our mock transport
|
||||
http.DefaultTransport = mockTransport
|
||||
|
||||
// Restore the original transport when the test is done
|
||||
defer func() {
|
||||
http.DefaultTransport = origTransport
|
||||
}()
|
||||
|
||||
// Create a temporary directory
|
||||
tempDir, err := os.MkdirTemp("", "fleet-test-scripts-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a test spec group
|
||||
teamData := map[string]interface{}{
|
||||
"name": "Team1",
|
||||
"scripts": []interface{}{tt.scriptNames[0]},
|
||||
}
|
||||
teamRaw, err := json.Marshal(teamData)
|
||||
require.NoError(t, err)
|
||||
|
||||
specs := &spec.Group{
|
||||
Teams: []json.RawMessage{teamRaw},
|
||||
}
|
||||
|
||||
// Call the actual production function
|
||||
err = DownloadAndUpdateScripts(context.Background(), specs, tt.scriptNames, tempDir, slog.New(slog.DiscardHandler))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFleetHTTPClientOverride(t *testing.T) {
|
||||
// Create a mock server
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("test response"))
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
// Create a custom client that uses our mock server
|
||||
client := fleethttp.NewClient()
|
||||
client.Transport = &mockRoundTripper{
|
||||
mockServer: mockServer.URL,
|
||||
origBaseURL: scriptsBaseURL,
|
||||
next: http.DefaultTransport,
|
||||
}
|
||||
|
||||
// Create a fleethttp client with our custom transport
|
||||
fleetClient := fleethttp.NewClient()
|
||||
// Replace the client's transport with our mock transport
|
||||
fleetClient.Transport = client.Transport
|
||||
|
||||
// Create a request to the original URL
|
||||
req, err := http.NewRequest("GET", scriptsBaseURL+"/test.sh", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send the request
|
||||
resp, err := fleetClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Verify the response
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test response", string(body))
|
||||
}
|
||||
|
||||
func TestDownloadAndUpdateScriptsTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sleepTime time.Duration
|
||||
contextTime time.Duration
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "timeout occurs",
|
||||
sleepTime: 6 * time.Second, // Server sleeps longer than timeout
|
||||
contextTime: 2 * time.Second, // Short context timeout
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "no timeout",
|
||||
sleepTime: 1 * time.Second, // Server responds quickly
|
||||
contextTime: 5 * time.Second, // Longer context timeout
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Skip the long-running test in short mode
|
||||
if tt.sleepTime > 2*time.Second && testing.Short() {
|
||||
t.Skip("Skipping long-running test in short mode")
|
||||
}
|
||||
|
||||
// Create a mock server that delays its response
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(tt.sleepTime)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("test response"))
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
// Save the original HTTP transport
|
||||
origTransport := http.DefaultTransport
|
||||
|
||||
// Create a custom transport that redirects requests to our mock server
|
||||
mockTransport := &mockRoundTripper{
|
||||
mockServer: mockServer.URL,
|
||||
origBaseURL: scriptsBaseURL,
|
||||
next: origTransport,
|
||||
}
|
||||
|
||||
// Replace the default transport with our mock transport
|
||||
http.DefaultTransport = mockTransport
|
||||
|
||||
// Restore the original transport when the test is done
|
||||
defer func() {
|
||||
http.DefaultTransport = origTransport
|
||||
}()
|
||||
|
||||
// Create a temporary directory
|
||||
tempDir, err := os.MkdirTemp("", "fleet-test-scripts-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a test spec group
|
||||
teamData := map[string]interface{}{
|
||||
"name": "Team1",
|
||||
"scripts": []interface{}{"test-script.sh"},
|
||||
}
|
||||
teamRaw, err := json.Marshal(teamData)
|
||||
require.NoError(t, err)
|
||||
|
||||
specs := &spec.Group{
|
||||
Teams: []json.RawMessage{teamRaw},
|
||||
}
|
||||
|
||||
// Define script names
|
||||
scriptNames := []string{"test-script.sh"}
|
||||
|
||||
// Set context timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), tt.contextTime)
|
||||
defer cancel()
|
||||
|
||||
// Call the actual production function
|
||||
err = DownloadAndUpdateScripts(ctx, specs, scriptNames, tempDir, slog.New(slog.DiscardHandler))
|
||||
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
// The error message might vary depending on the HTTP client implementation,
|
||||
// but it should contain either "timeout", "deadline exceeded", or "context canceled"
|
||||
errorMsg := strings.ToLower(err.Error())
|
||||
timeoutRelated := strings.Contains(errorMsg, "timeout") ||
|
||||
strings.Contains(errorMsg, "deadline exceeded") ||
|
||||
strings.Contains(errorMsg, "context canceled")
|
||||
assert.True(t, timeoutRelated, "Expected a timeout-related error, got: %s", err.Error())
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyStarterLibraryWithMockClient(t *testing.T) {
|
||||
// Read the real production starter library YAML file
|
||||
starterLibraryPath := "../../docs/01-Using-Fleet/starter-library/starter-library.yml"
|
||||
starterLibraryContent, err := os.ReadFile(starterLibraryPath)
|
||||
require.NoError(t, err, "Should be able to read starter library YAML file")
|
||||
|
||||
// Create mock HTTP client for downloading the starter library and scripts
|
||||
mockRT := &testRoundTripper2{
|
||||
calls: []string{},
|
||||
RoundTripFunc: func(req *http.Request) (*http.Response, error) {
|
||||
switch {
|
||||
case req.URL.String() == starterLibraryURL:
|
||||
// Return the real starter library content
|
||||
return createTestResponse(200, string(starterLibraryContent)), nil
|
||||
case strings.Contains(req.URL.String(), "uninstall-fleetd"):
|
||||
// Return a simple script for any script URL
|
||||
return createTestResponse(200, "#!/bin/bash\necho ok"), nil
|
||||
default:
|
||||
// For any other URL, return a 404
|
||||
return createTestResponse(404, "Not found"), nil
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
httpClientFactory := func(opts ...fleethttp.ClientOpt) *http.Client {
|
||||
client := fleethttp.NewClient(opts...)
|
||||
client.Transport = mockRT
|
||||
return client
|
||||
}
|
||||
|
||||
// Create a client factory that returns a real client
|
||||
// We're not testing the token setting functionality
|
||||
clientFactory := NewClient
|
||||
|
||||
// Track if ApplyGroup was called and capture the specs
|
||||
applyGroupCalled := false
|
||||
var capturedSpecs *spec.Group
|
||||
// Create a mock ApplyGroup function
|
||||
mockApplyGroup := func(ctx context.Context, specs *spec.Group) error {
|
||||
applyGroupCalled = true
|
||||
capturedSpecs = specs
|
||||
return nil
|
||||
}
|
||||
|
||||
// Call the function under test
|
||||
testErr := ApplyStarterLibrary(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
"test-token",
|
||||
slog.New(slog.DiscardHandler),
|
||||
httpClientFactory,
|
||||
clientFactory,
|
||||
mockApplyGroup,
|
||||
)
|
||||
|
||||
// Verify results
|
||||
require.NoError(t, testErr)
|
||||
assert.True(t, applyGroupCalled, "ApplyGroup should have been called")
|
||||
|
||||
// Verify that the specs were correctly parsed
|
||||
require.NotNil(t, capturedSpecs, "Specs should not be nil")
|
||||
require.NotEmpty(t, capturedSpecs.Teams, "Specs should contain teams")
|
||||
|
||||
// Verify that the first team has the expected structure
|
||||
var team1 map[string]interface{}
|
||||
unmarshalErr := json.Unmarshal(capturedSpecs.Teams[0], &team1)
|
||||
require.NoError(t, unmarshalErr, "Should be able to unmarshal team JSON")
|
||||
|
||||
// Verify that the team has a name
|
||||
require.Contains(t, team1, "name", "Team should have a name")
|
||||
teamName := team1["name"].(string)
|
||||
require.NotEmpty(t, teamName, "Team name should not be empty")
|
||||
|
||||
// Verify that the starter library URL was requested
|
||||
assert.Contains(t, mockRT.calls, starterLibraryURL, "The starter library URL should have been requested")
|
||||
}
|
||||
|
||||
func TestApplyStarterLibraryWithMalformedYAML(t *testing.T) {
|
||||
// Create mock HTTP client that returns malformed YAML
|
||||
malformedYAML := `
|
||||
teams:
|
||||
- name: "Malformed Team
|
||||
# Missing closing quote and improper indentation
|
||||
scripts:
|
||||
- "script1.sh
|
||||
`
|
||||
|
||||
mockRT := &testRoundTripper2{
|
||||
calls: []string{},
|
||||
RoundTripFunc: func(req *http.Request) (*http.Response, error) {
|
||||
switch {
|
||||
case req.URL.String() == starterLibraryURL:
|
||||
// Return malformed YAML content
|
||||
return createTestResponse(200, malformedYAML), nil
|
||||
default:
|
||||
// For any other URL, return a 404
|
||||
return createTestResponse(404, "Not found"), nil
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
httpClientFactory := func(opts ...fleethttp.ClientOpt) *http.Client {
|
||||
client := fleethttp.NewClient(opts...)
|
||||
client.Transport = mockRT
|
||||
return client
|
||||
}
|
||||
|
||||
// Create a client factory that returns a real client
|
||||
clientFactory := NewClient
|
||||
|
||||
// Create a mock ApplyGroup function that should not be called
|
||||
mockApplyGroup := func(ctx context.Context, specs *spec.Group) error {
|
||||
t.Error("ApplyGroup should not be called with malformed YAML")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use a defer/recover to explicitly catch any panics
|
||||
var panicValue interface{}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicValue = r
|
||||
t.Fatalf("Panic occurred when processing malformed YAML: %v", panicValue)
|
||||
}
|
||||
}()
|
||||
|
||||
// Call the function under test
|
||||
testErr := ApplyStarterLibrary(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
"test-token",
|
||||
slog.New(slog.DiscardHandler),
|
||||
httpClientFactory,
|
||||
clientFactory,
|
||||
mockApplyGroup,
|
||||
)
|
||||
|
||||
// Verify results
|
||||
require.Error(t, testErr, "Should return an error with malformed YAML")
|
||||
assert.Contains(t, testErr.Error(), "failed to parse starter library",
|
||||
"Error should indicate YAML parsing failure")
|
||||
|
||||
// Verify that the starter library URL was requested
|
||||
assert.Contains(t, mockRT.calls, starterLibraryURL, "The starter library URL should have been requested")
|
||||
|
||||
// If we reach here, no panic occurred and the setup flow was not interrupted
|
||||
}
|
||||
|
||||
func TestApplyStarterLibraryWithFreeLicense(t *testing.T) {
|
||||
// Read the real production starter library YAML file
|
||||
starterLibraryPath := "../../docs/01-Using-Fleet/starter-library/starter-library.yml"
|
||||
starterLibraryContent, err := os.ReadFile(starterLibraryPath)
|
||||
require.NoError(t, err, "Should be able to read starter library YAML file")
|
||||
|
||||
// Create mock HTTP client for downloading the starter library and scripts
|
||||
mockRT := &testRoundTripper2{
|
||||
calls: []string{},
|
||||
RoundTripFunc: func(req *http.Request) (*http.Response, error) {
|
||||
switch {
|
||||
case req.URL.String() == starterLibraryURL:
|
||||
// Return the real starter library content
|
||||
return createTestResponse(200, string(starterLibraryContent)), nil
|
||||
case strings.Contains(req.URL.String(), "uninstall-fleetd"):
|
||||
// Return a simple script for any script URL
|
||||
return createTestResponse(200, "#!/bin/bash\necho ok"), nil
|
||||
default:
|
||||
// For any other URL, return a 404
|
||||
return createTestResponse(404, "Not found"), nil
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
httpClientFactory := func(opts ...fleethttp.ClientOpt) *http.Client {
|
||||
client := fleethttp.NewClient(opts...)
|
||||
client.Transport = mockRT
|
||||
return client
|
||||
}
|
||||
|
||||
// Create a mock client that returns a free license
|
||||
// Create a properly structured EnrichedAppConfig
|
||||
mockEnrichedAppConfig := &fleet.EnrichedAppConfig{}
|
||||
// Set the License field using json marshaling/unmarshaling to bypass unexported field access
|
||||
configJSON := []byte(`{"license":{"tier":"free"}}`)
|
||||
if err := json.Unmarshal(configJSON, mockEnrichedAppConfig); err != nil {
|
||||
t.Fatal("Failed to unmarshal mock config:", err)
|
||||
}
|
||||
|
||||
// Create a mock client factory
|
||||
clientFactory := func(serverURL string, insecureSkipVerify bool, rootCA, urlPrefix string, options ...ClientOption) (*Client, error) {
|
||||
mockClient := &Client{}
|
||||
|
||||
// Override the baseClient with a mock implementation
|
||||
// Create a mock HTTP client
|
||||
mockHTTPClient := &mockHTTPClient{
|
||||
DoFunc: func(req *http.Request) (*http.Response, error) {
|
||||
// Mock the GetAppConfig response
|
||||
if req.URL.Path == "/api/v1/fleet/config" && req.Method == http.MethodGet {
|
||||
respBody, _ := json.Marshal(mockEnrichedAppConfig)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewBuffer(respBody)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
// Set up the baseClient with the mock HTTP client and a valid baseURL
|
||||
baseURL, _ := url.Parse(serverURL)
|
||||
mockClient.baseClient = &baseClient{
|
||||
HTTP: mockHTTPClient,
|
||||
BaseURL: baseURL,
|
||||
}
|
||||
|
||||
return mockClient, nil
|
||||
}
|
||||
|
||||
// Track if ApplyGroup was called and capture the specs
|
||||
applyGroupCalled := false
|
||||
var capturedSpecs *spec.Group
|
||||
// Create a mock ApplyGroup function
|
||||
mockApplyGroup := func(ctx context.Context, specs *spec.Group) error {
|
||||
applyGroupCalled = true
|
||||
capturedSpecs = specs
|
||||
return nil
|
||||
}
|
||||
|
||||
// Call the function under test - teams will be skipped automatically for free license
|
||||
testErr := ApplyStarterLibrary(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
"test-token",
|
||||
slog.New(slog.DiscardHandler),
|
||||
httpClientFactory,
|
||||
clientFactory,
|
||||
mockApplyGroup,
|
||||
)
|
||||
|
||||
// Verify results
|
||||
require.NoError(t, testErr)
|
||||
assert.True(t, applyGroupCalled, "ApplyGroup should have been called")
|
||||
|
||||
// Verify that the specs were correctly parsed
|
||||
require.NotNil(t, capturedSpecs, "Specs should not be nil")
|
||||
|
||||
// Verify that teams were removed
|
||||
require.Empty(t, capturedSpecs.Teams, "Teams should be empty for free license")
|
||||
|
||||
// Verify that policies referencing teams were filtered out
|
||||
if capturedSpecs.Policies != nil {
|
||||
for _, policy := range capturedSpecs.Policies {
|
||||
assert.Empty(t, policy.Team, "Policies should not reference teams for free license")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that scripts were removed from AppConfig
|
||||
if capturedSpecs.AppConfig != nil {
|
||||
appConfigMap, ok := capturedSpecs.AppConfig.(map[string]interface{})
|
||||
if ok {
|
||||
_, hasScripts := appConfigMap["scripts"]
|
||||
assert.False(t, hasScripts, "AppConfig should not contain scripts for free license")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that the starter library URL was requested
|
||||
assert.Contains(t, mockRT.calls, starterLibraryURL, "The starter library URL should have been requested")
|
||||
}
|
||||
|
||||
// mockHTTPClient is a mock implementation of the http.Client
|
||||
type mockHTTPClient struct {
|
||||
DoFunc func(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
return m.DoFunc(req)
|
||||
}
|
||||
|
|
@ -1120,12 +1120,12 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
|
|||
// WithSetup is an http middleware that checks if setup procedures have been completed.
|
||||
// If setup hasn't been completed it serves the API with a setup middleware.
|
||||
// If the server is already configured, the default API handler is exposed.
|
||||
func WithSetup(svc fleet.Service, logger *slog.Logger, next http.Handler) http.HandlerFunc {
|
||||
func WithSetup(svc fleet.Service, logger *slog.Logger, applyStarterLibrary func(ctx context.Context, serverURL, token string) error, next http.Handler) http.HandlerFunc {
|
||||
rxOsquery := regexp.MustCompile(`^/api/[^/]+/osquery`)
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
configRouter := http.NewServeMux()
|
||||
srv := kithttp.NewServer(
|
||||
makeSetupEndpoint(svc, logger),
|
||||
makeSetupEndpoint(svc, logger, applyStarterLibrary),
|
||||
decodeSetupRequest,
|
||||
encodeResponse,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import (
|
|||
"net/http/httptest"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -1541,3 +1542,24 @@ type acmeCSRSigner struct {
|
|||
func (a *acmeCSRSigner) SignCSR(_ context.Context, csr *x509.CertificateRequest) (*x509.Certificate, error) {
|
||||
return a.signer.Signx509CSR(csr)
|
||||
}
|
||||
|
||||
// mockRoundTripper is a custom http.RoundTripper that redirects requests to a mock server.
|
||||
type mockRoundTripper struct {
|
||||
mockServer string
|
||||
origBaseURL string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if strings.Contains(req.URL.String(), rt.origBaseURL) {
|
||||
path := strings.TrimPrefix(req.URL.Path, "/")
|
||||
newURL := fmt.Sprintf("%s/%s", rt.mockServer, path)
|
||||
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body) //nolint:gosec // test helper, URL is from mock server
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newReq.Header = req.Header
|
||||
return rt.next.RoundTrip(newReq)
|
||||
}
|
||||
return rt.next.RoundTrip(req)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue