Use fleetctl new templates for new instances (#42768)

<!-- Add the related story/sub-task/bug number, like Resolves #123, or
remove if NA -->
**Related issue:** Resolves #41409 

# Details

This PR updates the `ApplyStarterLibrary` method and functionality to
rely on the same templates and mechanisms as `fleetctl new`. The end
result is that running `fleetctl new` and `fleetctl gitops` on a new
instance should be a no-op; no changes should be made. Similarly,
changing the templates in a Fleet release will automatically affect
`fleetctl new` and `ApplyStarterLibrary` in the same exact way for that
release.

> Note that this moves the template files out of `fleetctl` and into
their own shared package. This move comprises the majority of the file
changes in the PR.

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [X] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files)
for more information.

## Testing

- [X] Added/updated automated tests
Note that 

<img width="668" height="44" alt="image"
src="https://github.com/user-attachments/assets/066cd566-f91d-4661-84fc-2aabbfce2ef9"
/>

will fail until the 4.83 Fleet docker image is published, since it's
trying to push 4.83 config (including `exceptions`) to a 4.82 server.

- [X] QA'd all new/changed functionality manually
- [X] Created a new instance and validated that the fleets, policies and
labels created matched the ones created by `fleetctl new`
- [X] Ran `fleetctl new` and verified that it created the expected
folders and files
- [X] Ran `fleetctl gitops` with the files created by `fleetctl new` and
verified that the instance was unchanged.
- [X] Ran `fleetctl preview` successfully using a dev build of the Fleet
server image (since it won't work against the latest published build,
which doesn't support `exceptions`). Verified it shows the expected
teams, policies and labels
This commit is contained in:
Scott Gress 2026-04-03 09:58:03 -05:00 committed by GitHub
parent a2e7c95c6c
commit c4aa6f5529
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 424 additions and 1048 deletions

View file

@ -1,7 +1,8 @@
name: Test latest changes in fleetctl preview
# Tests the `fleetctl preview` command with latest changes in fleetctl and
# docs/01-Using-Fleet/starter-library/starter-library.yml
# Tests the `fleetctl preview` command with the Fleet server and fleetctl
# built from the same commit, ensuring the starter library and GitOps
# pipeline work end-to-end.
on:
push:
@ -16,7 +17,6 @@ on:
- 'server/context/**.go'
- 'orbit/**.go'
- 'ee/fleetctl/**.go'
- 'docs/01-Using-Fleet/starter-library/starter-library.yml'
- '.github/workflows/fleetctl-preview-latest.yml'
- 'tools/osquery/in-a-box'
pull_request:
@ -27,7 +27,6 @@ on:
- 'server/context/**.go'
- 'orbit/**.go'
- 'ee/fleetctl/**.go'
- 'docs/01-Using-Fleet/starter-library/starter-library.yml'
- '.github/workflows/fleetctl-preview-latest.yml'
- 'tools/osquery/in-a-box'
workflow_dispatch: # Manual
@ -71,15 +70,43 @@ jobs:
with:
go-version-file: 'go.mod'
- name: Set up Node.js
uses: actions/setup-node@5e21ff4d9bc1a8cf6de233a3057d20ec6b3fb69d # v3.8.1
with:
node-version-file: package.json
check-latest: true
- name: Install JS dependencies
run: make deps
- name: Generate assets
run: make generate
- name: Build Fleetctl
run: make fleetctl
- name: Build Fleet server Docker image
run: |
make fleet-static
cp ./build/fleet fleet
docker build -t fleetdm/fleet:dev -f tools/fleet-docker/Dockerfile .
rm fleet
- name: Prepare preview config
run: |
# Copy the in-a-box config and set pull_policy so Docker uses the
# locally built image instead of trying to pull from Docker Hub.
cp -a tools/osquery/in-a-box /tmp/preview-config
# Add pull_policy: never to fleet01 and fleet02 services
sed -i '/^ fleet01:/,/^ [^ ]/{s/^\( image: fleetdm\/fleet.*\)/\1\n pull_policy: never/}' /tmp/preview-config/docker-compose.yml
sed -i '/^ fleet02:/,/^ [^ ]/{s/^\( image: fleetdm\/fleet.*\)/\1\n pull_policy: never/}' /tmp/preview-config/docker-compose.yml
- name: Run fleetctl preview
run: |
./build/fleetctl preview \
--tag dev \
--disable-open-browser \
--starter-library-file-path $(pwd)/docs/01-Using-Fleet/starter-library/starter-library.yml \
--preview-config-path ./tools/osquery/in-a-box
--preview-config-path /tmp/preview-config
sleep 10
./build/fleetctl get hosts | tee hosts.txt
[ $( cat hosts.txt | grep online | wc -l) -eq 8 ]

View file

@ -58,14 +58,16 @@ ifdef CIRCLE_TAG
DOCKER_IMAGE_TAG = ${CIRCLE_TAG}
endif
LDFLAGS_VERSION = "\
LDFLAGS_VERSION_RAW = \
-X github.com/fleetdm/fleet/v4/server/version.appName=${APP_NAME} \
-X github.com/fleetdm/fleet/v4/server/version.version=${VERSION} \
-X github.com/fleetdm/fleet/v4/server/version.branch=${BRANCH} \
-X github.com/fleetdm/fleet/v4/server/version.revision=${REVISION} \
-X github.com/fleetdm/fleet/v4/server/version.buildDate=${NOW} \
-X github.com/fleetdm/fleet/v4/server/version.buildUser=${USER} \
-X github.com/fleetdm/fleet/v4/server/version.goVersion=${GOVERSION}"
-X github.com/fleetdm/fleet/v4/server/version.goVersion=${GOVERSION}
LDFLAGS_VERSION = "${LDFLAGS_VERSION_RAW}"
LDFLAGS_VERSION_STATIC = "${LDFLAGS_VERSION_RAW} -extldflags '-static'"
# Macro to allow targets to filter out their own arguments from the arguments
# passed to the final command.
@ -198,6 +200,9 @@ endif
fleet: .prefix .pre-build .pre-fleet
CGO_ENABLED=1 go build -race=${GO_BUILD_RACE_ENABLED_VAR} -tags full,fts5,netgo -o build/${OUTPUT} -ldflags ${LDFLAGS_VERSION} ./cmd/fleet
fleet-static: .prefix .pre-build .pre-fleet
CGO_ENABLED=1 go build -tags full,fts5,netgo -trimpath -o build/${OUTPUT} -ldflags ${LDFLAGS_VERSION_STATIC} ./cmd/fleet
fleet-dev: GO_BUILD_RACE_ENABLED_VAR=true
fleet-dev: fleet

View file

@ -0,0 +1 @@
- Use the same templates for `fleetctl new` and new instance initialization

View file

@ -23,6 +23,7 @@ import (
"github.com/WatchBeam/clock"
"github.com/e-dard/netbug"
"github.com/fleetdm/fleet/v4/cmd/fleetctl/fleetctl"
"github.com/fleetdm/fleet/v4/ee/server/licensing"
"github.com/fleetdm/fleet/v4/ee/server/scim"
eeservice "github.com/fleetdm/fleet/v4/ee/server/service"
@ -1496,7 +1497,15 @@ func runServeCmd(cmd *cobra.Command, configManager configpkg.Manager, debug, dev
// By performing the same check inside main, we can make server startups
// more efficient after the first startup.
if setupRequired {
apiHandler = service.WithSetup(svc, logger, apiHandler)
// Pass in a closure to run the fleetctl command, so that the service layer
// doesn't need to import the CLI package.
applyStarterLibrary := func(ctx context.Context, serverURL, token string) error {
return service.ApplyStarterLibrary(ctx, serverURL, token, logger, func(args []string) error {
_, err := fleetctl.RunApp(args)
return err
})
}
apiHandler = service.WithSetup(svc, logger, applyStarterLibrary, apiHandler)
frontendHandler = service.RedirectLoginToSetup(svc, logger, frontendHandler, config.Server.URLPrefix)
} else {
frontendHandler = service.RedirectSetupToLogin(svc, logger, frontendHandler, config.Server.URLPrefix)

View file

@ -1,8 +1,10 @@
package fleetctl
import (
"bytes"
"errors"
"io"
"os"
eefleetctl "github.com/fleetdm/fleet/v4/ee/fleetctl"
"github.com/fleetdm/fleet/v4/server/version"
@ -78,3 +80,14 @@ func CreateApp(
}
return app
}
func RunApp(args []string) (*bytes.Buffer, error) {
// first arg must be the binary name. Allow tests to omit it.
args = append([]string{""}, args...)
w := new(bytes.Buffer)
app := CreateApp(nil, w, os.Stderr, noopExitErrHandler)
StashRawArgs(app, args)
err := app.Run(args)
return w, err
}

View file

@ -400,9 +400,10 @@ Use the stop and reset subcommands to manage the server and dependencies once st
address,
token,
logger,
fleethttp.NewClient,
service.NewClient,
nil, // No mock ApplyGroup for production code
func(args []string) error {
_, err := RunApp(args)
return err
},
); err != nil {
return fmt.Errorf("failed to apply starter library: %w", err)
}

View file

@ -3,7 +3,6 @@ package fleetctl
import (
"bytes"
"io"
"os"
"testing"
"github.com/stretchr/testify/require"
@ -23,17 +22,6 @@ func RunAppCheckErr(t *testing.T, args []string, errorMsg string) string {
return w.String()
}
func RunAppNoChecks(args []string) (*bytes.Buffer, error) {
// first arg must be the binary name. Allow tests to omit it.
args = append([]string{""}, args...)
w := new(bytes.Buffer)
app := CreateApp(nil, w, os.Stderr, noopExitErrHandler)
StashRawArgs(app, args)
err := app.Run(args)
return w, err
}
func RunWithErrWriter(args []string, errWriter io.Writer) (*bytes.Buffer, error) {
args = append([]string{""}, args...)
@ -45,3 +33,9 @@ func RunWithErrWriter(args []string, errWriter io.Writer) (*bytes.Buffer, error)
}
func noopExitErrHandler(c *cli.Context, err error) {}
// Alias for RunApp; added rather than changing all existing calls to `RunApp`,
// to avoid confusion and in case the behavior of `RunApp` needs to diverge in the future.
func RunAppNoChecks(args []string) (*bytes.Buffer, error) {
return RunApp(args)
}

View file

@ -0,0 +1,117 @@
package gitops
import (
"context"
"log/slog"
"testing"
"github.com/fleetdm/fleet/v4/cmd/fleetctl/fleetctl"
"github.com/fleetdm/fleet/v4/cmd/fleetctl/integrationtest"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/redis/redistest"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/service"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
func TestIntegrationsEnterpriseStarterLibrary(t *testing.T) {
testingSuite := new(starterLibraryIntegrationEnterpriseTestSuite)
testingSuite.WithServer.Suite = &testingSuite.Suite
suite.Run(t, testingSuite)
}
type starterLibraryIntegrationEnterpriseTestSuite struct {
suite.Suite
integrationtest.WithServer
}
func (s *starterLibraryIntegrationEnterpriseTestSuite) SetupSuite() {
s.WithDS.SetupSuite("starterLibraryIntegrationEnterpriseTestSuite")
appConf, err := s.DS.AppConfig(context.Background())
s.Require().NoError(err)
err = s.DS.SaveAppConfig(context.Background(), appConf)
s.Require().NoError(err)
fleetCfg := config.TestConfig()
fleetCfg.Osquery.EnrollCooldown = 0
redisPool := redistest.SetupRedis(s.T(), "starter_library_enterprise", false, false, false)
serverConfig := service.TestServerOpts{
License: &fleet.LicenseInfo{
Tier: fleet.TierPremium,
},
FleetConfig: &fleetCfg,
Pool: redisPool,
}
users, server := service.RunServerForTestsWithDS(s.T(), s.DS, &serverConfig)
s.T().Setenv("FLEET_SERVER_ADDRESS", server.URL)
s.Server = server
s.Users = users
appConf, err = s.DS.AppConfig(context.Background())
s.Require().NoError(err)
appConf.ServerSettings.ServerURL = server.URL
appConf.OrgInfo.OrgName = "Test Org"
appConf.GitOpsConfig.Exceptions = fleet.GitOpsExceptions{}
err = s.DS.SaveAppConfig(context.Background(), appConf)
s.Require().NoError(err)
}
// TestApplyStarterLibraryPremium verifies that ApplyStarterLibrary applies the
// global config and team configs when using a premium license.
func (s *starterLibraryIntegrationEnterpriseTestSuite) TestApplyStarterLibraryPremium() {
t := s.T()
ctx := context.Background()
token := s.GetTestToken("admin1@example.com", test.GoodPassword)
logger := slog.New(slog.DiscardHandler)
err := service.ApplyStarterLibrary(
ctx,
s.Server.URL,
token,
logger,
func(args []string) error {
_, err := fleetctl.RunAppNoChecks(args)
return err
},
)
require.NoError(t, err)
// Verify the org name was applied.
appConfig, err := s.DS.AppConfig(ctx)
require.NoError(t, err)
assert.Equal(t, "Test Org", appConfig.OrgInfo.OrgName)
// Verify that the teams from the starter templates were created.
teams, err := s.DS.ListTeams(ctx, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String("admin")}}, fleet.ListOptions{})
require.NoError(t, err)
teamNames := make([]string, len(teams))
for i, tm := range teams {
teamNames[i] = tm.Name
}
assert.Contains(t, teamNames, "💻 Workstations")
assert.Contains(t, teamNames, "📱🔐 Personal mobile devices")
// Verify labels were created (global labels, team_id=0).
labelSpecs, err := s.DS.GetLabelSpecs(ctx, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String("admin")}})
require.NoError(t, err)
var customLabelNames []string
for _, l := range labelSpecs {
if l.LabelType != fleet.LabelTypeBuiltIn {
customLabelNames = append(customLabelNames, l.Name)
}
}
assert.Contains(t, customLabelNames, "Apple Silicon macOS hosts")
assert.Contains(t, customLabelNames, "ARM-based Windows hosts")
assert.Contains(t, customLabelNames, "Debian-based Linux hosts")
assert.Contains(t, customLabelNames, "x86-based Windows hosts")
}

View file

@ -0,0 +1,108 @@
package gitops
import (
"context"
"log/slog"
"testing"
"github.com/fleetdm/fleet/v4/cmd/fleetctl/fleetctl"
"github.com/fleetdm/fleet/v4/cmd/fleetctl/integrationtest"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/redis/redistest"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/service"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
func TestIntegrationsStarterLibrary(t *testing.T) {
testingSuite := new(starterLibraryIntegrationTestSuite)
testingSuite.WithServer.Suite = &testingSuite.Suite
suite.Run(t, testingSuite)
}
type starterLibraryIntegrationTestSuite struct {
suite.Suite
integrationtest.WithServer
}
func (s *starterLibraryIntegrationTestSuite) SetupSuite() {
s.WithDS.SetupSuite("starterLibraryIntegrationTestSuite")
appConf, err := s.DS.AppConfig(context.Background())
s.Require().NoError(err)
err = s.DS.SaveAppConfig(context.Background(), appConf)
s.Require().NoError(err)
fleetCfg := config.TestConfig()
fleetCfg.Osquery.EnrollCooldown = 0
redisPool := redistest.SetupRedis(s.T(), "starter_library", false, false, false)
serverConfig := service.TestServerOpts{
FleetConfig: &fleetCfg,
Pool: redisPool,
}
users, server := service.RunServerForTestsWithDS(s.T(), s.DS, &serverConfig)
s.T().Setenv("FLEET_SERVER_ADDRESS", server.URL)
s.Server = server
s.Users = users
appConf, err = s.DS.AppConfig(context.Background())
s.Require().NoError(err)
appConf.ServerSettings.ServerURL = server.URL
appConf.OrgInfo.OrgName = "Test Org"
appConf.GitOpsConfig.Exceptions = fleet.GitOpsExceptions{}
err = s.DS.SaveAppConfig(context.Background(), appConf)
s.Require().NoError(err)
}
// TestApplyStarterLibraryFree verifies that ApplyStarterLibrary applies only
// the global config (no teams) when using a free license.
func (s *starterLibraryIntegrationTestSuite) TestApplyStarterLibraryFree() {
t := s.T()
ctx := context.Background()
token := s.GetTestToken("admin1@example.com", test.GoodPassword)
logger := slog.New(slog.DiscardHandler)
err := service.ApplyStarterLibrary(
ctx,
s.Server.URL,
token,
logger,
func(args []string) error {
_, err := fleetctl.RunAppNoChecks(args)
return err
},
)
require.NoError(t, err)
// Verify the org name was applied.
appConfig, err := s.DS.AppConfig(ctx)
require.NoError(t, err)
assert.Equal(t, "Test Org", appConfig.OrgInfo.OrgName)
// Verify that no teams were created for a free license.
teams, err := s.DS.ListTeams(ctx, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String("admin")}}, fleet.ListOptions{})
require.NoError(t, err)
assert.Empty(t, teams)
// Verify labels were created (global labels, team_id=0).
labelSpecs, err := s.DS.GetLabelSpecs(ctx, fleet.TeamFilter{User: &fleet.User{GlobalRole: ptr.String("admin")}})
require.NoError(t, err)
var customLabelNames []string
for _, l := range labelSpecs {
if l.LabelType != fleet.LabelTypeBuiltIn {
customLabelNames = append(customLabelNames, l.Name)
}
}
assert.Contains(t, customLabelNames, "Apple Silicon macOS hosts")
assert.Contains(t, customLabelNames, "ARM-based Windows hosts")
assert.Contains(t, customLabelNames, "Debian-based Linux hosts")
assert.Contains(t, customLabelNames, "x86-based Windows hosts")
}

View file

@ -9,10 +9,12 @@ import (
"encoding/xml"
"errors"
"fmt"
"maps"
"os"
"regexp"
"sort"
"strings"
"sync"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/platform/endpointer"
@ -23,6 +25,37 @@ import (
var yamlSeparator = regexp.MustCompile(`(?m:^---[\t ]*)`)
var (
envOverridesMu sync.RWMutex
envOverrides map[string]string
)
// SetEnvOverrides sets environment variable overrides that take precedence over
// os.LookupEnv during env expansion in GitOps file parsing. Pass nil to clear.
func SetEnvOverrides(overrides map[string]string) {
envOverridesMu.Lock()
defer envOverridesMu.Unlock()
if overrides == nil {
envOverrides = nil
return
}
envOverrides = make(map[string]string, len(overrides))
maps.Copy(envOverrides, overrides)
}
// lookupEnv checks env overrides first, then falls back to os.LookupEnv.
func lookupEnv(key string) (string, bool) {
envOverridesMu.RLock()
if envOverrides != nil {
if v, ok := envOverrides[key]; ok {
envOverridesMu.RUnlock()
return v, true
}
}
envOverridesMu.RUnlock()
return os.LookupEnv(key)
}
// Group holds a set of "specs" that can be applied to a Fleet server.
type Group struct {
Queries []*fleet.QuerySpec
@ -301,7 +334,7 @@ func expandEnv(s string, secretMode secretHandling) (string, error) {
switch secretMode {
case secretsExpand:
// Expand secrets for client-side validation
v, ok := os.LookupEnv(env)
v, ok := lookupEnv(env)
if ok {
if !documentIsXML {
return v, true
@ -334,7 +367,7 @@ func expandEnv(s string, secretMode secretHandling) (string, error) {
}
}
v, ok := os.LookupEnv(env)
v, ok := lookupEnv(env)
if !ok {
err = multierror.Append(err, fmt.Errorf("environment variable %q not set", env))
return "", false
@ -398,7 +431,7 @@ func LookupEnvSecrets(s string, secretsMap map[string]string) error {
_ = fleet.MaybeExpand(s, func(env string, startPos, endPos int) (string, bool) {
if strings.HasPrefix(env, fleet.ServerSecretPrefix) {
// lookup the secret and save it, but don't replace
v, ok := os.LookupEnv(env)
v, ok := lookupEnv(env)
if !ok {
err = multierror.Append(err, fmt.Errorf("environment variable %q not set", env))
return "", false

View file

@ -674,10 +674,9 @@ func (c *Client) ApplyGroup(
specs.AppConfig.(map[string]interface{})["yara_rules"] = rulePayloads
}
// Keep any existing GitOps mode config rather than attempting to set via GitOps.
if appconfig != nil {
specs.AppConfig.(map[string]any)["gitops"] = appconfig.GitOpsConfig
}
// GitOps mode config is managed server-side; remove it from the PATCH
// payload so the existing settings are preserved.
delete(specs.AppConfig.(map[string]any), "gitops")
if err := c.ApplyAppConfig(specs.AppConfig, opts.ApplySpecOptions); err != nil {
return nil, nil, nil, nil, fmt.Errorf("applying fleet config: %w", err)

View file

@ -2,17 +2,11 @@ package service
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/pkg/spec"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -20,11 +14,6 @@ import (
"github.com/go-kit/kit/endpoint"
)
const (
starterLibraryURL = "https://raw.githubusercontent.com/fleetdm/fleet/main/docs/01-Using-Fleet/starter-library/starter-library.yml"
scriptsBaseURL = "https://raw.githubusercontent.com/fleetdm/fleet/main/"
)
type setupRequest struct {
Admin *fleet.UserPayload `json:"admin"`
OrgInfo *fleet.OrgInfo `json:"org_info"`
@ -41,11 +30,9 @@ type setupResponse struct {
Err error `json:"error,omitempty"`
}
type applyGroupFunc func(context.Context, *spec.Group) error
func (r setupResponse) Error() error { return r.Err }
func makeSetupEndpoint(svc fleet.Service, logger *slog.Logger) endpoint.Endpoint {
func makeSetupEndpoint(svc fleet.Service, logger *slog.Logger, applyStarterLibrary func(ctx context.Context, serverURL, token string) error) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(setupRequest)
config := &fleet.AppConfig{}
@ -94,15 +81,7 @@ func makeSetupEndpoint(svc fleet.Service, logger *slog.Logger) endpoint.Endpoint
// Apply starter library using the admin token we just created
if req.ServerURL != nil {
if err := ApplyStarterLibrary(
ctx,
*req.ServerURL,
session.Key,
logger,
fleethttp.NewClient,
NewClient,
nil, // No mock ApplyGroup for production code
); err != nil {
if err := applyStarterLibrary(ctx, *req.ServerURL, session.Key); err != nil {
logger.DebugContext(ctx, "setup apply starter library", "endpoint", "setup", "op", "applyStarterLibrary", "err", err)
// Continue even if there's an error applying the starter library
}
@ -120,309 +99,89 @@ func makeSetupEndpoint(svc fleet.Service, logger *slog.Logger) endpoint.Endpoint
}
}
// ApplyStarterLibrary downloads the starter library from GitHub
// and applies it to the Fleet server using an authenticated client.
// TODO: Move the apply starter library logic to use the serve command as an entry point to simplify and leverage the entire fleet.Service.
// Entry point: https://github.com/fleetdm/fleet/blob/2dfadc0971c6ba45c19dad2f5f1f4cd0f1b89b20/cmd/fleet/serve.go#L1099-L1100
// ApplyStarterLibrary scaffolds the starter GitOps templates via `fleetctl new`
// and applies them via `fleetctl gitops`, producing the same result as a user
// running those commands manually.
//
// The runFleetctl callback should run the fleetctl CLI with the given arguments.
// This keeps the CLI dependency out of the service package.
func ApplyStarterLibrary(
ctx context.Context,
serverURL string,
token string,
logger *slog.Logger,
httpClientFactory func(opts ...fleethttp.ClientOpt) *http.Client,
clientFactory func(serverURL string, insecureSkipVerify bool, rootCA, urlPrefix string, options ...ClientOption) (*Client, error),
// For testing only - if provided, this function will be used instead of client.ApplyGroup
mockApplyGroup func(ctx context.Context, specs *spec.Group) error,
runFleetctl func(args []string) error,
) error {
logger.DebugContext(ctx, "Applying starter library")
// Create a request with context for downloading the starter library
req, err := http.NewRequestWithContext(ctx, http.MethodGet, starterLibraryURL, nil)
if err != nil {
return fmt.Errorf("failed to create request for starter library: %w", err)
}
// Download the starter library from GitHub using the provided HTTP client factory
httpClient := httpClientFactory(fleethttp.WithTimeout(5 * time.Second))
resp, err := httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download starter library: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to download starter library, status: %d", resp.StatusCode)
}
buf, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read starter library response body: %w", err)
}
// Create a temporary directory to store downloaded scripts
tempDir, err := os.MkdirTemp("", "fleet-scripts-*")
if err != nil {
return fmt.Errorf("failed to create temporary directory: %w", err)
}
defer os.RemoveAll(tempDir) // Clean up the temporary directory when done
logger.DebugContext(ctx, "Created temporary directory for scripts", "path", tempDir)
// Parse the YAML content into specs
specs, err := spec.GroupFromBytes(buf)
if err != nil {
return fmt.Errorf("failed to parse starter library: %w", err)
}
// Find all script references in the YAML and download them
scriptNames := ExtractScriptNames(specs)
logger.DebugContext(ctx, "Found script references in starter library", "count", len(scriptNames))
// Download scripts and update references in specs
if len(scriptNames) > 0 {
err = DownloadAndUpdateScripts(ctx, specs, scriptNames, tempDir, logger)
if err != nil {
return fmt.Errorf("failed to download and update scripts: %w", err)
}
}
// Create an authenticated client and apply specs using the provided client factory
client, err := clientFactory(serverURL, true, "", "")
// Create an authenticated client to fetch app config.
client, err := NewClient(serverURL, true, "", "")
if err != nil {
return fmt.Errorf("failed to create client: %w", err)
}
client.SetToken(token)
// Always check if license is free and skip teams for free licenses
appConfig, err := client.GetAppConfig()
if err != nil {
logger.DebugContext(ctx, "Error getting app config", "err", err)
// Continue even if there's an error getting the app config
} else if appConfig.License == nil || !appConfig.License.IsPremium() {
// Remove teams from specs to avoid applying them
logger.DebugContext(ctx, "Free license detected, skipping teams and team-related content in starter library")
specs.Teams = nil
return fmt.Errorf("failed to get app config: %w", err)
}
// Filter out policies that reference teams
if specs.Policies != nil {
var filteredPolicies []*fleet.PolicySpec
for _, policy := range specs.Policies {
// Keep only policies that don't reference a team
if policy.Team == "" {
filteredPolicies = append(filteredPolicies, policy)
}
}
specs.Policies = filteredPolicies
orgName := appConfig.OrgInfo.OrgName
if orgName == "" {
orgName = "Fleet"
}
// Create a temp directory for the rendered templates.
tempDir, err := os.MkdirTemp("", "fleet-starter-*")
if err != nil {
return fmt.Errorf("failed to create temp directory: %w", err)
}
defer os.RemoveAll(tempDir)
outDir := filepath.Join(tempDir, "gitops")
// Render templates using `fleetctl new`.
if err := runFleetctl([]string{"new", "--org-name", orgName, "--dir", outDir}); err != nil {
return fmt.Errorf("fleetctl new: %w", err)
}
// Set env overrides so GitOpsFromFile can expand $FLEET_URL without
// polluting the process environment.
spec.SetEnvOverrides(map[string]string{
"FLEET_URL": serverURL,
})
defer spec.SetEnvOverrides(nil)
// Write a temporary fleetctl config file with auth credentials.
configFile, err := os.CreateTemp(tempDir, "fleetctl-config-*.yml")
if err != nil {
return fmt.Errorf("failed to create fleetctl config: %w", err)
}
fmt.Fprintf(configFile, "contexts:\n default:\n address: %s\n tls-skip-verify: true\n token: %s\n",
serverURL, token)
configFile.Close()
// Build the gitops args: global config first, then team configs (premium only).
args := []string{"gitops", "--config", configFile.Name(), "-f", filepath.Join(outDir, "default.yml")}
if appConfig.License != nil && appConfig.License.IsPremium() {
fleetDir := filepath.Join(outDir, "fleets")
entries, err := os.ReadDir(fleetDir)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to read fleets directory: %w", err)
}
// Note: QuerySpec doesn't have a Team field, so we can't filter queries by team
// Remove scripts from AppConfig if present
if specs.AppConfig != nil {
appConfigMap, ok := specs.AppConfig.(map[string]interface{})
if ok {
// Remove scripts from AppConfig
delete(appConfigMap, "scripts")
for _, entry := range entries {
if entry.IsDir() || filepath.Ext(entry.Name()) != ".yml" {
continue
}
args = append(args, "-f", filepath.Join(fleetDir, entry.Name()))
}
}
// Log function for ApplyGroup (minimal logging)
logf := func(format string, a ...interface{}) {}
// Assign the real implementation
var applyGroupFn applyGroupFunc = func(ctx context.Context, specs *spec.Group) error {
teamsSoftwareInstallers := make(map[string][]fleet.SoftwarePackageResponse)
teamsScripts := make(map[string][]fleet.ScriptResponse)
teamsVPPApps := make(map[string][]fleet.VPPAppResponse)
_, _, _, _, err := client.ApplyGroup(
ctx,
false,
specs,
tempDir,
logf,
nil,
fleet.ApplyClientSpecOptions{},
teamsSoftwareInstallers,
teamsVPPApps,
teamsScripts,
nil,
)
return err
}
// Apply mock if mockApplyGroup is supplied
if mockApplyGroup != nil {
applyGroupFn = mockApplyGroup
}
if err := applyGroupFn(ctx, specs); err != nil {
return fmt.Errorf("failed to apply starter library: %w", err)
if err := runFleetctl(args); err != nil {
return fmt.Errorf("fleetctl gitops: %w", err)
}
logger.DebugContext(ctx, "Starter library applied successfully")
return nil
}
// ExtractScriptNames extracts all script names from the specs
func ExtractScriptNames(specs *spec.Group) []string {
var scriptNames []string
scriptMap := make(map[string]bool) // Use a map to deduplicate script names
// Process team specs
for _, teamRaw := range specs.Teams {
var teamData map[string]interface{}
if err := json.Unmarshal(teamRaw, &teamData); err != nil {
continue // Skip if we can't unmarshal
}
if scripts, ok := teamData["scripts"].([]interface{}); ok {
for _, script := range scripts {
if scriptName, ok := script.(string); ok && !scriptMap[scriptName] {
scriptMap[scriptName] = true
scriptNames = append(scriptNames, scriptName)
}
}
}
}
return scriptNames
}
// DownloadAndUpdateScripts downloads scripts from URLs and updates the specs to reference local files
func DownloadAndUpdateScripts(ctx context.Context, specs *spec.Group, scriptNames []string, tempDir string, logger *slog.Logger) error {
// Create a single HTTP client to be reused for all requests
httpClient := fleethttp.NewClient(fleethttp.WithTimeout(5 * time.Second))
// Map to store local paths for each script
scriptPaths := make(map[string]string, len(scriptNames))
// Download each script sequentially
for _, scriptName := range scriptNames {
// Sanitize the script name to prevent path traversal
sanitizedName := filepath.Clean(scriptName)
if strings.HasPrefix(sanitizedName, "..") || filepath.IsAbs(sanitizedName) {
return fmt.Errorf("invalid script name %s: must be a relative path", scriptName)
}
localPath := filepath.Join(tempDir, sanitizedName)
scriptPaths[scriptName] = localPath
// Create parent directories if they don't exist
parentDir := filepath.Dir(localPath)
if err := os.MkdirAll(parentDir, 0o755); err != nil {
return fmt.Errorf("failed to create parent directories for script %s: %w", scriptName, err)
}
scriptURL := fmt.Sprintf("%s/%s", scriptsBaseURL, scriptName)
logger.DebugContext(ctx, "Downloading script", "name", scriptName, "url", scriptURL, "local_path", localPath)
// Create the request with context
req, err := http.NewRequestWithContext(ctx, http.MethodGet, scriptURL, nil)
if err != nil {
return fmt.Errorf("failed to create request for script %s: %w", scriptName, err)
}
// Download the script using the shared HTTP client
resp, err := httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download script %s: %w", scriptName, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to download script %s, status: %d", scriptName, resp.StatusCode)
}
// Create the local file
file, err := os.Create(localPath)
if err != nil {
return fmt.Errorf("failed to create local file for script %s: %w", scriptName, err)
}
defer file.Close()
// Copy the content to the local file
_, err = io.Copy(file, resp.Body)
if err != nil {
return fmt.Errorf("failed to write script %s to local file: %w", scriptName, err)
}
}
// Read script contents and store them in memory
scriptContents := make(map[string][]byte, len(scriptNames))
for _, scriptName := range scriptNames {
localPath := scriptPaths[scriptName]
content, err := os.ReadFile(localPath)
if err != nil {
return fmt.Errorf("failed to read script %s from local file: %w", scriptName, err)
}
scriptContents[scriptName] = content
}
// Extract scripts from AppConfig if present
appConfigScripts := extractAppCfgScripts(specs.AppConfig)
if appConfigScripts != nil {
// Replace script paths with actual script contents
appScripts := make([]string, 0, len(appConfigScripts))
for _, scriptPath := range appConfigScripts {
if content, exists := scriptContents[scriptPath]; exists {
// Create a temporary file with the script content
tempFile, err := os.CreateTemp(tempDir, "script-*")
if err != nil {
return fmt.Errorf("failed to create temporary script file: %w", err)
}
if _, err := tempFile.Write(content); err != nil {
tempFile.Close()
return fmt.Errorf("failed to write script content to temporary file: %w", err)
}
tempFile.Close()
// Add the temporary file path to the list
appScripts = append(appScripts, tempFile.Name())
} else {
// Keep the original path if it's not one of our downloaded scripts
appScripts = append(appScripts, scriptPath)
}
}
// Update the AppConfig with the new script paths
if specs.AppConfig != nil {
specs.AppConfig.(map[string]interface{})["scripts"] = appScripts
}
}
// Update script references in the specs to point to local files
for i, teamRaw := range specs.Teams {
var teamData map[string]interface{}
if err := json.Unmarshal(teamRaw, &teamData); err != nil {
continue // Skip if we can't unmarshal
}
if scripts, ok := teamData["scripts"].([]interface{}); ok {
for j, script := range scripts {
if scriptName, ok := script.(string); ok {
// Update the script reference to the local path from our map
if localPath, exists := scriptPaths[scriptName]; exists {
scripts[j] = localPath
}
}
}
// Update the team data with modified scripts
teamData["scripts"] = scripts
// Marshal back to JSON
updatedTeamRaw, err := json.Marshal(teamData)
if err != nil {
logger.DebugContext(ctx, "Failed to marshal updated team data", "err", err)
continue
}
// Update the team in the specs
specs.Teams[i] = updatedTeamRaw
}
}
return nil
}

View file

@ -1,712 +0,0 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/pkg/spec"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockRoundTripper is a custom http.RoundTripper that redirects requests to a mock server
type mockRoundTripper struct {
mockServer string
origBaseURL string
next http.RoundTripper
}
func (rt *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// If the request URL contains the original base URL, replace it with the mock server URL
if strings.Contains(req.URL.String(), rt.origBaseURL) {
// Extract the path from the original URL
path := strings.TrimPrefix(req.URL.Path, "/")
// Create a new URL with the mock server
newURL := fmt.Sprintf("%s/%s", rt.mockServer, path)
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body)
if err != nil {
return nil, err
}
// Copy headers
newReq.Header = req.Header
// Use the next transport to perform the request
return rt.next.RoundTrip(newReq)
}
// For other requests, use the next transport
return rt.next.RoundTrip(req)
}
// Helper function to create HTTP responses for testing
func createTestResponse(code int, body string) *http.Response {
return &http.Response{
StatusCode: code,
Body: io.NopCloser(strings.NewReader(body)),
Header: make(http.Header),
}
}
// testRoundTripper2 is a mock implementation of http.RoundTripper for tracking URL calls
type testRoundTripper2 struct {
RoundTripFunc func(req *http.Request) (*http.Response, error)
calls []string // Track URLs that were called
}
func (m *testRoundTripper2) RoundTrip(req *http.Request) (*http.Response, error) {
m.calls = append(m.calls, req.URL.String())
return m.RoundTripFunc(req)
}
func TestExtractScriptNames(t *testing.T) {
tests := []struct {
name string
teams []map[string]interface{}
expected []string
}{
{
name: "multiple teams with scripts",
teams: []map[string]interface{}{
{
"name": "Team1",
"scripts": []interface{}{"script1.sh", "script2.sh"},
},
{
"name": "Team2",
"scripts": []interface{}{"script2.sh", "script3.sh"}, // Note: script2.sh is duplicated
},
{
"name": "Team3", // No scripts
},
},
expected: []string{"script1.sh", "script2.sh", "script3.sh"},
},
{
name: "no teams",
teams: []map[string]interface{}{},
expected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create spec group from test data
var teams []json.RawMessage
for _, team := range tt.teams {
teamRaw, err := json.Marshal(team)
require.NoError(t, err)
teams = append(teams, teamRaw)
}
specs := &spec.Group{Teams: teams}
// Call the function
scriptNames := ExtractScriptNames(specs)
// Verify the results
assert.Len(t, scriptNames, len(tt.expected))
for _, name := range tt.expected {
assert.Contains(t, scriptNames, name)
}
})
}
}
func TestDownloadAndUpdateScripts(t *testing.T) {
tests := []struct {
name string
scriptNames []string
scriptPaths []string
}{
{
name: "single script",
scriptNames: []string{"test-script.sh"},
scriptPaths: []string{"test-script.sh"},
},
{
name: "multiple scripts with nested path",
scriptNames: []string{"test-script.sh", "subfolder/nested-script.sh"},
scriptPaths: []string{"test-script.sh", "subfolder/nested-script.sh"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a mock server to serve the scripts
scriptContent := "#!/bin/bash\necho 'Hello, World!'"
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(scriptContent))
}))
defer mockServer.Close()
// Save the original HTTP transport
origTransport := http.DefaultTransport
// Create a custom transport that redirects requests to our mock server
mockTransport := &mockRoundTripper{
mockServer: mockServer.URL,
origBaseURL: scriptsBaseURL,
next: origTransport,
}
// Replace the default transport with our mock transport
http.DefaultTransport = mockTransport
// Restore the original transport when the test is done
defer func() {
http.DefaultTransport = origTransport
}()
// Create a temporary directory
tempDir, err := os.MkdirTemp("", "fleet-test-scripts-*")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create a test spec group
teamData := map[string]interface{}{
"name": "Team1",
"scripts": []interface{}{tt.scriptNames[0]},
}
teamRaw, err := json.Marshal(teamData)
require.NoError(t, err)
specs := &spec.Group{
Teams: []json.RawMessage{teamRaw},
}
// Call the actual production function
err = DownloadAndUpdateScripts(context.Background(), specs, tt.scriptNames, tempDir, slog.New(slog.DiscardHandler))
require.NoError(t, err)
// Verify the scripts were downloaded
for _, scriptName := range tt.scriptPaths {
scriptPath := filepath.Join(tempDir, scriptName)
_, err := os.Stat(scriptPath)
assert.NoError(t, err, "Script should exist: %s", scriptPath)
// Verify the content
content, err := os.ReadFile(scriptPath)
require.NoError(t, err)
assert.Equal(t, scriptContent, string(content))
}
// Verify the specs were updated
var updatedTeamData map[string]interface{}
err = json.Unmarshal(specs.Teams[0], &updatedTeamData)
require.NoError(t, err)
updatedScripts, ok := updatedTeamData["scripts"].([]interface{})
require.True(t, ok)
// The scripts should now be local paths
for i, script := range updatedScripts {
scriptPath, ok := script.(string)
require.True(t, ok)
assert.Contains(t, scriptPath, tempDir)
assert.Contains(t, scriptPath, tt.scriptNames[i])
}
})
}
}
func TestDownloadAndUpdateScriptsWithInvalidPaths(t *testing.T) {
tests := []struct {
name string
scriptNames []string
errorMsg string
}{
{
name: "path traversal attempt",
scriptNames: []string{"../test-script.sh"},
errorMsg: "invalid script name",
},
{
name: "absolute path attempt",
scriptNames: []string{"/etc/passwd"},
errorMsg: "invalid script name",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a mock server to serve the scripts
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("script content"))
}))
defer mockServer.Close()
// Save the original HTTP transport
origTransport := http.DefaultTransport
// Create a custom transport that redirects requests to our mock server
mockTransport := &mockRoundTripper{
mockServer: mockServer.URL,
origBaseURL: scriptsBaseURL,
next: origTransport,
}
// Replace the default transport with our mock transport
http.DefaultTransport = mockTransport
// Restore the original transport when the test is done
defer func() {
http.DefaultTransport = origTransport
}()
// Create a temporary directory
tempDir, err := os.MkdirTemp("", "fleet-test-scripts-*")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create a test spec group
teamData := map[string]interface{}{
"name": "Team1",
"scripts": []interface{}{tt.scriptNames[0]},
}
teamRaw, err := json.Marshal(teamData)
require.NoError(t, err)
specs := &spec.Group{
Teams: []json.RawMessage{teamRaw},
}
// Call the actual production function
err = DownloadAndUpdateScripts(context.Background(), specs, tt.scriptNames, tempDir, slog.New(slog.DiscardHandler))
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
})
}
}
func TestFleetHTTPClientOverride(t *testing.T) {
// Create a mock server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("test response"))
}))
defer mockServer.Close()
// Create a custom client that uses our mock server
client := fleethttp.NewClient()
client.Transport = &mockRoundTripper{
mockServer: mockServer.URL,
origBaseURL: scriptsBaseURL,
next: http.DefaultTransport,
}
// Create a fleethttp client with our custom transport
fleetClient := fleethttp.NewClient()
// Replace the client's transport with our mock transport
fleetClient.Transport = client.Transport
// Create a request to the original URL
req, err := http.NewRequest("GET", scriptsBaseURL+"/test.sh", nil)
require.NoError(t, err)
// Send the request
resp, err := fleetClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
// Verify the response
require.Equal(t, http.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "test response", string(body))
}
func TestDownloadAndUpdateScriptsTimeout(t *testing.T) {
tests := []struct {
name string
sleepTime time.Duration
contextTime time.Duration
expectError bool
}{
{
name: "timeout occurs",
sleepTime: 6 * time.Second, // Server sleeps longer than timeout
contextTime: 2 * time.Second, // Short context timeout
expectError: true,
},
{
name: "no timeout",
sleepTime: 1 * time.Second, // Server responds quickly
contextTime: 5 * time.Second, // Longer context timeout
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Skip the long-running test in short mode
if tt.sleepTime > 2*time.Second && testing.Short() {
t.Skip("Skipping long-running test in short mode")
}
// Create a mock server that delays its response
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(tt.sleepTime)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("test response"))
}))
defer mockServer.Close()
// Save the original HTTP transport
origTransport := http.DefaultTransport
// Create a custom transport that redirects requests to our mock server
mockTransport := &mockRoundTripper{
mockServer: mockServer.URL,
origBaseURL: scriptsBaseURL,
next: origTransport,
}
// Replace the default transport with our mock transport
http.DefaultTransport = mockTransport
// Restore the original transport when the test is done
defer func() {
http.DefaultTransport = origTransport
}()
// Create a temporary directory
tempDir, err := os.MkdirTemp("", "fleet-test-scripts-*")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create a test spec group
teamData := map[string]interface{}{
"name": "Team1",
"scripts": []interface{}{"test-script.sh"},
}
teamRaw, err := json.Marshal(teamData)
require.NoError(t, err)
specs := &spec.Group{
Teams: []json.RawMessage{teamRaw},
}
// Define script names
scriptNames := []string{"test-script.sh"}
// Set context timeout
ctx, cancel := context.WithTimeout(context.Background(), tt.contextTime)
defer cancel()
// Call the actual production function
err = DownloadAndUpdateScripts(ctx, specs, scriptNames, tempDir, slog.New(slog.DiscardHandler))
if tt.expectError {
require.Error(t, err)
// The error message might vary depending on the HTTP client implementation,
// but it should contain either "timeout", "deadline exceeded", or "context canceled"
errorMsg := strings.ToLower(err.Error())
timeoutRelated := strings.Contains(errorMsg, "timeout") ||
strings.Contains(errorMsg, "deadline exceeded") ||
strings.Contains(errorMsg, "context canceled")
assert.True(t, timeoutRelated, "Expected a timeout-related error, got: %s", err.Error())
} else {
require.NoError(t, err)
}
})
}
}
func TestApplyStarterLibraryWithMockClient(t *testing.T) {
// Read the real production starter library YAML file
starterLibraryPath := "../../docs/01-Using-Fleet/starter-library/starter-library.yml"
starterLibraryContent, err := os.ReadFile(starterLibraryPath)
require.NoError(t, err, "Should be able to read starter library YAML file")
// Create mock HTTP client for downloading the starter library and scripts
mockRT := &testRoundTripper2{
calls: []string{},
RoundTripFunc: func(req *http.Request) (*http.Response, error) {
switch {
case req.URL.String() == starterLibraryURL:
// Return the real starter library content
return createTestResponse(200, string(starterLibraryContent)), nil
case strings.Contains(req.URL.String(), "uninstall-fleetd"):
// Return a simple script for any script URL
return createTestResponse(200, "#!/bin/bash\necho ok"), nil
default:
// For any other URL, return a 404
return createTestResponse(404, "Not found"), nil
}
},
}
httpClientFactory := func(opts ...fleethttp.ClientOpt) *http.Client {
client := fleethttp.NewClient(opts...)
client.Transport = mockRT
return client
}
// Create a client factory that returns a real client
// We're not testing the token setting functionality
clientFactory := NewClient
// Track if ApplyGroup was called and capture the specs
applyGroupCalled := false
var capturedSpecs *spec.Group
// Create a mock ApplyGroup function
mockApplyGroup := func(ctx context.Context, specs *spec.Group) error {
applyGroupCalled = true
capturedSpecs = specs
return nil
}
// Call the function under test
testErr := ApplyStarterLibrary(
context.Background(),
"https://example.com",
"test-token",
slog.New(slog.DiscardHandler),
httpClientFactory,
clientFactory,
mockApplyGroup,
)
// Verify results
require.NoError(t, testErr)
assert.True(t, applyGroupCalled, "ApplyGroup should have been called")
// Verify that the specs were correctly parsed
require.NotNil(t, capturedSpecs, "Specs should not be nil")
require.NotEmpty(t, capturedSpecs.Teams, "Specs should contain teams")
// Verify that the first team has the expected structure
var team1 map[string]interface{}
unmarshalErr := json.Unmarshal(capturedSpecs.Teams[0], &team1)
require.NoError(t, unmarshalErr, "Should be able to unmarshal team JSON")
// Verify that the team has a name
require.Contains(t, team1, "name", "Team should have a name")
teamName := team1["name"].(string)
require.NotEmpty(t, teamName, "Team name should not be empty")
// Verify that the starter library URL was requested
assert.Contains(t, mockRT.calls, starterLibraryURL, "The starter library URL should have been requested")
}
func TestApplyStarterLibraryWithMalformedYAML(t *testing.T) {
// Create mock HTTP client that returns malformed YAML
malformedYAML := `
teams:
- name: "Malformed Team
# Missing closing quote and improper indentation
scripts:
- "script1.sh
`
mockRT := &testRoundTripper2{
calls: []string{},
RoundTripFunc: func(req *http.Request) (*http.Response, error) {
switch {
case req.URL.String() == starterLibraryURL:
// Return malformed YAML content
return createTestResponse(200, malformedYAML), nil
default:
// For any other URL, return a 404
return createTestResponse(404, "Not found"), nil
}
},
}
httpClientFactory := func(opts ...fleethttp.ClientOpt) *http.Client {
client := fleethttp.NewClient(opts...)
client.Transport = mockRT
return client
}
// Create a client factory that returns a real client
clientFactory := NewClient
// Create a mock ApplyGroup function that should not be called
mockApplyGroup := func(ctx context.Context, specs *spec.Group) error {
t.Error("ApplyGroup should not be called with malformed YAML")
return nil
}
// Use a defer/recover to explicitly catch any panics
var panicValue interface{}
defer func() {
if r := recover(); r != nil {
panicValue = r
t.Fatalf("Panic occurred when processing malformed YAML: %v", panicValue)
}
}()
// Call the function under test
testErr := ApplyStarterLibrary(
context.Background(),
"https://example.com",
"test-token",
slog.New(slog.DiscardHandler),
httpClientFactory,
clientFactory,
mockApplyGroup,
)
// Verify results
require.Error(t, testErr, "Should return an error with malformed YAML")
assert.Contains(t, testErr.Error(), "failed to parse starter library",
"Error should indicate YAML parsing failure")
// Verify that the starter library URL was requested
assert.Contains(t, mockRT.calls, starterLibraryURL, "The starter library URL should have been requested")
// If we reach here, no panic occurred and the setup flow was not interrupted
}
func TestApplyStarterLibraryWithFreeLicense(t *testing.T) {
// Read the real production starter library YAML file
starterLibraryPath := "../../docs/01-Using-Fleet/starter-library/starter-library.yml"
starterLibraryContent, err := os.ReadFile(starterLibraryPath)
require.NoError(t, err, "Should be able to read starter library YAML file")
// Create mock HTTP client for downloading the starter library and scripts
mockRT := &testRoundTripper2{
calls: []string{},
RoundTripFunc: func(req *http.Request) (*http.Response, error) {
switch {
case req.URL.String() == starterLibraryURL:
// Return the real starter library content
return createTestResponse(200, string(starterLibraryContent)), nil
case strings.Contains(req.URL.String(), "uninstall-fleetd"):
// Return a simple script for any script URL
return createTestResponse(200, "#!/bin/bash\necho ok"), nil
default:
// For any other URL, return a 404
return createTestResponse(404, "Not found"), nil
}
},
}
httpClientFactory := func(opts ...fleethttp.ClientOpt) *http.Client {
client := fleethttp.NewClient(opts...)
client.Transport = mockRT
return client
}
// Create a mock client that returns a free license
// Create a properly structured EnrichedAppConfig
mockEnrichedAppConfig := &fleet.EnrichedAppConfig{}
// Set the License field using json marshaling/unmarshaling to bypass unexported field access
configJSON := []byte(`{"license":{"tier":"free"}}`)
if err := json.Unmarshal(configJSON, mockEnrichedAppConfig); err != nil {
t.Fatal("Failed to unmarshal mock config:", err)
}
// Create a mock client factory
clientFactory := func(serverURL string, insecureSkipVerify bool, rootCA, urlPrefix string, options ...ClientOption) (*Client, error) {
mockClient := &Client{}
// Override the baseClient with a mock implementation
// Create a mock HTTP client
mockHTTPClient := &mockHTTPClient{
DoFunc: func(req *http.Request) (*http.Response, error) {
// Mock the GetAppConfig response
if req.URL.Path == "/api/v1/fleet/config" && req.Method == http.MethodGet {
respBody, _ := json.Marshal(mockEnrichedAppConfig)
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBuffer(respBody)),
Header: make(http.Header),
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("{}")),
Header: make(http.Header),
}, nil
},
}
// Set up the baseClient with the mock HTTP client and a valid baseURL
baseURL, _ := url.Parse(serverURL)
mockClient.baseClient = &baseClient{
HTTP: mockHTTPClient,
BaseURL: baseURL,
}
return mockClient, nil
}
// Track if ApplyGroup was called and capture the specs
applyGroupCalled := false
var capturedSpecs *spec.Group
// Create a mock ApplyGroup function
mockApplyGroup := func(ctx context.Context, specs *spec.Group) error {
applyGroupCalled = true
capturedSpecs = specs
return nil
}
// Call the function under test - teams will be skipped automatically for free license
testErr := ApplyStarterLibrary(
context.Background(),
"https://example.com",
"test-token",
slog.New(slog.DiscardHandler),
httpClientFactory,
clientFactory,
mockApplyGroup,
)
// Verify results
require.NoError(t, testErr)
assert.True(t, applyGroupCalled, "ApplyGroup should have been called")
// Verify that the specs were correctly parsed
require.NotNil(t, capturedSpecs, "Specs should not be nil")
// Verify that teams were removed
require.Empty(t, capturedSpecs.Teams, "Teams should be empty for free license")
// Verify that policies referencing teams were filtered out
if capturedSpecs.Policies != nil {
for _, policy := range capturedSpecs.Policies {
assert.Empty(t, policy.Team, "Policies should not reference teams for free license")
}
}
// Verify that scripts were removed from AppConfig
if capturedSpecs.AppConfig != nil {
appConfigMap, ok := capturedSpecs.AppConfig.(map[string]interface{})
if ok {
_, hasScripts := appConfigMap["scripts"]
assert.False(t, hasScripts, "AppConfig should not contain scripts for free license")
}
}
// Verify that the starter library URL was requested
assert.Contains(t, mockRT.calls, starterLibraryURL, "The starter library URL should have been requested")
}
// mockHTTPClient is a mock implementation of the http.Client
type mockHTTPClient struct {
DoFunc func(req *http.Request) (*http.Response, error)
}
func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
return m.DoFunc(req)
}

View file

@ -1120,12 +1120,12 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
// WithSetup is an http middleware that checks if setup procedures have been completed.
// If setup hasn't been completed it serves the API with a setup middleware.
// If the server is already configured, the default API handler is exposed.
func WithSetup(svc fleet.Service, logger *slog.Logger, next http.Handler) http.HandlerFunc {
func WithSetup(svc fleet.Service, logger *slog.Logger, applyStarterLibrary func(ctx context.Context, serverURL, token string) error, next http.Handler) http.HandlerFunc {
rxOsquery := regexp.MustCompile(`^/api/[^/]+/osquery`)
return func(w http.ResponseWriter, r *http.Request) {
configRouter := http.NewServeMux()
srv := kithttp.NewServer(
makeSetupEndpoint(svc, logger),
makeSetupEndpoint(svc, logger, applyStarterLibrary),
decodeSetupRequest,
encodeResponse,
)

View file

@ -16,6 +16,7 @@ import (
"net/http/httptest"
"os"
"sort"
"strings"
"sync"
"testing"
"time"
@ -1541,3 +1542,24 @@ type acmeCSRSigner struct {
func (a *acmeCSRSigner) SignCSR(_ context.Context, csr *x509.CertificateRequest) (*x509.Certificate, error) {
return a.signer.Signx509CSR(csr)
}
// mockRoundTripper is a custom http.RoundTripper that redirects requests to a mock server.
type mockRoundTripper struct {
mockServer string
origBaseURL string
next http.RoundTripper
}
func (rt *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if strings.Contains(req.URL.String(), rt.origBaseURL) {
path := strings.TrimPrefix(req.URL.Path, "/")
newURL := fmt.Sprintf("%s/%s", rt.mockServer, path)
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body) //nolint:gosec // test helper, URL is from mock server
if err != nil {
return nil, err
}
newReq.Header = req.Header
return rt.next.RoundTrip(newReq)
}
return rt.next.RoundTrip(req)
}