From ba211437c68b1f3429ca4bb6a7f1c838e0f60119 Mon Sep 17 00:00:00 2001 From: Sarah Gillespie <73313222+gillespi314@users.noreply.github.com> Date: Tue, 5 Mar 2024 15:12:52 -0600 Subject: [PATCH] Update integration tests for run script by name feature (#17381) --- server/fleet/scripts.go | 7 ++ server/service/integration_core_test.go | 4 +- server/service/integration_enterprise_test.go | 86 +++++++++++++++++-- server/service/scripts.go | 9 +- 4 files changed, 97 insertions(+), 9 deletions(-) diff --git a/server/fleet/scripts.go b/server/fleet/scripts.go index 4fdc583727..0bfd620fbb 100644 --- a/server/fleet/scripts.go +++ b/server/fleet/scripts.go @@ -150,6 +150,13 @@ type HostScriptRequestPayload struct { } func (r HostScriptRequestPayload) ValidateParams(waitForResult time.Duration) error { + if r.ScriptContents == "" && r.ScriptID == nil && r.ScriptName == "" { + if waitForResult <= 0 { + return NewInvalidArgumentError("script", `Script contents must not be empty.`) + } + return NewInvalidArgumentError("script", `One of 'script_id', 'script_contents', or 'script_name' is required.`) + } + if r.ScriptID != nil { switch { case r.ScriptContents != "": diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 568e74a993..c1e86e410f 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -5409,10 +5409,10 @@ func (s *integrationTestSuite) TestScriptsEndpointsWithoutLicense() { // run a script var runResp runScriptResponse - s.DoJSON("POST", "/api/latest/fleet/scripts/run", fleet.HostScriptRequestPayload{HostID: 1}, http.StatusNotFound, &runResp) + s.DoJSON("POST", "/api/latest/fleet/scripts/run", fleet.HostScriptRequestPayload{HostID: 1, ScriptContents: "echo foo"}, http.StatusNotFound, &runResp) // run a script sync - s.DoJSON("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: 1}, http.StatusNotFound, &runResp) + s.DoJSON("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: 1, ScriptContents: "echo foo"}, http.StatusNotFound, &runResp) // get script result var scriptResultResp getScriptResultResponse diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index a2773a5333..83441a6212 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -4676,9 +4676,9 @@ func (s *integrationEnterpriseTestSuite) TestRunHostScript() { require.Contains(t, errMsg, "Script contents must not be empty.") // attempt to run an overly long script - res = s.Do("POST", "/api/latest/fleet/scripts/run", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: strings.Repeat("a", 500001)}, http.StatusUnprocessableEntity) + res = s.Do("POST", "/api/latest/fleet/scripts/run", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: strings.Repeat("a", fleet.UnsavedScriptMaxRuneLen+1)}, http.StatusUnprocessableEntity) errMsg = extractServerErrorText(res.Body) - require.Contains(t, errMsg, "Script is too large.") + require.Contains(t, errMsg, "Script is too large. It's limited to 10,000 characters") // make sure the host is still seen as "online" err := s.ds.MarkHostsSeen(ctx, []uint{host.ID}, time.Now()) @@ -4783,12 +4783,12 @@ func (s *integrationEnterpriseTestSuite) TestRunHostScript() { // attempt to sync run an empty script res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: ""}, http.StatusUnprocessableEntity) errMsg = extractServerErrorText(res.Body) - require.Contains(t, errMsg, "Script contents must not be empty.") + require.Contains(t, errMsg, "One of 'script_id', 'script_contents', or 'script_name' is required.") // attempt to sync run an overly long script - res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: strings.Repeat("a", 500001)}, http.StatusUnprocessableEntity) + res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: strings.Repeat("a", fleet.UnsavedScriptMaxRuneLen+1)}, http.StatusUnprocessableEntity) errMsg = extractServerErrorText(res.Body) - require.Contains(t, errMsg, "Script is too large.") + require.Contains(t, errMsg, "Script is too large. It's limited to 10,000 characters") // make sure the host is still seen as "online" err = s.ds.MarkHostsSeen(ctx, []uint{host.ID}, time.Now()) @@ -5063,6 +5063,36 @@ func (s *integrationEnterpriseTestSuite) TestRunHostSavedScript() { require.True(t, runSyncResp.HostTimeout) require.Contains(t, runSyncResp.Message, fleet.RunScriptHostTimeoutErrMsg) + // attempt to run sync with both script contents and script id + res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: "echo", ScriptID: ptr.Uint(savedTmScript.ID + 999)}, http.StatusUnprocessableEntity) + errMsg = extractServerErrorText(res.Body) + require.Contains(t, errMsg, `Only one of 'script_id' or 'script_contents' is allowed.`) + + // attempt to run sync with both script contents and script name + res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: "echo", ScriptName: savedTmScript.Name}, http.StatusUnprocessableEntity) + errMsg = extractServerErrorText(res.Body) + require.Contains(t, errMsg, `Only one of 'script_contents' or 'script_name' is allowed.`) + + // attempt to run sync with both script id and script name + res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptID: ptr.Uint(savedTmScript.ID + 999), ScriptName: savedTmScript.Name}, http.StatusUnprocessableEntity) + errMsg = extractServerErrorText(res.Body) + require.Contains(t, errMsg, `Only one of 'script_id' or 'script_name' is allowed.`) + + // attempt to run sync with both script contents and team id + res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: "echo", TeamID: 1}, http.StatusUnprocessableEntity) + errMsg = extractServerErrorText(res.Body) + require.Contains(t, errMsg, `Only one of 'script_contents' or 'team_id' is allowed.`) + + // attempt to run sync with both script id and team id + res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptID: ptr.Uint(savedTmScript.ID + 999), TeamID: 1}, http.StatusUnprocessableEntity) + errMsg = extractServerErrorText(res.Body) + require.Contains(t, errMsg, `Only one of 'script_id' or 'team_id' is allowed.`) + + // attempt to run sync without script contents, script id, or script name + res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID}, http.StatusUnprocessableEntity) + errMsg = extractServerErrorText(res.Body) + require.Contains(t, errMsg, `One of 'script_id', 'script_contents', or 'script_name' is required.`) + // deleting the saved script does not impact the pending script s.Do("DELETE", fmt.Sprintf("/api/latest/fleet/scripts/%d", savedNoTmScript.ID), nil, http.StatusNoContent) @@ -5102,6 +5132,50 @@ func (s *integrationEnterpriseTestSuite) TestRunHostSavedScript() { s.DoJSON("POST", "/api/latest/fleet/scripts/run", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptID: &script.ID}, http.StatusConflict, &runResp) + // set up a new host, new team, and some new scripts + host2 := createOrbitEnrolledHost(t, "linux", "f1337", s.ds) + tm2, err := s.ds.NewTeam(ctx, &fleet.Team{Name: "team 2"}) + require.NoError(t, err) + savedNoTmScript2, err := s.ds.NewScript(ctx, &fleet.Script{ + TeamID: nil, + Name: "f1337.sh", + ScriptContents: "echo 'ALL YOUR BASE ARE BELONG TO US'", + }) + require.NoError(t, err) + savedTmScript2, err := s.ds.NewScript(ctx, &fleet.Script{ + TeamID: &tm2.ID, + Name: "f1337.sh", + ScriptContents: "echo 'ALL YOUR BASE ARE BELONG TO US'", + }) + require.NoError(t, err) + require.NotEqual(t, savedNoTmScript2.ID, savedTmScript2.ID) + + // make sure the new host is seen as "online" + err = s.ds.MarkHostsSeen(ctx, []uint{host2.ID}, time.Now()) + require.NoError(t, err) + + // attempt to run sync with a script that does not exist on the specified team + res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host2.ID, ScriptName: "f1337.sh", TeamID: tm.ID}, http.StatusUnprocessableEntity) + errMsg = extractServerErrorText(res.Body) + require.Contains(t, errMsg, `Script 'f1337.sh' doesn’t exist.`) + + // attempt to run sync with an existing team script that belongs to a team different from the host's team + res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host2.ID, ScriptName: "f1337.sh", TeamID: tm2.ID}, http.StatusUnprocessableEntity) + errMsg = extractServerErrorText(res.Body) + require.Contains(t, errMsg, `The script does not belong to the same team`) + + // create a valid sync script execution request by script name, fails because the + // request will time-out waiting for a result. + var runSyncResp2 runScriptSyncResponse + s.DoJSON("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host2.ID, ScriptName: "f1337.sh"}, http.StatusRequestTimeout, &runSyncResp2) + require.Equal(t, host2.ID, runSyncResp2.HostID) + require.NotEmpty(t, runSyncResp2.ExecutionID) + require.NotNil(t, runSyncResp2.ScriptID) + require.Equal(t, savedNoTmScript2.ID, *runSyncResp2.ScriptID) + require.Equal(t, "echo 'ALL YOUR BASE ARE BELONG TO US'", runSyncResp2.ScriptContents) + require.True(t, runSyncResp2.HostTimeout) + require.Contains(t, runSyncResp2.Message, fleet.RunScriptHostTimeoutErrMsg) + // attempt to run a script on a plain osquery host plainOsqueryHost, err := s.ds.NewHost(context.Background(), &fleet.Host{ DetailUpdatedAt: time.Now(), @@ -5392,7 +5466,7 @@ func (s *integrationEnterpriseTestSuite) TestSavedScripts() { // file content is too large body, headers = generateNewScriptMultipartRequest(t, - "script2.sh", []byte(strings.Repeat("a", 500001)), s.token, nil) + "script2.sh", []byte(strings.Repeat("a", fleet.SavedScriptMaxRuneLen+1)), s.token, nil) res = s.DoRawWithHeaders("POST", "/api/latest/fleet/scripts", body.Bytes(), http.StatusUnprocessableEntity, headers) errMsg = extractServerErrorText(res.Body) require.Contains(t, errMsg, "Script is too large. It's limited to 500,000 characters") diff --git a/server/service/scripts.go b/server/service/scripts.go index 727c69e663..e383b49746 100644 --- a/server/service/scripts.go +++ b/server/service/scripts.go @@ -127,7 +127,14 @@ func (svc *Service) GetScriptIDByName(ctx context.Context, scriptName string, te return 0, err } - return svc.ds.GetScriptIDByName(ctx, scriptName, teamID) + id, err := svc.ds.GetScriptIDByName(ctx, scriptName, teamID) + if err != nil { + if fleet.IsNotFound(err) { + return 0, fleet.NewInvalidArgumentError("script_name", fmt.Sprintf(`Script '%s' doesn’t exist.`, scriptName)) + } + return 0, err + } + return id, nil } const maxPendingScripts = 1000