diff --git a/server/fleet/scripts.go b/server/fleet/scripts.go index a2887d73b9..56558cc10f 100644 --- a/server/fleet/scripts.go +++ b/server/fleet/scripts.go @@ -237,7 +237,7 @@ func (hsr HostScriptResult) UserMessage(hostTimeout bool) string { } func (hsr HostScriptResult) HostTimeout(waitForResultTime time.Duration) bool { - return hsr.ExitCode == nil && time.Now().After(hsr.CreatedAt.Add(waitForResultTime)) + return hsr.SyncRequest && hsr.ExitCode == nil && time.Now().After(hsr.CreatedAt.Add(waitForResultTime)) } const MaxScriptRuneLen = 10000 diff --git a/server/fleet/scripts_test.go b/server/fleet/scripts_test.go index d138cc6bf8..06ee3845a9 100644 --- a/server/fleet/scripts_test.go +++ b/server/fleet/scripts_test.go @@ -4,6 +4,7 @@ import ( "errors" "strings" "testing" + "time" "unicode/utf8" "github.com/stretchr/testify/require" @@ -102,3 +103,61 @@ func TestValidateHostScriptContents(t *testing.T) { }) } } + +func TestHostTimeout(t *testing.T) { + now := time.Now() + tests := []struct { + name string + hostScriptResult HostScriptResult + waitForResultTime time.Duration + expectedResult bool + }{ + { + name: "sync exitcode nil timeout passed", + hostScriptResult: HostScriptResult{ + SyncRequest: true, + ExitCode: nil, + CreatedAt: now.Add(-10 * time.Minute), + }, + waitForResultTime: 5 * time.Minute, + expectedResult: true, + }, + { + name: "sync exitcode nil timeout not passed", + hostScriptResult: HostScriptResult{ + SyncRequest: true, + ExitCode: nil, + CreatedAt: now.Add(-3 * time.Minute), + }, + waitForResultTime: 5 * time.Minute, + expectedResult: false, + }, + { + name: "sync exitcode set", + hostScriptResult: HostScriptResult{ + SyncRequest: true, + ExitCode: new(int64), + CreatedAt: now.Add(-10 * time.Minute), + }, + waitForResultTime: 5 * time.Minute, + expectedResult: false, + }, + { + name: "async exitcode nil", + hostScriptResult: HostScriptResult{ + SyncRequest: false, + ExitCode: nil, + CreatedAt: now.Add(-10 * time.Minute), + }, + waitForResultTime: 5 * time.Minute, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.hostScriptResult.HostTimeout(tt.waitForResultTime) + require.Equal(t, tt.expectedResult, result) + }) + } +} diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index c743917a0a..9f695f8186 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -4464,6 +4464,23 @@ func (s *integrationEnterpriseTestSuite) TestRunHostScript() { require.False(t, scriptResultResp.HostTimeout) require.Contains(t, scriptResultResp.Message, fleet.RunScriptAlreadyRunningErrMsg) + // an async script doesn't care about timeouts + now := time.Now() + mysql.ExecAdhocSQL(t, s.ds, func(tx sqlx.ExtContext) error { + _, err := tx.ExecContext(ctx, `UPDATE host_script_results SET created_at = ? WHERE execution_id = ?`, + now.Add(-1*time.Hour), + runResp.ExecutionID, + ) + return err + }) + scriptResultResp = getScriptResultResponse{} + s.DoJSON("GET", "/api/latest/fleet/scripts/results/"+runResp.ExecutionID, nil, http.StatusOK, &scriptResultResp) + require.Equal(t, host.ID, scriptResultResp.HostID) + require.Equal(t, "echo", scriptResultResp.ScriptContents) + require.Nil(t, scriptResultResp.ExitCode) + require.False(t, scriptResultResp.HostTimeout) + require.Contains(t, scriptResultResp.Message, fleet.RunScriptAlreadyRunningErrMsg) + // Disable scripts and verify that there are no Orbit notifs acr := appConfigResponse{} s.DoJSON("PATCH", "/api/latest/fleet/config", json.RawMessage(`{ @@ -5401,9 +5418,9 @@ func (s *integrationEnterpriseTestSuite) TestHostScriptDetails() { insertResults := func(t *testing.T, hostID uint, script *fleet.Script, createdAt time.Time, execID string, exitCode *int64) { stmt := ` INSERT INTO - host_script_results (%s host_id, created_at, execution_id, exit_code, script_contents, output) + host_script_results (%s host_id, created_at, execution_id, exit_code, script_contents, output, sync_request) VALUES - (%s ?,?,?,?,?,?)` + (%s ?,?,?,?,?,?, 1)` args := []interface{}{} if script.ID == 0 {