fleet/server/datastore/mysql/scripts_test.go
Scott Gress 6c659050c0
Fix Orbit-canceled script runs being counted as "pending" (#33300)
<!-- Add the related story/sub-task/bug number, like Resolves #123, or
remove if NA -->
**Related issue:** Resolves #32220

# Details

This PR fixes an issue where hosts whose running scripts were canceled
by Orbit (e.g. due to timing out) were reported as being still "pending"
on the batch script details view. This was due to our only counting runs
as errored if the error code was > 0, and ignoring negative error codes
(which is what Orbit uses for this case).

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [X] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)

## Testing

- [X] Added/updated automated tests
Changed a couple of places where we were using `1` for an error code to
`-1`
- [X] Where appropriate, [automated tests simulate multiple hosts and
test for host
isolation](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/reference/patterns-backend.md#unit-testing)
(updates to one host's records do not affect another)

- [X] QA'd all new/changed functionality manually

For unreleased bug fixes in a release candidate, one of:

- [X] Confirmed that the fix is not expected to adversely impact load
test results
- [X] Alerted the release DRI if additional load testing is needed
2025-09-23 12:22:28 -05:00

2986 lines
108 KiB
Go

package mysql
import (
"context"
_ "embed"
"fmt"
"math"
"strings"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/fleetdm/fleet/v4/server/worker"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestScripts(t *testing.T) {
ds := CreateMySQLDS(t)
cases := []struct {
name string
fn func(t *testing.T, ds *Datastore)
}{
{"HostScriptResult", testHostScriptResult},
{"DEPRestoredHost", testListPendingScriptDEPRestoration},
{"Scripts", testScripts},
{"ListScripts", testListScripts},
{"GetHostScriptDetails", testGetHostScriptDetails},
{"BatchSetScripts", testBatchSetScripts},
{"TestLockHostViaScript", testLockHostViaScript},
{"TestUnlockHostViaScript", testUnlockHostViaScript},
{"TestLockUnlockWipeViaScripts", testLockUnlockWipeViaScripts},
{"TestLockUnlockManually", testLockUnlockManually},
{"TestInsertScriptContents", testInsertScriptContents},
{"TestCleanupUnusedScriptContents", testCleanupUnusedScriptContents},
{"TestGetAnyScriptContents", testGetAnyScriptContents},
{"TestDeleteScriptsAssignedToPolicy", testDeleteScriptsAssignedToPolicy},
{"TestDeletePendingHostScriptExecutionsForPolicy", testDeletePendingHostScriptExecutionsForPolicy},
{"UpdateScriptContents", testUpdateScriptContents},
{"UpdateScriptToDuplicateContent", testUpdateScriptToDuplicateContent},
{"UpdateSharedScriptContent", testUpdateSharedScriptContent},
{"UpdateScriptToSameContent", testUpdateScriptToSameContent},
{"UpdateDeletingUpcomingScriptExecutions", testUpdateDeletingUpcomingScriptExecutions},
{"BatchExecute", testBatchExecute},
{"BatchExecuteWithStatus", testBatchExecuteWithStatus},
{"BatchScriptSchedule", testBatchScriptSchedule},
{"BatchScriptCancel", testBatchScriptCancel},
{"TestMarkActivitiesAsCompleted", testMarkActivitiesAsCompleted},
{"DeleteScriptActivatesNextActivity", testDeleteScriptActivatesNextActivity},
{"BatchSetScriptActivatesNextActivity", testBatchSetScriptActivatesNextActivity},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
defer TruncateTables(t, ds)
c.fn(t, ds)
})
}
}
// testHostScriptResult walks through the lifecycle of ad-hoc script
// executions for host 1:
//   - requests show up as pending until a result is recorded,
//   - a second result for the same execution is ignored,
//   - output that is too large is truncated from the front,
//   - an upcoming sync request that is a day old is still listed as pending,
//   - an exit code of math.MaxUint32 is reported back as -1.
func testHostScriptResult(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// no script saved yet
	pending, err := ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Empty(t, pending)

	_, err = ds.GetHostScriptExecutionResult(ctx, "abc")
	require.Error(t, err)
	var nfe *common_mysql.NotFoundError
	require.ErrorAs(t, err, &nfe)

	// create a script execution request (with a user)
	u := test.NewUser(t, ds, "Bob", "bob@example.com", true)
	createdScript1, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo",
		UserID:         &u.ID,
		SyncRequest:    true,
	})
	require.NoError(t, err)
	require.NotZero(t, createdScript1.ID)
	require.NotEmpty(t, createdScript1.ExecutionID)
	require.Equal(t, uint(1), createdScript1.HostID)
	require.Equal(t, "echo", createdScript1.ScriptContents)
	require.Nil(t, createdScript1.ExitCode)
	require.Empty(t, createdScript1.Output)
	require.NotNil(t, createdScript1.UserID)
	require.Equal(t, u.ID, *createdScript1.UserID)
	require.True(t, createdScript1.SyncRequest)

	// createdScript1 is now activated, as the queue was empty

	// the script execution is now listed as pending for this host
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)
	require.Equal(t, createdScript1.ID, pending[0].ID)

	// the script execution isn't visible when looking at internal-only scripts
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, true)
	require.NoError(t, err)
	require.Empty(t, pending)

	// record a result for this execution
	hsr, action, err := ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      1,
		ExecutionID: createdScript1.ExecutionID,
		Output:      "foo",
		Runtime:     2,
		ExitCode:    0,
		Timeout:     300,
	})
	require.NoError(t, err)
	assert.Empty(t, action)
	assert.NotNil(t, hsr)

	// record a duplicate result for this execution, will be ignored
	hsr, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      1,
		ExecutionID: createdScript1.ExecutionID,
		Output:      "foobarbaz",
		Runtime:     22,
		ExitCode:    1,
		Timeout:     360,
	})
	require.NoError(t, err)
	require.Nil(t, hsr)

	// it is not pending anymore
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Empty(t, pending)

	// the script result can be retrieved, and still holds the values of the
	// first (non-duplicate) result
	script, err := ds.GetHostScriptExecutionResult(ctx, createdScript1.ExecutionID)
	require.NoError(t, err)
	expectScript := *createdScript1
	expectScript.Output = "foo"
	expectScript.Runtime = 2
	expectScript.ExitCode = ptr.Int64(0)
	expectScript.Timeout = ptr.Int(300)
	// timestamps are not part of the comparison
	expectScript.CreatedAt, script.CreatedAt = time.Time{}, time.Time{}
	require.Equal(t, &expectScript, script)

	// create another script execution request (null user id this time)
	time.Sleep(time.Millisecond) // ensure a different timestamp
	createdScript2, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo2",
	})
	require.NoError(t, err)
	require.NotZero(t, createdScript2.ID)
	require.NotEmpty(t, createdScript2.ExecutionID)
	require.Nil(t, createdScript2.UserID)
	require.False(t, createdScript2.SyncRequest)

	// createdScript2 is now activated as the queue was empty

	// the script execution is now listed as pending for this host
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)
	require.Equal(t, createdScript2.ID, pending[0].ID)

	// the script result can be retrieved even if it has no result yet
	script, err = ds.GetHostScriptExecutionResult(ctx, createdScript2.ExecutionID)
	require.NoError(t, err)
	expectedScript := *createdScript2
	expectedScript.CreatedAt, script.CreatedAt = time.Time{}, time.Time{}
	require.Equal(t, &expectedScript, script)

	// record a result for this execution, with an output that is too large:
	// 1000 characters each of "a" through "k" (11000 characters in total)
	var large strings.Builder
	for c := 'a'; c <= 'k'; c++ {
		large.WriteString(strings.Repeat(string(c), 1000))
	}
	largeOutput := large.String()
	// Note that the expectation is that the "a"s get truncated
	expectedOutput := largeOutput[1000:]

	_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      1,
		ExecutionID: createdScript2.ExecutionID,
		Output:      largeOutput,
		Runtime:     10,
		ExitCode:    1,
		Timeout:     300,
	})
	require.NoError(t, err)

	// the script result can be retrieved
	script, err = ds.GetHostScriptExecutionResult(ctx, createdScript2.ExecutionID)
	require.NoError(t, err)
	require.Equal(t, expectedOutput, script.Output)

	// create an async execution request
	time.Sleep(time.Millisecond) // ensure a different timestamp
	createdScript3, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo 3",
		UserID:         &u.ID,
		SyncRequest:    false,
	})
	require.NoError(t, err)
	require.NotZero(t, createdScript3.ID)
	require.NotEmpty(t, createdScript3.ExecutionID)
	require.Equal(t, uint(1), createdScript3.HostID)
	require.Equal(t, "echo 3", createdScript3.ScriptContents)
	require.Nil(t, createdScript3.ExitCode)
	require.Empty(t, createdScript3.Output)
	require.NotNil(t, createdScript3.UserID)
	require.Equal(t, u.ID, *createdScript3.UserID)
	require.False(t, createdScript3.SyncRequest)

	// createdScript3 is now activated as the queue was empty

	// the script execution is now listed as pending for this host
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)
	require.Equal(t, createdScript3.ID, pending[0].ID)

	// modify the upcoming script to be a sync script that has
	// been pending for a long time doesn't change result
	// https://github.com/fleetdm/fleet/issues/22866#issuecomment-2575961141
	ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(ctx, "UPDATE upcoming_activities SET created_at = ?, payload = JSON_SET(payload, '$.sync_request', ?) WHERE id = ?",
			time.Now().Add(-24*time.Hour), true, createdScript3.ID)
		return err
	})

	// the script is still pending
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)
	require.Equal(t, createdScript3.ExecutionID, pending[0].ExecutionID)

	// check that scripts with large unsigned error codes get
	// converted to signed error codes
	createdUnsignedScript, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo",
		UserID:         &u.ID,
		SyncRequest:    true,
	})
	require.NoError(t, err)

	// record a result for createdScript3 so that the unsigned script gets activated
	_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      1,
		ExecutionID: createdScript3.ExecutionID,
		Output:      "foo",
		Runtime:     1,
		ExitCode:    0,
		Timeout:     0,
	})
	require.NoError(t, err)

	// createdUnsignedScript is now activated, record its result
	unsignedScriptResult, _, err := ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      1,
		ExecutionID: createdUnsignedScript.ExecutionID,
		Output:      "foo",
		Runtime:     1,
		ExitCode:    math.MaxUint32,
		Timeout:     300,
	})
	require.NoError(t, err)
	// guard the dereference below so that a missing result or exit code fails
	// the test cleanly instead of panicking
	require.NotNil(t, unsignedScriptResult)
	require.NotNil(t, unsignedScriptResult.ExitCode)
	require.EqualValues(t, -1, *unsignedScriptResult.ExitCode)
}
// testListPendingScriptDEPRestoration verifies that a script execution that
// is pending for a host is no longer listed as pending once the host has been
// deleted and then restored via RestoreMDMApplePendingDEPHost.
func testListPendingScriptDEPRestoration(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	host := test.NewHost(t, ds, "host", "10.0.0.1", "1", "uuid1", time.Now())

	// no script saved yet
	pending, err := ds.ListPendingHostScriptExecutions(ctx, host.ID, false)
	require.NoError(t, err)
	require.Empty(t, pending)

	// create a script execution request (with a user)
	u := test.NewUser(t, ds, "Bob", "bob@example.com", true)
	createdScript, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         host.ID,
		ScriptContents: "echo",
		UserID:         &u.ID,
		SyncRequest:    true,
	})
	require.NoError(t, err)
	require.NotZero(t, createdScript.ID)
	require.NotEmpty(t, createdScript.ExecutionID)
	// compare against the host that was actually created rather than
	// assuming it received ID 1
	require.Equal(t, host.ID, createdScript.HostID)
	require.Equal(t, "echo", createdScript.ScriptContents)
	require.Nil(t, createdScript.ExitCode)
	require.Empty(t, createdScript.Output)
	require.NotNil(t, createdScript.UserID)
	require.Equal(t, u.ID, *createdScript.UserID)
	require.True(t, createdScript.SyncRequest)

	// the script execution is now listed as pending for this host
	pending, err = ds.ListPendingHostScriptExecutions(ctx, host.ID, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)
	require.Equal(t, createdScript.ID, pending[0].ID)

	// Set LastEnrolledAt before deleting the host (simulating a DEP enrolled host)
	host.LastEnrolledAt = time.Now()
	err = ds.DeleteHost(ctx, host.ID)
	require.NoError(t, err)

	err = ds.RestoreMDMApplePendingDEPHost(ctx, host)
	require.NoError(t, err)

	// the script execution is no longer listed as pending for this host
	pending, err = ds.ListPendingHostScriptExecutions(ctx, host.ID, false)
	require.NoError(t, err)
	require.Empty(t, pending)
}
// testScripts covers creating, reading and deleting saved scripts, both
// global and team-scoped: lookups of unknown IDs, name uniqueness within a
// scope, the foreign-key error for an unknown team, and re-creating a script
// under a name that was freed by a deletion.
func testScripts(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// targets for the error type assertions used throughout the test
	var (
		notFoundErr fleet.NotFoundError
		fkErr       fleet.ForeignKeyError
		existsErr   fleet.AlreadyExistsError
	)

	// requireContents asserts that both content getters return want for the
	// given script ID.
	requireContents := func(scriptID uint, want string) {
		got, err := ds.GetScriptContents(ctx, scriptID)
		require.NoError(t, err)
		require.Equal(t, want, string(got))

		got, err = ds.GetAnyScriptContents(ctx, scriptID)
		require.NoError(t, err)
		require.Equal(t, want, string(got))
	}

	// an unknown script cannot be loaded...
	_, err := ds.Script(ctx, 123)
	require.ErrorAs(t, err, &notFoundErr)

	// ...and neither can its contents
	_, err = ds.GetScriptContents(ctx, 123)
	require.ErrorAs(t, err, &notFoundErr)
	_, err = ds.GetAnyScriptContents(ctx, 123)
	require.ErrorAs(t, err, &notFoundErr)

	// a global script "a"
	globalScript, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "a",
		ScriptContents: "echo",
	})
	require.NoError(t, err)
	require.NotZero(t, globalScript.ID)
	require.Nil(t, globalScript.TeamID)
	require.Equal(t, "a", globalScript.Name)
	require.Empty(t, globalScript.ScriptContents) // we don't return the contents

	// it can be loaded back, along with its contents
	loaded, err := ds.Script(ctx, globalScript.ID)
	require.NoError(t, err)
	require.Equal(t, globalScript, loaded)
	requireContents(globalScript.ID, "echo")

	// a team script cannot reference a team that does not exist
	_, err = ds.NewScript(ctx, &fleet.Script{
		Name:           "a",
		TeamID:         ptr.Uint(123),
		ScriptContents: "echo",
	})
	require.Error(t, err)
	require.ErrorAs(t, err, &fkErr)

	// with a real team, a script named like the global one is accepted
	team, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name()})
	require.NoError(t, err)
	teamScript, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "a",
		TeamID:         &team.ID,
		ScriptContents: "echo 'team'",
	})
	require.NoError(t, err)
	require.NotEqual(t, globalScript.ID, teamScript.ID)
	require.NotNil(t, teamScript.TeamID)
	require.Equal(t, team.ID, *teamScript.TeamID)

	// the team script can be loaded back, along with its contents
	loaded, err = ds.Script(ctx, teamScript.ID)
	require.NoError(t, err)
	require.Equal(t, teamScript, loaded)
	requireContents(teamScript.ID, "echo 'team'")

	// a second team script with the same name is rejected
	_, err = ds.NewScript(ctx, &fleet.Script{
		Name:           "a",
		TeamID:         &team.ID,
		ScriptContents: "echo",
	})
	require.Error(t, err)
	require.ErrorAs(t, err, &existsErr)

	// and so is a second global script with the same name
	_, err = ds.NewScript(ctx, &fleet.Script{
		Name:           "a",
		ScriptContents: "echo",
	})
	require.Error(t, err)
	require.ErrorAs(t, err, &existsErr)

	// a different name for the same team is fine
	_, err = ds.NewScript(ctx, &fleet.Script{
		Name:           "b",
		TeamID:         &team.ID,
		ScriptContents: "echo",
	})
	require.NoError(t, err)

	// once script "a" is deleted for the team, the name can be used again and
	// the new script gets a new ID
	err = ds.DeleteScript(ctx, teamScript.ID)
	require.NoError(t, err)
	recreated, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "a",
		TeamID:         &team.ID,
		ScriptContents: "echo",
	})
	require.NoError(t, err)
	require.NotEqual(t, teamScript.ID, recreated.ID)
}
// testListScripts checks ListScripts pagination and team filtering using
// five scripts for no team, five for team 1, one for team 2 and none for
// team 3.
func testListScripts(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// listCase describes one ListScripts call and its expected outcome.
	type listCase struct {
		opts      fleet.ListOptions
		teamID    *uint
		wantNames []string
		wantMeta  *fleet.PaginationMetadata
	}

	// meta builds the expected pagination metadata.
	meta := func(hasNext, hasPrevious bool) *fleet.PaginationMetadata {
		return &fleet.PaginationMetadata{HasNextResults: hasNext, HasPreviousResults: hasPrevious}
	}

	// create three teams
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
	require.NoError(t, err)
	team3, err := ds.NewTeam(ctx, &fleet.Team{Name: "team3"})
	require.NoError(t, err)

	// scripts "a" through "e" exist both for no team and for team 1
	for _, name := range []string{"a", "b", "c", "d", "e"} {
		_, err = ds.NewScript(ctx, &fleet.Script{
			Name:           name,
			ScriptContents: "echo",
		})
		require.NoError(t, err)

		_, err = ds.NewScript(ctx, &fleet.Script{
			Name:           name,
			TeamID:         &team1.ID,
			ScriptContents: "echo",
		})
		require.NoError(t, err)
	}

	// team 2 has a single script
	_, err = ds.NewScript(ctx, &fleet.Script{Name: "a", TeamID: &team2.ID, ScriptContents: "echo"})
	require.NoError(t, err)

	listCases := []listCase{
		// no team
		{opts: fleet.ListOptions{}, wantNames: []string{"a", "b", "c", "d", "e"}, wantMeta: meta(false, false)},
		{opts: fleet.ListOptions{PerPage: 2}, wantNames: []string{"a", "b"}, wantMeta: meta(true, false)},
		{opts: fleet.ListOptions{Page: 1, PerPage: 2}, wantNames: []string{"c", "d"}, wantMeta: meta(true, true)},
		{opts: fleet.ListOptions{Page: 2, PerPage: 2}, wantNames: []string{"e"}, wantMeta: meta(false, true)},
		// team 1
		{opts: fleet.ListOptions{PerPage: 3}, teamID: &team1.ID, wantNames: []string{"a", "b", "c"}, wantMeta: meta(true, false)},
		{opts: fleet.ListOptions{Page: 1, PerPage: 3}, teamID: &team1.ID, wantNames: []string{"d", "e"}, wantMeta: meta(false, true)},
		{opts: fleet.ListOptions{Page: 2, PerPage: 3}, teamID: &team1.ID, wantNames: nil, wantMeta: meta(false, true)},
		// team 2
		{opts: fleet.ListOptions{PerPage: 3}, teamID: &team2.ID, wantNames: []string{"a"}, wantMeta: meta(false, false)},
		// team 3 has no scripts at all
		{opts: fleet.ListOptions{Page: 0, PerPage: 2}, teamID: &team3.ID, wantNames: nil, wantMeta: meta(false, false)},
	}

	for i := range listCases {
		tc := listCases[i]
		t.Run(fmt.Sprintf("%v: %#v", tc.teamID, tc.opts), func(t *testing.T) {
			// always include metadata
			tc.opts.IncludeMetadata = true

			scripts, gotMeta, err := ds.ListScripts(ctx, tc.teamID, tc.opts)
			require.NoError(t, err)
			require.Len(t, scripts, len(tc.wantNames))
			require.Equal(t, tc.wantMeta, gotMeta)

			// gotNames stays nil when nothing is returned, matching the nil
			// wantNames of the empty cases
			var gotNames []string
			for _, s := range scripts {
				gotNames = append(gotNames, s.Name)
				require.Equal(t, tc.teamID, s.TeamID)
			}
			require.Equal(t, tc.wantNames, gotNames)
		})
	}
}
// testGetHostScriptDetails checks GetHostScriptDetails for host 42: the last
// execution reported per saved script (status "pending", "error" or "ran"),
// the empty result for a team without scripts, list options/pagination, and
// platform filtering ("darwin" vs "windows"). It also checks
// IsExecutionPendingForHost and that deleting a script removes its pending
// executions.
//
// The test relies on ds.testActivateSpecificNextActivities to control which
// upcoming activity gets activated, so the order of the insertResults calls
// below matters.
func testGetHostScriptDetails(t *testing.T, ds *Datastore) {
ctx := context.Background()
// make sure the activation override never leaks into other tests
t.Cleanup(func() { ds.testActivateSpecificNextActivities = nil })
names := []string{"script-1.sh", "script-2.sh", "script-3.sh", "script-4.sh", "script-5.sh"}
// insert the scripts in the order script-2 ... script-5, script-1, i.e. not
// in name order
for _, r := range append(names[1:], names[0]) {
_, err := ds.NewScript(ctx, &fleet.Script{
Name: r,
ScriptContents: "echo " + r,
})
require.NoError(t, err)
}
// create a windows script as well
_, err := ds.NewScript(ctx, &fleet.Script{
Name: "script-6.ps1",
ScriptContents: `Write-Host "Hello, World!"`,
})
require.NoError(t, err)
// NOTE(review): the code below treats scripts[0] as script-1, scripts[1] as
// script-2, etc., so ListScripts presumably returns the scripts ordered by
// name - confirm.
scripts, _, err := ds.ListScripts(ctx, nil, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, scripts, 6)
// insertResults enqueues an execution of script on hostID and rewrites the
// upcoming activity so that it uses execID as execution ID and createdAt as
// creation time. When exitCode is non-nil, that specific activity is
// activated, a result with that exit code is recorded, and the created_at of
// the host_script_results row is forced to createdAt. When exitCode is nil
// the execution is left without a result.
insertResults := func(t *testing.T, hostID uint, script *fleet.Script, createdAt time.Time, execID string, exitCode *int64) {
var scriptID *uint
if script.ID != 0 {
scriptID = &script.ID
}
hsr, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
HostID: hostID,
ScriptID: scriptID,
})
require.NoError(t, err)
// replace the generated execution ID and timestamp with the test's values
ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, `UPDATE upcoming_activities SET execution_id = ?, created_at = ? WHERE execution_id = ?`,
execID, createdAt, hsr.ExecutionID)
return err
})
if exitCode != nil {
// activate exactly this execution, then reset the override
ds.testActivateSpecificNextActivities = []string{execID}
act, err := ds.activateNextUpcomingActivity(ctx, ds.writer(ctx), hostID, "")
require.NoError(t, err)
require.ElementsMatch(t, act, ds.testActivateSpecificNextActivities)
ds.testActivateSpecificNextActivities = nil
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
HostID: hostID,
ExecutionID: execID,
ExitCode: int(*exitCode),
})
require.NoError(t, err)
// force the test timestamp
ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, `UPDATE host_script_results SET created_at = ? WHERE execution_id = ?`,
createdAt, execID)
return err
})
}
}
// truncated to the second so that it can be compared with ExecutedAt below
now := time.Now().UTC().Truncate(time.Second)
// add some results for an ad-hoc, non-saved script, should not be included in results
// create it first so that this one gets activated, and the other ones are never
// activated automatically.
_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
HostID: 42,
ScriptContents: "echo script-6",
})
require.NoError(t, err)
// add some results for script-1
insertResults(t, 42, scripts[0], now.Add(-3*time.Minute), "execution-1-1", nil)
insertResults(t, 42, scripts[0], now.Add(-1*time.Minute), "execution-1-2", nil) // last execution for script-1, status "pending"
insertResults(t, 42, scripts[0], now.Add(-2*time.Minute), "execution-1-3", nil)
// add some results for script-2
insertResults(t, 42, scripts[1], now.Add(-3*time.Minute), "execution-2-1", ptr.Int64(0))
insertResults(t, 42, scripts[1], now.Add(-1*time.Minute), "execution-2-2", ptr.Int64(1)) // last execution for script-2, status "error"
// add some results for script-3
insertResults(t, 42, scripts[2], now.Add(-1*time.Minute), "execution-3-1", ptr.Int64(0))
insertResults(t, 42, scripts[2], now.Add(-1*time.Minute), "execution-3-2", ptr.Int64(0)) // last execution for script-3, status "ran"
insertResults(t, 42, scripts[2], now.Add(-2*time.Minute), "execution-3-3", ptr.Int64(0))
// add some results for script-4
insertResults(t, 42, scripts[3], now.Add(-1*time.Minute), "execution-4-1", ptr.Int64(-2)) // last execution for script-4, status "error"
// add a pending and a completed script execution for script-5
insertResults(t, 42, scripts[4], now.Add(-2*time.Minute), "execution-5-1", ptr.Int64(0))
insertResults(t, 42, scripts[4], now.Add(-3*time.Minute), "execution-5-2", nil) // upcoming is always latest, regardless of timestamp
t.Run("results match expected formatting and filtering", func(t *testing.T) {
// an empty platform string returns all six scripts
res, _, err := ds.GetHostScriptDetails(ctx, 42, nil, fleet.ListOptions{}, "")
require.NoError(t, err)
require.Len(t, res, 6)
for _, r := range res {
switch r.ScriptID {
case scripts[0].ID:
require.Equal(t, scripts[0].Name, r.Name)
require.NotNil(t, r.LastExecution)
require.Equal(t, now.Add(-1*time.Minute), r.LastExecution.ExecutedAt)
require.Equal(t, "execution-1-2", r.LastExecution.ExecutionID)
require.Equal(t, "pending", r.LastExecution.Status)
case scripts[1].ID:
require.Equal(t, scripts[1].Name, r.Name)
require.NotNil(t, r.LastExecution)
require.Equal(t, now.Add(-1*time.Minute), r.LastExecution.ExecutedAt)
require.Equal(t, "execution-2-2", r.LastExecution.ExecutionID)
require.Equal(t, "error", r.LastExecution.Status)
case scripts[2].ID:
require.Equal(t, scripts[2].Name, r.Name)
require.NotNil(t, r.LastExecution)
require.Equal(t, now.Add(-1*time.Minute), r.LastExecution.ExecutedAt)
require.Equal(t, "execution-3-2", r.LastExecution.ExecutionID)
require.Equal(t, "ran", r.LastExecution.Status)
case scripts[3].ID:
// a negative exit code (-2) is reported as "error"
require.Equal(t, scripts[3].Name, r.Name)
require.NotNil(t, r.LastExecution)
require.Equal(t, now.Add(-1*time.Minute), r.LastExecution.ExecutedAt)
require.Equal(t, "execution-4-1", r.LastExecution.ExecutionID)
require.Equal(t, "error", r.LastExecution.Status)
case scripts[4].ID:
require.Equal(t, scripts[4].Name, r.Name)
require.NotNil(t, r.LastExecution)
// NOTE(review): the ExecutedAt assertion below is disabled in the
// original; the reason is not visible here.
// require.Equal(t, now.Add(-3*time.Minute), r.LastExecution.ExecutedAt)
require.Equal(t, "execution-5-2", r.LastExecution.ExecutionID)
require.Equal(t, "pending", r.LastExecution.Status)
case scripts[5].ID:
// the windows script was never executed on this host
require.Equal(t, scripts[5].Name, r.Name)
require.Nil(t, r.LastExecution)
default:
t.Errorf("unexpected script id: %d", r.ScriptID)
}
}
})
t.Run("empty slice returned if no scripts", func(t *testing.T) {
res, _, err := ds.GetHostScriptDetails(ctx, 42, ptr.Uint(1), fleet.ListOptions{}, "") // team 1 has no scripts
require.NoError(t, err)
// an empty, non-nil slice is expected
require.NotNil(t, res)
require.Len(t, res, 0)
})
t.Run("list options are supported", func(t *testing.T) {
// NOTE(review): the teamID field is never set or read by the cases below.
cases := []struct {
opts fleet.ListOptions
teamID *uint
wantNames []string
wantMeta *fleet.PaginationMetadata
}{
{
opts: fleet.ListOptions{},
wantNames: names,
wantMeta: &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: false},
},
{
opts: fleet.ListOptions{PerPage: 2},
wantNames: names[:2],
wantMeta: &fleet.PaginationMetadata{HasNextResults: true, HasPreviousResults: false},
},
{
opts: fleet.ListOptions{Page: 1, PerPage: 2},
wantNames: names[2:4],
wantMeta: &fleet.PaginationMetadata{HasNextResults: true, HasPreviousResults: true},
},
{
opts: fleet.ListOptions{Page: 2, PerPage: 2},
wantNames: names[4:],
wantMeta: &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: true},
},
}
for _, c := range cases {
t.Run(fmt.Sprintf("%#v", c.opts), func(t *testing.T) {
// always include metadata
c.opts.IncludeMetadata = true
// custom ordering is not supported, always by name
c.opts.OrderKey = "name"
// with platform "darwin" only the five .sh scripts are expected
results, meta, err := ds.GetHostScriptDetails(ctx, 42, nil, c.opts, "darwin")
require.NoError(t, err)
require.Equal(t, len(c.wantNames), len(results))
require.Equal(t, c.wantMeta, meta)
var gotNames []string
if len(results) > 0 {
gotNames = make([]string, len(results))
for i, r := range results {
gotNames[i] = r.Name
}
}
require.Equal(t, c.wantNames, gotNames)
})
}
})
t.Run("windows ps1 scripts are supported", func(t *testing.T) {
// with platform "windows" only the .ps1 script is expected
res, _, err := ds.GetHostScriptDetails(ctx, 42, nil, fleet.ListOptions{}, "windows")
require.NoError(t, err)
require.NotNil(t, res)
require.Len(t, res, 1)
require.Equal(t, "script-6.ps1", res[0].Name)
})
t.Run("can check if pending host script results exist", func(t *testing.T) {
// enqueue one more execution of script-3 without a result
insertResults(t, 42, scripts[2], now.Add(-2*time.Minute), "execution-3-4", nil)
r, err := ds.IsExecutionPendingForHost(ctx, 42, scripts[2].ID)
require.NoError(t, err)
require.True(t, r)
})
t.Run("script deletion cancels pending script runs", func(t *testing.T) {
// host 43 has no other activity queued before this request
insertResults(t, 43, scripts[3], now.Add(-2*time.Minute), "execution-4-4", nil)
pending, err := ds.ListPendingHostScriptExecutions(ctx, 43, false)
require.NoError(t, err)
require.Len(t, pending, 1)
err = ds.DeleteScript(ctx, scripts[3].ID)
require.NoError(t, err)
pending, err = ds.ListPendingHostScriptExecutions(ctx, 43, false)
require.NoError(t, err)
require.Len(t, pending, 0)
})
}
// testBatchSetScripts verifies BatchSetScripts for a team and for no team:
// the resulting set of scripts, that script IDs are stable when a script is
// re-applied or only its contents change, that policies keep their script
// while it exists and lose it when it is removed, and that pending
// executions of removed scripts are no longer pending.
func testBatchSetScripts(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// applyAndExpect applies newSet for the given team (nil for no team),
	// lists the scripts of that team and asserts that they match want (IDs and
	// timestamps are ignored in that comparison). It also asserts that the IDs
	// returned by BatchSetScripts agree with the listed ones. The returned map
	// gives the script ID for each script name.
	applyAndExpect := func(newSet []*fleet.Script, tmID *uint, want []*fleet.Script) map[string]uint {
		responseFromSet, err := ds.BatchSetScripts(ctx, tmID, newSet)
		require.NoError(t, err)

		// no-team scripts are listed using team ID 0
		if tmID == nil {
			tmID = ptr.Uint(0)
		}
		got, _, err := ds.ListScripts(ctx, tmID, fleet.ListOptions{})
		require.NoError(t, err)

		fromGetByScriptName := make(map[string]uint, len(got))
		fromSetByScriptName := make(map[string]uint, len(responseFromSet))
		for _, setScript := range responseFromSet {
			fromSetByScriptName[setScript.Name] = setScript.ID
		}

		// compare only the fields we care about
		for _, gotScript := range got {
			fromGetByScriptName[gotScript.Name] = gotScript.ID
			if gotScript.TeamID != nil && *gotScript.TeamID == 0 {
				gotScript.TeamID = nil
			}
			// the ID reported by BatchSetScripts must be the one that is listed
			require.Equal(t, fromSetByScriptName[gotScript.Name], gotScript.ID)
			gotScript.ID = 0
			gotScript.CreatedAt = time.Time{}
			gotScript.UpdatedAt = time.Time{}
		}
		// order is not guaranteed
		require.ElementsMatch(t, want, got)

		return fromGetByScriptName
	}

	// apply empty set for no-team
	applyAndExpect(nil, nil, nil)

	// create a team
	tm1, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "_tm1"})
	require.NoError(t, err)

	// apply single script set for tm1
	sTm1 := applyAndExpect([]*fleet.Script{
		{Name: "N1", ScriptContents: "C1"},
	}, ptr.Uint(tm1.ID), []*fleet.Script{
		{Name: "N1", TeamID: ptr.Uint(tm1.ID)},
	})
	n1WithTeamID := sTm1["N1"]

	teamPolicy, err := ds.NewTeamPolicy(ctx, tm1.ID, nil, fleet.PolicyPayload{
		Name:     "Team One Policy",
		Query:    "SELECT 1",
		Platform: "darwin",
		ScriptID: &n1WithTeamID,
	})
	require.NoError(t, err)

	// apply single script set for no-team
	sNoTm := applyAndExpect([]*fleet.Script{
		{Name: "N1", ScriptContents: "C1"},
	}, nil, []*fleet.Script{
		{Name: "N1", TeamID: nil},
	})
	n1WithNoTeamId := sNoTm["N1"]

	noTeamPolicy, err := ds.NewTeamPolicy(ctx, fleet.PolicyNoTeamID, nil, fleet.PolicyPayload{
		Name:     "No Team Policy",
		Query:    "SELECT 1",
		Platform: "darwin",
		ScriptID: &n1WithNoTeamId,
	})
	require.NoError(t, err)

	// apply new script set for tm1
	sTm1b := applyAndExpect([]*fleet.Script{
		{Name: "N1", ScriptContents: "C1"},
		{Name: "N2", ScriptContents: "C2"},
	}, ptr.Uint(tm1.ID), []*fleet.Script{
		{Name: "N1", TeamID: ptr.Uint(tm1.ID)},
		{Name: "N2", TeamID: ptr.Uint(tm1.ID)},
	})
	// ID for N1 is unchanged (the maps are keyed by script name)
	require.Equal(t, sTm1["N1"], sTm1b["N1"])

	// policy still has script associated
	teamPolicy, err = ds.Policy(ctx, teamPolicy.ID)
	require.NoError(t, err)
	require.Equal(t, n1WithTeamID, *teamPolicy.ScriptID)

	// apply edited (by contents only) script set for no-team
	sNoTmb := applyAndExpect([]*fleet.Script{
		{Name: "N1", ScriptContents: "C1-changed"},
	}, nil, []*fleet.Script{
		{Name: "N1", TeamID: nil},
	})
	// ID for the no-team N1 is unchanged
	require.Equal(t, sNoTm["N1"], sNoTmb["N1"])

	// policy still has script associated
	noTeamPolicy, err = ds.Policy(ctx, noTeamPolicy.ID)
	require.NoError(t, err)
	require.Equal(t, n1WithNoTeamId, *noTeamPolicy.ScriptID)

	// apply edited script (by content only), unchanged script and new
	// script for tm1
	sTm1c := applyAndExpect([]*fleet.Script{
		{Name: "N1", ScriptContents: "C1-updated"}, // content updated
		{Name: "N2", ScriptContents: "C2"},         // unchanged
		{Name: "N3", ScriptContents: "C3"},         // new
	}, ptr.Uint(tm1.ID), []*fleet.Script{
		{Name: "N1", TeamID: ptr.Uint(tm1.ID)}, // content updated
		{Name: "N2", TeamID: ptr.Uint(tm1.ID)}, // unchanged
		{Name: "N3", TeamID: ptr.Uint(tm1.ID)}, // new
	})
	// ID for N1 is unchanged
	require.Equal(t, sTm1b["N1"], sTm1c["N1"])
	// ID for N2 is unchanged
	require.Equal(t, sTm1b["N2"], sTm1c["N2"])

	// policy still has script associated
	teamPolicy, err = ds.Policy(ctx, teamPolicy.ID)
	require.NoError(t, err)
	require.Equal(t, n1WithTeamID, *teamPolicy.ScriptID)

	// add pending scripts on team and no-team and confirm they're shown as pending
	_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:   44,
		ScriptID: &n1WithTeamID,
	})
	require.NoError(t, err)
	_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:   45,
		ScriptID: &n1WithNoTeamId,
	})
	require.NoError(t, err)
	pending, err := ds.ListPendingHostScriptExecutions(ctx, 44, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 45, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)

	// clear scripts for tm1 (use the team's actual ID rather than assuming 1)
	applyAndExpect(nil, ptr.Uint(tm1.ID), nil)

	// policy on team should not have script assigned
	teamPolicy, err = ds.Policy(ctx, teamPolicy.ID)
	require.NoError(t, err)
	require.Nil(t, teamPolicy.ScriptID)

	// no-team policy still has script associated
	noTeamPolicy, err = ds.Policy(ctx, noTeamPolicy.ID)
	require.NoError(t, err)
	require.Equal(t, n1WithNoTeamId, *noTeamPolicy.ScriptID)

	// team script should no longer be pending, no-team script should still be pending
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 44, false)
	require.NoError(t, err)
	require.Len(t, pending, 0)
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 45, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)

	// apply only new scripts to no-team
	applyAndExpect([]*fleet.Script{
		{Name: "N4", ScriptContents: "C4"},
		{Name: "N5", ScriptContents: "C5"},
	}, nil, []*fleet.Script{
		{Name: "N4", TeamID: nil},
		{Name: "N5", TeamID: nil},
	})

	// policy on team should not have script assigned
	teamPolicy, err = ds.Policy(ctx, teamPolicy.ID)
	require.NoError(t, err)
	require.Nil(t, teamPolicy.ScriptID)

	// no-team policy should not have script associated
	noTeamPolicy, err = ds.Policy(ctx, noTeamPolicy.ID)
	require.NoError(t, err)
	require.Nil(t, noTeamPolicy.ScriptID)

	// no-team script should no longer be pending
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 45, false)
	require.NoError(t, err)
	require.Len(t, pending, 0)
}
// testLockHostViaScript checks that a script-based lock request is recorded as
// a pending lock for the host, and that reporting a successful result for the
// lock script moves the host to the locked state.
func testLockHostViaScript(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// nothing has been requested yet, so nothing can be pending
	initial, err := ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Empty(t, initial)

	requester := test.NewUser(t, ds, "Bob", "bob@example.com", true)
	winHostID := uint(1)
	lockContents := "lock"

	// fetchStatus returns the current lock/wipe status of the Windows host.
	fetchStatus := func() *fleet.HostLockWipeStatus {
		st, err := ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: winHostID, Platform: "windows", UUID: "uuid"})
		require.NoError(t, err)
		return st
	}

	lockReq := &fleet.HostScriptRequestPayload{
		HostID:         winHostID,
		ScriptContents: lockContents,
		UserID:         &requester.ID,
		SyncRequest:    false,
	}
	require.NoError(t, ds.LockHostViaScript(ctx, lockReq, "windows"))

	// the request must now be visible through the host's lock/wipe status
	st := fetchStatus()
	require.Equal(t, "windows", st.HostFleetPlatform)
	require.NotNil(t, st.LockScript)
	assert.Nil(t, st.UnlockScript)

	lockRun := st.LockScript
	require.Equal(t, lockContents, lockRun.ScriptContents)
	require.Equal(t, winHostID, lockRun.HostID)
	require.False(t, lockRun.SyncRequest)
	require.Equal(t, &requester.ID, lockRun.UserID)
	require.True(t, st.IsPendingLock())

	// report a successful (exit code 0) result for the lock script
	_, action, err := ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      lockRun.HostID,
		ExecutionID: lockRun.ExecutionID,
		ExitCode:    0,
	})
	require.NoError(t, err)
	assert.Equal(t, "lock_ref", action)

	// the host is locked and no longer pending
	st = fetchStatus()
	require.True(t, st.IsLocked())
	require.False(t, st.IsPendingLock())
	require.False(t, st.IsUnlocked())
}
// testUnlockHostViaScript checks the script-based unlock flow: a request shows
// up as a pending unlock, canceling the upcoming activity clears the pending
// state, and a successful result for a new request marks the host unlocked.
func testUnlockHostViaScript(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// nothing has been requested yet, so nothing can be pending
	initial, err := ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Empty(t, initial)

	requester := test.NewUser(t, ds, "Bob", "bob@example.com", true)

	var (
		hostID       = uint(1)
		hostUUID     = "uuid"
		hostPlatform = "windows"
		contents     = "unlock"
	)
	host, err := ds.NewHost(ctx, &fleet.Host{
		ID:            hostID,
		UUID:          hostUUID,
		Platform:      hostPlatform,
		OsqueryHostID: &hostUUID,
	})
	require.NoError(t, err)

	// requestUnlock enqueues an unlock script for the host.
	requestUnlock := func() {
		err := ds.UnlockHostViaScript(ctx, &fleet.HostScriptRequestPayload{
			HostID:         hostID,
			ScriptContents: contents,
			UserID:         &requester.ID,
			SyncRequest:    false,
		}, hostPlatform)
		require.NoError(t, err)
	}
	// fetchStatus returns the current lock/wipe status of the host.
	fetchStatus := func() *fleet.HostLockWipeStatus {
		st, err := ds.GetHostLockWipeStatus(ctx, host)
		require.NoError(t, err)
		return st
	}

	requestUnlock()

	// the request must now be visible through the host's lock/wipe status
	st := fetchStatus()
	require.Equal(t, hostPlatform, st.HostFleetPlatform)
	require.NotNil(t, st.UnlockScript)
	unlockRun := st.UnlockScript
	require.Equal(t, contents, unlockRun.ScriptContents)
	require.Equal(t, hostID, unlockRun.HostID)
	require.False(t, unlockRun.SyncRequest)
	require.Equal(t, &requester.ID, unlockRun.UserID)
	require.True(t, st.IsPendingUnlock())

	// cancel the unlock while it is still pending
	_, err = ds.CancelHostUpcomingActivity(ctx, unlockRun.HostID, unlockRun.ExecutionID)
	require.NoError(t, err)
	st = fetchStatus()
	require.False(t, st.IsPendingUnlock())

	// enqueue a second unlock request
	requestUnlock()
	st = fetchStatus()
	require.Equal(t, hostPlatform, st.HostFleetPlatform)
	require.NotNil(t, st.UnlockScript)
	unlockRun = st.UnlockScript

	// report a successful (exit code 0) result for the unlock script
	_, action, err := ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      unlockRun.HostID,
		ExecutionID: unlockRun.ExecutionID,
		ExitCode:    0,
	})
	require.NoError(t, err)
	assert.Equal(t, "unlock_ref", action)

	// the host is unlocked and no longer pending
	st = fetchStatus()
	require.True(t, st.IsUnlocked())
	require.False(t, st.IsPendingUnlock())
	require.False(t, st.IsLocked())
}
// testLockUnlockWipeViaScripts walks a host through the lock, unlock and wipe
// state machine for each script-capable platform, asserting the reported
// lock/wipe state after every request and after every (successful or failed)
// script result.
//
// The boolean arguments passed to checkLockWipeState are, in order: unlocked,
// locked, wiped, pendingUnlock, pendingLock, pendingWipe.
func testLockUnlockWipeViaScripts(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user := test.NewUser(t, ds, "Bob", "bob@example.com", true)

	for idx, platform := range []string{"windows", "linux"} {
		hostID := uint(idx + 1) //nolint:gosec // dismiss G115

		t.Run(platform, func(t *testing.T) {
			// fetchStatus returns the current lock/wipe status of the host
			// under test.
			fetchStatus := func() *fleet.HostLockWipeStatus {
				st, err := ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
				require.NoError(t, err)
				return st
			}
			// newRequest builds an async script request for the host under
			// test with the given script contents.
			newRequest := func(contents string) *fleet.HostScriptRequestPayload {
				return &fleet.HostScriptRequestPayload{
					HostID:         hostID,
					ScriptContents: contents,
					UserID:         &user.ID,
					SyncRequest:    false,
				}
			}
			// reportResult stores a script result and checks which action
			// reference the datastore reports for it.
			reportResult := func(payload *fleet.HostScriptResultPayload, wantAction string) {
				_, action, err := ds.SetHostScriptExecutionResult(ctx, payload)
				require.NoError(t, err)
				assert.Equal(t, wantAction, action)
			}

			// default state: unlocked, nothing pending
			st := fetchStatus()
			checkLockWipeState(t, st, true, false, false, false, false, false)

			// request a lock: still unlocked, lock pending
			require.NoError(t, ds.LockHostViaScript(ctx, newRequest("lock"), platform))
			st = fetchStatus()
			checkLockWipeState(t, st, true, false, false, false, true, false)

			// lock script succeeds: host is locked
			reportResult(&fleet.HostScriptResultPayload{
				HostID:      hostID,
				ExecutionID: st.LockScript.ExecutionID,
				ExitCode:    0,
			}, "lock_ref")
			st = fetchStatus()
			checkLockWipeState(t, st, false, true, false, false, false, false)

			// request an unlock: still locked, unlock pending
			require.NoError(t, ds.UnlockHostViaScript(ctx, newRequest("unlock"), platform))
			st = fetchStatus()
			checkLockWipeState(t, st, false, true, false, true, false, false)

			// unlock script fails: host stays locked
			reportResult(&fleet.HostScriptResultPayload{
				HostID:      hostID,
				ExecutionID: st.UnlockScript.ExecutionID,
				ExitCode:    -1,
			}, "unlock_ref")
			st = fetchStatus()
			checkLockWipeState(t, st, false, true, false, false, false, false)

			// request an unlock again: locked, unlock pending
			require.NoError(t, ds.UnlockHostViaScript(ctx, newRequest("unlock"), platform))
			st = fetchStatus()
			checkLockWipeState(t, st, false, true, false, true, false, false)

			// unlock script succeeds this time: host is unlocked
			reportResult(&fleet.HostScriptResultPayload{
				HostID:      hostID,
				ExecutionID: st.UnlockScript.ExecutionID,
				ExitCode:    0,
			}, "unlock_ref")
			st = fetchStatus()
			checkLockWipeState(t, st, true, false, false, false, false, false)

			// request a lock again: unlocked, lock pending
			require.NoError(t, ds.LockHostViaScript(ctx, newRequest("lock"), platform))
			st = fetchStatus()
			checkLockWipeState(t, st, true, false, false, false, true, false)

			// lock script fails: host stays unlocked
			reportResult(&fleet.HostScriptResultPayload{
				HostID:      hostID,
				ExecutionID: st.LockScript.ExecutionID,
				ExitCode:    2,
			}, "lock_ref")
			st = fetchStatus()
			checkLockWipeState(t, st, true, false, false, false, false, false)

			switch platform {
			case "windows":
				// MDM commands need a real MDM-enrolled host
				enrolled, err := ds.NewHost(ctx, &fleet.Host{
					Hostname:      "test-host-windows",
					OsqueryHostID: ptr.String("osquery-windows"),
					NodeKey:       ptr.String("nodekey-windows"),
					UUID:          "test-uuid-windows",
					Platform:      "windows",
				})
				require.NoError(t, err)
				windowsEnroll(t, ds, enrolled)

				// request a wipe through Windows MDM: wipe pending
				wipeCmd := &fleet.MDMWindowsCommand{
					CommandUUID:  uuid.NewString(),
					RawCommand:   []byte(`<Exec></Exec>`),
					TargetLocURI: "./Device/Vendor/MSFT/RemoteWipe/doWipeProtected",
				}
				require.NoError(t, ds.WipeHostViaWindowsMDM(ctx, enrolled, wipeCmd))
				st, err = ds.GetHostLockWipeStatus(ctx, enrolled)
				require.NoError(t, err)
				checkLockWipeState(t, st, true, false, false, false, false, true)

				// TODO: we don't seem to have an easy way to simulate a Windows MDM
				// protocol response, and there are lots of validations happening so we
				// can't just send a simple XML. Will test the rest via integration
				// tests.

			case "linux":
				// request a wipe: unlocked, wipe pending
				require.NoError(t, ds.WipeHostViaScript(ctx, newRequest("wipe"), platform))
				st = fetchStatus()
				checkLockWipeState(t, st, true, false, false, false, false, true)

				// wipe script fails: back to unlocked, nothing pending
				reportResult(&fleet.HostScriptResultPayload{
					HostID:      hostID,
					ExecutionID: st.WipeScript.ExecutionID,
					ExitCode:    1,
				}, "wipe_ref")
				st = fetchStatus()
				checkLockWipeState(t, st, true, false, false, false, false, false)

				// request a wipe again: unlocked, wipe pending
				require.NoError(t, ds.WipeHostViaScript(ctx, newRequest("wipe2"), platform))
				st = fetchStatus()
				checkLockWipeState(t, st, true, false, false, false, false, true)

				// wipe script succeeds: host is wiped
				reportResult(&fleet.HostScriptResultPayload{
					HostID:      hostID,
					ExecutionID: st.WipeScript.ExecutionID,
					ExitCode:    0,
				}, "wipe_ref")
				st = fetchStatus()
				checkLockWipeState(t, st, false, false, true, false, false, false)
			}
		})
	}
}
// testLockUnlockManually checks that UnlockHostManually records the unlock
// request timestamp once per host: a later call for the same host does not
// overwrite it, while a different host gets its own timestamp even when it
// already has a host_mdm_actions row.
func testLockUnlockManually(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	twoDaysAgo := time.Now().AddDate(0, 0, -2).UTC()
	today := time.Now().UTC()

	// requireUnlockRequestedAt asserts that the host's recorded unlock request
	// time is set and within one second of want.
	requireUnlockRequestedAt := func(hostID uint, want time.Time) {
		st, err := ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: "darwin", UUID: "uuid"})
		require.NoError(t, err)
		require.False(t, st.UnlockRequestedAt.IsZero())
		require.WithinDuration(t, want, st.UnlockRequestedAt, 1*time.Second)
	}

	// first request for host 1 stores the provided timestamp
	require.NoError(t, ds.UnlockHostManually(ctx, 1, "darwin", twoDaysAgo))
	requireUnlockRequestedAt(1, twoDaysAgo)

	// a subsequent request for the same host keeps the original timestamp
	require.NoError(t, ds.UnlockHostManually(ctx, 1, "darwin", today))
	requireUnlockRequestedAt(1, twoDaysAgo)

	// a host that has a host_mdm_actions row but no unlock request yet gets
	// the new timestamp
	ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
		_, execErr := tx.ExecContext(ctx, "INSERT INTO host_mdm_actions (host_id) VALUES (2)")
		return execErr
	})
	require.NoError(t, ds.UnlockHostManually(ctx, 2, "darwin", today))
	requireUnlockRequestedAt(2, today)
}
// checkLockWipeState asserts that status reports exactly the given
// combination of lock/wipe states. Each boolean is the expected return value
// of the corresponding predicate on status (IsUnlocked, IsLocked, IsWiped,
// IsPendingUnlock, IsPendingLock, IsPendingWipe).
func checkLockWipeState(t *testing.T, status *fleet.HostLockWipeStatus, unlocked, locked, wiped, pendingUnlock, pendingLock, pendingWipe bool) {
	// Mark as a helper so assertion failures are attributed to the calling
	// line rather than to this function.
	t.Helper()

	require.Equal(t, unlocked, status.IsUnlocked(), "unlocked")
	require.Equal(t, locked, status.IsLocked(), "locked")
	require.Equal(t, wiped, status.IsWiped(), "wiped")
	require.Equal(t, pendingLock, status.IsPendingLock(), "pending lock")
	require.Equal(t, pendingUnlock, status.IsPendingUnlock(), "pending unlock")
	require.Equal(t, pendingWipe, status.IsPendingWipe(), "pending wipe")
}
// scriptContents is the scan target used by the tests in this file when
// selecting rows from the script_contents table.
type scriptContents struct {
	// ID is scanned from the `id` column.
	ID uint `db:"id"`
	// Checksum is scanned from the `md5_checksum` column; the queries in this
	// file select it as HEX(md5_checksum).
	Checksum string `db:"md5_checksum"`
}
// testInsertScriptContents checks that inserting the same script contents
// twice yields the same row ID, and that the stored checksum matches
// md5ChecksumScriptContent of the contents.
func testInsertScriptContents(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	contents := `echo foobar;`
	res, err := insertScriptContents(ctx, ds.writer(ctx), contents)
	require.NoError(t, err)
	// Check the LastInsertId error instead of discarding it, so a driver
	// failure is reported as such rather than as a confusing ID mismatch.
	id, err := res.LastInsertId()
	require.NoError(t, err)
	require.Equal(t, int64(1), id)

	expectedCS := md5ChecksumScriptContent(contents)

	// insert same contents again, verify that the checksum and ID stayed the same
	res, err = insertScriptContents(ctx, ds.writer(ctx), contents)
	require.NoError(t, err)
	id, err = res.LastInsertId()
	require.NoError(t, err)
	require.Equal(t, int64(1), id)

	stmt := `SELECT id, HEX(md5_checksum) as md5_checksum FROM script_contents WHERE id = ?`

	var sc []scriptContents
	err = sqlx.SelectContext(ctx, ds.reader(ctx),
		&sc, stmt,
		id,
	)
	require.NoError(t, err)

	require.Len(t, sc, 1)
	require.EqualValues(t, id, sc[0].ID)
	require.Equal(t, expectedCS, sc[0].Checksum)
}
// testCleanupUnusedScriptContents checks that CleanupUnusedScriptContents
// removes script_contents rows once nothing references them (a deleted saved
// script, a deleted software installer) while keeping the contents that are
// still referenced by a script execution or an existing installer.
func testCleanupUnusedScriptContents(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	const listStmt = `SELECT id, HEX(md5_checksum) as md5_checksum FROM script_contents`
	// storedContents returns every row currently in script_contents.
	storedContents := func() []scriptContents {
		rows := []scriptContents{}
		require.NoError(t, sqlx.SelectContext(ctx, ds.reader(ctx), &rows, listStmt))
		return rows
	}

	// a saved script that will never be executed
	saved, err := ds.NewScript(ctx, &fleet.Script{ScriptContents: "echo foobar"})
	require.NoError(t, err)

	user1 := test.NewUser(t, ds, "Bob", "bob@example.com", true)

	// a sync script execution request with its own contents
	res, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{ScriptContents: "echo something_else", SyncRequest: true})
	require.NoError(t, err)

	// a software installer that references install, post-install and
	// uninstall scripts
	tfr1, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
	require.NoError(t, err)
	swi, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "install-script",
		UninstallScript:   "uninstall-script",
		PreInstallQuery:   "SELECT 1",
		PostInstallScript: "post-install-script",
		InstallerFile:     tfr1,
		StorageID:         "storage1",
		Filename:          "file1",
		Title:             "file1",
		Version:           "1.0",
		Source:            "apps",
		UserID:            user1.ID,
		ValidatedLabels:   &fleet.LabelIdentsWithScope{},
	})
	require.NoError(t, err)

	// delete the saved script without ever executing it
	require.NoError(t, ds.DeleteScript(ctx, saved.ID))

	// deleting the script alone does not remove any contents: all five rows
	// are still present
	require.Len(t, storedContents(), 5)

	// the cleanup drops one row; the sync execution's contents and the three
	// installer scripts remain
	require.NoError(t, ds.CleanupUnusedScriptContents(ctx))
	rows := storedContents()
	require.Len(t, rows, 4)
	gotChecksums := make([]string, 0, len(rows))
	for _, row := range rows {
		gotChecksums = append(gotChecksums, row.Checksum)
	}
	require.ElementsMatch(t, []string{
		md5ChecksumScriptContent(res.ScriptContents),
		md5ChecksumScriptContent("install-script"),
		md5ChecksumScriptContent("post-install-script"),
		md5ChecksumScriptContent("uninstall-script"),
	}, gotChecksums)

	// remove the software installer from the DB and clean up again
	require.NoError(t, ds.DeleteSoftwareInstaller(ctx, swi))
	require.NoError(t, ds.CleanupUnusedScriptContents(ctx))

	// only the sync execution's contents are left
	rows = storedContents()
	require.Len(t, rows, 1)
	require.Equal(t, md5ChecksumScriptContent(res.ScriptContents), rows[0].Checksum)

	// create a software installer without a post-install script
	tfr2, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
	require.NoError(t, err)
	swi, _, err = ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		PreInstallQuery: "SELECT 1",
		InstallScript:   "install-script",
		UninstallScript: "uninstall-script",
		InstallerFile:   tfr2,
		StorageID:       "storage1",
		Filename:        "file1",
		Title:           "file1",
		Version:         "1.0",
		Source:          "apps",
		UserID:          user1.ID,
		ValidatedLabels: &fleet.LabelIdentsWithScope{},
	})
	require.NoError(t, err)

	// with the installer in place the cleanup leaves three rows
	require.NoError(t, ds.CleanupUnusedScriptContents(ctx))
	require.Len(t, storedContents(), 3)

	// remove the software installer from the DB and clean up again
	require.NoError(t, ds.DeleteSoftwareInstaller(ctx, swi))
	require.NoError(t, ds.CleanupUnusedScriptContents(ctx))

	// once more only the sync execution's contents are left
	rows = storedContents()
	require.Len(t, rows, 1)
	require.Equal(t, md5ChecksumScriptContent(res.ScriptContents), rows[0].Checksum)
}
// testGetAnyScriptContents checks that GetAnyScriptContents returns the
// contents stored under the ID of a freshly inserted script_contents row.
func testGetAnyScriptContents(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	contents := `echo foobar;`
	res, err := insertScriptContents(ctx, ds.writer(ctx), contents)
	require.NoError(t, err)
	// Check the LastInsertId error instead of discarding it, so the lookup
	// below never runs against a meaningless zero ID.
	id, err := res.LastInsertId()
	require.NoError(t, err)

	result, err := ds.GetAnyScriptContents(ctx, uint(id)) //nolint:gosec // dismiss G115
	require.NoError(t, err)
	require.Equal(t, contents, string(result))
}
// testDeleteScriptsAssignedToPolicy checks that a script referenced by a
// policy cannot be deleted (errDeleteScriptWithAssociatedPolicy), and that it
// can be deleted once the referencing policy is gone.
func testDeleteScriptsAssignedToPolicy(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	team, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	teamScript, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script.sh",
		TeamID:         &team.ID,
		ScriptContents: "hello world",
	})
	require.NoError(t, err)

	policy, err := ds.NewTeamPolicy(ctx, team.ID, nil, fleet.PolicyPayload{
		Name:     "p1",
		Query:    "SELECT 1;",
		ScriptID: &teamScript.ID,
	})
	require.NoError(t, err)

	// deletion is refused while the policy still references the script
	deleteErr := ds.DeleteScript(ctx, teamScript.ID)
	require.Error(t, deleteErr)
	require.ErrorIs(t, deleteErr, errDeleteScriptWithAssociatedPolicy)

	// after removing the policy the script can be deleted
	_, err = ds.DeleteTeamPolicies(ctx, team.ID, []uint{policy.ID})
	require.NoError(t, err)
	require.NoError(t, ds.DeleteScript(ctx, teamScript.ID))
}
// testDeletePendingHostScriptExecutionsForPolicy checks that
// deletePendingHostScriptExecutionsForPolicy removes pending script executions
// that belong to the given policy, leaves pending executions of other policies
// alone, and does not remove an execution of the given policy that already has
// a result.
func testDeletePendingHostScriptExecutionsForPolicy(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	// Check the error here: team1 is dereferenced right below, so an ignored
	// failure would surface as a nil-pointer panic instead of a test failure.
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	script1, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1.sh",
		TeamID:         &team1.ID,
		ScriptContents: "hello world",
	})
	require.NoError(t, err)

	script2, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script2.sh",
		TeamID:         &team1.ID,
		ScriptContents: "hello world",
	})
	require.NoError(t, err)

	p1, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
		Name:     "p1",
		Query:    "SELECT 1;",
		ScriptID: &script1.ID,
	})
	require.NoError(t, err)

	p2, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
		Name:     "p2",
		Query:    "SELECT 2;",
		ScriptID: &script2.ID,
	})
	require.NoError(t, err)

	// pending host script execution for correct policy
	_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo",
		UserID:         &user.ID,
		PolicyID:       &p1.ID,
		SyncRequest:    true,
		ScriptID:       &script1.ID,
	})
	require.NoError(t, err)

	pending, err := ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)

	err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
	require.NoError(t, err)

	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Len(t, pending, 0)

	// test pending host script execution for incorrect policy: deleting for
	// p1 must not touch an execution that belongs to p2
	hsr, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo",
		UserID:         &user.ID,
		PolicyID:       &p2.ID,
		SyncRequest:    true,
		ScriptID:       &script2.ID,
	})
	require.NoError(t, err)

	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)

	err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
	require.NoError(t, err)

	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)

	// test not pending host script execution for correct policy
	scriptExecution, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo",
		UserID:         &user.ID,
		PolicyID:       &p1.ID,
		SyncRequest:    true,
		ScriptID:       &script1.ID,
	})
	require.NoError(t, err)

	// record a result for the previous pending script
	_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      1,
		ExecutionID: hsr.ExecutionID,
		Output:      "foo",
		ExitCode:    0,
	})
	require.NoError(t, err)

	// record a failed result for the current pending script
	_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      1,
		ExecutionID: scriptExecution.ExecutionID,
		Output:      "foo",
		ExitCode:    1,
	})
	require.NoError(t, err)

	err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
	require.NoError(t, err)

	// the execution that already has a result must still be stored
	var count int
	err = sqlx.GetContext(
		ctx,
		ds.reader(ctx),
		&count,
		"SELECT count(1) FROM host_script_results WHERE id = ?",
		scriptExecution.ID,
	)
	require.NoError(t, err)
	require.Equal(t, 1, count)
}
// testUpdateScriptContents checks that UpdateScriptContents keeps the script
// ID, points the script at a different script content ID, stores the new
// contents and changes the script's updated_at timestamp.
func testUpdateScriptContents(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	originalScript, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1",
		ScriptContents: "hello world",
	})
	require.NoError(t, err)

	// NOTE(review): this call passes ScriptContentID while the lookup after
	// the update passes the script ID; confirm which identifier
	// GetScriptContents expects.
	originalContents, err := ds.GetScriptContents(ctx, originalScript.ScriptContentID)
	require.NoError(t, err)
	require.Equal(t, "hello world", string(originalContents))

	// Backdate updated_at so the change made by the update below is
	// observable.
	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		_, err := q.ExecContext(ctx, "UPDATE scripts SET updated_at = ? WHERE id = ?", time.Now().Add(-2*time.Minute), originalScript.ID)
		return err
	})

	// Make sure updated_at was changed correctly, but the script is the same.
	// The error is checked before oldScript is dereferenced so a failed lookup
	// fails the test cleanly instead of panicking.
	oldScript, err := ds.Script(ctx, originalScript.ID)
	require.NoError(t, err)
	require.Equal(t, originalScript.ScriptContentID, oldScript.ScriptContentID)
	require.NotEqual(t, originalScript.UpdatedAt, oldScript.UpdatedAt)

	// Modify the script
	updatedScript, err := ds.UpdateScriptContents(ctx, originalScript.ID, "updated script")
	require.NoError(t, err)
	require.Equal(t, originalScript.ID, updatedScript.ID)
	// With the fix, the script should get a new content ID since content changed
	require.NotEqual(t, originalScript.ScriptContentID, updatedScript.ScriptContentID)

	updatedContents, err := ds.GetScriptContents(ctx, updatedScript.ID)
	require.NoError(t, err)
	require.Equal(t, "updated script", string(updatedContents))
	require.NotEqual(t, oldScript.UpdatedAt, updatedScript.UpdatedAt)
}
// testUpdateDeletingUpcomingScriptExecutions checks that updating a script's
// contents removes that script's upcoming executions from every host, while
// upcoming executions of other scripts are left untouched.
func testUpdateDeletingUpcomingScriptExecutions(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	user := test.NewUser(t, ds, "User", "user@example.com", true)
	host1 := test.NewHost(t, ds, "host1", "10.0.0.1", "host1Key", "host1UUID", time.Now())
	host2 := test.NewHost(t, ds, "host2", "10.0.0.2", "host2Key", "host2UUID", time.Now())

	// mustNewScript saves a script with the given name and contents.
	mustNewScript := func(name, contents string) *fleet.Script {
		created, err := ds.NewScript(ctx, &fleet.Script{
			Name:           name,
			ScriptContents: contents,
		})
		require.NoError(t, err)
		return created
	}
	script1 := mustNewScript("script1", "contents1")
	script2 := mustNewScript("script2", "contents2")
	script3 := mustNewScript("script3", "contents3")

	// Queue script executions, in this order: script1 and script2 on host1,
	// then script2 and script1 on host2.
	queue := []struct {
		hostID uint
		script *fleet.Script
	}{
		{host1.ID, script1},
		{host1.ID, script2},
		{host2.ID, script2},
		{host2.ID, script1},
	}
	for _, item := range queue {
		_, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
			HostID:   item.hostID,
			ScriptID: &item.script.ID,
			UserID:   &user.ID,
		})
		require.NoError(t, err)
	}

	hostIDs := []uint{host1.ID, host2.ID}

	// both hosts start with two upcoming executions
	for _, hostID := range hostIDs {
		upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, hostID, false, false)
		require.NoError(t, err)
		require.Len(t, upcoming, 2)
	}

	// requireOnlyScript2Upcoming asserts that each host has exactly one
	// upcoming execution and that it is for script2.
	requireOnlyScript2Upcoming := func() {
		for _, hostID := range hostIDs {
			upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, hostID, false, false)
			require.NoError(t, err)
			require.Len(t, upcoming, 1)
			require.Equal(t, script2.ID, *upcoming[0].ScriptID)
		}
	}

	// Updating the "pending/upcoming" script will cancel the activity and stop it from running
	_, err := ds.UpdateScriptContents(ctx, script1.ID, "new contents1")
	require.NoError(t, err)
	requireOnlyScript2Upcoming()

	// Updating a script with no upcoming activities shouldn't affect anything
	_, err = ds.UpdateScriptContents(ctx, script3.ID, "new contents")
	require.NoError(t, err)
	requireOnlyScript2Upcoming()
}
func testBatchExecute(t *testing.T, ds *Datastore) {
ctx := context.Background()
user := test.NewUser(t, ds, "user1", "user@example.com", true)
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
require.NoError(t, err)
hostNoScripts := test.NewHost(t, ds, "hostNoScripts", "10.0.0.1", "hostnoscripts", "hostnoscriptsuuid", time.Now())
hostWindows := test.NewHost(t, ds, "hostWin", "10.0.0.2", "hostWinKey", "hostWinUuid", time.Now(), test.WithPlatform("windows"))
host1 := test.NewHost(t, ds, "host1", "10.0.0.3", "host1key", "host1uuid", time.Now())
host2 := test.NewHost(t, ds, "host2", "10.0.0.4", "host2key", "host2uuid", time.Now())
host3 := test.NewHost(t, ds, "host3", "10.0.0.4", "host3key", "host3uuid", time.Now())
hostTeam1 := test.NewHost(t, ds, "hostTeam1", "10.0.0.5", "hostTeam1key", "hostTeam1uuid", time.Now(), test.WithTeamID(team1.ID))
test.SetOrbitEnrollment(t, hostWindows, ds)
test.SetOrbitEnrollment(t, host1, ds)
test.SetOrbitEnrollment(t, host2, ds)
test.SetOrbitEnrollment(t, host3, ds)
test.SetOrbitEnrollment(t, hostTeam1, ds)
script, err := ds.NewScript(ctx, &fleet.Script{
Name: "script1.sh",
ScriptContents: "echo hi",
})
require.NoError(t, err)
// Hosts all have to be on the same team as the script
execID, err := ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostTeam1.ID})
require.Empty(t, execID)
require.ErrorContains(t, err, "same team")
// Actual good execution
execID, err = ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostWindows.ID, host1.ID, host2.ID, host3.ID})
require.NoError(t, err)
summary, err := ds.BatchExecuteSummary(ctx, execID)
require.NoError(t, err)
require.Equal(t, script.ID, *summary.ScriptID)
require.Equal(t, script.Name, summary.ScriptName)
require.Equal(t, uint(0), *summary.TeamID)
require.NotNil(t, summary.CreatedAt)
// The summary should have two pending hosts and two errored ones, because
// the script is not compatible with the hostNoScripts and hostWindows.
require.Equal(t, *summary.NumPending, uint(3))
require.Equal(t, *summary.NumErrored, uint(2))
require.Equal(t, *summary.NumRan, uint(0))
require.Equal(t, *summary.NumCanceled, uint(0))
// Host 1 should have an upcoming execution
host1Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
require.NoError(t, err)
require.Len(t, host1Upcoming, 1)
require.Equal(t, summary.ScriptID, host1Upcoming[0].ScriptID)
// Host 2 should have an upcoming execution
host2Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host2.ID, false, false)
require.NoError(t, err)
require.Len(t, host2Upcoming, 1)
require.Equal(t, summary.ScriptID, host2Upcoming[0].ScriptID)
// Host 3 should have an upcoming execution
host3Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host3.ID, false, false)
require.NoError(t, err)
require.Len(t, host3Upcoming, 1)
require.Equal(t, summary.ScriptID, host3Upcoming[0].ScriptID)
// Host Windows should not have an upcoming execution
hostWindowsUpcoming, err := ds.listUpcomingHostScriptExecutions(ctx, hostWindows.ID, false, false)
require.NoError(t, err)
require.Len(t, hostWindowsUpcoming, 0)
// Host No Scripts should not have an upcoming execution
hostNoScriptsUpcoming, err := ds.listUpcomingHostScriptExecutions(ctx, hostNoScripts.ID, false, false)
require.NoError(t, err)
require.Len(t, hostNoScriptsUpcoming, 0)
// Host Windows should have an error in its `batch_activity_host_results` row
var exec_error string
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
db := q.(*sqlx.DB)
err := db.Get(&exec_error, "SELECT error FROM batch_activity_host_results WHERE host_id = ? AND batch_execution_id = ?", hostWindows.ID, execID)
require.NoError(t, err)
return nil
})
require.Equal(t, fleet.BatchExecuteIncompatiblePlatform, exec_error)
// Host No Scripts should have an error in its `batch_activity_host_results` row
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
db := q.(*sqlx.DB)
err := db.Get(&exec_error, "SELECT error FROM batch_activity_host_results WHERE host_id = ? AND batch_execution_id = ?", hostNoScripts.ID, execID)
require.NoError(t, err)
return nil
})
require.Equal(t, fleet.BatchExecuteIncompatibleFleetd, exec_error)
// Set host 1 to have a successful script result
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
HostID: host1.ID,
ExecutionID: host1Upcoming[0].ExecutionID,
Output: "foo",
ExitCode: 0,
})
require.NoError(t, err)
// Get the summary again
summary, err = ds.BatchExecuteSummary(ctx, execID)
require.NoError(t, err)
// The summary should have one pending host, one run host and two errored ones.
require.Equal(t, *summary.NumPending, uint(2))
require.Equal(t, *summary.NumErrored, uint(2))
require.Equal(t, *summary.NumRan, uint(1))
require.Equal(t, *summary.NumCanceled, uint(0))
// Set host 2 to have a failed script result
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
HostID: host2.ID,
ExecutionID: host2Upcoming[0].ExecutionID,
Output: "bar",
ExitCode: 1,
})
require.NoError(t, err)
// Get the summary again
summary, err = ds.BatchExecuteSummary(ctx, execID)
require.NoError(t, err)
// The summary should have one pending host, one run host and three errored ones.
require.Equal(t, *summary.NumPending, uint(1))
require.Equal(t, *summary.NumErrored, uint(3))
require.Equal(t, *summary.NumRan, uint(1))
require.Equal(t, *summary.NumCanceled, uint(0))
// Cancel the execution
_, err = ds.CancelHostUpcomingActivity(ctx, host3.ID, host3Upcoming[0].ExecutionID)
require.NoError(t, err)
// Get the summary again
summary, err = ds.BatchExecuteSummary(ctx, execID)
require.NoError(t, err)
// The summary should have no pending hosts, one run host, three errored ones and one canceled.
require.Equal(t, *summary.NumPending, uint(0))
require.Equal(t, *summary.NumErrored, uint(3))
require.Equal(t, *summary.NumRan, uint(1))
require.Equal(t, *summary.NumCanceled, uint(1))
}
// testBatchExecuteWithStatus runs a batch script against a mix of compatible
// and incompatible hosts and checks, through ListBatchScriptExecutions, that
// the per-host counters follow each host's outcome (ran, errored, canceled)
// and that the execution ID, status and team filters all return the batch.
// It ends by marking the batch as finished with made-up counters and checking
// that the listing then reports those stored values.
func testBatchExecuteWithStatus(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user := test.NewUser(t, ds, "user1", "user@example.com", true)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	hostNoScripts := test.NewHost(t, ds, "hostNoScripts", "10.0.0.1", "hostnoscripts", "hostnoscriptsuuid", time.Now())
	hostWindows := test.NewHost(t, ds, "hostWin", "10.0.0.2", "hostWinKey", "hostWinUuid", time.Now(), test.WithPlatform("windows"))
	host1 := test.NewHost(t, ds, "host1", "10.0.0.3", "host1key", "host1uuid", time.Now())
	host2 := test.NewHost(t, ds, "host2", "10.0.0.4", "host2key", "host2uuid", time.Now())
	host3 := test.NewHost(t, ds, "host3", "10.0.0.4", "host3key", "host3uuid", time.Now())
	hostTeam1 := test.NewHost(t, ds, "hostTeam1", "10.0.0.5", "hostTeam1key", "hostTeam1uuid", time.Now(), test.WithTeamID(team1.ID))

	// Every host except hostNoScripts gets an Orbit enrollment.
	test.SetOrbitEnrollment(t, hostWindows, ds)
	test.SetOrbitEnrollment(t, host1, ds)
	test.SetOrbitEnrollment(t, host2, ds)
	test.SetOrbitEnrollment(t, host3, ds)
	test.SetOrbitEnrollment(t, hostTeam1, ds)

	script, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1.sh",
		ScriptContents: "echo hi",
	})
	require.NoError(t, err)

	// Hosts all have to be on the same team as the script
	execID, err := ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostTeam1.ID})
	require.Empty(t, execID)
	require.ErrorContains(t, err, "same team")

	// Actual good execution
	execID, err = ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostWindows.ID, host1.ID, host2.ID, host3.ID})
	require.NoError(t, err)
	require.NotEmpty(t, execID)

	// Force the batch into the "scheduled" status so the status filter used
	// further down can match it.
	// TODO -- remove this when status is set automatically
	ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(ctx, "UPDATE batch_activities SET status = 'scheduled' WHERE execution_id = ?", execID)
		return err
	})

	byExecID := fleet.BatchExecutionStatusFilter{ExecutionID: &execID}
	summaryList, err := ds.ListBatchScriptExecutions(ctx, byExecID)
	require.NoError(t, err)
	require.Len(t, summaryList, 1)
	summary := summaryList[0]

	// reloadSummary re-lists the batches matching filter, requires exactly one
	// match and stores it in summary.
	reloadSummary := func(filter fleet.BatchExecutionStatusFilter) {
		t.Helper()
		list, err := ds.ListBatchScriptExecutions(ctx, filter)
		require.NoError(t, err)
		require.Len(t, list, 1)
		summary = list[0]
	}

	// hostCounts holds the expected per-outcome host counters of a summary;
	// fields left out of a literal are expected to be zero.
	type hostCounts struct {
		pending, incompatible, errored, ran, canceled uint
	}
	// requireCounts compares the current summary's counters against want.
	requireCounts := func(want hostCounts) {
		t.Helper()
		require.Equal(t, want.pending, *summary.NumPending)
		require.Equal(t, want.incompatible, *summary.NumIncompatible)
		require.Equal(t, want.errored, *summary.NumErrored)
		require.Equal(t, want.ran, *summary.NumRan)
		require.Equal(t, want.canceled, *summary.NumCanceled)
	}

	// hostResultError reads the error recorded for hostID in this batch's
	// `batch_activity_host_results` row.
	hostResultError := func(hostID uint) string {
		t.Helper()
		var execError string
		ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
			return sqlx.GetContext(ctx, q, &execError, "SELECT error FROM batch_activity_host_results WHERE host_id = ? AND batch_execution_id = ?", hostID, execID)
		})
		return execError
	}

	require.Equal(t, execID, summary.BatchExecutionID)
	require.Equal(t, script.ID, *summary.ScriptID)
	require.Equal(t, script.Name, summary.ScriptName)
	require.Equal(t, uint(0), *summary.TeamID)
	require.NotNil(t, summary.CreatedAt)

	// The summary should have three pending hosts and two incompatible ones,
	// because the script is not compatible with hostNoScripts and hostWindows.
	require.Equal(t, uint(5), *summary.NumTargeted)
	requireCounts(hostCounts{pending: 3, incompatible: 2})

	// Host 1 should have an upcoming execution
	host1Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
	require.NoError(t, err)
	require.Len(t, host1Upcoming, 1)
	require.Equal(t, summary.ScriptID, host1Upcoming[0].ScriptID)

	// Host 2 should have an upcoming execution
	host2Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host2.ID, false, false)
	require.NoError(t, err)
	require.Len(t, host2Upcoming, 1)
	require.Equal(t, summary.ScriptID, host2Upcoming[0].ScriptID)

	// Host 3 should have an upcoming execution
	host3Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host3.ID, false, false)
	require.NoError(t, err)
	require.Len(t, host3Upcoming, 1)
	require.Equal(t, summary.ScriptID, host3Upcoming[0].ScriptID)

	// Host Windows should not have an upcoming execution
	hostWindowsUpcoming, err := ds.listUpcomingHostScriptExecutions(ctx, hostWindows.ID, false, false)
	require.NoError(t, err)
	require.Empty(t, hostWindowsUpcoming)

	// Host No Scripts should not have an upcoming execution
	hostNoScriptsUpcoming, err := ds.listUpcomingHostScriptExecutions(ctx, hostNoScripts.ID, false, false)
	require.NoError(t, err)
	require.Empty(t, hostNoScriptsUpcoming)

	// Both incompatible hosts should have an error in their
	// `batch_activity_host_results` row.
	require.Equal(t, fleet.BatchExecuteIncompatiblePlatform, hostResultError(hostWindows.ID))
	require.Equal(t, fleet.BatchExecuteIncompatibleFleetd, hostResultError(hostNoScripts.ID))

	// Set host 1 to have a successful script result
	_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      host1.ID,
		ExecutionID: host1Upcoming[0].ExecutionID,
		Output:      "foo",
		ExitCode:    0,
	})
	require.NoError(t, err)

	// The summary should now have two pending hosts, one run host and two
	// incompatible ones.
	reloadSummary(byExecID)
	require.Equal(t, uint(5), *summary.NumTargeted)
	requireCounts(hostCounts{pending: 2, incompatible: 2, ran: 1})

	// Set host 2 to have a failed script result. The exit code is negative,
	// which is what Orbit reports when it cancels a run (e.g. on timeout);
	// such runs must be counted as errored rather than left as pending.
	_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      host2.ID,
		ExecutionID: host2Upcoming[0].ExecutionID,
		Output:      "bar",
		ExitCode:    -1,
	})
	require.NoError(t, err)

	// The summary should now have one pending host, one run host, one errored
	// host and two incompatible ones.
	reloadSummary(byExecID)
	require.Equal(t, uint(5), *summary.NumTargeted)
	requireCounts(hostCounts{pending: 1, incompatible: 2, errored: 1, ran: 1})

	// Cancel the execution on host 3
	_, err = ds.CancelHostUpcomingActivity(ctx, host3.ID, host3Upcoming[0].ExecutionID)
	require.NoError(t, err)

	// The summary should have no pending hosts, one run host, one errored
	// host, two incompatible ones and one canceled.
	final := hostCounts{incompatible: 2, errored: 1, ran: 1, canceled: 1}
	reloadSummary(byExecID)
	requireCounts(final)

	// The same summary should be returned when filtering by status "scheduled".
	reloadSummary(fleet.BatchExecutionStatusFilter{Status: ptr.String("scheduled")})
	requireCounts(final)

	// The same summary should be returned when filtering by team ID 0, the
	// team ID reported for this batch above.
	reloadSummary(fleet.BatchExecutionStatusFilter{TeamID: ptr.Uint(0)})
	requireCounts(final)

	// Mark the execution as completed, and make up some host numbers
	ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(ctx, "UPDATE batch_activities SET status = 'finished', num_pending = 4, num_ran = 5, num_errored = 6, num_canceled = 7, num_incompatible = 8, num_targeted = 9 WHERE execution_id = ?", execID)
		return err
	})

	// For a finished batch the summary should report the stored counters, not
	// the ones derived from the host results.
	reloadSummary(byExecID)
	requireCounts(hostCounts{pending: 4, ran: 5, errored: 6, canceled: 7, incompatible: 8})
	require.Equal(t, uint(9), *summary.NumTargeted)
}
// testBatchScriptSchedule covers scheduling a batch script for a future time:
// the queued worker job and batch activity rows created by
// BatchScheduleScript, the state after RunScheduledBatchActivity, the refusal
// to run a batch that was already started or canceled, the per-host errors
// recorded for hosts that cannot run the script, and the counters reported
// after canceling a batch that was scheduled but never run.
func testBatchScriptSchedule(t *testing.T, ds *Datastore) {
	ctx := t.Context()
	user := test.NewUser(t, ds, "user1", "user@example.com", true)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	host1 := test.NewHost(t, ds, "host1", "10.0.0.3", "host1key", "host1uuid", time.Now())
	host2 := test.NewHost(t, ds, "host2", "10.0.0.4", "host2key", "host2uuid", time.Now())
	host3 := test.NewHost(t, ds, "host3", "10.0.0.4", "host3key", "host3uuid", time.Now())
	host4 := test.NewHost(t, ds, "host4", "10.0.0.5", "host4key", "host4uuid", time.Now())
	hostTeam1 := test.NewHost(t, ds, "hostTeam1", "10.0.0.6", "hostTeam1key", "hostTeam1uuid", time.Now(), test.WithTeamID(team1.ID))
	hostWindows := test.NewHost(t, ds, "hostWin", "10.0.0.2", "hostWinKey", "hostWinUuid", time.Now(), test.WithPlatform("windows"))
	hostNoScripts := test.NewHost(t, ds, "hostNoScripts", "10.0.0.1", "hostnoscripts", "hostnoscriptsuuid", time.Now())

	// Every host except hostNoScripts gets an Orbit enrollment.
	test.SetOrbitEnrollment(t, host1, ds)
	test.SetOrbitEnrollment(t, host2, ds)
	test.SetOrbitEnrollment(t, host3, ds)
	test.SetOrbitEnrollment(t, hostTeam1, ds)
	test.SetOrbitEnrollment(t, host4, ds)
	test.SetOrbitEnrollment(t, hostWindows, ds)

	script, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1.sh",
		ScriptContents: "echo hi",
	})
	require.NoError(t, err)

	scheduledTime := time.Now().Add(10 * time.Hour).Truncate(time.Second).UTC()
	execID, err := ds.BatchScheduleScript(ctx, &user.ID, script.ID, []uint{host1.ID, host2.ID, host3.ID}, scheduledTime)
	require.NoError(t, err)
	require.NotEmpty(t, execID)

	jobs, err := ds.GetQueuedJobs(ctx, 10, scheduledTime.Add(10*time.Minute))
	require.NoError(t, err)
	// Should have scheduled at least one job
	require.NotEmpty(t, jobs)

	// find our job (if several match, the last one listed is kept)
	var job *fleet.Job
	for _, j := range jobs {
		if j.Name == fleet.BatchActivityScriptsJobName {
			job = j
		}
	}
	require.NotNil(t, job)
	require.Equal(t, fleet.BatchActivityScriptsJobName, job.Name)

	// Make sure the name matches the name on the worker job
	batchJob := worker.BatchScripts{}
	require.Equal(t, batchJob.Name(), job.Name)

	// Time from DB isn't super accurate
	require.Equal(t, scheduledTime.Truncate(time.Minute), job.NotBefore.Truncate(time.Minute))
	assert.JSONEq(t, fmt.Sprintf(`{"execution_id":%q}`, execID), string(*job.Args))

	// Before running: the activity is scheduled, not started, and no host has
	// an execution ID yet.
	batchActivity, err := ds.GetBatchActivity(ctx, execID)
	require.NoError(t, err)
	require.Equal(t, execID, batchActivity.BatchExecutionID)
	require.Nil(t, batchActivity.StartedAt)
	require.Equal(t, user.ID, *batchActivity.UserID)
	require.Equal(t, script.ID, *batchActivity.ScriptID)
	require.Equal(t, script.Name, batchActivity.ScriptName)
	require.Equal(t, fleet.BatchExecutionActivityScript, batchActivity.ActivityType)
	require.Equal(t, job.ID, *batchActivity.JobID)
	require.Equal(t, fleet.ScheduledBatchExecutionScheduled, batchActivity.Status)
	require.Equal(t, uint(3), *batchActivity.NumTargeted)

	hostResults, err := ds.GetBatchActivityHostResults(ctx, execID)
	require.NoError(t, err)
	require.Len(t, hostResults, 3)
	for _, hostResult := range hostResults {
		require.Equal(t, execID, hostResult.BatchExecutionID)
		require.Nil(t, hostResult.HostExecutionID)
	}

	// Run it manually, the same as the job running it but without waiting for the time
	err = ds.RunScheduledBatchActivity(ctx, execID)
	require.NoError(t, err)

	// After running: the activity is started and every host has an execution
	// ID plus one pending script execution.
	batchActivity, err = ds.GetBatchActivity(ctx, execID)
	require.NoError(t, err)
	require.Equal(t, execID, batchActivity.BatchExecutionID)
	require.NotNil(t, batchActivity.StartedAt)
	require.Equal(t, user.ID, *batchActivity.UserID)
	require.Equal(t, script.ID, *batchActivity.ScriptID)
	require.Equal(t, script.Name, batchActivity.ScriptName)
	require.Equal(t, fleet.BatchExecutionActivityScript, batchActivity.ActivityType)
	require.Equal(t, job.ID, *batchActivity.JobID)
	require.Equal(t, fleet.ScheduledBatchExecutionStarted, batchActivity.Status)
	require.Equal(t, uint(3), *batchActivity.NumTargeted)

	hostResults, err = ds.GetBatchActivityHostResults(ctx, execID)
	require.NoError(t, err)
	require.Len(t, hostResults, 3)
	for _, hostResult := range hostResults {
		require.Equal(t, execID, hostResult.BatchExecutionID)
		require.NotNil(t, hostResult.HostExecutionID)
		upcomingScripts, err := ds.ListPendingHostScriptExecutions(ctx, hostResult.HostID, false)
		require.NoError(t, err)
		require.Len(t, upcomingScripts, 1)
	}

	// Make sure we can't run the same scheduled script again after it's started
	err = ds.RunScheduledBatchActivity(ctx, execID)
	require.Error(t, err)

	// Make sure we can't run a canceled scheduled script after it's been canceled
	execID, err = ds.BatchScheduleScript(ctx, &user.ID, script.ID, []uint{host1.ID}, scheduledTime)
	require.NoError(t, err)
	err = ds.CancelBatchScript(ctx, execID)
	require.NoError(t, err)
	err = ds.RunScheduledBatchActivity(ctx, execID)
	require.Error(t, err)

	// Schedule script where most hosts will fail for various reasons
	// These would be checked for some validity before insertion if submitted by a user
	execID, err = ds.BatchScheduleScript(ctx, &user.ID, script.ID, []uint{host4.ID, hostWindows.ID, hostTeam1.ID, hostNoScripts.ID, 0xbeef}, scheduledTime)
	require.NoError(t, err)
	require.NotEmpty(t, execID)
	err = ds.RunScheduledBatchActivity(ctx, execID)
	require.NoError(t, err)

	batchActivity, err = ds.GetBatchActivity(ctx, execID)
	require.NoError(t, err)
	require.Equal(t, execID, batchActivity.BatchExecutionID)
	require.NotNil(t, batchActivity.StartedAt)
	require.Equal(t, fleet.ScheduledBatchExecutionStarted, batchActivity.Status)
	require.Equal(t, uint(5), *batchActivity.NumTargeted)

	// Three hosts are expected to be incompatible: hostWindows, hostNoScripts
	// and 0xbeef (see the per-host checks below).
	executions, err := ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{
		ExecutionID: &execID,
	})
	require.NoError(t, err)
	require.Len(t, executions, 1)
	require.Equal(t, uint(3), *executions[0].NumIncompatible)

	hostResults, err = ds.GetBatchActivityHostResults(ctx, execID)
	require.NoError(t, err)
	require.Len(t, hostResults, 5)
	for _, hostResult := range hostResults {
		require.Equal(t, execID, hostResult.BatchExecutionID)

		// 0xbeef is skipped here; its pending executions are never inspected.
		var upcomingScripts []*fleet.HostScriptResult
		if hostResult.HostID != 0xbeef {
			upcomingScripts, err = ds.ListPendingHostScriptExecutions(ctx, hostResult.HostID, false)
			require.NoError(t, err)
		}

		switch hostResult.HostID {
		case host4.ID:
			// The only fully valid host in the group
			require.NotNil(t, hostResult.HostExecutionID)
			require.Len(t, upcomingScripts, 1)
		case hostWindows.ID:
			// Bad platform
			require.Empty(t, upcomingScripts)
			require.NotNil(t, hostResult.Error)
			require.Equal(t, fleet.BatchExecuteIncompatiblePlatform, *hostResult.Error)
		case hostTeam1.ID:
			// Host on a different team than the script. The assertions expect
			// it to still get a pending execution and no error.
			// NOTE(review): confirm this is the intended behavior for
			// scheduled batches.
			require.Len(t, upcomingScripts, 1)
			require.Nil(t, hostResult.Error)
		case hostNoScripts.ID:
			// Host doesn't support scripts
			require.Empty(t, upcomingScripts)
			require.NotNil(t, hostResult.Error)
			require.Equal(t, fleet.BatchExecuteIncompatibleFleetd, *hostResult.Error)
		case 0xbeef:
			// Stands in for a host that was deleted after scheduling
			require.NotNil(t, hostResult.Error)
			require.Equal(t, fleet.BatchExecuteInvalidHost, *hostResult.Error)
		default:
			require.Failf(t, "forgot to check a host", "host_id: %d", hostResult.HostID)
		}
	}

	// Schedule script that we will subsequently cancel.
	execID, err = ds.BatchScheduleScript(ctx, &user.ID, script.ID, []uint{host4.ID, hostWindows.ID, hostTeam1.ID, hostNoScripts.ID, 0xbeef}, scheduledTime)
	require.NoError(t, err)
	require.NotEmpty(t, execID)
	err = ds.CancelBatchScript(ctx, execID)
	require.NoError(t, err)

	// Get the summary again
	summaryList, err := ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{
		ExecutionID: &execID,
	})
	require.NoError(t, err)
	require.Len(t, summaryList, 1)
	summary := summaryList[0]

	// The batch was canceled before it ran, so all five targeted hosts should
	// be reported as canceled and every other counter should be zero.
	require.Equal(t, uint(0), *summary.NumPending)
	require.Equal(t, uint(0), *summary.NumIncompatible)
	require.Equal(t, uint(0), *summary.NumErrored)
	require.Equal(t, uint(0), *summary.NumRan)
	require.Equal(t, uint(5), *summary.NumCanceled)
}
// testMarkActivitiesAsCompleted checks that MarkActivitiesAsCompleted moves a
// started batch to finished, with the expected stored counters, once its
// hosts have results, while leaving another started batch untouched. It also
// checks that canceling a started batch finishes it with its compatible hosts
// counted as canceled, and that a started batch with no host rows at all is
// finished with all counters at zero.
func testMarkActivitiesAsCompleted(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user := test.NewUser(t, ds, "user1", "user@example.com", true)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	hostNoScripts := test.NewHost(t, ds, "hostNoScripts", "10.0.0.1", "hostnoscripts", "hostnoscriptsuuid", time.Now())
	hostWindows := test.NewHost(t, ds, "hostWin", "10.0.0.2", "hostWinKey", "hostWinUuid", time.Now(), test.WithPlatform("windows"))
	host1 := test.NewHost(t, ds, "host1", "10.0.0.3", "host1key", "host1uuid", time.Now())
	host2 := test.NewHost(t, ds, "host2", "10.0.0.4", "host2key", "host2uuid", time.Now())
	host3 := test.NewHost(t, ds, "host3", "10.0.0.4", "host3key", "host3uuid", time.Now())
	hostTeam1 := test.NewHost(t, ds, "hostTeam1", "10.0.0.5", "hostTeam1key", "hostTeam1uuid", time.Now(), test.WithTeamID(team1.ID))

	// Every host except hostNoScripts gets an Orbit enrollment.
	test.SetOrbitEnrollment(t, hostWindows, ds)
	test.SetOrbitEnrollment(t, host1, ds)
	test.SetOrbitEnrollment(t, host2, ds)
	test.SetOrbitEnrollment(t, host3, ds)
	test.SetOrbitEnrollment(t, hostTeam1, ds)

	script, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1.sh",
		ScriptContents: "echo hi",
	})
	require.NoError(t, err)

	// Actual good execution
	execID, err := ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostWindows.ID, host1.ID, host2.ID, host3.ID})
	require.NoError(t, err)
	require.NotEmpty(t, execID)

	// Schedule another one
	execID2, err := ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostWindows.ID, host1.ID, host2.ID, host3.ID})
	require.NoError(t, err)
	require.NotEmpty(t, execID2)

	// Get the upcoming activities for each host. The first entry of each list
	// is used below, so fail cleanly instead of panicking if a list is empty.
	host1Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
	require.NoError(t, err)
	require.NotEmpty(t, host1Upcoming)
	host2Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host2.ID, false, false)
	require.NoError(t, err)
	require.NotEmpty(t, host2Upcoming)
	host3Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host3.ID, false, false)
	require.NoError(t, err)
	require.NotEmpty(t, host3Upcoming)

	// Set host 1 to have a successful script result
	_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      host1.ID,
		ExecutionID: host1Upcoming[0].ExecutionID,
		Output:      "foo",
		ExitCode:    0,
	})
	require.NoError(t, err)

	// Set host 2 to have a failed script result. The exit code is negative,
	// which is what Orbit reports when it cancels a run (e.g. on timeout);
	// it must be counted as errored.
	_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      host2.ID,
		ExecutionID: host2Upcoming[0].ExecutionID,
		Output:      "bar",
		ExitCode:    -1,
	})
	require.NoError(t, err)

	// Cancel the execution for host 3
	_, err = ds.CancelHostUpcomingActivity(ctx, host3.ID, host3Upcoming[0].ExecutionID)
	require.NoError(t, err)

	// Update both batch activities' status to "started"
	ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(ctx, "UPDATE batch_activities SET status='started' WHERE execution_id IN (?,?)", execID, execID2)
		return err
	})

	// Mark activities as completed
	err = ds.MarkActivitiesAsCompleted(ctx)
	require.NoError(t, err)

	// First activity should be marked as finished and updated accordingly.
	batchActivity, err := ds.GetBatchActivity(ctx, execID)
	require.NoError(t, err)
	require.Equal(t, fleet.ScheduledBatchExecutionFinished, batchActivity.Status)
	require.Equal(t, uint(5), *batchActivity.NumTargeted)
	require.Equal(t, uint(1), *batchActivity.NumRan)
	require.Equal(t, uint(1), *batchActivity.NumErrored)
	require.Equal(t, uint(2), *batchActivity.NumIncompatible)
	require.Equal(t, uint(1), *batchActivity.NumCanceled)
	require.Equal(t, uint(0), *batchActivity.NumPending)

	// Second activity should still be in "started" status (presumably because
	// its host executions are still pending).
	batchActivity2, err := ds.GetBatchActivity(ctx, execID2)
	require.NoError(t, err)
	require.Equal(t, fleet.ScheduledBatchExecutionStarted, batchActivity2.Status)

	// Schedule another batch that we will cancel.
	execID3, err := ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostWindows.ID, host1.ID, host2.ID, host3.ID})
	require.NoError(t, err)
	require.NotEmpty(t, execID3)

	// Update the batch activity status to "started"
	ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(ctx, "UPDATE batch_activities SET status='started' WHERE execution_id = ?", execID3)
		return err
	})

	// Cancel the batch.
	err = ds.CancelBatchScript(ctx, execID3)
	require.NoError(t, err)

	// The canceled (third) activity should be marked as finished, with its
	// three compatible hosts counted as canceled.
	batchActivity, err = ds.GetBatchActivity(ctx, execID3)
	require.NoError(t, err)
	require.Equal(t, fleet.ScheduledBatchExecutionFinished, batchActivity.Status)
	require.Equal(t, uint(5), *batchActivity.NumTargeted)
	require.Equal(t, uint(0), *batchActivity.NumRan)
	require.Equal(t, uint(0), *batchActivity.NumErrored)
	require.Equal(t, uint(2), *batchActivity.NumIncompatible)
	require.Equal(t, uint(3), *batchActivity.NumCanceled)
	require.Equal(t, uint(0), *batchActivity.NumPending)

	// Edge case -- batch activity with no hosts.
	// In reality this could happen if all the hosts in a batch get deleted.
	ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(ctx, "INSERT INTO batch_activities (execution_id, status, script_id, activity_type) VALUES (?, ?, ?, ?)",
			"abc123", fleet.ScheduledBatchExecutionStarted, script.ID, "script")
		return err
	})

	// Mark activities as completed
	err = ds.MarkActivitiesAsCompleted(ctx)
	require.NoError(t, err)

	// Activity should be marked as finished and updated accordingly.
	batchActivity, err = ds.GetBatchActivity(ctx, "abc123")
	require.NoError(t, err)
	require.Equal(t, fleet.ScheduledBatchExecutionFinished, batchActivity.Status)
	require.Equal(t, uint(0), *batchActivity.NumTargeted)
	require.Equal(t, uint(0), *batchActivity.NumRan)
	require.Equal(t, uint(0), *batchActivity.NumErrored)
	require.Equal(t, uint(0), *batchActivity.NumIncompatible)
	require.Equal(t, uint(0), *batchActivity.NumCanceled)
	require.Equal(t, uint(0), *batchActivity.NumPending)
}
// testBatchScriptCancel exercises CancelBatchScript in two situations: a
// batch that is already executing (one host has reported a result, the other
// is still pending) and a batch that is scheduled for the future and has not
// been run. In both cases the batch must end up finished and flagged as
// canceled, with no upcoming script executions left on its hosts.
func testBatchScriptCancel(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user := test.NewUser(t, ds, "user1", "user@example.com", true)
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	host1 := test.NewHost(t, ds, "host1", "10.0.0.3", "host1key", "host1uuid", time.Now())
	host2 := test.NewHost(t, ds, "host2", "10.0.0.4", "host2key", "host2uuid", time.Now())
	host3 := test.NewHost(t, ds, "host3", "10.0.0.4", "host3key", "host3uuid", time.Now())
	hostTeam1 := test.NewHost(t, ds, "hostTeam1", "10.0.0.5", "hostTeam1key", "hostTeam1uuid", time.Now(), test.WithTeamID(team1.ID))
	test.SetOrbitEnrollment(t, host1, ds)
	test.SetOrbitEnrollment(t, host2, ds)
	test.SetOrbitEnrollment(t, host3, ds)
	test.SetOrbitEnrollment(t, hostTeam1, ds)

	script, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1.sh",
		ScriptContents: "echo hi",
	})
	require.NoError(t, err)

	// requireUpcomingCount checks how many upcoming script executions the
	// given host currently has.
	requireUpcomingCount := func(hostID uint, want int) {
		t.Helper()
		pending, err := ds.listUpcomingHostScriptExecutions(ctx, hostID, false, false)
		require.NoError(t, err)
		require.Len(t, pending, want)
	}

	//
	// Case 1: cancel a batch that started immediately.
	//
	runningID, err := ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{host1.ID, host2.ID})
	require.NoError(t, err)
	require.NotEmpty(t, runningID)
	runningFilter := fleet.BatchExecutionStatusFilter{ExecutionID: &runningID}

	running, err := ds.ListBatchScriptExecutions(ctx, runningFilter)
	require.NoError(t, err)
	require.Len(t, running, 1)
	require.Equal(t, fleet.ScheduledBatchExecutionStarted, running[0].Status)
	require.False(t, running[0].Canceled)
	require.Equal(t, uint(2), *running[0].NumTargeted)
	require.Equal(t, uint(2), *running[0].NumPending)

	// Host 1 reports a successful result before the batch is canceled.
	host1Pending, err := ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
	require.NoError(t, err)
	require.Len(t, host1Pending, 1)
	_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      host1.ID,
		ExecutionID: host1Pending[0].ExecutionID,
		Output:      "",
		ExitCode:    0,
	})
	require.NoError(t, err)

	// Host 2 is still waiting when the cancellation happens.
	requireUpcomingCount(host2.ID, 1)

	err = ds.CancelBatchScript(ctx, runningID)
	require.NoError(t, err)

	running, err = ds.ListBatchScriptExecutions(ctx, runningFilter)
	require.NoError(t, err)
	require.Len(t, running, 1)
	require.Equal(t, fleet.ScheduledBatchExecutionFinished, running[0].Status)
	require.True(t, running[0].Canceled)
	require.Equal(t, uint(0), *running[0].NumPending)
	require.Equal(t, uint(1), *running[0].NumRan)
	require.Equal(t, uint(2), *running[0].NumTargeted)
	require.Equal(t, uint(1), *running[0].NumCanceled)

	// Nothing is left queued on either host.
	requireUpcomingCount(host1.ID, 0)
	requireUpcomingCount(host2.ID, 0)

	//
	// Case 2: cancel a batch scheduled for the future.
	//
	runAt := time.Now().Add(2 * time.Hour)
	futureID, err := ds.BatchScheduleScript(ctx, &user.ID, script.ID, []uint{host1.ID, host2.ID}, runAt)
	require.NoError(t, err)
	require.NotEmpty(t, futureID)
	futureFilter := fleet.BatchExecutionStatusFilter{ExecutionID: &futureID}

	future, err := ds.ListBatchScriptExecutions(ctx, futureFilter)
	require.NoError(t, err)
	require.Len(t, future, 1)
	require.Equal(t, fleet.ScheduledBatchExecutionScheduled, future[0].Status)
	require.False(t, future[0].Canceled)
	require.Equal(t, uint(2), *future[0].NumTargeted)
	require.Equal(t, uint(2), *future[0].NumPending)

	// The batch has not run, so no host has anything queued yet.
	requireUpcomingCount(host1.ID, 0)
	requireUpcomingCount(host2.ID, 0)

	err = ds.CancelBatchScript(ctx, futureID)
	require.NoError(t, err)

	future, err = ds.ListBatchScriptExecutions(ctx, futureFilter)
	require.NoError(t, err)
	require.Len(t, future, 1)
	require.Equal(t, fleet.ScheduledBatchExecutionFinished, future[0].Status)
	require.True(t, future[0].Canceled)
	require.Equal(t, uint(0), *future[0].NumPending)
	require.Equal(t, uint(0), *future[0].NumRan)
	require.Equal(t, uint(2), *future[0].NumCanceled)
	require.Equal(t, uint(2), *future[0].NumTargeted)

	// Still nothing queued on either host after the cancellation.
	requireUpcomingCount(host1.ID, 0)
	requireUpcomingCount(host2.ID, 0)
}
// testDeleteScriptActivatesNextActivity enqueues executions of two saved
// scripts on several hosts, deletes one of the scripts, and checks that each
// host's upcoming activities then contain only the executions of the
// remaining script.
func testDeleteScriptActivatesNextActivity(t *testing.T, ds *Datastore) {
	ctx := t.Context()
	u := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	// create a couple of scripts
	scriptA, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "a",
		ScriptContents: "echo 'a'",
	})
	require.NoError(t, err)
	scriptB, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "b",
		ScriptContents: "echo 'b'",
	})
	require.NoError(t, err)

	// create some hosts
	hosts := make([]*fleet.Host, 4)
	for i := range hosts {
		host, err := ds.NewHost(ctx, &fleet.Host{
			DetailUpdatedAt: time.Now(),
			LabelUpdatedAt:  time.Now(),
			SeenTime:        time.Now(),
			NodeKey:         ptr.String(fmt.Sprint(i)),
			UUID:            fmt.Sprint(i),
			Hostname:        fmt.Sprintf("%d-foo.local", i),
			PrimaryIP:       fmt.Sprintf("192.168.1.%d", i),
			PrimaryMac:      fmt.Sprintf("30-65-EC-6F-C4-5%d", i),
		})
		require.NoError(t, err)
		hosts[i] = host
	}

	// enqueue requests an execution of the script with the given ID on host,
	// on behalf of user u, and returns the new execution's ID.
	enqueue := func(host *fleet.Host, scriptID *uint) string {
		t.Helper()
		req, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
			HostID:      host.ID,
			ScriptID:    scriptID,
			UserID:      &u.ID,
			SyncRequest: true,
		})
		require.NoError(t, err)
		return req.ExecutionID
	}

	// enqueue script executions, in this exact order per host:
	// * hosts[0]: a, b
	// * hosts[1]: a, b
	// * hosts[2]: b, a
	// * hosts[3]: b
	host0A := enqueue(hosts[0], &scriptA.ID)
	host0B := enqueue(hosts[0], &scriptB.ID)
	host1A := enqueue(hosts[1], &scriptA.ID)
	host1B := enqueue(hosts[1], &scriptB.ID)
	host2B := enqueue(hosts[2], &scriptB.ID)
	host2A := enqueue(hosts[2], &scriptA.ID)
	host3B := enqueue(hosts[3], &scriptB.ID)

	checkUpcomingActivities(t, ds, hosts[0], host0A, host0B)
	checkUpcomingActivities(t, ds, hosts[1], host1A, host1B)
	checkUpcomingActivities(t, ds, hosts[2], host2B, host2A)
	checkUpcomingActivities(t, ds, hosts[3], host3B)

	// delete scriptA removes pending upcoming activity and activates next activity
	err = ds.DeleteScript(ctx, scriptA.ID)
	require.NoError(t, err)

	checkUpcomingActivities(t, ds, hosts[0], host0B)
	checkUpcomingActivities(t, ds, hosts[1], host1B)
	checkUpcomingActivities(t, ds, hosts[2], host2B)
	checkUpcomingActivities(t, ds, hosts[3], host3B)
}
// testBatchSetScriptActivatesNextActivity enqueues script executions on
// several hosts and checks, via checkUpcomingActivities, which executions
// remain upcoming after successive BatchSetScripts calls (no change, then
// removing/updating scripts, then removing all scripts).
func testBatchSetScriptActivatesNextActivity(t *testing.T, ds *Datastore) {
	ctx := t.Context()

	// start from three scripts: A, B and C
	created, err := ds.BatchSetScripts(ctx, nil, []*fleet.Script{
		{Name: "A", ScriptContents: "C1"},
		{Name: "B", ScriptContents: "C2"},
		{Name: "C", ScriptContents: "C3"},
	})
	require.NoError(t, err)

	// map each script name to its ID
	idByName := make(map[string]uint)
	for _, sc := range created {
		idByName[sc.Name] = sc.ID
	}

	// create four hosts, each with identifiers derived from its index
	hosts := make([]*fleet.Host, 4)
	for i := range hosts {
		h, err := ds.NewHost(context.Background(), &fleet.Host{
			DetailUpdatedAt: time.Now(),
			LabelUpdatedAt:  time.Now(),
			SeenTime:        time.Now(),
			NodeKey:         ptr.String(fmt.Sprint(i)),
			UUID:            fmt.Sprint(i),
			Hostname:        fmt.Sprintf("%d-foo.local", i),
			PrimaryIP:       fmt.Sprintf("192.168.1.%d", i),
			PrimaryMac:      fmt.Sprintf("30-65-EC-6F-C4-5%d", i),
		})
		require.NoError(t, err)
		hosts[i] = h
	}

	// payload builds a sync execution request of the named script (with the
	// given contents) for the host at index hostIdx.
	payload := func(hostIdx int, name, contents string) *fleet.HostScriptRequestPayload {
		return &fleet.HostScriptRequestPayload{
			HostID:         hosts[hostIdx].ID,
			ScriptID:       ptr.Uint(idByName[name]),
			SyncRequest:    true,
			ScriptContents: contents,
		}
	}

	// enqueue script executions, in this order per host:
	// * hosts[0]: A, C, A, B
	// * hosts[1]: B, B, C
	// * hosts[2]: C, A
	// * hosts[3]: A, B, C
	h0A, err := ds.NewHostScriptExecutionRequest(ctx, payload(0, "A", "C1"))
	require.NoError(t, err)
	h0C, err := ds.NewHostScriptExecutionRequest(ctx, payload(0, "C", "C3"))
	require.NoError(t, err)
	h0A2, err := ds.NewHostScriptExecutionRequest(ctx, payload(0, "A", "C1"))
	require.NoError(t, err)
	h0B, err := ds.NewHostScriptExecutionRequest(ctx, payload(0, "B", "C2"))
	require.NoError(t, err)

	h1B, err := ds.NewHostScriptExecutionRequest(ctx, payload(1, "B", "C2"))
	require.NoError(t, err)
	h1B2, err := ds.NewHostScriptExecutionRequest(ctx, payload(1, "B", "C2"))
	require.NoError(t, err)
	h1C, err := ds.NewHostScriptExecutionRequest(ctx, payload(1, "C", "C3"))
	require.NoError(t, err)

	h2C, err := ds.NewHostScriptExecutionRequest(ctx, payload(2, "C", "C3"))
	require.NoError(t, err)
	h2A, err := ds.NewHostScriptExecutionRequest(ctx, payload(2, "A", "C1"))
	require.NoError(t, err)

	h3A, err := ds.NewHostScriptExecutionRequest(ctx, payload(3, "A", "C1"))
	require.NoError(t, err)
	h3B, err := ds.NewHostScriptExecutionRequest(ctx, payload(3, "B", "C2"))
	require.NoError(t, err)
	h3C, err := ds.NewHostScriptExecutionRequest(ctx, payload(3, "C", "C3"))
	require.NoError(t, err)

	// expectInitialQueues asserts that every host still has all of the
	// executions enqueued above, in enqueue order.
	expectInitialQueues := func() {
		checkUpcomingActivities(t, ds, hosts[0], h0A.ExecutionID, h0C.ExecutionID, h0A2.ExecutionID, h0B.ExecutionID)
		checkUpcomingActivities(t, ds, hosts[1], h1B.ExecutionID, h1B2.ExecutionID, h1C.ExecutionID)
		checkUpcomingActivities(t, ds, hosts[2], h2C.ExecutionID, h2A.ExecutionID)
		checkUpcomingActivities(t, ds, hosts[3], h3A.ExecutionID, h3B.ExecutionID, h3C.ExecutionID)
	}
	expectInitialQueues()

	// batch-setting the very same scripts must not change anything
	_, err = ds.BatchSetScripts(ctx, nil, []*fleet.Script{
		{Name: "A", ScriptContents: "C1"},
		{Name: "B", ScriptContents: "C2"},
		{Name: "C", ScriptContents: "C3"},
	})
	require.NoError(t, err)
	expectInitialQueues()

	// batch-set removes A, updates B and creates D, cancelling any pending
	// A and B executions; only the C executions are expected to remain
	_, err = ds.BatchSetScripts(ctx, nil, []*fleet.Script{
		{Name: "B", ScriptContents: "C2updated"},
		{Name: "C", ScriptContents: "C3"},
		{Name: "D", ScriptContents: "C4"},
	})
	require.NoError(t, err)
	checkUpcomingActivities(t, ds, hosts[0], h0C.ExecutionID)
	checkUpcomingActivities(t, ds, hosts[1], h1C.ExecutionID)
	checkUpcomingActivities(t, ds, hosts[2], h2C.ExecutionID)
	checkUpcomingActivities(t, ds, hosts[3], h3C.ExecutionID)

	// batch-set with an empty list removes every script, so no host should
	// have any upcoming activity left
	_, err = ds.BatchSetScripts(ctx, nil, []*fleet.Script{})
	require.NoError(t, err)
	for _, h := range hosts {
		checkUpcomingActivities(t, ds, h)
	}
}
// testUpdateScriptToDuplicateContent checks that a script's contents can be
// updated to match the contents of another existing script: the update must
// succeed, both scripts must end up pointing at the same script_contents row,
// and the row previously used by the updated script must be gone.
func testUpdateScriptToDuplicateContent(t *testing.T, ds *Datastore) {
	ctx := t.Context()

	const (
		firstContents  = "echo hello"
		secondContents = "echo world"
	)

	// two scripts, each with its own distinct contents
	first, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1.sh",
		ScriptContents: firstContents,
	})
	require.NoError(t, err)
	second, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script2.sh",
		ScriptContents: secondContents,
	})
	require.NoError(t, err)

	// record the content IDs before the update; they must differ
	firstBefore, err := ds.Script(ctx, first.ID)
	require.NoError(t, err)
	secondBefore, err := ds.Script(ctx, second.ID)
	require.NoError(t, err)
	firstContentID := firstBefore.ScriptContentID
	secondContentID := secondBefore.ScriptContentID
	require.NotEqual(t, firstContentID, secondContentID)

	// give the second script the first script's contents; this should NOT
	// cause a duplicate key error
	_, err = ds.UpdateScriptContents(ctx, second.ID, firstContents)
	require.NoError(t, err)

	// ScriptContents is not populated from the DB, so read the contents back
	// with GetScriptContents (which takes a script ID, not a script_content_id)
	gotContents, err := ds.GetScriptContents(ctx, second.ID)
	require.NoError(t, err)
	require.Equal(t, firstContents, string(gotContents))

	// both scripts must now share the first script's original content ID
	firstAfter, err := ds.Script(ctx, first.ID)
	require.NoError(t, err)
	secondAfter, err := ds.Script(ctx, second.ID)
	require.NoError(t, err)
	require.Equal(t, firstAfter.ScriptContentID, secondAfter.ScriptContentID)
	require.Equal(t, firstContentID, secondAfter.ScriptContentID)

	// the content row formerly used by the second script must be cleaned up
	var remaining int
	err = sqlx.GetContext(ctx, ds.reader(ctx), &remaining,
		`SELECT COUNT(*) FROM script_contents WHERE id = ?`, secondContentID)
	require.NoError(t, err)
	require.Equal(t, 0, remaining, "old script content should be deleted")
}
// testUpdateSharedScriptContent checks that when two scripts share the same
// contents, updating the contents of one of them leaves the other script's
// contents untouched and makes the two scripts use different content IDs.
func testUpdateSharedScriptContent(t *testing.T, ds *Datastore) {
	ctx := t.Context()

	const (
		sharedContent   = "echo shared"
		modifiedContent = "echo modified"
	)

	// two scripts created with the SAME contents
	first, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1.sh",
		ScriptContents: sharedContent,
	})
	require.NoError(t, err)
	second, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script2.sh",
		ScriptContents: sharedContent,
	})
	require.NoError(t, err)

	// precondition: both scripts reference the same content ID
	firstBefore, err := ds.Script(ctx, first.ID)
	require.NoError(t, err)
	secondBefore, err := ds.Script(ctx, second.ID)
	require.NoError(t, err)
	require.Equal(t, firstBefore.ScriptContentID, secondBefore.ScriptContentID)

	// change the contents of the first script only
	firstUpdated, err := ds.UpdateScriptContents(ctx, first.ID, modifiedContent)
	require.NoError(t, err)

	// ScriptContents is not populated from the DB, so read the contents back
	// with GetScriptContents (which takes a script ID, not a script_content_id)
	firstContents, err := ds.GetScriptContents(ctx, first.ID)
	require.NoError(t, err)
	require.Equal(t, modifiedContent, string(firstContents))

	// CRITICAL: the second script must still have the original contents and
	// must no longer share a content ID with the updated script
	secondAfter, err := ds.Script(ctx, second.ID)
	require.NoError(t, err)
	secondContents, err := ds.GetScriptContents(ctx, second.ID)
	require.NoError(t, err)
	require.Equal(t, sharedContent, string(secondContents))
	require.NotEqual(t, firstUpdated.ScriptContentID, secondAfter.ScriptContentID)
}
// testUpdateScriptToSameContent covers the no-op case: updating a script
// with the contents it already has must succeed, keep the contents readable
// and leave the script's content ID unchanged.
func testUpdateScriptToSameContent(t *testing.T, ds *Datastore) {
	ctx := t.Context()

	const contents = "echo hello"

	// create the script under test
	created, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script.sh",
		ScriptContents: contents,
	})
	require.NoError(t, err)

	// remember the content ID assigned at creation
	before, err := ds.Script(ctx, created.ID)
	require.NoError(t, err)
	contentIDBefore := before.ScriptContentID

	// "update" the script with identical contents
	_, err = ds.UpdateScriptContents(ctx, created.ID, contents)
	require.NoError(t, err)

	// the stored contents are still the same
	gotContents, err := ds.GetScriptContents(ctx, created.ID)
	require.NoError(t, err)
	require.Equal(t, contents, string(gotContents))

	// and the script still references the same content ID
	after, err := ds.Script(ctx, created.ID)
	require.NoError(t, err)
	require.Equal(t, contentIDBefore, after.ScriptContentID)
}