fleet/cmd/fleetctl/scripts_test.go
Sarah Gillespie c29f0abf92
Update API and CLI to enable running scripts by name and team id (#17322)
TODO:
- Integration tests

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

<!-- Note that API documentation changes are now addressed by the
product design team. -->

- [ ] Changes file added for user-visible changes in `changes/` or
`orbit/changes/`.
See [Changes
files](https://fleetdm.com/docs/contributing/committing-changes#changes-files)
for more information.
- [ ] Documented any permissions changes (docs/Using
Fleet/manage-access.md)
- [ ] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)
- [ ] Added support on fleet's osquery simulator `cmd/osquery-perf` for
new osquery data ingestion features.
- [ ] Added/updated tests
- [ ] If database migrations are included, checked table schema to
confirm autoupdate
- For database migrations:
  - [ ] Checked schema for all modified tables for columns that will
    auto-update timestamps during migration.
  - [ ] Confirmed that updating the timestamps is acceptable, and will not
    cause unwanted side effects.
- [ ] Manual QA for all new/changed functionality
  - For Orbit and Fleet Desktop changes:
    - [ ] Manual QA must be performed in the three main OSs: macOS, Windows,
      and Linux.
    - [ ] Auto-update manual QA, from released version of component to new
      version (see [tools/tuf/test](../tools/tuf/test/README.md)).
2024-03-05 08:53:17 -06:00

336 lines
9.9 KiB
Go

package main
import (
"context"
"fmt"
"net/http"
"os"
"strings"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/service"
"github.com/stretchr/testify/require"
)
// TestRunScriptCommand runs the `run-script` command through runAppNoChecks
// against a test server backed by a mocked datastore. Each case configures
// the case-specific datastore mocks via setupDS, builds the CLI arguments
// from its fields, and then checks either the returned error message or the
// exact command output.
func TestRunScriptCommand(t *testing.T) {
	_, ds := runServerWithMockedDS(t,
		&service.TestServerOpts{
			License: &fleet.LicenseInfo{
				Tier: fleet.TierPremium,
			},
		},
		&service.TestServerOpts{
			HTTPServerConfig: &http.Server{WriteTimeout: 90 * time.Second}, // nolint:gosec
		},
	)

	// Mocks shared by every case; case-specific mocks are installed by setupDS.
	ds.LoadHostSoftwareFunc = func(ctx context.Context, host *fleet.Host, includeCVEScores bool) error {
		return nil
	}
	ds.ListLabelsForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Label, error) {
		return nil, nil
	}
	ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
		return nil, nil
	}
	ds.ListPoliciesForHostFunc = func(ctx context.Context, host *fleet.Host) ([]*fleet.HostPolicy, error) {
		return nil, nil
	}
	ds.ListHostBatteriesFunc = func(ctx context.Context, hid uint) ([]*fleet.HostBattery, error) {
		return nil, nil
	}
	ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
		return &fleet.AppConfig{ServerSettings: fleet.ServerSettings{ScriptsDisabled: false}}, nil
	}
	ds.GetScriptIDByNameFunc = func(ctx context.Context, name string, teamID *uint) (uint, error) {
		return 1, nil
	}
	ds.IsExecutionPendingForHostFunc = func(ctx context.Context, hid uint, scriptID uint) ([]*uint, error) {
		return []*uint{}, nil
	}
	ds.GetScriptContentsFunc = func(ctx context.Context, id uint) ([]byte, error) {
		return []byte("echo hello world"), nil
	}

	// generateValidPath writes a valid shell script into the TempDir of the
	// test it is given, so the file is cleaned up together with that test.
	// It takes the subtest's t instead of capturing the parent's: failing a
	// parent test from a subtest's goroutine is not allowed by package testing.
	generateValidPath := func(t *testing.T) string {
		return writeTmpScriptContents(t, "echo hello world", ".sh")
	}

	// overLimitContents is one character over the 10,000-character limit
	// quoted in the "script too long (unsaved)" error message.
	overLimitContents := strings.Repeat("a", 10001)

	// testCase describes one run-script invocation. A nil scriptPath, an empty
	// scriptName or a nil teamID means the corresponding flag is not passed.
	// scriptsDisabled sets the global ScriptsDisabled app setting. A non-nil
	// scriptResult is what the mocked datastore returns as the execution
	// result, and it also means command output is expected.
	type testCase struct {
		name            string
		scriptPath      func(t *testing.T) string
		scriptName      string
		teamID          *uint
		scriptsDisabled bool
		scriptResult    *fleet.HostScriptResult
		expectOutput    string
		expectErrMsg    string
		expectNotFound  bool
		expectOffline   bool
		expectPending   bool
	}

	cases := []testCase{
		{
			name:          "host offline",
			scriptPath:    generateValidPath,
			expectErrMsg:  fleet.RunScriptHostOfflineErrMsg,
			expectOffline: true,
		},
		{
			name:           "host not found",
			scriptPath:     generateValidPath,
			expectErrMsg:   fleet.RunScriptHostNotFoundErrMsg,
			expectNotFound: true,
		},
		{
			name:         "invalid file type",
			scriptPath:   func(t *testing.T) string { return writeTmpScriptContents(t, "echo hello world", ".txt") },
			expectErrMsg: fleet.RunScriptInvalidTypeErrMsg,
		},
		{
			name:         "invalid hashbang",
			scriptPath:   func(t *testing.T) string { return writeTmpScriptContents(t, "#! /foo/bar", ".sh") },
			expectErrMsg: `Interpreter not supported. Bash scripts must run in "#!/bin/sh”.`,
		},
		{
			name: "script too long (unsaved)",
			scriptPath: func(t *testing.T) string {
				return writeTmpScriptContents(t, overLimitContents, ".sh")
			},
			expectErrMsg: "Script is too large. Script referenced by '--script-path' is limited to 10,000 characters. To run larger script save it to Fleet and use '--script-name'.",
		},
		{
			name:         "script-path and script-name disallowed",
			scriptPath:   generateValidPath,
			scriptName:   "foo",
			expectErrMsg: `Only one of --script-path or --script-name is allowed.`,
		},
		{
			name:         "missing one of script-path and script-name",
			expectErrMsg: `One of --script-path or --script-name must be specified.`,
		},
		{
			name:         "script-path and team-id disallowed",
			scriptPath:   generateValidPath,
			teamID:       ptr.Uint(1),
			expectErrMsg: `Only one of --script-path or --team-id is allowed.`,
		},
		{
			name:         "script empty",
			scriptPath:   func(t *testing.T) string { return writeTmpScriptContents(t, "", ".sh") },
			expectErrMsg: `Script contents must not be empty.`,
		},
		{
			name:         "invalid utf8",
			scriptPath:   func(t *testing.T) string { return writeTmpScriptContents(t, "\xff\xfa", ".sh") },
			expectErrMsg: `Wrong data format. Only plain text allowed.`,
		},
		{
			name:          "script already running",
			scriptPath:    generateValidPath,
			expectErrMsg:  fleet.RunScriptAlreadyRunningErrMsg,
			expectPending: true,
		},
		{
			name:       "script successful",
			scriptPath: generateValidPath,
			scriptResult: &fleet.HostScriptResult{
				ExitCode: ptr.Int64(0),
				Output:   "hello world",
			},
			expectOutput: `
Exit code: 0 (Script ran successfully.)
Output:
-------------------------------------------------------------------------------------
hello world
-------------------------------------------------------------------------------------
`,
		},
		{
			name:       "script failed",
			scriptPath: generateValidPath,
			scriptResult: &fleet.HostScriptResult{
				ExitCode: ptr.Int64(1),
				Output:   "",
			},
			expectOutput: `
Exit code: 1 (Script failed.)
Output:
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
`,
		},
		{
			name:       "script killed",
			scriptPath: generateValidPath,
			scriptResult: &fleet.HostScriptResult{
				ExitCode: ptr.Int64(-1),
				Output:   "Oh no!",
				Message:  fleet.RunScriptScriptTimeoutErrMsg,
			},
			expectOutput: `
Error: Timeout. Fleet stopped the script after 5 minutes to protect host performance.
Output before timeout:
-------------------------------------------------------------------------------------
Oh no!
-------------------------------------------------------------------------------------
`,
		},
		{
			name:       "scripts disabled",
			scriptPath: generateValidPath,
			scriptResult: &fleet.HostScriptResult{
				ExitCode: ptr.Int64(-2),
				Output:   "",
				Message:  fleet.RunScriptDisabledErrMsg,
			},
			expectOutput: `
Error: Scripts are disabled for this host. To run scripts, deploy the fleetd agent with scripts enabled.
`,
		},
		{
			name:       "output truncated",
			scriptPath: generateValidPath,
			scriptResult: &fleet.HostScriptResult{
				ExitCode: ptr.Int64(0),
				Output:   overLimitContents,
			},
			expectOutput: fmt.Sprintf(`
Exit code: 0 (Script ran successfully.)
Output:
-------------------------------------------------------------------------------------
Fleet records the last 10,000 characters to prevent downtime.
%s
-------------------------------------------------------------------------------------
`, overLimitContents),
		},
		// TODO: this would take 5 minutes to run, we don't want that kind of slowdown in our test suite
		// but can be useful to have around for manual testing.
		//{
		//	name:         "host timeout",
		//	scriptPath:   generateValidPath,
		//	expectErrMsg: fleet.RunScriptHostTimeoutErrMsg,
		//},
		{
			name:            "disabled scripts globally",
			scriptPath:      generateValidPath,
			scriptsDisabled: true,
			expectErrMsg:    fleet.RunScriptScriptsDisabledGloballyErrMsg,
		},
	}

	// setupDS installs the datastore mocks whose behavior depends on the case.
	// The identifier "host1" resolves to host ID 42 unless the case expects
	// the host to be missing.
	setupDS := func(t *testing.T, c testCase) {
		ds.HostByIdentifierFunc = func(ctx context.Context, ident string) (*fleet.Host, error) {
			if ident != "host1" || c.expectNotFound {
				return nil, &notFoundError{}
			}
			return &fleet.Host{ID: 42, SeenTime: time.Now(), OrbitNodeKey: ptr.String("abc")}, nil
		}
		ds.HostFunc = func(ctx context.Context, hid uint) (*fleet.Host, error) {
			if hid != 42 || c.expectNotFound {
				return nil, &notFoundError{}
			}
			h := fleet.Host{ID: hid, SeenTime: time.Now(), OrbitNodeKey: ptr.String("abc")}
			if c.expectOffline {
				// Last seen an hour ago.
				h.SeenTime = time.Now().Add(-time.Hour)
			}
			return &h, nil
		}
		ds.ListPendingHostScriptExecutionsFunc = func(ctx context.Context, hid uint) ([]*fleet.HostScriptResult, error) {
			require.Equal(t, uint(42), hid)
			if c.expectPending {
				return []*fleet.HostScriptResult{{HostID: uint(42)}}, nil
			}
			return nil, nil
		}
		ds.GetHostScriptExecutionResultFunc = func(ctx context.Context, execID string) (*fleet.HostScriptResult, error) {
			if c.scriptResult != nil {
				return c.scriptResult, nil
			}
			return &fleet.HostScriptResult{}, nil
		}
		ds.GetHostLockWipeStatusFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
			return &fleet.HostLockWipeStatus{}, nil
		}
		ds.NewHostScriptExecutionRequestFunc = func(ctx context.Context, req *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error) {
			require.Equal(t, uint(42), req.HostID)
			return &fleet.HostScriptResult{
				Hostname:       "host1",
				HostID:         req.HostID,
				ScriptContents: req.ScriptContents,
			}, nil
		}
		// Driven by an explicit case field rather than by matching the case name.
		ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
			return &fleet.AppConfig{ServerSettings: fleet.ServerSettings{ScriptsDisabled: c.scriptsDisabled}}, nil
		}
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			setupDS(t, c)

			args := []string{"run-script", "--host", "host1"}
			if c.scriptPath != nil {
				// The script is written to this subtest's TempDir, which the
				// testing package removes when the subtest finishes.
				args = append(args, "--script-path", c.scriptPath(t))
			}
			if c.scriptName != "" {
				args = append(args, "--script-name", c.scriptName)
			}
			if c.teamID != nil {
				args = append(args, "--team-id", fmt.Sprintf("%d", *c.teamID))
			}

			b, err := runAppNoChecks(args)
			if c.expectErrMsg != "" {
				require.Error(t, err)
				require.Contains(t, err.Error(), c.expectErrMsg)
			} else {
				require.NoError(t, err)
			}

			out := b.String()
			if c.scriptResult != nil {
				require.NotEmpty(t, out)
				require.Equal(t, c.expectOutput, out)
			} else {
				require.Empty(t, out)
			}
		})
	}
}
// writeTmpScriptContents writes scriptContents to a new file whose name ends
// with extension (e.g. ".sh"). The file is created inside t.TempDir(), and the
// function returns its path. The testing package removes that directory, and
// therefore the file, when t and its subtests complete.
func writeTmpScriptContents(t *testing.T, scriptContents string, extension string) string {
	t.Helper()

	tmpFile, err := os.CreateTemp(t.TempDir(), "*"+extension)
	require.NoError(t, err)

	_, err = tmpFile.WriteString(scriptContents)
	// Always close the handle: leaving it open leaks a file descriptor and,
	// on Windows, prevents the TempDir cleanup from deleting the file. A write
	// error takes precedence over a close error.
	if closeErr := tmpFile.Close(); err == nil {
		err = closeErr
	}
	require.NoError(t, err)

	return tmpFile.Name()
}