mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 13:37:30 +00:00
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves # Dismisses some gosec rules in test code where they do not apply, since they show up when running `golangci-lint run` locally and make it harder to spot newly introduced errors. # Checklist for submitter ## Testing - [x] Added/updated automated tests - [ ] Where appropriate, [automated tests simulate multiple hosts and test for host isolation](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/reference/patterns-backend.md#unit-testing) (updates to one host's records do not affect another) - [ ] QA'd all new/changed functionality manually
3222 lines
115 KiB
Go
3222 lines
115 KiB
Go
package mysql
|
|
|
|
import (
|
|
"context"
|
|
_ "embed"
|
|
"fmt"
|
|
"math"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
common_mysql "github.com/fleetdm/fleet/v4/server/platform/mysql"
|
|
"github.com/fleetdm/fleet/v4/server/ptr"
|
|
"github.com/fleetdm/fleet/v4/server/test"
|
|
"github.com/fleetdm/fleet/v4/server/worker"
|
|
"github.com/google/uuid"
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestScripts(t *testing.T) {
|
|
ds := CreateMySQLDS(t)
|
|
|
|
cases := []struct {
|
|
name string
|
|
fn func(t *testing.T, ds *Datastore)
|
|
}{
|
|
{"HostScriptResult", testHostScriptResult},
|
|
{"DEPRestoredHost", testListPendingScriptDEPRestoration},
|
|
{"Scripts", testScripts},
|
|
{"ListScripts", testListScripts},
|
|
{"GetHostScriptDetails", testGetHostScriptDetails},
|
|
{"BatchSetScripts", testBatchSetScripts},
|
|
{"TestLockHostViaScript", testLockHostViaScript},
|
|
{"TestUnlockHostViaScript", testUnlockHostViaScript},
|
|
{"TestLockUnlockWipeViaScripts", testLockUnlockWipeViaScripts},
|
|
{"TestLockUnlockManually", testLockUnlockManually},
|
|
{"TestInsertScriptContents", testInsertScriptContents},
|
|
{"TestCleanupUnusedScriptContents", testCleanupUnusedScriptContents},
|
|
{"TestGetAnyScriptContents", testGetAnyScriptContents},
|
|
{"TestDeleteScriptsAssignedToPolicy", testDeleteScriptsAssignedToPolicy},
|
|
{"TestDeletePendingHostScriptExecutionsForPolicy", testDeletePendingHostScriptExecutionsForPolicy},
|
|
{"UpdateScriptContents", testUpdateScriptContents},
|
|
{"UpdateScriptToDuplicateContent", testUpdateScriptToDuplicateContent},
|
|
{"UpdateSharedScriptContent", testUpdateSharedScriptContent},
|
|
{"UpdateScriptToSameContent", testUpdateScriptToSameContent},
|
|
{"UpdateDeletingUpcomingScriptExecutions", testUpdateDeletingUpcomingScriptExecutions},
|
|
{"BatchExecute", testBatchExecute},
|
|
{"BatchExecuteWithStatus", testBatchExecuteWithStatus},
|
|
{"BatchScriptSchedule", testBatchScriptSchedule},
|
|
{"BatchScriptCancel", testBatchScriptCancel},
|
|
{"TestMarkActivitiesAsCompleted", testMarkActivitiesAsCompleted},
|
|
{"DeleteScriptActivatesNextActivity", testDeleteScriptActivatesNextActivity},
|
|
{"BatchSetScriptActivatesNextActivity", testBatchSetScriptActivatesNextActivity},
|
|
{"CountHostScriptAttempts", testCountHostScriptAttempts},
|
|
{"ScriptModificationResetsAttemptNumber", testScriptModificationResetsAttemptNumber},
|
|
}
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
defer TruncateTables(t, ds)
|
|
|
|
c.fn(t, ds)
|
|
})
|
|
}
|
|
}
|
|
|
|
// testHostScriptResult exercises ad-hoc (non-saved) host script executions for
// host 1. It creates requests (sync and async, with and without a user), lists
// them as pending, and records results. The recorded cases are a duplicate
// result that is ignored, an oversized output that is truncated, and a large
// unsigned exit code that is stored as a signed value. It then reads the
// results back.
func testHostScriptResult(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// no script saved yet
	pending, err := ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Empty(t, pending)

	_, err = ds.GetHostScriptExecutionResult(ctx, "abc")
	require.Error(t, err)
	var nfe *common_mysql.NotFoundError
	require.ErrorAs(t, err, &nfe)

	// create a script execution request (with a user)
	u := test.NewUser(t, ds, "Bob", "bob@example.com", true)
	createdScript1, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo",
		UserID:         &u.ID,
		SyncRequest:    true,
	})
	require.NoError(t, err)
	require.NotZero(t, createdScript1.ID)
	require.NotEmpty(t, createdScript1.ExecutionID)
	require.Equal(t, uint(1), createdScript1.HostID)
	require.Equal(t, "echo", createdScript1.ScriptContents)
	require.Nil(t, createdScript1.ExitCode)
	require.Empty(t, createdScript1.Output)
	require.NotNil(t, createdScript1.UserID)
	require.Equal(t, u.ID, *createdScript1.UserID)
	require.True(t, createdScript1.SyncRequest)
	// createdScript1 is now activated, as the queue was empty

	// the script execution is now listed as pending for this host
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)
	require.Equal(t, createdScript1.ID, pending[0].ID)

	// the script execution isn't visible when looking at internal-only scripts
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, true)
	require.NoError(t, err)
	require.Empty(t, pending)

	// record a result for this execution
	hsr, action, err := ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      1,
		ExecutionID: createdScript1.ExecutionID,
		Output:      "foo",
		Runtime:     2,
		ExitCode:    0,
		Timeout:     300,
	}, nil)
	require.NoError(t, err)
	assert.Empty(t, action)
	assert.NotNil(t, hsr)

	// record a duplicate result for this execution, will be ignored
	hsr, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      1,
		ExecutionID: createdScript1.ExecutionID,
		Output:      "foobarbaz",
		Runtime:     22,
		ExitCode:    1,
		Timeout:     360,
	}, nil)
	require.NoError(t, err)
	require.Nil(t, hsr)

	// it is not pending anymore
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Empty(t, pending)

	// the script result can be retrieved
	script, err := ds.GetHostScriptExecutionResult(ctx, createdScript1.ExecutionID)
	require.NoError(t, err)
	expectScript := *createdScript1
	expectScript.Output = "foo"
	expectScript.Runtime = 2
	expectScript.ExitCode = ptr.Int64(0)
	expectScript.Timeout = ptr.Int(300)
	expectScript.CreatedAt, script.CreatedAt = time.Time{}, time.Time{}
	require.Equal(t, &expectScript, script)

	// create another script execution request (null user id this time)
	time.Sleep(time.Millisecond) // ensure a different timestamp
	createdScript2, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo2",
	})
	require.NoError(t, err)
	require.NotZero(t, createdScript2.ID)
	require.NotEmpty(t, createdScript2.ExecutionID)
	require.Nil(t, createdScript2.UserID)
	require.False(t, createdScript2.SyncRequest)
	// createdScript2 is now activated as the queue was empty

	// the script execution is now listed as pending for this host
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)
	require.Equal(t, createdScript2.ID, pending[0].ID)

	// the script result can be retrieved even if it has no result yet
	script, err = ds.GetHostScriptExecutionResult(ctx, createdScript2.ExecutionID)
	require.NoError(t, err)
	expectedScript := *createdScript2
	expectedScript.CreatedAt, script.CreatedAt = time.Time{}, time.Time{}
	require.Equal(t, &expectedScript, script)

	// record a result for this execution, with an output that is too large:
	// eleven consecutive 1000-character runs of "a" through "k".
	const chunkLen = 1000
	var outputBuilder strings.Builder
	for _, letter := range "abcdefghijk" {
		outputBuilder.WriteString(strings.Repeat(string(letter), chunkLen))
	}
	largeOutput := outputBuilder.String()
	// Only the trailing part of the output is expected to be kept, i.e. the
	// leading run of "a"s gets truncated.
	expectedOutput := largeOutput[chunkLen:]

	_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      1,
		ExecutionID: createdScript2.ExecutionID,
		Output:      largeOutput,
		Runtime:     10,
		ExitCode:    1,
		Timeout:     300,
	}, nil)
	require.NoError(t, err)

	// the script result can be retrieved
	script, err = ds.GetHostScriptExecutionResult(ctx, createdScript2.ExecutionID)
	require.NoError(t, err)
	require.Equal(t, expectedOutput, script.Output)

	// create an async execution request
	time.Sleep(time.Millisecond) // ensure a different timestamp
	createdScript3, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo 3",
		UserID:         &u.ID,
		SyncRequest:    false,
	})
	require.NoError(t, err)
	require.NotZero(t, createdScript3.ID)
	require.NotEmpty(t, createdScript3.ExecutionID)
	require.Equal(t, uint(1), createdScript3.HostID)
	require.Equal(t, "echo 3", createdScript3.ScriptContents)
	require.Nil(t, createdScript3.ExitCode)
	require.Empty(t, createdScript3.Output)
	require.NotNil(t, createdScript3.UserID)
	require.Equal(t, u.ID, *createdScript3.UserID)
	require.False(t, createdScript3.SyncRequest)
	// createdScript3 is now activated as the queue was empty

	// the script execution is now listed as pending for this host
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)
	require.Equal(t, createdScript3.ID, pending[0].ID)

	// modify the upcoming script to be a sync script that has
	// been pending for a long time doesn't change result
	// https://github.com/fleetdm/fleet/issues/22866#issuecomment-2575961141
	ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
		_, err := tx.ExecContext(ctx, "UPDATE upcoming_activities SET created_at = ?, payload = JSON_SET(payload, '$.sync_request', ?) WHERE id = ?",
			time.Now().Add(-24*time.Hour), true, createdScript3.ID)
		return err
	})

	// the script is still pending
	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)
	require.Equal(t, createdScript3.ExecutionID, pending[0].ExecutionID)

	// check that scripts with large unsigned error codes get
	// converted to signed error codes
	createdUnsignedScript, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo",
		UserID:         &u.ID,
		SyncRequest:    true,
	})
	require.NoError(t, err)

	// record a result for createdScript3 so that the unsigned script gets activated
	_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      1,
		ExecutionID: createdScript3.ExecutionID,
		Output:      "foo",
		Runtime:     1,
		ExitCode:    0,
		Timeout:     0,
	}, nil)
	require.NoError(t, err)
	// createdUnsignedScript is now activated, record its result

	unsignedScriptResult, _, err := ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
		HostID:      1,
		ExecutionID: createdUnsignedScript.ExecutionID,
		Output:      "foo",
		Runtime:     1,
		ExitCode:    math.MaxUint32,
		Timeout:     300,
	}, nil)
	require.NoError(t, err)
	// Guard the dereference below so a missing result fails the test with a
	// clear message instead of panicking.
	require.NotNil(t, unsignedScriptResult)
	require.NotNil(t, unsignedScriptResult.ExitCode)
	require.EqualValues(t, -1, *unsignedScriptResult.ExitCode)
}
|
|
|
|
// testListPendingScriptDEPRestoration creates a pending script execution for a
// host, deletes the host, and restores it via RestoreMDMApplePendingDEPHost. It
// then checks that the execution is no longer listed as pending for that host.
func testListPendingScriptDEPRestoration(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	host := test.NewHost(t, ds, "host", "10.0.0.1", "1", "uuid1", time.Now())

	// no script saved yet
	pending, err := ds.ListPendingHostScriptExecutions(ctx, host.ID, false)
	require.NoError(t, err)
	require.Empty(t, pending)

	// create a script execution request (with a user)
	u := test.NewUser(t, ds, "Bob", "bob@example.com", true)
	createdScript, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         host.ID,
		ScriptContents: "echo",
		UserID:         &u.ID,
		SyncRequest:    true,
	})
	require.NoError(t, err)
	require.NotZero(t, createdScript.ID)
	require.NotEmpty(t, createdScript.ExecutionID)
	// compare against the created host rather than assuming it was assigned ID 1
	require.Equal(t, host.ID, createdScript.HostID)
	require.Equal(t, "echo", createdScript.ScriptContents)
	require.Nil(t, createdScript.ExitCode)
	require.Empty(t, createdScript.Output)
	require.NotNil(t, createdScript.UserID)
	require.Equal(t, u.ID, *createdScript.UserID)
	require.True(t, createdScript.SyncRequest)

	// the script execution is now listed as pending for this host
	pending, err = ds.ListPendingHostScriptExecutions(ctx, host.ID, false)
	require.NoError(t, err)
	require.Len(t, pending, 1)
	require.Equal(t, createdScript.ID, pending[0].ID)

	// Set LastEnrolledAt before deleting the host (simulating a DEP enrolled host)
	host.LastEnrolledAt = time.Now()

	err = ds.DeleteHost(ctx, host.ID)
	require.NoError(t, err)

	err = ds.RestoreMDMApplePendingDEPHost(ctx, host)
	require.NoError(t, err)

	// the script execution is no longer listed as pending for this host
	pending, err = ds.ListPendingHostScriptExecutions(ctx, host.ID, false)
	require.NoError(t, err)
	require.Empty(t, pending)
}
|
|
|
|
// testScripts covers creating, fetching, and deleting saved scripts:
//   - lookups of an unknown ID fail with a not-found error;
//   - creating a script for a missing team fails with a foreign-key error;
//   - script names must be unique within a team (or within "no team") but may
//     repeat across them, and become reusable once the script is deleted.
func testScripts(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	var notFoundErr fleet.NotFoundError

	// Unknown script ID: the script and both content getters report not found.
	_, err := ds.Script(ctx, 123)
	require.ErrorAs(t, err, &notFoundErr)
	_, err = ds.GetScriptContents(ctx, 123)
	require.ErrorAs(t, err, &notFoundErr)
	_, err = ds.GetAnyScriptContents(ctx, 123)
	require.ErrorAs(t, err, &notFoundErr)

	// Global (no team) script named "a".
	globalScript, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "a",
		ScriptContents: "echo",
	})
	require.NoError(t, err)
	require.NotZero(t, globalScript.ID)
	require.Nil(t, globalScript.TeamID)
	require.Equal(t, "a", globalScript.Name)
	require.Empty(t, globalScript.ScriptContents) // contents are not part of the returned script

	fetched, err := ds.Script(ctx, globalScript.ID)
	require.NoError(t, err)
	require.Equal(t, globalScript, fetched)

	body, err := ds.GetScriptContents(ctx, globalScript.ID)
	require.NoError(t, err)
	require.Equal(t, "echo", string(body))
	body, err = ds.GetAnyScriptContents(ctx, globalScript.ID)
	require.NoError(t, err)
	require.Equal(t, "echo", string(body))

	// A script for a team that does not exist violates the foreign key.
	_, err = ds.NewScript(ctx, &fleet.Script{
		Name:           "a",
		TeamID:         ptr.Uint(123),
		ScriptContents: "echo",
	})
	require.Error(t, err)
	var foreignKeyErr fleet.ForeignKeyError
	require.ErrorAs(t, err, &foreignKeyErr)

	// For an existing team, the name "a" may be reused even though the global
	// scope already has it.
	team, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name()})
	require.NoError(t, err)
	teamScript, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "a",
		TeamID:         &team.ID,
		ScriptContents: "echo 'team'",
	})
	require.NoError(t, err)
	require.NotEqual(t, globalScript.ID, teamScript.ID)
	require.NotNil(t, teamScript.TeamID)
	require.Equal(t, team.ID, *teamScript.TeamID)

	fetched, err = ds.Script(ctx, teamScript.ID)
	require.NoError(t, err)
	require.Equal(t, teamScript, fetched)

	body, err = ds.GetScriptContents(ctx, teamScript.ID)
	require.NoError(t, err)
	require.Equal(t, "echo 'team'", string(body))
	body, err = ds.GetAnyScriptContents(ctx, teamScript.ID)
	require.NoError(t, err)
	require.Equal(t, "echo 'team'", string(body))

	// A second "a" within the same team is rejected as already existing...
	_, err = ds.NewScript(ctx, &fleet.Script{
		Name:           "a",
		TeamID:         &team.ID,
		ScriptContents: "echo",
	})
	require.Error(t, err)
	var alreadyExistsErr fleet.AlreadyExistsError
	require.ErrorAs(t, err, &alreadyExistsErr)

	// ...and so is a second global "a".
	_, err = ds.NewScript(ctx, &fleet.Script{
		Name:           "a",
		ScriptContents: "echo",
	})
	require.Error(t, err)
	require.ErrorAs(t, err, &alreadyExistsErr)

	// A different name within the team is accepted.
	_, err = ds.NewScript(ctx, &fleet.Script{
		Name:           "b",
		TeamID:         &team.ID,
		ScriptContents: "echo",
	})
	require.NoError(t, err)

	// After deleting the team's "a", that name can be created again.
	require.NoError(t, ds.DeleteScript(ctx, teamScript.ID))
	recreated, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "a",
		TeamID:         &team.ID,
		ScriptContents: "echo",
	})
	require.NoError(t, err)
	require.NotEqual(t, teamScript.ID, recreated.ID)
}
|
|
|
|
// testListScripts verifies ListScripts pagination and team scoping. No team and
// team 1 each get scripts "a" through "e", team 2 gets a single "a", and team 3
// gets none. Each case checks the returned names, that every script belongs to
// the requested team, and the pagination metadata.
func testListScripts(t *testing.T, ds *Datastore) {
	ctx := context.Background()

	// create three teams
	tm1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	tm2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
	require.NoError(t, err)
	tm3, err := ds.NewTeam(ctx, &fleet.Team{Name: "team3"})
	require.NoError(t, err)

	// Create 5 scripts ("a" through "e") for no team and for team 1. Ranging
	// over the names avoids an int -> byte narrowing conversion (gosec G115).
	for _, name := range []string{"a", "b", "c", "d", "e"} {
		_, err = ds.NewScript(ctx, &fleet.Script{
			Name:           name,
			ScriptContents: "echo",
		})
		require.NoError(t, err)
		_, err = ds.NewScript(ctx, &fleet.Script{Name: name, TeamID: &tm1.ID, ScriptContents: "echo"})
		require.NoError(t, err)
	}

	// create a single script for team 2
	_, err = ds.NewScript(ctx, &fleet.Script{Name: "a", TeamID: &tm2.ID, ScriptContents: "echo"})
	require.NoError(t, err)

	cases := []struct {
		opts      fleet.ListOptions
		teamID    *uint
		wantNames []string
		wantMeta  *fleet.PaginationMetadata
	}{
		{
			opts:      fleet.ListOptions{},
			wantNames: []string{"a", "b", "c", "d", "e"},
			wantMeta:  &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: false},
		},
		{
			opts:      fleet.ListOptions{PerPage: 2},
			wantNames: []string{"a", "b"},
			wantMeta:  &fleet.PaginationMetadata{HasNextResults: true, HasPreviousResults: false},
		},
		{
			opts:      fleet.ListOptions{Page: 1, PerPage: 2},
			wantNames: []string{"c", "d"},
			wantMeta:  &fleet.PaginationMetadata{HasNextResults: true, HasPreviousResults: true},
		},
		{
			opts:      fleet.ListOptions{Page: 2, PerPage: 2},
			wantNames: []string{"e"},
			wantMeta:  &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: true},
		},
		{
			opts:      fleet.ListOptions{PerPage: 3},
			teamID:    &tm1.ID,
			wantNames: []string{"a", "b", "c"},
			wantMeta:  &fleet.PaginationMetadata{HasNextResults: true, HasPreviousResults: false},
		},
		{
			opts:      fleet.ListOptions{Page: 1, PerPage: 3},
			teamID:    &tm1.ID,
			wantNames: []string{"d", "e"},
			wantMeta:  &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: true},
		},
		{
			opts:      fleet.ListOptions{Page: 2, PerPage: 3},
			teamID:    &tm1.ID,
			wantNames: nil,
			wantMeta:  &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: true},
		},
		{
			opts:      fleet.ListOptions{PerPage: 3},
			teamID:    &tm2.ID,
			wantNames: []string{"a"},
			wantMeta:  &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: false},
		},
		{
			opts:      fleet.ListOptions{Page: 0, PerPage: 2},
			teamID:    &tm3.ID,
			wantNames: nil,
			wantMeta:  &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: false},
		},
	}
	for _, c := range cases {
		// Label sub-tests with the team ID value: formatting the *uint itself
		// would embed a memory address in the sub-test name.
		teamLabel := "nil"
		if c.teamID != nil {
			teamLabel = fmt.Sprint(*c.teamID)
		}
		t.Run(fmt.Sprintf("%s: %#v", teamLabel, c.opts), func(t *testing.T) {
			// always include metadata
			c.opts.IncludeMetadata = true
			scripts, meta, err := ds.ListScripts(ctx, c.teamID, c.opts)
			require.NoError(t, err)

			require.Equal(t, len(c.wantNames), len(scripts))
			require.Equal(t, c.wantMeta, meta)

			var gotNames []string
			if len(scripts) > 0 {
				gotNames = make([]string, len(scripts))
				for i, s := range scripts {
					gotNames[i] = s.Name
					require.Equal(t, c.teamID, s.TeamID)
				}
			}
			require.Equal(t, c.wantNames, gotNames)
		})
	}
}
|
|
|
|
// testGetHostScriptDetails checks GetHostScriptDetails for host 42. Each saved
// script is reported with its latest execution, and a pending upcoming
// execution is reported even over a more recent completed result. The test
// also checks team and platform filtering, pagination, IsExecutionPendingForHost,
// and that deleting a script cancels its pending executions.
func testGetHostScriptDetails(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	t.Cleanup(func() { ds.testActivateSpecificNextActivities = nil })

	names := []string{"script-1.sh", "script-2.sh", "script-3.sh", "script-4.sh", "script-5.sh"}
	// Create the scripts out of name order (script-2 .. script-5, then
	// script-1). The order is built in a fresh slice so that appending can
	// never write into the backing array shared with names.
	creationOrder := make([]string, 0, len(names))
	creationOrder = append(creationOrder, names[1:]...)
	creationOrder = append(creationOrder, names[0])
	for _, r := range creationOrder {
		_, err := ds.NewScript(ctx, &fleet.Script{
			Name:           r,
			ScriptContents: "echo " + r,
		})
		require.NoError(t, err)
	}

	// create a windows script as well
	_, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script-6.ps1",
		ScriptContents: `Write-Host "Hello, World!"`,
	})
	require.NoError(t, err)

	scripts, _, err := ds.ListScripts(ctx, nil, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, scripts, 6)

	// insertResults queues an execution of script on hostID and forces its
	// execution ID and creation timestamp. When exitCode is non-nil, it also
	// activates that execution and records a result with that exit code.
	// exitCode is an *int to match the payload's int ExitCode field without a
	// narrowing int64 -> int conversion (gosec G115).
	insertResults := func(t *testing.T, hostID uint, script *fleet.Script, createdAt time.Time, execID string, exitCode *int) {
		var scriptID *uint
		if script.ID != 0 {
			scriptID = &script.ID
		}
		hsr, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
			HostID:   hostID,
			ScriptID: scriptID,
		})
		require.NoError(t, err)
		ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
			_, err := tx.ExecContext(ctx, `UPDATE upcoming_activities SET execution_id = ?, created_at = ? WHERE execution_id = ?`,
				execID, createdAt, hsr.ExecutionID)
			return err
		})

		if exitCode != nil {
			// Restrict activation to this execution via the test hook, and
			// check that exactly that execution was activated.
			ds.testActivateSpecificNextActivities = []string{execID}
			act, err := ds.activateNextUpcomingActivity(ctx, ds.writer(ctx), hostID, "")
			require.NoError(t, err)
			require.ElementsMatch(t, act, ds.testActivateSpecificNextActivities)
			ds.testActivateSpecificNextActivities = nil

			_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
				HostID:      hostID,
				ExecutionID: execID,
				ExitCode:    *exitCode,
			}, nil)
			require.NoError(t, err)

			// force the test timestamp
			ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
				_, err := tx.ExecContext(ctx, `UPDATE host_script_results SET created_at = ? WHERE execution_id = ?`,
					createdAt, execID)
				return err
			})
		}
	}

	now := time.Now().UTC().Truncate(time.Second)

	// add some results for an ad-hoc, non-saved script, should not be included in results
	// create it first so that this one gets activated, and the other ones are never
	// activated automatically.
	_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         42,
		ScriptContents: "echo script-6",
	})
	require.NoError(t, err)

	// add some results for script-1
	insertResults(t, 42, scripts[0], now.Add(-3*time.Minute), "execution-1-1", nil)
	insertResults(t, 42, scripts[0], now.Add(-1*time.Minute), "execution-1-2", nil) // last execution for script-1, status "pending"
	insertResults(t, 42, scripts[0], now.Add(-2*time.Minute), "execution-1-3", nil)

	// add some results for script-2
	insertResults(t, 42, scripts[1], now.Add(-3*time.Minute), "execution-2-1", ptr.Int(0))
	insertResults(t, 42, scripts[1], now.Add(-1*time.Minute), "execution-2-2", ptr.Int(1)) // last execution for script-2, status "error"

	// add some results for script-3
	insertResults(t, 42, scripts[2], now.Add(-1*time.Minute), "execution-3-1", ptr.Int(0))
	insertResults(t, 42, scripts[2], now.Add(-1*time.Minute), "execution-3-2", ptr.Int(0)) // last execution for script-3, status "ran"
	insertResults(t, 42, scripts[2], now.Add(-2*time.Minute), "execution-3-3", ptr.Int(0))

	// add some results for script-4
	insertResults(t, 42, scripts[3], now.Add(-1*time.Minute), "execution-4-1", ptr.Int(-2)) // last execution for script-4, status "error"

	// add a pending and a completed script execution for script-5
	insertResults(t, 42, scripts[4], now.Add(-2*time.Minute), "execution-5-1", ptr.Int(0))
	insertResults(t, 42, scripts[4], now.Add(-3*time.Minute), "execution-5-2", nil) // upcoming is always latest, regardless of timestamp

	t.Run("results match expected formatting and filtering", func(t *testing.T) {
		res, _, err := ds.GetHostScriptDetails(ctx, 42, nil, fleet.ListOptions{}, "")
		require.NoError(t, err)
		require.Len(t, res, 6)
		for _, r := range res {
			switch r.ScriptID {
			case scripts[0].ID:
				require.Equal(t, scripts[0].Name, r.Name)
				require.NotNil(t, r.LastExecution)
				require.Equal(t, now.Add(-1*time.Minute), r.LastExecution.ExecutedAt)
				require.Equal(t, "execution-1-2", r.LastExecution.ExecutionID)
				require.Equal(t, "pending", r.LastExecution.Status)
			case scripts[1].ID:
				require.Equal(t, scripts[1].Name, r.Name)
				require.NotNil(t, r.LastExecution)
				require.Equal(t, now.Add(-1*time.Minute), r.LastExecution.ExecutedAt)
				require.Equal(t, "execution-2-2", r.LastExecution.ExecutionID)
				require.Equal(t, "error", r.LastExecution.Status)
			case scripts[2].ID:
				require.Equal(t, scripts[2].Name, r.Name)
				require.NotNil(t, r.LastExecution)
				require.Equal(t, now.Add(-1*time.Minute), r.LastExecution.ExecutedAt)
				require.Equal(t, "execution-3-2", r.LastExecution.ExecutionID)
				require.Equal(t, "ran", r.LastExecution.Status)
			case scripts[3].ID:
				require.Equal(t, scripts[3].Name, r.Name)
				require.NotNil(t, r.LastExecution)
				require.Equal(t, now.Add(-1*time.Minute), r.LastExecution.ExecutedAt)
				require.Equal(t, "execution-4-1", r.LastExecution.ExecutionID)
				require.Equal(t, "error", r.LastExecution.Status)
			case scripts[4].ID:
				require.Equal(t, scripts[4].Name, r.Name)
				require.NotNil(t, r.LastExecution)
				// NOTE(review): ExecutedAt is intentionally not asserted for this
				// pending upcoming execution; confirm the expected value.
				// require.Equal(t, now.Add(-3*time.Minute), r.LastExecution.ExecutedAt)
				require.Equal(t, "execution-5-2", r.LastExecution.ExecutionID)
				require.Equal(t, "pending", r.LastExecution.Status)
			case scripts[5].ID:
				require.Equal(t, scripts[5].Name, r.Name)
				require.Nil(t, r.LastExecution)
			default:
				t.Errorf("unexpected script id: %d", r.ScriptID)
			}
		}
	})

	t.Run("empty slice returned if no scripts", func(t *testing.T) {
		res, _, err := ds.GetHostScriptDetails(ctx, 42, ptr.Uint(1), fleet.ListOptions{}, "") // team 1 has no scripts
		require.NoError(t, err)
		require.NotNil(t, res)
		require.Empty(t, res)
	})

	t.Run("list options are supported", func(t *testing.T) {
		cases := []struct {
			opts      fleet.ListOptions
			teamID    *uint
			wantNames []string
			wantMeta  *fleet.PaginationMetadata
		}{
			{
				opts:      fleet.ListOptions{},
				wantNames: names,
				wantMeta:  &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: false},
			},
			{
				opts:      fleet.ListOptions{PerPage: 2},
				wantNames: names[:2],
				wantMeta:  &fleet.PaginationMetadata{HasNextResults: true, HasPreviousResults: false},
			},
			{
				opts:      fleet.ListOptions{Page: 1, PerPage: 2},
				wantNames: names[2:4],
				wantMeta:  &fleet.PaginationMetadata{HasNextResults: true, HasPreviousResults: true},
			},
			{
				opts:      fleet.ListOptions{Page: 2, PerPage: 2},
				wantNames: names[4:],
				wantMeta:  &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: true},
			},
		}
		for _, c := range cases {
			t.Run(fmt.Sprintf("%#v", c.opts), func(t *testing.T) {
				// always include metadata
				c.opts.IncludeMetadata = true
				// custom ordering is not supported, always by name
				c.opts.OrderKey = "name"
				results, meta, err := ds.GetHostScriptDetails(ctx, 42, nil, c.opts, "darwin")
				require.NoError(t, err)

				require.Equal(t, len(c.wantNames), len(results))
				require.Equal(t, c.wantMeta, meta)

				var gotNames []string
				if len(results) > 0 {
					gotNames = make([]string, len(results))
					for i, r := range results {
						gotNames[i] = r.Name
					}
				}
				require.Equal(t, c.wantNames, gotNames)
			})
		}
	})

	t.Run("windows ps1 scripts are supported", func(t *testing.T) {
		res, _, err := ds.GetHostScriptDetails(ctx, 42, nil, fleet.ListOptions{}, "windows")
		require.NoError(t, err)
		require.NotNil(t, res)
		require.Len(t, res, 1)
		require.Equal(t, "script-6.ps1", res[0].Name)
	})

	t.Run("can check if pending host script results exist", func(t *testing.T) {
		insertResults(t, 42, scripts[2], now.Add(-2*time.Minute), "execution-3-4", nil)
		r, err := ds.IsExecutionPendingForHost(ctx, 42, scripts[2].ID)
		require.NoError(t, err)
		require.True(t, r)
	})

	t.Run("script deletion cancels pending script runs", func(t *testing.T) {
		insertResults(t, 43, scripts[3], now.Add(-2*time.Minute), "execution-4-4", nil)
		pending, err := ds.ListPendingHostScriptExecutions(ctx, 43, false)
		require.NoError(t, err)
		require.Len(t, pending, 1)

		err = ds.DeleteScript(ctx, scripts[3].ID)
		require.NoError(t, err)

		pending, err = ds.ListPendingHostScriptExecutions(ctx, 43, false)
		require.NoError(t, err)
		require.Empty(t, pending)
	})
}
|
|
|
|
func testBatchSetScripts(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
applyAndExpect := func(newSet []*fleet.Script, tmID *uint, want []*fleet.Script) map[string]uint {
|
|
responseFromSet, err := ds.BatchSetScripts(ctx, tmID, newSet)
|
|
require.NoError(t, err)
|
|
|
|
if tmID == nil {
|
|
tmID = ptr.Uint(0)
|
|
}
|
|
got, _, err := ds.ListScripts(ctx, tmID, fleet.ListOptions{})
|
|
require.NoError(t, err)
|
|
|
|
// compare only the fields we care about
|
|
fromGetByScriptName := make(map[string]uint)
|
|
fromSetByScriptName := make(map[string]uint)
|
|
for _, gotScript := range responseFromSet {
|
|
fromSetByScriptName[gotScript.Name] = gotScript.ID
|
|
}
|
|
for _, gotScript := range got {
|
|
fromGetByScriptName[gotScript.Name] = gotScript.ID
|
|
if gotScript.TeamID != nil && *gotScript.TeamID == 0 {
|
|
gotScript.TeamID = nil
|
|
}
|
|
|
|
require.Equal(t, fromGetByScriptName[gotScript.Name], gotScript.ID)
|
|
gotScript.ID = 0
|
|
gotScript.CreatedAt = time.Time{}
|
|
gotScript.UpdatedAt = time.Time{}
|
|
}
|
|
// order is not guaranteed
|
|
require.ElementsMatch(t, want, got)
|
|
|
|
return fromGetByScriptName
|
|
}
|
|
|
|
// apply empty set for no-team
|
|
applyAndExpect(nil, nil, nil)
|
|
|
|
// create a team
|
|
tm1, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "_tm1"})
|
|
require.NoError(t, err)
|
|
|
|
// apply single script set for tm1
|
|
sTm1 := applyAndExpect([]*fleet.Script{
|
|
{Name: "N1", ScriptContents: "C1"},
|
|
}, ptr.Uint(tm1.ID), []*fleet.Script{
|
|
{Name: "N1", TeamID: ptr.Uint(tm1.ID)},
|
|
})
|
|
n1WithTeamID := sTm1["N1"]
|
|
|
|
teamPolicy, err := ds.NewTeamPolicy(ctx, tm1.ID, nil, fleet.PolicyPayload{
|
|
Name: "Team One Policy",
|
|
Query: "SELECT 1",
|
|
Platform: "darwin",
|
|
ScriptID: &n1WithTeamID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// apply single script set for no-team
|
|
sNoTm := applyAndExpect([]*fleet.Script{
|
|
{Name: "N1", ScriptContents: "C1"},
|
|
}, nil, []*fleet.Script{
|
|
{Name: "N1", TeamID: nil},
|
|
})
|
|
n1WithNoTeamId := sNoTm["N1"]
|
|
|
|
noTeamPolicy, err := ds.NewTeamPolicy(ctx, fleet.PolicyNoTeamID, nil, fleet.PolicyPayload{
|
|
Name: "No Team Policy",
|
|
Query: "SELECT 1",
|
|
Platform: "darwin",
|
|
ScriptID: &n1WithNoTeamId,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// apply new script set for tm1
|
|
sTm1b := applyAndExpect([]*fleet.Script{
|
|
{Name: "N1", ScriptContents: "C1"},
|
|
{Name: "N2", ScriptContents: "C2"},
|
|
}, ptr.Uint(tm1.ID), []*fleet.Script{
|
|
{Name: "N1", TeamID: ptr.Uint(tm1.ID)},
|
|
{Name: "N2", TeamID: ptr.Uint(tm1.ID)},
|
|
})
|
|
// name for N1-I1 is unchanged
|
|
require.Equal(t, sTm1["I1"], sTm1b["I1"])
|
|
|
|
// policy still has script associated
|
|
teamPolicy, err = ds.Policy(ctx, teamPolicy.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, n1WithTeamID, *teamPolicy.ScriptID)
|
|
|
|
// apply edited (by contents only) script set for no-team
|
|
sNoTmb := applyAndExpect([]*fleet.Script{
|
|
{Name: "N1", ScriptContents: "C1-changed"},
|
|
}, nil, []*fleet.Script{
|
|
{Name: "N1", TeamID: nil},
|
|
})
|
|
require.Equal(t, sNoTm["I1"], sNoTmb["I1"])
|
|
|
|
// policy still has script associated
|
|
noTeamPolicy, err = ds.Policy(ctx, noTeamPolicy.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, n1WithNoTeamId, *noTeamPolicy.ScriptID)
|
|
|
|
// apply edited script (by content only), unchanged script and new
|
|
// script for tm1
|
|
sTm1c := applyAndExpect([]*fleet.Script{
|
|
{Name: "N1", ScriptContents: "C1-updated"}, // content updated
|
|
{Name: "N2", ScriptContents: "C2"}, // unchanged
|
|
{Name: "N3", ScriptContents: "C3"}, // new
|
|
}, ptr.Uint(tm1.ID), []*fleet.Script{
|
|
{Name: "N1", TeamID: ptr.Uint(tm1.ID)}, // content updated
|
|
{Name: "N2", TeamID: ptr.Uint(tm1.ID)}, // unchanged
|
|
{Name: "N3", TeamID: ptr.Uint(tm1.ID)}, // new
|
|
})
|
|
// name for N1-I1 is unchanged
|
|
require.Equal(t, sTm1b["I1"], sTm1c["I1"])
|
|
// identifier for N2-I2 is unchanged
|
|
require.Equal(t, sTm1b["I2"], sTm1c["I2"])
|
|
|
|
// policy still has script associated
|
|
teamPolicy, err = ds.Policy(ctx, teamPolicy.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, n1WithTeamID, *teamPolicy.ScriptID)
|
|
|
|
// add pending scripts on team and no-team and confirm they're shown as pending
|
|
_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: 44,
|
|
ScriptID: &n1WithTeamID,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: 45,
|
|
ScriptID: &n1WithNoTeamId,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
pending, err := ds.ListPendingHostScriptExecutions(ctx, 44, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, pending, 1)
|
|
pending, err = ds.ListPendingHostScriptExecutions(ctx, 45, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, pending, 1)
|
|
|
|
// clear scripts for tm1
|
|
applyAndExpect(nil, ptr.Uint(1), nil)
|
|
|
|
// policy on team should not have script assigned
|
|
teamPolicy, err = ds.Policy(ctx, teamPolicy.ID)
|
|
require.NoError(t, err)
|
|
require.Nil(t, teamPolicy.ScriptID)
|
|
|
|
// no-team policy still has script associated
|
|
noTeamPolicy, err = ds.Policy(ctx, noTeamPolicy.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, n1WithNoTeamId, *noTeamPolicy.ScriptID)
|
|
|
|
// team script should no longer be pending, no-team script should still be pending
|
|
pending, err = ds.ListPendingHostScriptExecutions(ctx, 44, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, pending, 0)
|
|
pending, err = ds.ListPendingHostScriptExecutions(ctx, 45, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, pending, 1)
|
|
|
|
// apply only new scripts to no-team
|
|
applyAndExpect([]*fleet.Script{
|
|
{Name: "N4", ScriptContents: "C4"},
|
|
{Name: "N5", ScriptContents: "C5"},
|
|
}, nil, []*fleet.Script{
|
|
{Name: "N4", TeamID: nil},
|
|
{Name: "N5", TeamID: nil},
|
|
})
|
|
|
|
// policy on team should not have script assigned
|
|
teamPolicy, err = ds.Policy(ctx, teamPolicy.ID)
|
|
require.NoError(t, err)
|
|
require.Nil(t, teamPolicy.ScriptID)
|
|
|
|
// no-team policy should not have script associated
|
|
noTeamPolicy, err = ds.Policy(ctx, noTeamPolicy.ID)
|
|
require.NoError(t, err)
|
|
require.Nil(t, noTeamPolicy.ScriptID)
|
|
|
|
// no-team script should no longer be pending
|
|
pending, err = ds.ListPendingHostScriptExecutions(ctx, 45, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, pending, 0)
|
|
}
|
|
|
|
func testLockHostViaScript(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
// no script saved yet
|
|
pending, err := ds.ListPendingHostScriptExecutions(ctx, 1, false)
|
|
require.NoError(t, err)
|
|
require.Empty(t, pending)
|
|
|
|
user := test.NewUser(t, ds, "Bob", "bob@example.com", true)
|
|
|
|
windowsHostID := uint(1)
|
|
|
|
script := "lock"
|
|
|
|
err = ds.LockHostViaScript(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: windowsHostID,
|
|
ScriptContents: script,
|
|
UserID: &user.ID,
|
|
SyncRequest: false,
|
|
}, "windows")
|
|
|
|
require.NoError(t, err)
|
|
|
|
// verify that we have created entries in host_mdm_actions and host_script_results
|
|
status, err := ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: windowsHostID, Platform: "windows", UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
require.Equal(t, "windows", status.HostFleetPlatform)
|
|
require.NotNil(t, status.LockScript)
|
|
assert.Nil(t, status.UnlockScript)
|
|
|
|
s := status.LockScript
|
|
require.Equal(t, script, s.ScriptContents)
|
|
require.Equal(t, windowsHostID, s.HostID)
|
|
require.False(t, s.SyncRequest)
|
|
require.Equal(t, &user.ID, s.UserID)
|
|
|
|
require.True(t, status.IsPendingLock())
|
|
|
|
// simulate a successful result for the lock script execution
|
|
_, action, err := ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: s.HostID,
|
|
ExecutionID: s.ExecutionID,
|
|
ExitCode: 0,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "lock_ref", action)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: windowsHostID, Platform: "windows", UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
require.True(t, status.IsLocked())
|
|
require.False(t, status.IsPendingLock())
|
|
require.False(t, status.IsUnlocked())
|
|
}
|
|
|
|
func testUnlockHostViaScript(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
// no script saved yet
|
|
pending, err := ds.ListPendingHostScriptExecutions(ctx, 1, false)
|
|
require.NoError(t, err)
|
|
require.Empty(t, pending)
|
|
|
|
user := test.NewUser(t, ds, "Bob", "bob@example.com", true)
|
|
|
|
hostID := uint(1)
|
|
hostUUID := "uuid"
|
|
hostPlatform := "windows"
|
|
host, err := ds.NewHost(ctx, &fleet.Host{
|
|
ID: hostID,
|
|
UUID: hostUUID,
|
|
Platform: hostPlatform,
|
|
OsqueryHostID: &hostUUID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
script := "unlock"
|
|
|
|
err = ds.UnlockHostViaScript(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hostID,
|
|
ScriptContents: script,
|
|
UserID: &user.ID,
|
|
SyncRequest: false,
|
|
}, hostPlatform)
|
|
|
|
require.NoError(t, err)
|
|
|
|
// verify that we have created entries in host_mdm_actions and host_script_results
|
|
status, err := ds.GetHostLockWipeStatus(ctx, host)
|
|
require.NoError(t, err)
|
|
require.Equal(t, hostPlatform, status.HostFleetPlatform)
|
|
require.NotNil(t, status.UnlockScript)
|
|
|
|
s := status.UnlockScript
|
|
require.Equal(t, script, s.ScriptContents)
|
|
require.Equal(t, hostID, s.HostID)
|
|
require.False(t, s.SyncRequest)
|
|
require.Equal(t, &user.ID, s.UserID)
|
|
|
|
require.True(t, status.IsPendingUnlock())
|
|
|
|
// simulate a cancel while it's pending unlock
|
|
_, err = ds.CancelHostUpcomingActivity(ctx, s.HostID, s.ExecutionID)
|
|
require.NoError(t, err)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, host)
|
|
require.NoError(t, err)
|
|
require.False(t, status.IsPendingUnlock())
|
|
|
|
// add a new unlock script execution
|
|
err = ds.UnlockHostViaScript(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hostID,
|
|
ScriptContents: script,
|
|
UserID: &user.ID,
|
|
SyncRequest: false,
|
|
}, hostPlatform)
|
|
require.NoError(t, err)
|
|
status, err = ds.GetHostLockWipeStatus(ctx, host)
|
|
require.NoError(t, err)
|
|
require.Equal(t, hostPlatform, status.HostFleetPlatform)
|
|
require.NotNil(t, status.UnlockScript)
|
|
s = status.UnlockScript
|
|
|
|
// simulate a successful result for the unlock script execution
|
|
_, action, err := ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: s.HostID,
|
|
ExecutionID: s.ExecutionID,
|
|
ExitCode: 0,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "unlock_ref", action)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, host)
|
|
require.NoError(t, err)
|
|
require.True(t, status.IsUnlocked())
|
|
require.False(t, status.IsPendingUnlock())
|
|
require.False(t, status.IsLocked())
|
|
}
|
|
|
|
func testLockUnlockWipeViaScripts(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
user := test.NewUser(t, ds, "Bob", "bob@example.com", true)
|
|
|
|
for i, platform := range []string{"windows", "linux"} {
|
|
hostID := uint(i + 1) //nolint:gosec // dismiss G115
|
|
|
|
t.Run(platform, func(t *testing.T) {
|
|
status, err := ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
|
|
// default state
|
|
checkLockWipeState(t, status, true, false, false, false, false, false)
|
|
|
|
// record a request to lock the host
|
|
err = ds.LockHostViaScript(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hostID,
|
|
ScriptContents: "lock",
|
|
UserID: &user.ID,
|
|
SyncRequest: false,
|
|
}, platform)
|
|
require.NoError(t, err)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
checkLockWipeState(t, status, true, false, false, false, true, false)
|
|
|
|
// simulate a successful result for the lock script execution
|
|
_, action, err := ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: hostID,
|
|
ExecutionID: status.LockScript.ExecutionID,
|
|
ExitCode: 0,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "lock_ref", action)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
checkLockWipeState(t, status, false, true, false, false, false, false)
|
|
|
|
// record a request to unlock the host
|
|
err = ds.UnlockHostViaScript(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hostID,
|
|
ScriptContents: "unlock",
|
|
UserID: &user.ID,
|
|
SyncRequest: false,
|
|
}, platform)
|
|
require.NoError(t, err)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
checkLockWipeState(t, status, false, true, false, true, false, false)
|
|
|
|
// simulate a failed result for the unlock script execution
|
|
_, action, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: hostID,
|
|
ExecutionID: status.UnlockScript.ExecutionID,
|
|
ExitCode: -1,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "unlock_ref", action)
|
|
|
|
// still locked
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
checkLockWipeState(t, status, false, true, false, false, false, false)
|
|
|
|
// record another request to unlock the host
|
|
err = ds.UnlockHostViaScript(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hostID,
|
|
ScriptContents: "unlock",
|
|
UserID: &user.ID,
|
|
SyncRequest: false,
|
|
}, platform)
|
|
require.NoError(t, err)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
checkLockWipeState(t, status, false, true, false, true, false, false)
|
|
|
|
// this time simulate a successful result for the unlock script execution
|
|
_, action, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: hostID,
|
|
ExecutionID: status.UnlockScript.ExecutionID,
|
|
ExitCode: 0,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "unlock_ref", action)
|
|
|
|
// host is now unlocked
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
checkLockWipeState(t, status, true, false, false, false, false, false)
|
|
|
|
// record another request to lock the host
|
|
err = ds.LockHostViaScript(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hostID,
|
|
ScriptContents: "lock",
|
|
UserID: &user.ID,
|
|
SyncRequest: false,
|
|
}, platform)
|
|
require.NoError(t, err)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
checkLockWipeState(t, status, true, false, false, false, true, false)
|
|
|
|
// simulate a failed result for the lock script execution
|
|
_, action, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: hostID,
|
|
ExecutionID: status.LockScript.ExecutionID,
|
|
ExitCode: 2,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "lock_ref", action)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
checkLockWipeState(t, status, true, false, false, false, false, false)
|
|
|
|
switch platform {
|
|
case "windows":
|
|
// need a real MDM-enrolled host for MDM commands
|
|
h, err := ds.NewHost(ctx, &fleet.Host{
|
|
Hostname: "test-host-windows",
|
|
OsqueryHostID: ptr.String("osquery-windows"),
|
|
NodeKey: ptr.String("nodekey-windows"),
|
|
UUID: "test-uuid-windows",
|
|
Platform: "windows",
|
|
})
|
|
require.NoError(t, err)
|
|
windowsEnroll(t, ds, h)
|
|
|
|
// record a request to wipe the host
|
|
wipeCmdUUID := uuid.NewString()
|
|
wipeCmd := &fleet.MDMWindowsCommand{
|
|
CommandUUID: wipeCmdUUID,
|
|
RawCommand: []byte(`<Exec></Exec>`),
|
|
TargetLocURI: "./Device/Vendor/MSFT/RemoteWipe/doWipeProtected",
|
|
}
|
|
err = ds.WipeHostViaWindowsMDM(ctx, h, wipeCmd)
|
|
require.NoError(t, err)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, h)
|
|
require.NoError(t, err)
|
|
checkLockWipeState(t, status, true, false, false, false, false, true)
|
|
|
|
// TODO: we don't seem to have an easy way to simulate a Windows MDM
|
|
// protocol response, and there are lots of validations happening so we
|
|
// can't just send a simple XML. Will test the rest via integration
|
|
// tests.
|
|
|
|
case "linux":
|
|
// record a request to wipe the host
|
|
err = ds.WipeHostViaScript(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hostID,
|
|
ScriptContents: "wipe",
|
|
UserID: &user.ID,
|
|
SyncRequest: false,
|
|
}, platform)
|
|
require.NoError(t, err)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
checkLockWipeState(t, status, true, false, false, false, false, true)
|
|
|
|
// simulate a failed result for the wipe script execution
|
|
_, action, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: hostID,
|
|
ExecutionID: status.WipeScript.ExecutionID,
|
|
ExitCode: 1,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "wipe_ref", action)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
checkLockWipeState(t, status, true, false, false, false, false, false)
|
|
|
|
// record another request to wipe the host
|
|
err = ds.WipeHostViaScript(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hostID,
|
|
ScriptContents: "wipe2",
|
|
UserID: &user.ID,
|
|
SyncRequest: false,
|
|
}, platform)
|
|
require.NoError(t, err)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
checkLockWipeState(t, status, true, false, false, false, false, true)
|
|
|
|
// simulate a successful result for the wipe script execution
|
|
_, action, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: hostID,
|
|
ExecutionID: status.WipeScript.ExecutionID,
|
|
ExitCode: 0,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "wipe_ref", action)
|
|
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: hostID, Platform: platform, UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
checkLockWipeState(t, status, false, false, true, false, false, false)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func testLockUnlockManually(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
twoDaysAgo := time.Now().AddDate(0, 0, -2).UTC()
|
|
today := time.Now().UTC()
|
|
err := ds.UnlockHostManually(ctx, 1, "darwin", twoDaysAgo)
|
|
require.NoError(t, err)
|
|
|
|
status, err := ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: 1, Platform: "darwin", UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
require.False(t, status.UnlockRequestedAt.IsZero())
|
|
require.WithinDuration(t, twoDaysAgo, status.UnlockRequestedAt, 1*time.Second)
|
|
|
|
// if the unlock request already exists, it is not overwritten by subsequent
|
|
// requests
|
|
err = ds.UnlockHostManually(ctx, 1, "darwin", today)
|
|
require.NoError(t, err)
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: 1, Platform: "darwin", UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
require.False(t, status.UnlockRequestedAt.IsZero())
|
|
require.WithinDuration(t, twoDaysAgo, status.UnlockRequestedAt, 1*time.Second)
|
|
|
|
// but for a new host, it will set it properly, even if that host already has a
|
|
// host_mdm_actions entry
|
|
ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
|
|
_, err := tx.ExecContext(ctx, "INSERT INTO host_mdm_actions (host_id) VALUES (2)")
|
|
return err
|
|
})
|
|
err = ds.UnlockHostManually(ctx, 2, "darwin", today)
|
|
require.NoError(t, err)
|
|
status, err = ds.GetHostLockWipeStatus(ctx, &fleet.Host{ID: 2, Platform: "darwin", UUID: "uuid"})
|
|
require.NoError(t, err)
|
|
require.False(t, status.UnlockRequestedAt.IsZero())
|
|
require.WithinDuration(t, today, status.UnlockRequestedAt, 1*time.Second)
|
|
}
|
|
|
|
func checkLockWipeState(t *testing.T, status *fleet.HostLockWipeStatus, unlocked, locked, wiped, pendingUnlock, pendingLock, pendingWipe bool) {
|
|
require.Equal(t, unlocked, status.IsUnlocked(), "unlocked")
|
|
require.Equal(t, locked, status.IsLocked(), "locked")
|
|
require.Equal(t, wiped, status.IsWiped(), "wiped")
|
|
require.Equal(t, pendingLock, status.IsPendingLock(), "pending lock")
|
|
require.Equal(t, pendingUnlock, status.IsPendingUnlock(), "pending unlock")
|
|
require.Equal(t, pendingWipe, status.IsPendingWipe(), "pending wipe")
|
|
}
|
|
|
|
// scriptContents is a row of the script_contents table as scanned by sqlx in
// these tests. The queries in this file select HEX(md5_checksum) AS
// md5_checksum, so Checksum holds the hex-encoded digest rather than raw bytes.
type scriptContents struct {
	ID       uint   `db:"id"`
	Checksum string `db:"md5_checksum"`
}
|
|
|
|
func testInsertScriptContents(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
contents := `echo foobar;`
|
|
res, err := insertScriptContents(ctx, ds.writer(ctx), contents)
|
|
require.NoError(t, err)
|
|
id, _ := res.LastInsertId()
|
|
require.Equal(t, int64(1), id)
|
|
expectedCS := md5ChecksumScriptContent(contents)
|
|
|
|
// insert same contents again, verify that the checksum and ID stayed the same
|
|
res, err = insertScriptContents(ctx, ds.writer(ctx), contents)
|
|
require.NoError(t, err)
|
|
id, _ = res.LastInsertId()
|
|
require.Equal(t, int64(1), id)
|
|
|
|
stmt := `SELECT id, HEX(md5_checksum) as md5_checksum FROM script_contents WHERE id = ?`
|
|
|
|
var sc []scriptContents
|
|
err = sqlx.SelectContext(ctx, ds.reader(ctx),
|
|
&sc, stmt,
|
|
id,
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, sc, 1)
|
|
require.EqualValues(t, id, sc[0].ID)
|
|
require.Equal(t, expectedCS, sc[0].Checksum)
|
|
}
|
|
|
|
func testCleanupUnusedScriptContents(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
// create a saved script
|
|
s := &fleet.Script{
|
|
ScriptContents: "echo foobar",
|
|
}
|
|
s, err := ds.NewScript(ctx, s)
|
|
require.NoError(t, err)
|
|
|
|
user1 := test.NewUser(t, ds, "Bob", "bob@example.com", true)
|
|
|
|
// create a sync script execution
|
|
res, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{ScriptContents: "echo something_else", SyncRequest: true})
|
|
require.NoError(t, err)
|
|
|
|
// create a software install that references scripts
|
|
tfr1, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
|
|
require.NoError(t, err)
|
|
swi, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
|
|
InstallScript: "install-script",
|
|
UninstallScript: "uninstall-script",
|
|
PreInstallQuery: "SELECT 1",
|
|
PostInstallScript: "post-install-script",
|
|
InstallerFile: tfr1,
|
|
StorageID: "storage1",
|
|
Filename: "file1",
|
|
Title: "file1",
|
|
Version: "1.0",
|
|
Source: "apps",
|
|
UserID: user1.ID,
|
|
ValidatedLabels: &fleet.LabelIdentsWithScope{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// delete our saved script without ever executing it
|
|
require.NoError(t, ds.DeleteScript(ctx, s.ID))
|
|
|
|
// validate that script contents still exist
|
|
var sc []scriptContents
|
|
stmt := `SELECT id, HEX(md5_checksum) as md5_checksum FROM script_contents`
|
|
err = sqlx.SelectContext(ctx, ds.reader(ctx), &sc, stmt)
|
|
require.NoError(t, err)
|
|
require.Len(t, sc, 5)
|
|
|
|
// this should only remove the script_contents of the saved script, since the sync script is
|
|
// still "in use" by the script execution
|
|
require.NoError(t, ds.CleanupUnusedScriptContents(ctx))
|
|
|
|
sc = []scriptContents{}
|
|
err = sqlx.SelectContext(ctx, ds.reader(ctx), &sc, stmt)
|
|
require.NoError(t, err)
|
|
require.Len(t, sc, 4)
|
|
require.ElementsMatch(t, []string{
|
|
md5ChecksumScriptContent(res.ScriptContents),
|
|
md5ChecksumScriptContent("install-script"),
|
|
md5ChecksumScriptContent("post-install-script"),
|
|
md5ChecksumScriptContent("uninstall-script"),
|
|
}, []string{
|
|
sc[0].Checksum,
|
|
sc[1].Checksum,
|
|
sc[2].Checksum,
|
|
sc[3].Checksum,
|
|
})
|
|
|
|
// remove the software installer from the DB
|
|
err = ds.DeleteSoftwareInstaller(ctx, swi)
|
|
require.NoError(t, err)
|
|
|
|
require.NoError(t, ds.CleanupUnusedScriptContents(ctx))
|
|
|
|
// validate that script contents still exist
|
|
sc = []scriptContents{}
|
|
err = sqlx.SelectContext(ctx, ds.reader(ctx), &sc, stmt)
|
|
require.NoError(t, err)
|
|
require.Len(t, sc, 1)
|
|
require.Equal(t, md5ChecksumScriptContent(res.ScriptContents), sc[0].Checksum)
|
|
|
|
// create a software install without a post-install script
|
|
tfr2, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
|
|
require.NoError(t, err)
|
|
swi, _, err = ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
|
|
PreInstallQuery: "SELECT 1",
|
|
InstallScript: "install-script",
|
|
UninstallScript: "uninstall-script",
|
|
InstallerFile: tfr2,
|
|
StorageID: "storage1",
|
|
Filename: "file1",
|
|
Title: "file1",
|
|
Version: "1.0",
|
|
Source: "apps",
|
|
UserID: user1.ID,
|
|
ValidatedLabels: &fleet.LabelIdentsWithScope{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// run the cleanup function
|
|
require.NoError(t, ds.CleanupUnusedScriptContents(ctx))
|
|
sc = []scriptContents{}
|
|
err = sqlx.SelectContext(ctx, ds.reader(ctx), &sc, stmt)
|
|
require.NoError(t, err)
|
|
require.Len(t, sc, 3)
|
|
|
|
// remove the software installer from the DB
|
|
err = ds.DeleteSoftwareInstaller(ctx, swi)
|
|
require.NoError(t, err)
|
|
require.NoError(t, ds.CleanupUnusedScriptContents(ctx))
|
|
|
|
// validate that script contents still exist
|
|
sc = []scriptContents{}
|
|
err = sqlx.SelectContext(ctx, ds.reader(ctx), &sc, stmt)
|
|
require.NoError(t, err)
|
|
require.Len(t, sc, 1)
|
|
require.Equal(t, md5ChecksumScriptContent(res.ScriptContents), sc[0].Checksum)
|
|
}
|
|
|
|
func testGetAnyScriptContents(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
contents := `echo foobar;`
|
|
res, err := insertScriptContents(ctx, ds.writer(ctx), contents)
|
|
require.NoError(t, err)
|
|
id, _ := res.LastInsertId()
|
|
|
|
result, err := ds.GetAnyScriptContents(ctx, uint(id)) //nolint:gosec // dismiss G115
|
|
require.NoError(t, err)
|
|
require.Equal(t, contents, string(result))
|
|
}
|
|
|
|
func testDeleteScriptsAssignedToPolicy(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
|
|
require.NoError(t, err)
|
|
|
|
script, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script.sh",
|
|
TeamID: &team1.ID,
|
|
ScriptContents: "hello world",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
p1, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
|
|
Name: "p1",
|
|
Query: "SELECT 1;",
|
|
ScriptID: &script.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = ds.DeleteScript(ctx, script.ID)
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, errDeleteScriptWithAssociatedPolicy)
|
|
|
|
_, err = ds.DeleteTeamPolicies(ctx, team1.ID, []uint{p1.ID})
|
|
require.NoError(t, err)
|
|
|
|
err = ds.DeleteScript(ctx, script.ID)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func testDeletePendingHostScriptExecutionsForPolicy(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
user := test.NewUser(t, ds, "Alice", "alice@example.com", true)
|
|
team1, _ := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
|
|
|
|
script1, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script1.sh",
|
|
TeamID: &team1.ID,
|
|
ScriptContents: "hello world",
|
|
})
|
|
require.NoError(t, err)
|
|
script2, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script2.sh",
|
|
TeamID: &team1.ID,
|
|
ScriptContents: "hello world",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
p1, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
|
|
Name: "p1",
|
|
Query: "SELECT 1;",
|
|
ScriptID: &script1.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
p2, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
|
|
Name: "p2",
|
|
Query: "SELECT 2;",
|
|
ScriptID: &script2.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// pending host script execution for correct policy
|
|
_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: 1,
|
|
ScriptContents: "echo",
|
|
UserID: &user.ID,
|
|
PolicyID: &p1.ID,
|
|
SyncRequest: true,
|
|
ScriptID: &script1.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
pending, err := ds.ListPendingHostScriptExecutions(ctx, 1, false)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 1, len(pending))
|
|
|
|
err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
|
|
require.NoError(t, err)
|
|
|
|
pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 0, len(pending))
|
|
|
|
// test pending host script execution for incorrect policy
|
|
hsr, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: 1,
|
|
ScriptContents: "echo",
|
|
UserID: &user.ID,
|
|
PolicyID: &p2.ID,
|
|
SyncRequest: true,
|
|
ScriptID: &script2.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 1, len(pending))
|
|
|
|
err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
|
|
require.NoError(t, err)
|
|
|
|
pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, false)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 1, len(pending))
|
|
|
|
// test not pending host script execution for correct policy
|
|
scriptExecution, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: 1,
|
|
ScriptContents: "echo",
|
|
UserID: &user.ID,
|
|
PolicyID: &p1.ID,
|
|
SyncRequest: true,
|
|
ScriptID: &script1.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// record a result for the previous pending script
|
|
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: 1,
|
|
ExecutionID: hsr.ExecutionID,
|
|
Output: "foo",
|
|
ExitCode: 0,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
|
|
// record a failed result for the current pending script
|
|
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: 1,
|
|
ExecutionID: scriptExecution.ExecutionID,
|
|
Output: "foo",
|
|
ExitCode: 1,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
|
|
err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
|
|
require.NoError(t, err)
|
|
|
|
var count int
|
|
err = sqlx.GetContext(
|
|
ctx,
|
|
ds.reader(ctx),
|
|
&count,
|
|
"SELECT count(1) FROM host_script_results WHERE id = ?",
|
|
scriptExecution.ID,
|
|
)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 1, count)
|
|
}
|
|
|
|
func testUpdateScriptContents(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
originalScript, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script1",
|
|
ScriptContents: "hello world",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
originalContents, err := ds.GetScriptContents(ctx, originalScript.ScriptContentID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "hello world", string(originalContents))
|
|
|
|
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
|
|
_, err := q.ExecContext(ctx, "UPDATE scripts SET updated_at = ? WHERE id = ?", time.Now().Add(-2*time.Minute), originalScript.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
|
|
// Make sure updated_at was changed correctly, but the script is the same
|
|
oldScript, err := ds.Script(ctx, originalScript.ID)
|
|
require.Equal(t, originalScript.ScriptContentID, oldScript.ScriptContentID)
|
|
require.NoError(t, err)
|
|
require.NotEqual(t, originalScript.UpdatedAt, oldScript.UpdatedAt)
|
|
|
|
// Modify the script
|
|
updatedScript, err := ds.UpdateScriptContents(ctx, originalScript.ID, "updated script")
|
|
require.NoError(t, err)
|
|
require.Equal(t, originalScript.ID, updatedScript.ID)
|
|
// With the fix, the script should get a new content ID since content changed
|
|
require.NotEqual(t, originalScript.ScriptContentID, updatedScript.ScriptContentID)
|
|
|
|
updatedContents, err := ds.GetScriptContents(ctx, updatedScript.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "updated script", string(updatedContents))
|
|
require.NotEqual(t, oldScript.UpdatedAt, updatedScript.UpdatedAt)
|
|
}
|
|
|
|
func testUpdateDeletingUpcomingScriptExecutions(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
user := test.NewUser(t, ds, "User", "user@example.com", true)
|
|
host1 := test.NewHost(t, ds, "host1", "10.0.0.1", "host1Key", "host1UUID", time.Now())
|
|
host2 := test.NewHost(t, ds, "host2", "10.0.0.2", "host2Key", "host2UUID", time.Now())
|
|
|
|
script1, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script1",
|
|
ScriptContents: "contents1",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
script2, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script2",
|
|
ScriptContents: "contents2",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
script3, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script3",
|
|
ScriptContents: "contents3",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Queue script executions
|
|
_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: host1.ID,
|
|
ScriptID: &script1.ID,
|
|
UserID: &user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: host1.ID,
|
|
ScriptID: &script2.ID,
|
|
UserID: &user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: host2.ID,
|
|
ScriptID: &script2.ID,
|
|
UserID: &user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: host2.ID,
|
|
ScriptID: &script1.ID,
|
|
UserID: &user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
upcoming1, err := ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming1, 2)
|
|
|
|
upcoming2, err := ds.listUpcomingHostScriptExecutions(ctx, host2.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming2, 2)
|
|
|
|
// Updating the "pending/upcoming" script will cancel the activity and stop it from running
|
|
_, err = ds.UpdateScriptContents(ctx, script1.ID, "new contents1")
|
|
require.NoError(t, err)
|
|
|
|
upcoming1, err = ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming1, 1)
|
|
require.Equal(t, script2.ID, *upcoming1[0].ScriptID)
|
|
|
|
upcoming2, err = ds.listUpcomingHostScriptExecutions(ctx, host2.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming2, 1)
|
|
require.Equal(t, script2.ID, *upcoming2[0].ScriptID)
|
|
|
|
// Updating a script with no upcoming activities shouldn't affect anything
|
|
_, err = ds.UpdateScriptContents(ctx, script3.ID, "new contents")
|
|
require.NoError(t, err)
|
|
|
|
upcoming1, err = ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming1, 1)
|
|
require.Equal(t, script2.ID, *upcoming1[0].ScriptID)
|
|
|
|
upcoming2, err = ds.listUpcomingHostScriptExecutions(ctx, host2.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming2, 1)
|
|
require.Equal(t, script2.ID, *upcoming2[0].ScriptID)
|
|
}
|
|
|
|
func testBatchExecute(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
user := test.NewUser(t, ds, "user1", "user@example.com", true)
|
|
|
|
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
|
|
require.NoError(t, err)
|
|
|
|
hostNoScripts := test.NewHost(t, ds, "hostNoScripts", "10.0.0.1", "hostnoscripts", "hostnoscriptsuuid", time.Now())
|
|
hostWindows := test.NewHost(t, ds, "hostWin", "10.0.0.2", "hostWinKey", "hostWinUuid", time.Now(), test.WithPlatform("windows"))
|
|
host1 := test.NewHost(t, ds, "host1", "10.0.0.3", "host1key", "host1uuid", time.Now())
|
|
host2 := test.NewHost(t, ds, "host2", "10.0.0.4", "host2key", "host2uuid", time.Now())
|
|
host3 := test.NewHost(t, ds, "host3", "10.0.0.4", "host3key", "host3uuid", time.Now())
|
|
hostTeam1 := test.NewHost(t, ds, "hostTeam1", "10.0.0.5", "hostTeam1key", "hostTeam1uuid", time.Now(), test.WithTeamID(team1.ID))
|
|
|
|
test.SetOrbitEnrollment(t, hostWindows, ds)
|
|
test.SetOrbitEnrollment(t, host1, ds)
|
|
test.SetOrbitEnrollment(t, host2, ds)
|
|
test.SetOrbitEnrollment(t, host3, ds)
|
|
test.SetOrbitEnrollment(t, hostTeam1, ds)
|
|
|
|
script, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script1.sh",
|
|
ScriptContents: "echo hi",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Hosts all have to be on the same team as the script
|
|
execID, err := ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostTeam1.ID})
|
|
require.Empty(t, execID)
|
|
require.ErrorContains(t, err, "same fleet")
|
|
|
|
// Actual good execution
|
|
execID, err = ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostWindows.ID, host1.ID, host2.ID, host3.ID})
|
|
require.NoError(t, err)
|
|
|
|
summary, err := ds.BatchExecuteSummary(ctx, execID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, script.ID, *summary.ScriptID)
|
|
require.Equal(t, script.Name, summary.ScriptName)
|
|
require.Equal(t, uint(0), *summary.TeamID)
|
|
require.NotNil(t, summary.CreatedAt)
|
|
|
|
// The summary should have two pending hosts and two errored ones, because
|
|
// the script is not compatible with the hostNoScripts and hostWindows.
|
|
require.Equal(t, *summary.NumPending, uint(3))
|
|
require.Equal(t, *summary.NumErrored, uint(2))
|
|
require.Equal(t, *summary.NumRan, uint(0))
|
|
require.Equal(t, *summary.NumCanceled, uint(0))
|
|
// Host 1 should have an upcoming execution
|
|
host1Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, host1Upcoming, 1)
|
|
require.Equal(t, summary.ScriptID, host1Upcoming[0].ScriptID)
|
|
// Host 2 should have an upcoming execution
|
|
host2Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host2.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, host2Upcoming, 1)
|
|
require.Equal(t, summary.ScriptID, host2Upcoming[0].ScriptID)
|
|
// Host 3 should have an upcoming execution
|
|
host3Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host3.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, host3Upcoming, 1)
|
|
require.Equal(t, summary.ScriptID, host3Upcoming[0].ScriptID)
|
|
// Host Windows should not have an upcoming execution
|
|
hostWindowsUpcoming, err := ds.listUpcomingHostScriptExecutions(ctx, hostWindows.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, hostWindowsUpcoming, 0)
|
|
// Host No Scripts should not have an upcoming execution
|
|
hostNoScriptsUpcoming, err := ds.listUpcomingHostScriptExecutions(ctx, hostNoScripts.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, hostNoScriptsUpcoming, 0)
|
|
// Host Windows should have an error in its `batch_activity_host_results` row
|
|
var exec_error string
|
|
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
|
|
db := q.(*sqlx.DB)
|
|
err := db.Get(&exec_error, "SELECT error FROM batch_activity_host_results WHERE host_id = ? AND batch_execution_id = ?", hostWindows.ID, execID)
|
|
require.NoError(t, err)
|
|
return nil
|
|
})
|
|
require.Equal(t, fleet.BatchExecuteIncompatiblePlatform, exec_error)
|
|
// Host No Scripts should have an error in its `batch_activity_host_results` row
|
|
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
|
|
db := q.(*sqlx.DB)
|
|
err := db.Get(&exec_error, "SELECT error FROM batch_activity_host_results WHERE host_id = ? AND batch_execution_id = ?", hostNoScripts.ID, execID)
|
|
require.NoError(t, err)
|
|
return nil
|
|
})
|
|
require.Equal(t, fleet.BatchExecuteIncompatibleFleetd, exec_error)
|
|
|
|
// Set host 1 to have a successful script result
|
|
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: host1.ID,
|
|
ExecutionID: host1Upcoming[0].ExecutionID,
|
|
Output: "foo",
|
|
ExitCode: 0,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
// Get the summary again
|
|
summary, err = ds.BatchExecuteSummary(ctx, execID)
|
|
require.NoError(t, err)
|
|
// The summary should have one pending host, one run host and two errored ones.
|
|
require.Equal(t, *summary.NumPending, uint(2))
|
|
require.Equal(t, *summary.NumErrored, uint(2))
|
|
require.Equal(t, *summary.NumRan, uint(1))
|
|
require.Equal(t, *summary.NumCanceled, uint(0))
|
|
|
|
// Set host 1 to have a failed script result
|
|
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: host2.ID,
|
|
ExecutionID: host2Upcoming[0].ExecutionID,
|
|
Output: "bar",
|
|
ExitCode: 1,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
|
|
// Get the summary again
|
|
summary, err = ds.BatchExecuteSummary(ctx, execID)
|
|
require.NoError(t, err)
|
|
// The summary should have one pending host, one run host and two errored ones.
|
|
require.Equal(t, *summary.NumPending, uint(1))
|
|
require.Equal(t, *summary.NumErrored, uint(3))
|
|
require.Equal(t, *summary.NumRan, uint(1))
|
|
require.Equal(t, *summary.NumCanceled, uint(0))
|
|
|
|
// Cancel the execution
|
|
_, err = ds.CancelHostUpcomingActivity(ctx, host3.ID, host3Upcoming[0].ExecutionID)
|
|
require.NoError(t, err)
|
|
// Get the summary again
|
|
summary, err = ds.BatchExecuteSummary(ctx, execID)
|
|
require.NoError(t, err)
|
|
// The summary should have no pending hosts, one run host, three errored ones and one canceled.
|
|
require.Equal(t, *summary.NumPending, uint(0))
|
|
require.Equal(t, *summary.NumErrored, uint(3))
|
|
require.Equal(t, *summary.NumRan, uint(1))
|
|
require.Equal(t, *summary.NumCanceled, uint(1))
|
|
}
|
|
|
|
func testBatchExecuteWithStatus(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
user := test.NewUser(t, ds, "user1", "user@example.com", true)
|
|
|
|
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
|
|
require.NoError(t, err)
|
|
|
|
hostNoScripts := test.NewHost(t, ds, "hostNoScripts", "10.0.0.1", "hostnoscripts", "hostnoscriptsuuid", time.Now())
|
|
hostWindows := test.NewHost(t, ds, "hostWin", "10.0.0.2", "hostWinKey", "hostWinUuid", time.Now(), test.WithPlatform("windows"))
|
|
host1 := test.NewHost(t, ds, "host1", "10.0.0.3", "host1key", "host1uuid", time.Now())
|
|
host2 := test.NewHost(t, ds, "host2", "10.0.0.4", "host2key", "host2uuid", time.Now())
|
|
host3 := test.NewHost(t, ds, "host3", "10.0.0.4", "host3key", "host3uuid", time.Now())
|
|
hostTeam1 := test.NewHost(t, ds, "hostTeam1", "10.0.0.5", "hostTeam1key", "hostTeam1uuid", time.Now(), test.WithTeamID(team1.ID))
|
|
|
|
test.SetOrbitEnrollment(t, hostWindows, ds)
|
|
test.SetOrbitEnrollment(t, host1, ds)
|
|
test.SetOrbitEnrollment(t, host2, ds)
|
|
test.SetOrbitEnrollment(t, host3, ds)
|
|
test.SetOrbitEnrollment(t, hostTeam1, ds)
|
|
|
|
script, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script1.sh",
|
|
ScriptContents: "echo hi",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Hosts all have to be on the same team as the script
|
|
execID, err := ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostTeam1.ID})
|
|
require.Empty(t, execID)
|
|
require.ErrorContains(t, err, "same fleet")
|
|
|
|
// Actual good execution
|
|
execID, err = ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostWindows.ID, host1.ID, host2.ID, host3.ID})
|
|
require.NoError(t, err)
|
|
|
|
// Update the batch to have a pending status
|
|
// TODO -- remove this when status is set automatically
|
|
ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
|
|
_, err := tx.ExecContext(ctx, "UPDATE batch_activities SET status = 'scheduled' WHERE execution_id = ?", execID)
|
|
return err
|
|
})
|
|
|
|
summaryList, err := ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{
|
|
ExecutionID: &execID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, summaryList, 1)
|
|
summary := (summaryList)[0]
|
|
require.Equal(t, execID, summary.BatchExecutionID)
|
|
require.Equal(t, script.ID, *summary.ScriptID)
|
|
require.Equal(t, script.Name, summary.ScriptName)
|
|
require.Equal(t, uint(0), *summary.TeamID)
|
|
require.NotNil(t, summary.CreatedAt)
|
|
|
|
// The summary should have two pending hosts and two errored ones, because
|
|
// the script is not compatible with the hostNoScripts and hostWindows.
|
|
require.Equal(t, *summary.NumTargeted, uint(5))
|
|
require.Equal(t, *summary.NumPending, uint(3))
|
|
require.Equal(t, *summary.NumIncompatible, uint(2))
|
|
require.Equal(t, *summary.NumErrored, uint(0))
|
|
require.Equal(t, *summary.NumRan, uint(0))
|
|
require.Equal(t, *summary.NumCanceled, uint(0))
|
|
// Host 1 should have an upcoming execution
|
|
host1Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, host1Upcoming, 1)
|
|
require.Equal(t, summary.ScriptID, host1Upcoming[0].ScriptID)
|
|
// Host 2 should have an upcoming execution
|
|
host2Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host2.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, host2Upcoming, 1)
|
|
require.Equal(t, summary.ScriptID, host2Upcoming[0].ScriptID)
|
|
// Host 3 should have an upcoming execution
|
|
host3Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host3.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, host3Upcoming, 1)
|
|
require.Equal(t, summary.ScriptID, host3Upcoming[0].ScriptID)
|
|
// Host Windows should not have an upcoming execution
|
|
hostWindowsUpcoming, err := ds.listUpcomingHostScriptExecutions(ctx, hostWindows.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, hostWindowsUpcoming, 0)
|
|
// Host No Scripts should not have an upcoming execution
|
|
hostNoScriptsUpcoming, err := ds.listUpcomingHostScriptExecutions(ctx, hostNoScripts.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, hostNoScriptsUpcoming, 0)
|
|
// Host Windows should have an error in its `batch_activity_host_results` row
|
|
var exec_error string
|
|
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
|
|
db := q.(*sqlx.DB)
|
|
err := db.Get(&exec_error, "SELECT error FROM batch_activity_host_results WHERE host_id = ? AND batch_execution_id = ?", hostWindows.ID, execID)
|
|
require.NoError(t, err)
|
|
return nil
|
|
})
|
|
require.Equal(t, fleet.BatchExecuteIncompatiblePlatform, exec_error)
|
|
// Host No Scripts should have an error in its `batch_activity_host_results` row
|
|
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
|
|
db := q.(*sqlx.DB)
|
|
err := db.Get(&exec_error, "SELECT error FROM batch_activity_host_results WHERE host_id = ? AND batch_execution_id = ?", hostNoScripts.ID, execID)
|
|
require.NoError(t, err)
|
|
return nil
|
|
})
|
|
require.Equal(t, fleet.BatchExecuteIncompatibleFleetd, exec_error)
|
|
|
|
// Set host 1 to have a successful script result
|
|
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: host1.ID,
|
|
ExecutionID: host1Upcoming[0].ExecutionID,
|
|
Output: "foo",
|
|
ExitCode: 0,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
|
|
// Get the summary again
|
|
summaryList, err = ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{
|
|
ExecutionID: &execID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, summaryList, 1)
|
|
summary = (summaryList)[0]
|
|
// The summary should have one pending host, one run host and two errored ones.
|
|
require.Equal(t, *summary.NumTargeted, uint(5))
|
|
require.Equal(t, *summary.NumPending, uint(2))
|
|
require.Equal(t, *summary.NumIncompatible, uint(2))
|
|
require.Equal(t, *summary.NumErrored, uint(0))
|
|
require.Equal(t, *summary.NumRan, uint(1))
|
|
require.Equal(t, *summary.NumCanceled, uint(0))
|
|
|
|
// Set host 1 to have a failed script result
|
|
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: host2.ID,
|
|
ExecutionID: host2Upcoming[0].ExecutionID,
|
|
Output: "bar",
|
|
ExitCode: -1,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
|
|
// Get the summary again
|
|
summaryList, err = ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{
|
|
ExecutionID: &execID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, summaryList, 1)
|
|
summary = (summaryList)[0] // The summary should have one pending host, one run host and two errored ones.
|
|
require.Equal(t, *summary.NumTargeted, uint(5))
|
|
require.Equal(t, *summary.NumPending, uint(1))
|
|
require.Equal(t, *summary.NumIncompatible, uint(2))
|
|
require.Equal(t, *summary.NumErrored, uint(1))
|
|
require.Equal(t, *summary.NumRan, uint(1))
|
|
require.Equal(t, *summary.NumCanceled, uint(0))
|
|
|
|
// Cancel the execution
|
|
_, err = ds.CancelHostUpcomingActivity(ctx, host3.ID, host3Upcoming[0].ExecutionID)
|
|
require.NoError(t, err)
|
|
// Get the summary again
|
|
summaryList, err = ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{
|
|
ExecutionID: &execID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, summaryList, 1)
|
|
summary = (summaryList)[0]
|
|
// The summary should have no pending hosts, one run host, three errored ones and one canceled.
|
|
require.Equal(t, *summary.NumPending, uint(0))
|
|
require.Equal(t, *summary.NumIncompatible, uint(2))
|
|
require.Equal(t, *summary.NumErrored, uint(1))
|
|
require.Equal(t, *summary.NumRan, uint(1))
|
|
require.Equal(t, *summary.NumCanceled, uint(1))
|
|
|
|
// The summary should be returned when filtering by status "scheduled".
|
|
summaryList, err = ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{
|
|
Status: ptr.String("scheduled"),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, summaryList, 1)
|
|
summary = (summaryList)[0]
|
|
// The summary should have no pending hosts, one run host, three errored ones and one canceled.
|
|
require.Equal(t, *summary.NumPending, uint(0))
|
|
require.Equal(t, *summary.NumIncompatible, uint(2))
|
|
require.Equal(t, *summary.NumErrored, uint(1))
|
|
require.Equal(t, *summary.NumRan, uint(1))
|
|
require.Equal(t, *summary.NumCanceled, uint(1))
|
|
|
|
// The summary should be returned when filtering by team 1.
|
|
summaryList, err = ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{
|
|
TeamID: ptr.Uint(0),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, summaryList, 1)
|
|
summary = (summaryList)[0]
|
|
// The summary should have no pending hosts, one run host, three errored ones and one canceled.
|
|
require.Equal(t, *summary.NumPending, uint(0))
|
|
require.Equal(t, *summary.NumIncompatible, uint(2))
|
|
require.Equal(t, *summary.NumErrored, uint(1))
|
|
require.Equal(t, *summary.NumRan, uint(1))
|
|
require.Equal(t, *summary.NumCanceled, uint(1))
|
|
|
|
// Mark the execution as completed, and make up some host numbers
|
|
ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
|
|
_, err := tx.ExecContext(ctx, "UPDATE batch_activities SET status = 'finished', num_pending = 4, num_ran = 5, num_errored = 6, num_canceled = 7, num_incompatible = 8, num_targeted = 9 WHERE execution_id = ?", execID)
|
|
return err
|
|
})
|
|
// Get the summary again
|
|
summaryList, err = ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{
|
|
ExecutionID: &execID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, summaryList, 1)
|
|
summary = (summaryList)[0] // The summary should have one pending host, one run host and two errored ones.
|
|
require.Equal(t, *summary.NumPending, uint(4))
|
|
require.Equal(t, *summary.NumRan, uint(5))
|
|
require.Equal(t, *summary.NumErrored, uint(6))
|
|
require.Equal(t, *summary.NumCanceled, uint(7))
|
|
require.Equal(t, *summary.NumIncompatible, uint(8))
|
|
require.Equal(t, *summary.NumTargeted, uint(9))
|
|
}
|
|
|
|
func testBatchScriptSchedule(t *testing.T, ds *Datastore) {
|
|
ctx := t.Context()
|
|
|
|
user := test.NewUser(t, ds, "user1", "user@example.com", true)
|
|
|
|
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
|
|
require.NoError(t, err)
|
|
|
|
host1 := test.NewHost(t, ds, "host1", "10.0.0.3", "host1key", "host1uuid", time.Now())
|
|
host2 := test.NewHost(t, ds, "host2", "10.0.0.4", "host2key", "host2uuid", time.Now())
|
|
host3 := test.NewHost(t, ds, "host3", "10.0.0.4", "host3key", "host3uuid", time.Now())
|
|
|
|
host4 := test.NewHost(t, ds, "host4", "10.0.0.5", "host4key", "host4uuid", time.Now())
|
|
hostTeam1 := test.NewHost(t, ds, "hostTeam1", "10.0.0.6", "hostTeam1key", "hostTeam1uuid", time.Now(), test.WithTeamID(team1.ID))
|
|
hostWindows := test.NewHost(t, ds, "hostWin", "10.0.0.2", "hostWinKey", "hostWinUuid", time.Now(), test.WithPlatform("windows"))
|
|
hostNoScripts := test.NewHost(t, ds, "hostNoScripts", "10.0.0.1", "hostnoscripts", "hostnoscriptsuuid", time.Now())
|
|
|
|
test.SetOrbitEnrollment(t, host1, ds)
|
|
test.SetOrbitEnrollment(t, host2, ds)
|
|
test.SetOrbitEnrollment(t, host3, ds)
|
|
test.SetOrbitEnrollment(t, hostTeam1, ds)
|
|
test.SetOrbitEnrollment(t, host4, ds)
|
|
test.SetOrbitEnrollment(t, hostWindows, ds)
|
|
|
|
script, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script1.sh",
|
|
ScriptContents: "echo hi",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
scheduledTime := time.Now().Add(10 * time.Hour).Truncate(time.Second).UTC()
|
|
execID, err := ds.BatchScheduleScript(ctx, &user.ID, script.ID, []uint{host1.ID, host2.ID, host3.ID}, scheduledTime)
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, execID)
|
|
|
|
jobs, err := ds.GetQueuedJobs(ctx, 10, scheduledTime.Add(10*time.Minute))
|
|
require.NoError(t, err)
|
|
// Should have scheduled one job
|
|
require.NotZero(t, jobs)
|
|
// find our job
|
|
var job *fleet.Job
|
|
for _, j := range jobs {
|
|
if j.Name == fleet.BatchActivityScriptsJobName {
|
|
job = j
|
|
}
|
|
}
|
|
require.NotNil(t, job)
|
|
require.Equal(t, fleet.BatchActivityScriptsJobName, job.Name)
|
|
// Make sure the name matches the name on the worker job
|
|
batchJob := worker.BatchScripts{}
|
|
require.Equal(t, batchJob.Name(), job.Name)
|
|
// Time from DB isn't super accurate
|
|
require.Equal(t, scheduledTime.Truncate(time.Minute), job.NotBefore.Truncate(time.Minute))
|
|
assert.JSONEq(t, fmt.Sprintf(`{"execution_id":%q}`, execID), string(*job.Args))
|
|
|
|
batchActivity, err := ds.GetBatchActivity(ctx, execID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, execID, batchActivity.BatchExecutionID)
|
|
require.Nil(t, batchActivity.StartedAt)
|
|
require.Equal(t, user.ID, *batchActivity.UserID)
|
|
require.Equal(t, script.ID, *batchActivity.ScriptID)
|
|
require.Equal(t, script.Name, batchActivity.ScriptName)
|
|
require.Equal(t, fleet.BatchExecutionActivityScript, batchActivity.ActivityType)
|
|
require.Equal(t, job.ID, *batchActivity.JobID)
|
|
require.Equal(t, fleet.ScheduledBatchExecutionScheduled, batchActivity.Status)
|
|
require.Equal(t, uint(3), *batchActivity.NumTargeted)
|
|
|
|
hostResults, err := ds.GetBatchActivityHostResults(ctx, execID)
|
|
require.NoError(t, err)
|
|
require.Len(t, hostResults, 3)
|
|
for _, hostResult := range hostResults {
|
|
require.Equal(t, execID, hostResult.BatchExecutionID)
|
|
require.Nil(t, hostResult.HostExecutionID)
|
|
}
|
|
|
|
// Run it manually, the same as the job running it but without waiting for the time
|
|
err = ds.RunScheduledBatchActivity(ctx, execID)
|
|
require.NoError(t, err)
|
|
|
|
batchActivity, err = ds.GetBatchActivity(ctx, execID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, execID, batchActivity.BatchExecutionID)
|
|
require.NotNil(t, batchActivity.StartedAt)
|
|
require.Equal(t, user.ID, *batchActivity.UserID)
|
|
require.Equal(t, script.ID, *batchActivity.ScriptID)
|
|
require.Equal(t, script.Name, batchActivity.ScriptName)
|
|
require.Equal(t, fleet.BatchExecutionActivityScript, batchActivity.ActivityType)
|
|
require.Equal(t, job.ID, *batchActivity.JobID)
|
|
require.Equal(t, fleet.ScheduledBatchExecutionStarted, batchActivity.Status)
|
|
require.Equal(t, uint(3), *batchActivity.NumTargeted)
|
|
|
|
hostResults, err = ds.GetBatchActivityHostResults(ctx, execID)
|
|
require.NoError(t, err)
|
|
require.Len(t, hostResults, 3)
|
|
for _, hostResult := range hostResults {
|
|
require.Equal(t, execID, hostResult.BatchExecutionID)
|
|
require.NotNil(t, hostResult.HostExecutionID)
|
|
upcomingScripts, err := ds.ListPendingHostScriptExecutions(ctx, hostResult.HostID, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcomingScripts, 1)
|
|
}
|
|
|
|
// Make sure we can't run the same scheduled script again after it's started
|
|
err = ds.RunScheduledBatchActivity(ctx, execID)
|
|
require.Error(t, err)
|
|
|
|
// Make sure we can't run a canceled scheduled script after it's been canceled
|
|
execID, err = ds.BatchScheduleScript(ctx, &user.ID, script.ID, []uint{host1.ID}, scheduledTime)
|
|
require.NoError(t, err)
|
|
|
|
err = ds.CancelBatchScript(ctx, execID)
|
|
require.NoError(t, err)
|
|
|
|
err = ds.RunScheduledBatchActivity(ctx, execID)
|
|
require.Error(t, err)
|
|
|
|
// Schedule script where most hosts will fail for various reaons
|
|
// These would be checked for some validity before insertion if submitted by a user
|
|
execID, err = ds.BatchScheduleScript(ctx, &user.ID, script.ID, []uint{host4.ID, hostWindows.ID, hostTeam1.ID, hostNoScripts.ID, 0xbeef}, scheduledTime)
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, execID)
|
|
|
|
err = ds.RunScheduledBatchActivity(ctx, execID)
|
|
require.NoError(t, err)
|
|
|
|
batchActivity, err = ds.GetBatchActivity(ctx, execID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, execID, batchActivity.BatchExecutionID)
|
|
require.NotNil(t, batchActivity.StartedAt)
|
|
require.Equal(t, fleet.ScheduledBatchExecutionStarted, batchActivity.Status)
|
|
require.Equal(t, uint(5), *batchActivity.NumTargeted)
|
|
|
|
executions, err := ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{
|
|
ExecutionID: &execID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, executions, 1)
|
|
require.Equal(t, uint(3), *executions[0].NumIncompatible)
|
|
|
|
hostResults, err = ds.GetBatchActivityHostResults(ctx, execID)
|
|
require.NoError(t, err)
|
|
require.Len(t, hostResults, 5)
|
|
for _, hostResult := range hostResults {
|
|
require.Equal(t, execID, hostResult.BatchExecutionID)
|
|
var upcomingScripts []*fleet.HostScriptResult
|
|
if hostResult.HostID != 0xbeef {
|
|
upcomingScripts, err = ds.ListPendingHostScriptExecutions(ctx, hostResult.HostID, false)
|
|
require.NoError(t, err)
|
|
}
|
|
switch hostResult.HostID {
|
|
case host4.ID:
|
|
// The only valid host in the group
|
|
require.NotNil(t, hostResult.HostExecutionID)
|
|
require.Len(t, upcomingScripts, 1)
|
|
case hostWindows.ID:
|
|
// Bad platform
|
|
require.Len(t, upcomingScripts, 0)
|
|
require.NotNil(t, hostResult.Error)
|
|
require.Equal(t, fleet.BatchExecuteIncompatiblePlatform, *hostResult.Error)
|
|
case hostTeam1.ID:
|
|
// Bad team
|
|
require.Len(t, upcomingScripts, 1)
|
|
require.Nil(t, hostResult.Error)
|
|
case hostNoScripts.ID:
|
|
// Host doesn't support scripts
|
|
require.Len(t, upcomingScripts, 0)
|
|
require.NotNil(t, hostResult.Error)
|
|
require.Equal(t, fleet.BatchExecuteIncompatibleFleetd, *hostResult.Error)
|
|
case 0xbeef:
|
|
// Host was deleted after scheduling
|
|
require.NotNil(t, hostResult.Error)
|
|
require.Equal(t, fleet.BatchExecuteInvalidHost, *hostResult.Error)
|
|
default:
|
|
require.Failf(t, "forgot to check a host", "host_id: %d", hostResult.HostID)
|
|
}
|
|
}
|
|
|
|
// Schedule script that we will subsequently cancel.
|
|
execID, err = ds.BatchScheduleScript(ctx, &user.ID, script.ID, []uint{host4.ID, hostWindows.ID, hostTeam1.ID, hostNoScripts.ID, 0xbeef}, scheduledTime)
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, execID)
|
|
|
|
err = ds.CancelBatchScript(ctx, execID)
|
|
require.NoError(t, err)
|
|
|
|
// Get the summary again
|
|
summaryList, err := ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{
|
|
ExecutionID: &execID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, summaryList, 1)
|
|
summary := (summaryList)[0]
|
|
// The summary should have no pending hosts, one run host, three errored ones and one canceled.
|
|
require.Equal(t, *summary.NumPending, uint(0))
|
|
require.Equal(t, *summary.NumIncompatible, uint(0))
|
|
require.Equal(t, *summary.NumErrored, uint(0))
|
|
require.Equal(t, *summary.NumRan, uint(0))
|
|
require.Equal(t, *summary.NumCanceled, uint(5))
|
|
}
|
|
|
|
func testMarkActivitiesAsCompleted(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
user := test.NewUser(t, ds, "user1", "user@example.com", true)
|
|
|
|
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
|
|
require.NoError(t, err)
|
|
|
|
hostNoScripts := test.NewHost(t, ds, "hostNoScripts", "10.0.0.1", "hostnoscripts", "hostnoscriptsuuid", time.Now())
|
|
hostWindows := test.NewHost(t, ds, "hostWin", "10.0.0.2", "hostWinKey", "hostWinUuid", time.Now(), test.WithPlatform("windows"))
|
|
host1 := test.NewHost(t, ds, "host1", "10.0.0.3", "host1key", "host1uuid", time.Now())
|
|
host2 := test.NewHost(t, ds, "host2", "10.0.0.4", "host2key", "host2uuid", time.Now())
|
|
host3 := test.NewHost(t, ds, "host3", "10.0.0.4", "host3key", "host3uuid", time.Now())
|
|
hostTeam1 := test.NewHost(t, ds, "hostTeam1", "10.0.0.5", "hostTeam1key", "hostTeam1uuid", time.Now(), test.WithTeamID(team1.ID))
|
|
|
|
test.SetOrbitEnrollment(t, hostWindows, ds)
|
|
test.SetOrbitEnrollment(t, host1, ds)
|
|
test.SetOrbitEnrollment(t, host2, ds)
|
|
test.SetOrbitEnrollment(t, host3, ds)
|
|
test.SetOrbitEnrollment(t, hostTeam1, ds)
|
|
|
|
script, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script1.sh",
|
|
ScriptContents: "echo hi",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Actual good execution
|
|
execID, err := ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostWindows.ID, host1.ID, host2.ID, host3.ID})
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, execID)
|
|
|
|
// Schedule another one
|
|
execID2, err := ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostWindows.ID, host1.ID, host2.ID, host3.ID})
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, execID2)
|
|
|
|
// Get the upcoming activities for each host
|
|
host1Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
|
|
require.NoError(t, err)
|
|
host2Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host2.ID, false, false)
|
|
require.NoError(t, err)
|
|
host3Upcoming, err := ds.listUpcomingHostScriptExecutions(ctx, host3.ID, false, false)
|
|
require.NoError(t, err)
|
|
|
|
// Set host 1 to have a successful script result
|
|
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: host1.ID,
|
|
ExecutionID: host1Upcoming[0].ExecutionID,
|
|
Output: "foo",
|
|
ExitCode: 0,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
|
|
// Set host 2 to have a failed script result
|
|
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: host2.ID,
|
|
ExecutionID: host2Upcoming[0].ExecutionID,
|
|
Output: "bar",
|
|
ExitCode: -1,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
|
|
// Cancel the execution for host 3
|
|
_, err = ds.CancelHostUpcomingActivity(ctx, host3.ID, host3Upcoming[0].ExecutionID)
|
|
require.NoError(t, err)
|
|
|
|
// Update the batch activity status to "started"
|
|
ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
|
|
_, err := tx.ExecContext(ctx, "UPDATE batch_activities SET status='started' WHERE execution_id IN (?,?)", execID, execID2)
|
|
return err
|
|
})
|
|
|
|
// Mark activities as completed
|
|
err = ds.MarkActivitiesAsCompleted(ctx)
|
|
require.NoError(t, err)
|
|
|
|
// First activity should be marked as finished and updated accordingly.
|
|
batchActivity, err := ds.GetBatchActivity(ctx, execID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fleet.ScheduledBatchExecutionFinished, batchActivity.Status)
|
|
require.Equal(t, uint(5), *batchActivity.NumTargeted)
|
|
require.Equal(t, uint(1), *batchActivity.NumRan)
|
|
require.Equal(t, uint(1), *batchActivity.NumErrored)
|
|
require.Equal(t, uint(2), *batchActivity.NumIncompatible)
|
|
require.Equal(t, uint(1), *batchActivity.NumCanceled)
|
|
require.Equal(t, uint(0), *batchActivity.NumPending)
|
|
|
|
// Second activity should still be in "started" status.
|
|
batchActivity2, err := ds.GetBatchActivity(ctx, execID2)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fleet.ScheduledBatchExecutionStarted, batchActivity2.Status)
|
|
|
|
// Schedule another batch that we will cancel.
|
|
execID3, err := ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{hostNoScripts.ID, hostWindows.ID, host1.ID, host2.ID, host3.ID})
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, execID3)
|
|
|
|
// Update the batch activity status to "started"
|
|
ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
|
|
_, err := tx.ExecContext(ctx, "UPDATE batch_activities SET status='started' WHERE execution_id = ?", execID3)
|
|
return err
|
|
})
|
|
|
|
// Cancel the batch.
|
|
err = ds.CancelBatchScript(ctx, execID3)
|
|
require.NoError(t, err)
|
|
|
|
// First activity should be marked as finished and updated accordingly.
|
|
batchActivity, err = ds.GetBatchActivity(ctx, execID3)
|
|
require.NoError(t, err)
|
|
require.Equal(t, fleet.ScheduledBatchExecutionFinished, batchActivity.Status)
|
|
require.Equal(t, uint(5), *batchActivity.NumTargeted)
|
|
require.Equal(t, uint(0), *batchActivity.NumRan)
|
|
require.Equal(t, uint(0), *batchActivity.NumErrored)
|
|
require.Equal(t, uint(2), *batchActivity.NumIncompatible)
|
|
require.Equal(t, uint(3), *batchActivity.NumCanceled)
|
|
require.Equal(t, uint(0), *batchActivity.NumPending)
|
|
|
|
// Edge case -- batch activity with no hosts.
|
|
// In reality this could happen if all the hosts in a batch get deleted.
|
|
ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
|
|
_, err := tx.ExecContext(ctx, "INSERT INTO batch_activities (execution_id, status, script_id, activity_type) VALUES (?, ?, ?, ?)",
|
|
"abc123", fleet.ScheduledBatchExecutionStarted, script.ID, "script")
|
|
return err
|
|
})
|
|
|
|
// Mark activities as completed
|
|
err = ds.MarkActivitiesAsCompleted(ctx)
|
|
require.NoError(t, err)
|
|
|
|
// Activity should be marked as finished and updated accordingly.
|
|
batchActivity, err = ds.GetBatchActivity(ctx, "abc123")
|
|
require.NoError(t, err)
|
|
require.Equal(t, fleet.ScheduledBatchExecutionFinished, batchActivity.Status)
|
|
require.Equal(t, uint(0), *batchActivity.NumTargeted)
|
|
require.Equal(t, uint(0), *batchActivity.NumRan)
|
|
require.Equal(t, uint(0), *batchActivity.NumErrored)
|
|
require.Equal(t, uint(0), *batchActivity.NumIncompatible)
|
|
require.Equal(t, uint(0), *batchActivity.NumCanceled)
|
|
require.Equal(t, uint(0), *batchActivity.NumPending)
|
|
}
|
|
|
|
func testBatchScriptCancel(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
user := test.NewUser(t, ds, "user1", "user@example.com", true)
|
|
|
|
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
|
|
require.NoError(t, err)
|
|
|
|
host1 := test.NewHost(t, ds, "host1", "10.0.0.3", "host1key", "host1uuid", time.Now())
|
|
host2 := test.NewHost(t, ds, "host2", "10.0.0.4", "host2key", "host2uuid", time.Now())
|
|
host3 := test.NewHost(t, ds, "host3", "10.0.0.4", "host3key", "host3uuid", time.Now())
|
|
hostTeam1 := test.NewHost(t, ds, "hostTeam1", "10.0.0.5", "hostTeam1key", "hostTeam1uuid", time.Now(), test.WithTeamID(team1.ID))
|
|
|
|
test.SetOrbitEnrollment(t, host1, ds)
|
|
test.SetOrbitEnrollment(t, host2, ds)
|
|
test.SetOrbitEnrollment(t, host3, ds)
|
|
test.SetOrbitEnrollment(t, hostTeam1, ds)
|
|
|
|
script, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script1.sh",
|
|
ScriptContents: "echo hi",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
////
|
|
// Immediate execution
|
|
//
|
|
execID1, err := ds.BatchExecuteScript(ctx, &user.ID, script.ID, []uint{host1.ID, host2.ID})
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, execID1)
|
|
|
|
summary1, err := ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{ExecutionID: &execID1})
|
|
require.NoError(t, err)
|
|
require.Len(t, summary1, 1)
|
|
require.Equal(t, fleet.ScheduledBatchExecutionStarted, summary1[0].Status)
|
|
require.False(t, summary1[0].Canceled)
|
|
require.Equal(t, uint(2), *summary1[0].NumTargeted)
|
|
require.Equal(t, uint(2), *summary1[0].NumPending)
|
|
|
|
upcoming1, err := ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming1, 1)
|
|
|
|
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: host1.ID,
|
|
ExecutionID: upcoming1[0].ExecutionID,
|
|
Output: "",
|
|
ExitCode: 0,
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
|
|
upcoming1, err = ds.listUpcomingHostScriptExecutions(ctx, host2.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming1, 1)
|
|
|
|
err = ds.CancelBatchScript(ctx, execID1)
|
|
require.NoError(t, err)
|
|
|
|
summary1, err = ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{ExecutionID: &execID1})
|
|
require.NoError(t, err)
|
|
require.Len(t, summary1, 1)
|
|
require.Equal(t, fleet.ScheduledBatchExecutionFinished, summary1[0].Status)
|
|
require.True(t, summary1[0].Canceled)
|
|
require.Equal(t, uint(0), *summary1[0].NumPending)
|
|
require.Equal(t, uint(1), *summary1[0].NumRan)
|
|
require.Equal(t, uint(2), *summary1[0].NumTargeted)
|
|
require.Equal(t, uint(1), *summary1[0].NumCanceled)
|
|
|
|
upcoming1, err = ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming1, 0)
|
|
|
|
upcoming1, err = ds.listUpcomingHostScriptExecutions(ctx, host2.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming1, 0)
|
|
|
|
////
|
|
// Future execution
|
|
//
|
|
execID2, err := ds.BatchScheduleScript(ctx, &user.ID, script.ID, []uint{host1.ID, host2.ID}, time.Now().Add(2*time.Hour))
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, execID2)
|
|
|
|
summary2, err := ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{ExecutionID: &execID2})
|
|
require.NoError(t, err)
|
|
require.Len(t, summary2, 1)
|
|
require.Equal(t, fleet.ScheduledBatchExecutionScheduled, summary2[0].Status)
|
|
require.False(t, summary2[0].Canceled)
|
|
require.Equal(t, uint(2), *summary2[0].NumTargeted)
|
|
require.Equal(t, uint(2), *summary2[0].NumPending)
|
|
|
|
upcoming2, err := ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming2, 0)
|
|
|
|
upcoming2, err = ds.listUpcomingHostScriptExecutions(ctx, host2.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming2, 0)
|
|
|
|
err = ds.CancelBatchScript(ctx, execID2)
|
|
require.NoError(t, err)
|
|
|
|
summary2, err = ds.ListBatchScriptExecutions(ctx, fleet.BatchExecutionStatusFilter{ExecutionID: &execID2})
|
|
require.NoError(t, err)
|
|
require.Len(t, summary2, 1)
|
|
require.Equal(t, fleet.ScheduledBatchExecutionFinished, summary2[0].Status)
|
|
require.True(t, summary2[0].Canceled)
|
|
require.Equal(t, uint(0), *summary2[0].NumPending)
|
|
require.Equal(t, uint(0), *summary2[0].NumRan)
|
|
require.Equal(t, uint(2), *summary2[0].NumCanceled)
|
|
require.Equal(t, uint(2), *summary2[0].NumTargeted)
|
|
|
|
upcoming2, err = ds.listUpcomingHostScriptExecutions(ctx, host1.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming2, 0)
|
|
|
|
upcoming2, err = ds.listUpcomingHostScriptExecutions(ctx, host2.ID, false, false)
|
|
require.NoError(t, err)
|
|
require.Len(t, upcoming2, 0)
|
|
}
|
|
|
|
func testDeleteScriptActivatesNextActivity(t *testing.T, ds *Datastore) {
|
|
ctx := t.Context()
|
|
u := test.NewUser(t, ds, "Alice", "alice@example.com", true)
|
|
|
|
// create a couple of scripts
|
|
scriptA, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "a",
|
|
ScriptContents: "echo 'a'",
|
|
})
|
|
require.NoError(t, err)
|
|
scriptB, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "b",
|
|
ScriptContents: "echo 'b'",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// create some hosts
|
|
hosts := make([]*fleet.Host, 4)
|
|
for i := range hosts {
|
|
host, err := ds.NewHost(context.Background(), &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
SeenTime: time.Now(),
|
|
NodeKey: ptr.String(fmt.Sprint(i)),
|
|
UUID: fmt.Sprint(i),
|
|
Hostname: fmt.Sprintf("%d-foo.local", i),
|
|
PrimaryIP: fmt.Sprintf("192.168.1.%d", i),
|
|
PrimaryMac: fmt.Sprintf("30-65-EC-6F-C4-5%d", i),
|
|
})
|
|
require.NoError(t, err)
|
|
hosts[i] = host
|
|
}
|
|
|
|
// enqueue scripts executions:
|
|
// * hosts[0]: a, b
|
|
// * hosts[1]: a, b
|
|
// * hosts[2]: b, a
|
|
// * hosts[3]: b
|
|
execHost0ScriptA, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[0].ID,
|
|
ScriptID: &scriptA.ID,
|
|
UserID: &u.ID,
|
|
SyncRequest: true,
|
|
})
|
|
require.NoError(t, err)
|
|
execHost0ScriptB, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[0].ID,
|
|
ScriptID: &scriptB.ID,
|
|
UserID: &u.ID,
|
|
SyncRequest: true,
|
|
})
|
|
require.NoError(t, err)
|
|
execHost1ScriptA, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[1].ID,
|
|
ScriptID: &scriptA.ID,
|
|
UserID: &u.ID,
|
|
SyncRequest: true,
|
|
})
|
|
require.NoError(t, err)
|
|
execHost1ScriptB, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[1].ID,
|
|
ScriptID: &scriptB.ID,
|
|
UserID: &u.ID,
|
|
SyncRequest: true,
|
|
})
|
|
require.NoError(t, err)
|
|
execHost2ScriptB, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[2].ID,
|
|
ScriptID: &scriptB.ID,
|
|
UserID: &u.ID,
|
|
SyncRequest: true,
|
|
})
|
|
require.NoError(t, err)
|
|
execHost2ScriptA, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[2].ID,
|
|
ScriptID: &scriptA.ID,
|
|
UserID: &u.ID,
|
|
SyncRequest: true,
|
|
})
|
|
require.NoError(t, err)
|
|
execHost3ScriptB, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[3].ID,
|
|
ScriptID: &scriptB.ID,
|
|
UserID: &u.ID,
|
|
SyncRequest: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
checkUpcomingActivities(t, ds, hosts[0], execHost0ScriptA.ExecutionID, execHost0ScriptB.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[1], execHost1ScriptA.ExecutionID, execHost1ScriptB.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[2], execHost2ScriptB.ExecutionID, execHost2ScriptA.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[3], execHost3ScriptB.ExecutionID)
|
|
|
|
// delete scriptA removes pending upcoming activity and activates next activity
|
|
err = ds.DeleteScript(ctx, scriptA.ID)
|
|
require.NoError(t, err)
|
|
|
|
checkUpcomingActivities(t, ds, hosts[0], execHost0ScriptB.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[1], execHost1ScriptB.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[2], execHost2ScriptB.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[3], execHost3ScriptB.ExecutionID)
|
|
}
|
|
|
|
func testBatchSetScriptActivatesNextActivity(t *testing.T, ds *Datastore) {
|
|
ctx := t.Context()
|
|
|
|
// batch-set some scripts
|
|
scripts, err := ds.BatchSetScripts(ctx, nil, []*fleet.Script{
|
|
{Name: "A", ScriptContents: "C1"},
|
|
{Name: "B", ScriptContents: "C2"},
|
|
{Name: "C", ScriptContents: "C3"},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// index scripts by name
|
|
scriptByName := make(map[string]uint)
|
|
for _, s := range scripts {
|
|
scriptByName[s.Name] = s.ID
|
|
}
|
|
|
|
// create some hosts
|
|
hosts := make([]*fleet.Host, 4)
|
|
for i := range hosts {
|
|
host, err := ds.NewHost(context.Background(), &fleet.Host{
|
|
DetailUpdatedAt: time.Now(),
|
|
LabelUpdatedAt: time.Now(),
|
|
SeenTime: time.Now(),
|
|
NodeKey: ptr.String(fmt.Sprint(i)),
|
|
UUID: fmt.Sprint(i),
|
|
Hostname: fmt.Sprintf("%d-foo.local", i),
|
|
PrimaryIP: fmt.Sprintf("192.168.1.%d", i),
|
|
PrimaryMac: fmt.Sprintf("30-65-EC-6F-C4-5%d", i),
|
|
})
|
|
require.NoError(t, err)
|
|
hosts[i] = host
|
|
}
|
|
|
|
// enqeue script executions:
|
|
// * hosts[0]: A, C, A, B
|
|
// * hosts[1]: B, B, C
|
|
// * hosts[2]: C, A
|
|
// * hosts[3]: A, B, C
|
|
execHost0ScriptA, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[0].ID, ScriptID: ptr.Uint(scriptByName["A"]), SyncRequest: true, ScriptContents: "C1",
|
|
})
|
|
require.NoError(t, err)
|
|
execHost0ScriptC, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[0].ID, ScriptID: ptr.Uint(scriptByName["C"]), SyncRequest: true, ScriptContents: "C3",
|
|
})
|
|
require.NoError(t, err)
|
|
execHost0ScriptA2, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[0].ID, ScriptID: ptr.Uint(scriptByName["A"]), SyncRequest: true, ScriptContents: "C1",
|
|
})
|
|
require.NoError(t, err)
|
|
execHost0ScriptB, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[0].ID, ScriptID: ptr.Uint(scriptByName["B"]), SyncRequest: true, ScriptContents: "C2",
|
|
})
|
|
require.NoError(t, err)
|
|
execHost1ScriptB, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[1].ID, ScriptID: ptr.Uint(scriptByName["B"]), SyncRequest: true, ScriptContents: "C2",
|
|
})
|
|
require.NoError(t, err)
|
|
execHost1ScriptB2, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[1].ID, ScriptID: ptr.Uint(scriptByName["B"]), SyncRequest: true, ScriptContents: "C2",
|
|
})
|
|
require.NoError(t, err)
|
|
execHost1ScriptC, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[1].ID, ScriptID: ptr.Uint(scriptByName["C"]), SyncRequest: true, ScriptContents: "C3",
|
|
})
|
|
require.NoError(t, err)
|
|
execHost2ScriptC, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[2].ID, ScriptID: ptr.Uint(scriptByName["C"]), SyncRequest: true, ScriptContents: "C3",
|
|
})
|
|
require.NoError(t, err)
|
|
execHost2ScriptA, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[2].ID, ScriptID: ptr.Uint(scriptByName["A"]), SyncRequest: true, ScriptContents: "C1",
|
|
})
|
|
require.NoError(t, err)
|
|
execHost3ScriptA, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[3].ID, ScriptID: ptr.Uint(scriptByName["A"]), SyncRequest: true, ScriptContents: "C1",
|
|
})
|
|
require.NoError(t, err)
|
|
execHost3ScriptB, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[3].ID, ScriptID: ptr.Uint(scriptByName["B"]), SyncRequest: true, ScriptContents: "C2",
|
|
})
|
|
require.NoError(t, err)
|
|
execHost3ScriptC, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: hosts[3].ID, ScriptID: ptr.Uint(scriptByName["C"]), SyncRequest: true, ScriptContents: "C3",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
checkUpcomingActivities(t, ds, hosts[0], execHost0ScriptA.ExecutionID, execHost0ScriptC.ExecutionID, execHost0ScriptA2.ExecutionID, execHost0ScriptB.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[1], execHost1ScriptB.ExecutionID, execHost1ScriptB2.ExecutionID, execHost1ScriptC.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[2], execHost2ScriptC.ExecutionID, execHost2ScriptA.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[3], execHost3ScriptA.ExecutionID, execHost3ScriptB.ExecutionID, execHost3ScriptC.ExecutionID)
|
|
|
|
// no change
|
|
_, err = ds.BatchSetScripts(ctx, nil, []*fleet.Script{
|
|
{Name: "A", ScriptContents: "C1"},
|
|
{Name: "B", ScriptContents: "C2"},
|
|
{Name: "C", ScriptContents: "C3"},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
checkUpcomingActivities(t, ds, hosts[0], execHost0ScriptA.ExecutionID, execHost0ScriptC.ExecutionID, execHost0ScriptA2.ExecutionID, execHost0ScriptB.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[1], execHost1ScriptB.ExecutionID, execHost1ScriptB2.ExecutionID, execHost1ScriptC.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[2], execHost2ScriptC.ExecutionID, execHost2ScriptA.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[3], execHost3ScriptA.ExecutionID, execHost3ScriptB.ExecutionID, execHost3ScriptC.ExecutionID)
|
|
|
|
// batch-set removes A, updates B and creates D, cancelling any pending A and B executions
|
|
_, err = ds.BatchSetScripts(ctx, nil, []*fleet.Script{
|
|
{Name: "B", ScriptContents: "C2updated"},
|
|
{Name: "C", ScriptContents: "C3"},
|
|
{Name: "D", ScriptContents: "C4"},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
checkUpcomingActivities(t, ds, hosts[0], execHost0ScriptC.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[1], execHost1ScriptC.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[2], execHost2ScriptC.ExecutionID)
|
|
checkUpcomingActivities(t, ds, hosts[3], execHost3ScriptC.ExecutionID)
|
|
|
|
// batch-set remove all
|
|
_, err = ds.BatchSetScripts(ctx, nil, []*fleet.Script{})
|
|
require.NoError(t, err)
|
|
|
|
checkUpcomingActivities(t, ds, hosts[0])
|
|
checkUpcomingActivities(t, ds, hosts[1])
|
|
checkUpcomingActivities(t, ds, hosts[2])
|
|
checkUpcomingActivities(t, ds, hosts[3])
|
|
}
|
|
|
|
// Test updating a script to match another script's contents
|
|
func testUpdateScriptToDuplicateContent(t *testing.T, ds *Datastore) {
|
|
ctx := t.Context()
|
|
|
|
// Create two scripts with different content
|
|
script1, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script1.sh",
|
|
ScriptContents: "echo hello",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
script2, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script2.sh",
|
|
ScriptContents: "echo world",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Get initial content IDs
|
|
s1, err := ds.Script(ctx, script1.ID)
|
|
require.NoError(t, err)
|
|
s2, err := ds.Script(ctx, script2.ID)
|
|
require.NoError(t, err)
|
|
initialContentID1 := s1.ScriptContentID
|
|
initialContentID2 := s2.ScriptContentID
|
|
require.NotEqual(t, initialContentID1, initialContentID2)
|
|
|
|
// Update script2 to have the same content as script1
|
|
// This should NOT cause a duplicate key error
|
|
_, err = ds.UpdateScriptContents(ctx, script2.ID, "echo hello")
|
|
require.NoError(t, err)
|
|
// ScriptContents is not populated from the DB, check via GetScriptContents
|
|
// GetScriptContents takes a script ID, not script_content_id
|
|
updatedContents, err := ds.GetScriptContents(ctx, script2.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "echo hello", string(updatedContents))
|
|
|
|
// Verify both scripts now share the same content ID
|
|
s1After, err := ds.Script(ctx, script1.ID)
|
|
require.NoError(t, err)
|
|
s2After, err := ds.Script(ctx, script2.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, s1After.ScriptContentID, s2After.ScriptContentID)
|
|
require.Equal(t, initialContentID1, s2After.ScriptContentID)
|
|
|
|
// Verify the old content ID was cleaned up
|
|
var count int
|
|
err = sqlx.GetContext(ctx, ds.reader(ctx), &count,
|
|
`SELECT COUNT(*) FROM script_contents WHERE id = ?`, initialContentID2)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 0, count, "old script content should be deleted")
|
|
}
|
|
|
|
// Test modifying a script whose content currently matches another script's content
|
|
func testUpdateSharedScriptContent(t *testing.T, ds *Datastore) {
|
|
ctx := t.Context()
|
|
|
|
// Create two scripts with the SAME content
|
|
sharedContent := "echo shared"
|
|
script1, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script1.sh",
|
|
ScriptContents: sharedContent,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
script2, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script2.sh",
|
|
ScriptContents: sharedContent,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Verify they share the same content ID
|
|
s1, err := ds.Script(ctx, script1.ID)
|
|
require.NoError(t, err)
|
|
s2, err := ds.Script(ctx, script2.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, s1.ScriptContentID, s2.ScriptContentID)
|
|
|
|
// Update script1 to different content
|
|
updated, err := ds.UpdateScriptContents(ctx, script1.ID, "echo modified")
|
|
require.NoError(t, err)
|
|
// ScriptContents is not populated from the DB, check via GetScriptContents
|
|
// GetScriptContents takes a script ID, not script_content_id
|
|
updatedContents, err := ds.GetScriptContents(ctx, script1.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "echo modified", string(updatedContents))
|
|
|
|
// CRITICAL: Verify script2 still has the original content
|
|
s2After, err := ds.Script(ctx, script2.ID)
|
|
require.NoError(t, err)
|
|
s2Contents, err := ds.GetScriptContents(ctx, script2.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, sharedContent, string(s2Contents))
|
|
require.NotEqual(t, updated.ScriptContentID, s2After.ScriptContentID)
|
|
}
|
|
|
|
// Test updating script to same content -- a no-op case
|
|
func testUpdateScriptToSameContent(t *testing.T, ds *Datastore) {
|
|
ctx := t.Context()
|
|
|
|
// Create a script
|
|
script, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "script.sh",
|
|
ScriptContents: "echo hello",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
s, err := ds.Script(ctx, script.ID)
|
|
require.NoError(t, err)
|
|
originalContentID := s.ScriptContentID
|
|
|
|
// Update with the same content
|
|
_, err = ds.UpdateScriptContents(ctx, script.ID, "echo hello")
|
|
require.NoError(t, err)
|
|
updatedContents, err := ds.GetScriptContents(ctx, script.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "echo hello", string(updatedContents))
|
|
|
|
// Verify content ID hasn't changed
|
|
sAfter, err := ds.Script(ctx, script.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, originalContentID, sAfter.ScriptContentID)
|
|
}
|
|
|
|
func testCountHostScriptAttempts(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
// Create test data
|
|
host := test.NewHost(t, ds, "host1", "10.0.0.1", "host1Key", "host1UUID", time.Now())
|
|
user := test.NewUser(t, ds, "User", "test@example.com", true)
|
|
|
|
policy, err := ds.NewGlobalPolicy(ctx, &user.ID, fleet.PolicyPayload{
|
|
Name: "policy",
|
|
Query: "SELECT 1;",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
script, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "test.sh",
|
|
ScriptContents: "echo test",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// no attempts exist, count 0
|
|
count, err := ds.CountHostScriptAttempts(ctx, host.ID, script.ID, policy.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 0, count)
|
|
|
|
// script execution attempt
|
|
execReq1, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: host.ID,
|
|
ScriptID: &script.ID,
|
|
PolicyID: &policy.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, execReq1.ExecutionID)
|
|
|
|
// Set result for first attempt
|
|
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: host.ID,
|
|
ExecutionID: execReq1.ExecutionID,
|
|
Output: "output1",
|
|
Runtime: 1,
|
|
ExitCode: 1, // failed
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
|
|
// 1 attempt, count should be 1
|
|
count, err = ds.CountHostScriptAttempts(ctx, host.ID, script.ID, policy.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 1, count)
|
|
|
|
// retry
|
|
execReq2, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: host.ID,
|
|
ScriptID: &script.ID,
|
|
PolicyID: &policy.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, execReq2.ExecutionID)
|
|
|
|
// Set result for second attempt
|
|
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: host.ID,
|
|
ExecutionID: execReq2.ExecutionID,
|
|
Output: "output2",
|
|
Runtime: 2,
|
|
ExitCode: 1, // failed again
|
|
}, ptr.Int(2))
|
|
require.NoError(t, err)
|
|
|
|
// 2 attempts, count should be 2
|
|
count, err = ds.CountHostScriptAttempts(ctx, host.ID, script.ID, policy.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 2, count)
|
|
|
|
// retry
|
|
execReq3, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: host.ID,
|
|
ScriptID: &script.ID,
|
|
PolicyID: &policy.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, execReq3.ExecutionID)
|
|
|
|
// Set result for third attempt
|
|
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: host.ID,
|
|
ExecutionID: execReq3.ExecutionID,
|
|
Output: "output3",
|
|
Runtime: 3,
|
|
ExitCode: 0, // success
|
|
}, ptr.Int(3))
|
|
require.NoError(t, err)
|
|
|
|
// 3 attempts, count should be 3
|
|
count, err = ds.CountHostScriptAttempts(ctx, host.ID, script.ID, policy.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 3, count)
|
|
|
|
// script execution but without policy_id
|
|
execReq4, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
|
HostID: host.ID,
|
|
ScriptID: &script.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, execReq4.ExecutionID)
|
|
|
|
_, _, err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
|
|
HostID: host.ID,
|
|
ExecutionID: execReq4.ExecutionID,
|
|
Output: "output4",
|
|
Runtime: 4,
|
|
ExitCode: 0,
|
|
}, ptr.Int(0))
|
|
require.NoError(t, err)
|
|
|
|
// Count should not change
|
|
count, err = ds.CountHostScriptAttempts(ctx, host.ID, script.ID, policy.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 3, count)
|
|
|
|
// new host
|
|
host2 := test.NewHost(t, ds, "host2", "10.0.0.2", "host2Key", "host2UUID", time.Now())
|
|
count, err = ds.CountHostScriptAttempts(ctx, host2.ID, script.ID, policy.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 0, count)
|
|
|
|
// Same host, different policy
|
|
policy2, err := ds.NewGlobalPolicy(ctx, &user.ID, fleet.PolicyPayload{
|
|
Name: "test policy 2",
|
|
Query: "SELECT 2;",
|
|
})
|
|
require.NoError(t, err)
|
|
count, err = ds.CountHostScriptAttempts(ctx, host.ID, script.ID, policy2.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 0, count)
|
|
|
|
// Same host and policy, new script
|
|
script2, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "test2.sh",
|
|
ScriptContents: "echo test2",
|
|
})
|
|
require.NoError(t, err)
|
|
count, err = ds.CountHostScriptAttempts(ctx, host.ID, script2.ID, policy.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 0, count)
|
|
}
|
|
|
|
func testScriptModificationResetsAttemptNumber(t *testing.T, ds *Datastore) {
|
|
ctx := context.Background()
|
|
|
|
// Create a team
|
|
team, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name()})
|
|
require.NoError(t, err)
|
|
|
|
// Create a policy
|
|
policy, err := ds.NewTeamPolicy(ctx, team.ID, nil, fleet.PolicyPayload{
|
|
Name: t.Name(),
|
|
Query: "SELECT 1;",
|
|
Platform: "darwin",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Create script content
|
|
var scriptContentID int64
|
|
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
|
|
res, err := q.ExecContext(ctx, `INSERT INTO script_contents (md5_checksum, contents) VALUES (?, ?)`,
|
|
"md5hash", "echo 'v1'")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
scriptContentID, err = res.LastInsertId()
|
|
return err
|
|
})
|
|
|
|
// Create a script
|
|
script, err := ds.NewScript(ctx, &fleet.Script{
|
|
Name: "test.sh",
|
|
TeamID: &team.ID,
|
|
ScriptContentID: uint(scriptContentID), //nolint:gosec // dismiss G115
|
|
ScriptContents: "echo 'v1'",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Completed first attempt (exit_code IS NOT NULL, attempt_number = 1)
|
|
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
|
|
_, err := q.ExecContext(ctx, `
|
|
INSERT INTO host_script_results (host_id, execution_id, script_content_id, output, exit_code, script_id, policy_id, attempt_number)
|
|
VALUES (1, 'exec-1', ?, 'output', 1, ?, ?, 1)
|
|
`, scriptContentID, script.ID, policy.ID)
|
|
return err
|
|
})
|
|
// Pending second attempt (exit_code IS NULL, attempt_number = 2)
|
|
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
|
|
_, err := q.ExecContext(ctx, `
|
|
INSERT INTO host_script_results (host_id, execution_id, script_content_id, output, exit_code, script_id, policy_id, attempt_number)
|
|
VALUES (1, 'exec-2', ?, '', NULL, ?, ?, 2)
|
|
`, scriptContentID, script.ID, policy.ID)
|
|
return err
|
|
})
|
|
|
|
// Update script contents - this should reset all attempt_number to 0
|
|
_, err = ds.UpdateScriptContents(ctx, script.ID, "echo 'v2'")
|
|
require.NoError(t, err)
|
|
|
|
// Verify results
|
|
type result struct {
|
|
ExecutionID string `db:"execution_id"`
|
|
ExitCode *int64 `db:"exit_code"`
|
|
AttemptNumber *int64 `db:"attempt_number"`
|
|
Canceled bool `db:"canceled"`
|
|
}
|
|
var results []result
|
|
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
|
|
return sqlx.SelectContext(ctx, q, &results, `
|
|
SELECT execution_id, exit_code, attempt_number, canceled
|
|
FROM host_script_results
|
|
WHERE script_id = ? AND policy_id = ?
|
|
ORDER BY execution_id ASC
|
|
`, script.ID, policy.ID)
|
|
})
|
|
|
|
require.Len(t, results, 2)
|
|
|
|
// completed, reset to 0, not canceled
|
|
require.Equal(t, "exec-1", results[0].ExecutionID)
|
|
require.NotNil(t, results[0].AttemptNumber)
|
|
require.Equal(t, int64(0), *results[0].AttemptNumber)
|
|
require.False(t, results[0].Canceled)
|
|
|
|
// pending, reset to 0, canceled
|
|
require.Equal(t, "exec-2", results[1].ExecutionID)
|
|
require.NotNil(t, results[1].AttemptNumber)
|
|
require.Equal(t, int64(0), *results[1].AttemptNumber)
|
|
require.True(t, results[1].Canceled)
|
|
}
|