mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 21:47:20 +00:00
712 lines
21 KiB
Go
712 lines
21 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
|
|
"github.com/fleetdm/fleet/v4/pkg/spec"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
kitlog "github.com/go-kit/log"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// mockRoundTripper is a custom http.RoundTripper that redirects requests whose
// URL contains origBaseURL to a mock server, and forwards every other request
// to the next transport unchanged.
type mockRoundTripper struct {
	// mockServer is the base URL that matching requests are rewritten to.
	mockServer string
	// origBaseURL is the URL substring that marks a request for redirection.
	origBaseURL string
	// next is the transport that actually performs the round trip.
	next http.RoundTripper
}

// RoundTrip implements http.RoundTripper.
//
// A request whose URL contains origBaseURL is re-issued against mockServer with
// the same method, context, body, path and query string. Its headers are
// cloned rather than shared, so the caller's request is never modified through
// the redirected copy. All other requests are handed to next as-is.
func (rt *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	// Requests that do not target the original base URL pass straight through.
	if !strings.Contains(req.URL.String(), rt.origBaseURL) {
		return rt.next.RoundTrip(req)
	}

	// Rebuild the URL against the mock server, keeping the original path.
	path := strings.TrimPrefix(req.URL.Path, "/")
	newURL := fmt.Sprintf("%s/%s", rt.mockServer, path)
	// Keep the query string so the mock server sees the same parameters.
	if req.URL.RawQuery != "" {
		newURL += "?" + req.URL.RawQuery
	}

	newReq, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body)
	if err != nil {
		return nil, err
	}

	// Copy (not alias) the headers: sharing the map would let downstream
	// transports mutate the caller's request.
	if req.Header != nil {
		newReq.Header = req.Header.Clone()
	}

	// Use the next transport to perform the redirected request.
	return rt.next.RoundTrip(newReq)
}
|
|
|
|
// createTestResponse builds a minimal *http.Response for tests: the given
// status code, an empty (non-nil) header map, and a body that yields the given
// string.
func createTestResponse(code int, body string) *http.Response {
	resp := new(http.Response)
	resp.StatusCode = code
	resp.Header = http.Header{}
	// NopCloser turns the string reader into the io.ReadCloser that Body requires.
	resp.Body = io.NopCloser(strings.NewReader(body))
	return resp
}
|
|
|
|
// testRoundTripper2 is an http.RoundTripper test double: it records the URL of
// every request it receives and delegates the response to RoundTripFunc.
type testRoundTripper2 struct {
	// RoundTripFunc produces the response for each request.
	RoundTripFunc func(req *http.Request) (*http.Response, error)
	// calls holds the requested URLs in the order they were seen.
	calls []string
}

// RoundTrip implements http.RoundTripper. The URL is recorded before
// RoundTripFunc runs, so it is tracked even when RoundTripFunc returns an error.
func (m *testRoundTripper2) RoundTrip(req *http.Request) (*http.Response, error) {
	requested := req.URL.String()
	m.calls = append(m.calls, requested)
	return m.RoundTripFunc(req)
}
|
|
|
|
func TestExtractScriptNames(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
teams []map[string]interface{}
|
|
expected []string
|
|
}{
|
|
{
|
|
name: "multiple teams with scripts",
|
|
teams: []map[string]interface{}{
|
|
{
|
|
"name": "Team1",
|
|
"scripts": []interface{}{"script1.sh", "script2.sh"},
|
|
},
|
|
{
|
|
"name": "Team2",
|
|
"scripts": []interface{}{"script2.sh", "script3.sh"}, // Note: script2.sh is duplicated
|
|
},
|
|
{
|
|
"name": "Team3", // No scripts
|
|
},
|
|
},
|
|
expected: []string{"script1.sh", "script2.sh", "script3.sh"},
|
|
},
|
|
{
|
|
name: "no teams",
|
|
teams: []map[string]interface{}{},
|
|
expected: []string{},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create spec group from test data
|
|
var teams []json.RawMessage
|
|
for _, team := range tt.teams {
|
|
teamRaw, err := json.Marshal(team)
|
|
require.NoError(t, err)
|
|
teams = append(teams, teamRaw)
|
|
}
|
|
specs := &spec.Group{Teams: teams}
|
|
|
|
// Call the function
|
|
scriptNames := ExtractScriptNames(specs)
|
|
|
|
// Verify the results
|
|
assert.Len(t, scriptNames, len(tt.expected))
|
|
for _, name := range tt.expected {
|
|
assert.Contains(t, scriptNames, name)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDownloadAndUpdateScripts(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
scriptNames []string
|
|
scriptPaths []string
|
|
}{
|
|
{
|
|
name: "single script",
|
|
scriptNames: []string{"test-script.sh"},
|
|
scriptPaths: []string{"test-script.sh"},
|
|
},
|
|
{
|
|
name: "multiple scripts with nested path",
|
|
scriptNames: []string{"test-script.sh", "subfolder/nested-script.sh"},
|
|
scriptPaths: []string{"test-script.sh", "subfolder/nested-script.sh"},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create a mock server to serve the scripts
|
|
scriptContent := "#!/bin/bash\necho 'Hello, World!'"
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(scriptContent))
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
// Save the original HTTP transport
|
|
origTransport := http.DefaultTransport
|
|
|
|
// Create a custom transport that redirects requests to our mock server
|
|
mockTransport := &mockRoundTripper{
|
|
mockServer: mockServer.URL,
|
|
origBaseURL: scriptsBaseURL,
|
|
next: origTransport,
|
|
}
|
|
|
|
// Replace the default transport with our mock transport
|
|
http.DefaultTransport = mockTransport
|
|
|
|
// Restore the original transport when the test is done
|
|
defer func() {
|
|
http.DefaultTransport = origTransport
|
|
}()
|
|
|
|
// Create a temporary directory
|
|
tempDir, err := os.MkdirTemp("", "fleet-test-scripts-*")
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
// Create a test spec group
|
|
teamData := map[string]interface{}{
|
|
"name": "Team1",
|
|
"scripts": []interface{}{tt.scriptNames[0]},
|
|
}
|
|
teamRaw, err := json.Marshal(teamData)
|
|
require.NoError(t, err)
|
|
|
|
specs := &spec.Group{
|
|
Teams: []json.RawMessage{teamRaw},
|
|
}
|
|
|
|
// Call the actual production function
|
|
err = DownloadAndUpdateScripts(context.Background(), specs, tt.scriptNames, tempDir, kitlog.NewNopLogger())
|
|
require.NoError(t, err)
|
|
|
|
// Verify the scripts were downloaded
|
|
for _, scriptName := range tt.scriptPaths {
|
|
scriptPath := filepath.Join(tempDir, scriptName)
|
|
_, err := os.Stat(scriptPath)
|
|
assert.NoError(t, err, "Script should exist: %s", scriptPath)
|
|
|
|
// Verify the content
|
|
content, err := os.ReadFile(scriptPath)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, scriptContent, string(content))
|
|
}
|
|
|
|
// Verify the specs were updated
|
|
var updatedTeamData map[string]interface{}
|
|
err = json.Unmarshal(specs.Teams[0], &updatedTeamData)
|
|
require.NoError(t, err)
|
|
|
|
updatedScripts, ok := updatedTeamData["scripts"].([]interface{})
|
|
require.True(t, ok)
|
|
|
|
// The scripts should now be local paths
|
|
for i, script := range updatedScripts {
|
|
scriptPath, ok := script.(string)
|
|
require.True(t, ok)
|
|
assert.Contains(t, scriptPath, tempDir)
|
|
assert.Contains(t, scriptPath, tt.scriptNames[i])
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDownloadAndUpdateScriptsWithInvalidPaths(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
scriptNames []string
|
|
errorMsg string
|
|
}{
|
|
{
|
|
name: "path traversal attempt",
|
|
scriptNames: []string{"../test-script.sh"},
|
|
errorMsg: "invalid script name",
|
|
},
|
|
{
|
|
name: "absolute path attempt",
|
|
scriptNames: []string{"/etc/passwd"},
|
|
errorMsg: "invalid script name",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create a mock server to serve the scripts
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("script content"))
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
// Save the original HTTP transport
|
|
origTransport := http.DefaultTransport
|
|
|
|
// Create a custom transport that redirects requests to our mock server
|
|
mockTransport := &mockRoundTripper{
|
|
mockServer: mockServer.URL,
|
|
origBaseURL: scriptsBaseURL,
|
|
next: origTransport,
|
|
}
|
|
|
|
// Replace the default transport with our mock transport
|
|
http.DefaultTransport = mockTransport
|
|
|
|
// Restore the original transport when the test is done
|
|
defer func() {
|
|
http.DefaultTransport = origTransport
|
|
}()
|
|
|
|
// Create a temporary directory
|
|
tempDir, err := os.MkdirTemp("", "fleet-test-scripts-*")
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
// Create a test spec group
|
|
teamData := map[string]interface{}{
|
|
"name": "Team1",
|
|
"scripts": []interface{}{tt.scriptNames[0]},
|
|
}
|
|
teamRaw, err := json.Marshal(teamData)
|
|
require.NoError(t, err)
|
|
|
|
specs := &spec.Group{
|
|
Teams: []json.RawMessage{teamRaw},
|
|
}
|
|
|
|
// Call the actual production function
|
|
err = DownloadAndUpdateScripts(context.Background(), specs, tt.scriptNames, tempDir, kitlog.NewNopLogger())
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFleetHTTPClientOverride(t *testing.T) {
|
|
// Create a mock server
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("test response"))
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
// Create a custom client that uses our mock server
|
|
client := fleethttp.NewClient()
|
|
client.Transport = &mockRoundTripper{
|
|
mockServer: mockServer.URL,
|
|
origBaseURL: scriptsBaseURL,
|
|
next: http.DefaultTransport,
|
|
}
|
|
|
|
// Create a fleethttp client with our custom transport
|
|
fleetClient := fleethttp.NewClient()
|
|
// Replace the client's transport with our mock transport
|
|
fleetClient.Transport = client.Transport
|
|
|
|
// Create a request to the original URL
|
|
req, err := http.NewRequest("GET", scriptsBaseURL+"/test.sh", nil)
|
|
require.NoError(t, err)
|
|
|
|
// Send the request
|
|
resp, err := fleetClient.Do(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
// Verify the response
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "test response", string(body))
|
|
}
|
|
|
|
func TestDownloadAndUpdateScriptsTimeout(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
sleepTime time.Duration
|
|
contextTime time.Duration
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "timeout occurs",
|
|
sleepTime: 6 * time.Second, // Server sleeps longer than timeout
|
|
contextTime: 2 * time.Second, // Short context timeout
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "no timeout",
|
|
sleepTime: 1 * time.Second, // Server responds quickly
|
|
contextTime: 5 * time.Second, // Longer context timeout
|
|
expectError: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Skip the long-running test in short mode
|
|
if tt.sleepTime > 2*time.Second && testing.Short() {
|
|
t.Skip("Skipping long-running test in short mode")
|
|
}
|
|
|
|
// Create a mock server that delays its response
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
time.Sleep(tt.sleepTime)
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("test response"))
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
// Save the original HTTP transport
|
|
origTransport := http.DefaultTransport
|
|
|
|
// Create a custom transport that redirects requests to our mock server
|
|
mockTransport := &mockRoundTripper{
|
|
mockServer: mockServer.URL,
|
|
origBaseURL: scriptsBaseURL,
|
|
next: origTransport,
|
|
}
|
|
|
|
// Replace the default transport with our mock transport
|
|
http.DefaultTransport = mockTransport
|
|
|
|
// Restore the original transport when the test is done
|
|
defer func() {
|
|
http.DefaultTransport = origTransport
|
|
}()
|
|
|
|
// Create a temporary directory
|
|
tempDir, err := os.MkdirTemp("", "fleet-test-scripts-*")
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
// Create a test spec group
|
|
teamData := map[string]interface{}{
|
|
"name": "Team1",
|
|
"scripts": []interface{}{"test-script.sh"},
|
|
}
|
|
teamRaw, err := json.Marshal(teamData)
|
|
require.NoError(t, err)
|
|
|
|
specs := &spec.Group{
|
|
Teams: []json.RawMessage{teamRaw},
|
|
}
|
|
|
|
// Define script names
|
|
scriptNames := []string{"test-script.sh"}
|
|
|
|
// Set context timeout
|
|
ctx, cancel := context.WithTimeout(context.Background(), tt.contextTime)
|
|
defer cancel()
|
|
|
|
// Call the actual production function
|
|
err = DownloadAndUpdateScripts(ctx, specs, scriptNames, tempDir, kitlog.NewNopLogger())
|
|
|
|
if tt.expectError {
|
|
require.Error(t, err)
|
|
// The error message might vary depending on the HTTP client implementation,
|
|
// but it should contain either "timeout", "deadline exceeded", or "context canceled"
|
|
errorMsg := strings.ToLower(err.Error())
|
|
timeoutRelated := strings.Contains(errorMsg, "timeout") ||
|
|
strings.Contains(errorMsg, "deadline exceeded") ||
|
|
strings.Contains(errorMsg, "context canceled")
|
|
assert.True(t, timeoutRelated, "Expected a timeout-related error, got: %s", err.Error())
|
|
} else {
|
|
require.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestApplyStarterLibraryWithMockClient(t *testing.T) {
|
|
// Read the real production starter library YAML file
|
|
starterLibraryPath := "../../docs/01-Using-Fleet/starter-library/starter-library.yml"
|
|
starterLibraryContent, err := os.ReadFile(starterLibraryPath)
|
|
require.NoError(t, err, "Should be able to read starter library YAML file")
|
|
|
|
// Create mock HTTP client for downloading the starter library and scripts
|
|
mockRT := &testRoundTripper2{
|
|
calls: []string{},
|
|
RoundTripFunc: func(req *http.Request) (*http.Response, error) {
|
|
switch {
|
|
case req.URL.String() == starterLibraryURL:
|
|
// Return the real starter library content
|
|
return createTestResponse(200, string(starterLibraryContent)), nil
|
|
case strings.Contains(req.URL.String(), "uninstall-fleetd"):
|
|
// Return a simple script for any script URL
|
|
return createTestResponse(200, "#!/bin/bash\necho ok"), nil
|
|
default:
|
|
// For any other URL, return a 404
|
|
return createTestResponse(404, "Not found"), nil
|
|
}
|
|
},
|
|
}
|
|
|
|
httpClientFactory := func(opts ...fleethttp.ClientOpt) *http.Client {
|
|
client := fleethttp.NewClient(opts...)
|
|
client.Transport = mockRT
|
|
return client
|
|
}
|
|
|
|
// Create a client factory that returns a real client
|
|
// We're not testing the token setting functionality
|
|
clientFactory := NewClient
|
|
|
|
// Track if ApplyGroup was called and capture the specs
|
|
applyGroupCalled := false
|
|
var capturedSpecs *spec.Group
|
|
// Create a mock ApplyGroup function
|
|
mockApplyGroup := func(ctx context.Context, specs *spec.Group) error {
|
|
applyGroupCalled = true
|
|
capturedSpecs = specs
|
|
return nil
|
|
}
|
|
|
|
// Call the function under test
|
|
testErr := ApplyStarterLibrary(
|
|
context.Background(),
|
|
"https://example.com",
|
|
"test-token",
|
|
kitlog.NewNopLogger(),
|
|
httpClientFactory,
|
|
clientFactory,
|
|
mockApplyGroup,
|
|
)
|
|
|
|
// Verify results
|
|
require.NoError(t, testErr)
|
|
assert.True(t, applyGroupCalled, "ApplyGroup should have been called")
|
|
|
|
// Verify that the specs were correctly parsed
|
|
require.NotNil(t, capturedSpecs, "Specs should not be nil")
|
|
require.NotEmpty(t, capturedSpecs.Teams, "Specs should contain teams")
|
|
|
|
// Verify that the first team has the expected structure
|
|
var team1 map[string]interface{}
|
|
unmarshalErr := json.Unmarshal(capturedSpecs.Teams[0], &team1)
|
|
require.NoError(t, unmarshalErr, "Should be able to unmarshal team JSON")
|
|
|
|
// Verify that the team has a name
|
|
require.Contains(t, team1, "name", "Team should have a name")
|
|
teamName := team1["name"].(string)
|
|
require.NotEmpty(t, teamName, "Team name should not be empty")
|
|
|
|
// Verify that the starter library URL was requested
|
|
assert.Contains(t, mockRT.calls, starterLibraryURL, "The starter library URL should have been requested")
|
|
}
|
|
|
|
func TestApplyStarterLibraryWithMalformedYAML(t *testing.T) {
|
|
// Create mock HTTP client that returns malformed YAML
|
|
malformedYAML := `
|
|
teams:
|
|
- name: "Malformed Team
|
|
# Missing closing quote and improper indentation
|
|
scripts:
|
|
- "script1.sh
|
|
`
|
|
|
|
mockRT := &testRoundTripper2{
|
|
calls: []string{},
|
|
RoundTripFunc: func(req *http.Request) (*http.Response, error) {
|
|
switch {
|
|
case req.URL.String() == starterLibraryURL:
|
|
// Return malformed YAML content
|
|
return createTestResponse(200, malformedYAML), nil
|
|
default:
|
|
// For any other URL, return a 404
|
|
return createTestResponse(404, "Not found"), nil
|
|
}
|
|
},
|
|
}
|
|
|
|
httpClientFactory := func(opts ...fleethttp.ClientOpt) *http.Client {
|
|
client := fleethttp.NewClient(opts...)
|
|
client.Transport = mockRT
|
|
return client
|
|
}
|
|
|
|
// Create a client factory that returns a real client
|
|
clientFactory := NewClient
|
|
|
|
// Create a mock ApplyGroup function that should not be called
|
|
mockApplyGroup := func(ctx context.Context, specs *spec.Group) error {
|
|
t.Error("ApplyGroup should not be called with malformed YAML")
|
|
return nil
|
|
}
|
|
|
|
// Use a defer/recover to explicitly catch any panics
|
|
var panicValue interface{}
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
panicValue = r
|
|
t.Fatalf("Panic occurred when processing malformed YAML: %v", panicValue)
|
|
}
|
|
}()
|
|
|
|
// Call the function under test
|
|
testErr := ApplyStarterLibrary(
|
|
context.Background(),
|
|
"https://example.com",
|
|
"test-token",
|
|
kitlog.NewNopLogger(),
|
|
httpClientFactory,
|
|
clientFactory,
|
|
mockApplyGroup,
|
|
)
|
|
|
|
// Verify results
|
|
require.Error(t, testErr, "Should return an error with malformed YAML")
|
|
assert.Contains(t, testErr.Error(), "failed to parse starter library",
|
|
"Error should indicate YAML parsing failure")
|
|
|
|
// Verify that the starter library URL was requested
|
|
assert.Contains(t, mockRT.calls, starterLibraryURL, "The starter library URL should have been requested")
|
|
|
|
// If we reach here, no panic occurred and the setup flow was not interrupted
|
|
}
|
|
|
|
func TestApplyStarterLibraryWithFreeLicense(t *testing.T) {
|
|
// Read the real production starter library YAML file
|
|
starterLibraryPath := "../../docs/01-Using-Fleet/starter-library/starter-library.yml"
|
|
starterLibraryContent, err := os.ReadFile(starterLibraryPath)
|
|
require.NoError(t, err, "Should be able to read starter library YAML file")
|
|
|
|
// Create mock HTTP client for downloading the starter library and scripts
|
|
mockRT := &testRoundTripper2{
|
|
calls: []string{},
|
|
RoundTripFunc: func(req *http.Request) (*http.Response, error) {
|
|
switch {
|
|
case req.URL.String() == starterLibraryURL:
|
|
// Return the real starter library content
|
|
return createTestResponse(200, string(starterLibraryContent)), nil
|
|
case strings.Contains(req.URL.String(), "uninstall-fleetd"):
|
|
// Return a simple script for any script URL
|
|
return createTestResponse(200, "#!/bin/bash\necho ok"), nil
|
|
default:
|
|
// For any other URL, return a 404
|
|
return createTestResponse(404, "Not found"), nil
|
|
}
|
|
},
|
|
}
|
|
|
|
httpClientFactory := func(opts ...fleethttp.ClientOpt) *http.Client {
|
|
client := fleethttp.NewClient(opts...)
|
|
client.Transport = mockRT
|
|
return client
|
|
}
|
|
|
|
// Create a mock client that returns a free license
|
|
// Create a properly structured EnrichedAppConfig
|
|
mockEnrichedAppConfig := &fleet.EnrichedAppConfig{}
|
|
// Set the License field using json marshaling/unmarshaling to bypass unexported field access
|
|
configJSON := []byte(`{"license":{"tier":"free"}}`)
|
|
if err := json.Unmarshal(configJSON, mockEnrichedAppConfig); err != nil {
|
|
t.Fatal("Failed to unmarshal mock config:", err)
|
|
}
|
|
|
|
// Create a mock client factory
|
|
clientFactory := func(serverURL string, insecureSkipVerify bool, rootCA, urlPrefix string, options ...ClientOption) (*Client, error) {
|
|
mockClient := &Client{}
|
|
|
|
// Override the baseClient with a mock implementation
|
|
// Create a mock HTTP client
|
|
mockHTTPClient := &mockHTTPClient{
|
|
DoFunc: func(req *http.Request) (*http.Response, error) {
|
|
// Mock the GetAppConfig response
|
|
if req.URL.Path == "/api/v1/fleet/config" && req.Method == http.MethodGet {
|
|
respBody, _ := json.Marshal(mockEnrichedAppConfig)
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(bytes.NewBuffer(respBody)),
|
|
Header: make(http.Header),
|
|
}, nil
|
|
}
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(strings.NewReader("{}")),
|
|
Header: make(http.Header),
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
// Set up the baseClient with the mock HTTP client and a valid baseURL
|
|
baseURL, _ := url.Parse(serverURL)
|
|
mockClient.baseClient = &baseClient{
|
|
http: mockHTTPClient,
|
|
baseURL: baseURL,
|
|
}
|
|
|
|
return mockClient, nil
|
|
}
|
|
|
|
// Track if ApplyGroup was called and capture the specs
|
|
applyGroupCalled := false
|
|
var capturedSpecs *spec.Group
|
|
// Create a mock ApplyGroup function
|
|
mockApplyGroup := func(ctx context.Context, specs *spec.Group) error {
|
|
applyGroupCalled = true
|
|
capturedSpecs = specs
|
|
return nil
|
|
}
|
|
|
|
// Call the function under test - teams will be skipped automatically for free license
|
|
testErr := ApplyStarterLibrary(
|
|
context.Background(),
|
|
"https://example.com",
|
|
"test-token",
|
|
kitlog.NewNopLogger(),
|
|
httpClientFactory,
|
|
clientFactory,
|
|
mockApplyGroup,
|
|
)
|
|
|
|
// Verify results
|
|
require.NoError(t, testErr)
|
|
assert.True(t, applyGroupCalled, "ApplyGroup should have been called")
|
|
|
|
// Verify that the specs were correctly parsed
|
|
require.NotNil(t, capturedSpecs, "Specs should not be nil")
|
|
|
|
// Verify that teams were removed
|
|
require.Empty(t, capturedSpecs.Teams, "Teams should be empty for free license")
|
|
|
|
// Verify that policies referencing teams were filtered out
|
|
if capturedSpecs.Policies != nil {
|
|
for _, policy := range capturedSpecs.Policies {
|
|
assert.Empty(t, policy.Team, "Policies should not reference teams for free license")
|
|
}
|
|
}
|
|
|
|
// Verify that scripts were removed from AppConfig
|
|
if capturedSpecs.AppConfig != nil {
|
|
appConfigMap, ok := capturedSpecs.AppConfig.(map[string]interface{})
|
|
if ok {
|
|
_, hasScripts := appConfigMap["scripts"]
|
|
assert.False(t, hasScripts, "AppConfig should not contain scripts for free license")
|
|
}
|
|
}
|
|
|
|
// Verify that the starter library URL was requested
|
|
assert.Contains(t, mockRT.calls, starterLibraryURL, "The starter library URL should have been requested")
|
|
}
|
|
|
|
// mockHTTPClient is a test double exposing the Do method of an HTTP client;
// each call is answered by the configurable DoFunc.
type mockHTTPClient struct {
	// DoFunc produces the response for every request passed to Do.
	DoFunc func(req *http.Request) (*http.Response, error)
}

// Do forwards the request to DoFunc and returns its result unchanged.
func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
	handler := m.DoFunc
	return handler(req)
}
|