From 35cfde8b34784ea44750d31b9dfc3d55213d6784 Mon Sep 17 00:00:00 2001 From: gillespi314 <73313222+gillespi314@users.noreply.github.com> Date: Thu, 12 Oct 2023 13:25:05 -0500 Subject: [PATCH] Always return empty host scripts details for unsupported platforms (#14451) --- ee/server/service/scripts.go | 15 +++-- server/service/integration_enterprise_test.go | 64 +++++++++++++++++-- server/service/scripts_test.go | 48 +++++++++++++- 3 files changed, 117 insertions(+), 10 deletions(-) diff --git a/ee/server/service/scripts.go b/ee/server/service/scripts.go index d2abd3745b..8c83e6070d 100644 --- a/ee/server/service/scripts.go +++ b/ee/server/service/scripts.go @@ -11,6 +11,7 @@ import ( "github.com/fleetdm/fleet/v4/server/authz" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/go-kit/kit/log/level" ) func (svc *Service) RunHostScript(ctx context.Context, request *fleet.HostScriptRequestPayload, waitForResult time.Duration) (*fleet.HostScriptResult, error) { @@ -360,6 +361,16 @@ func (svc *Service) GetHostScriptDetails(ctx context.Context, hostID uint, opt f return nil, nil, err } + if err := svc.authz.Authorize(ctx, &fleet.Script{TeamID: h.TeamID}, fleet.ActionRead); err != nil { + return nil, nil, err + } + + if h.Platform != "darwin" { + // only darwin is supported for now, all other platforms return empty results + level.Debug(svc.logger).Log("msg", "unsupported platform for host script details", "platform", h.Platform, "host_id", h.ID) + return []*fleet.HostScriptDetail{}, &fleet.PaginationMetadata{}, nil + } + // cursor-based pagination is not supported for scripts opt.After = "" // custom ordering is not supported, always by name @@ -370,10 +381,6 @@ func (svc *Service) GetHostScriptDetails(ctx context.Context, hostID uint, opt f // always include metadata for scripts opt.IncludeMetadata = true - if err := svc.authz.Authorize(ctx, &fleet.Script{TeamID: h.TeamID}, fleet.ActionRead); err != nil { - return nil, nil, err - } - return svc.ds.GetHostScriptDetails(ctx, h.ID, h.TeamID, opt) } diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index 750e3e7ca2..8b2b12a571 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -4571,7 +4571,7 @@ func (s *integrationEnterpriseTestSuite) TestHostScriptDetails() { NodeKey: ptr.String("host1"), UUID: uuid.New().String(), Hostname: "host1", - Platform: "windows", + Platform: "darwin", TeamID: &tm1.ID, }) require.NoError(t, err) @@ -4586,11 +4586,41 @@ func (s *integrationEnterpriseTestSuite) TestHostScriptDetails() { NodeKey: ptr.String("host2"), UUID: uuid.New().String(), Hostname: "host2", - Platform: "linux", + Platform: "darwin", TeamID: &tm3.ID, }) require.NoError(t, err) + // create a Windows host (unsupported) + host3, err := s.ds.NewHost(ctx, &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now().Add(-1 * time.Minute), + OsqueryHostID: ptr.String("host3"), + NodeKey: ptr.String("host3"), + UUID: uuid.New().String(), + Hostname: "host3", + Platform: "windows", + TeamID: nil, + }) + require.NoError(t, err) + + // create a Linux host (unsupported) + host4, err := s.ds.NewHost(ctx, &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now().Add(-1 * time.Minute), + OsqueryHostID: ptr.String("host4"), + NodeKey: ptr.String("host4"), + UUID: uuid.New().String(), + Hostname: "host4", + Platform: "ubuntu", + TeamID: nil, + }) + require.NoError(t, err) + insertResults := func(t *testing.T, hostID uint, script *fleet.Script, createdAt time.Time, execID string, exitCode *int64) { stmt := ` INSERT INTO @@ -4772,6 +4802,30 @@ VALUES require.NotNil(t, resp.Scripts) require.Len(t, resp.Scripts, 0) }) + + t.Run("unsupported platform windows", func(t *testing.T) { + require.Nil(t, host3.TeamID) + noTeamScripts, _, err := s.ds.ListScripts(ctx, nil, fleet.ListOptions{}) + require.NoError(t, err) + require.True(t, len(noTeamScripts) > 0) + + var resp getHostScriptDetailsResponse + s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/hosts/%d/scripts", host3.ID), nil, http.StatusOK, &resp) + require.NotNil(t, resp.Scripts) + require.Len(t, resp.Scripts, 0) + }) + + t.Run("unsupported platform linux", func(t *testing.T) { + require.Nil(t, host4.TeamID) + noTeamScripts, _, err := s.ds.ListScripts(ctx, nil, fleet.ListOptions{}) + require.NoError(t, err) + require.True(t, len(noTeamScripts) > 0) + + var resp getHostScriptDetailsResponse + s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/hosts/%d/scripts", host4.ID), nil, http.StatusOK, &resp) + require.NotNil(t, resp.Scripts) + require.Len(t, resp.Scripts, 0) + }) } // generates the body and headers part of a multipart request ready to be @@ -5073,14 +5127,14 @@ func (s *integrationEnterpriseTestSuite) TestTeamConfigDetailQueriesOverrides() require.NoError(t, err) // get distributed queries for the host - s.lq.On("QueriesForHost", linuxHost.ID).Return(map[string]string{fmt.Sprintf("%d", linuxHost.ID): "select 1 from osquery;"}, nil) + s.lq.On("QueriesForHost", linuxHost.ID).Return(map[string]string{t.Name(): "select 1 from osquery;"}, nil) req := getDistributedQueriesRequest{NodeKey: *linuxHost.NodeKey} var dqResp getDistributedQueriesResponse s.DoJSON("POST", "/api/osquery/distributed/read", req, http.StatusOK, &dqResp) require.NotContains(t, dqResp.Queries, "fleet_detail_query_users") require.NotContains(t, dqResp.Queries, "fleet_detail_query_disk_encryption_linux") require.Contains(t, dqResp.Queries, "fleet_detail_query_software_linux") - require.Contains(t, dqResp.Queries, "fleet_distributed_query_21") + require.Contains(t, dqResp.Queries, fmt.Sprintf("fleet_distributed_query_%s", t.Name())) spec = []byte(fmt.Sprintf(` name: %s @@ -5108,5 +5162,5 @@ func (s *integrationEnterpriseTestSuite) TestTeamConfigDetailQueriesOverrides() require.Contains(t, dqResp.Queries, "fleet_detail_query_users") require.Contains(t, dqResp.Queries, "fleet_detail_query_disk_encryption_linux") require.Contains(t, dqResp.Queries, "fleet_detail_query_software_linux") - require.Contains(t, dqResp.Queries, "fleet_distributed_query_21") + require.Contains(t, dqResp.Queries, fmt.Sprintf("fleet_distributed_query_%s", t.Name())) } diff --git a/server/service/scripts_test.go b/server/service/scripts_test.go index 8251cc62d5..f07cc46c56 100644 --- a/server/service/scripts_test.go +++ b/server/service/scripts_test.go @@ -681,7 +681,7 @@ func TestSavedScripts(t *testing.T) { } } -func TestHostScriptDetails(t *testing.T) { +func TestHostScriptDetailsAuth(t *testing.T) { ds := new(mock.Store) license := &fleet.LicenseInfo{Tier: fleet.TierPremium, Expiration: time.Now().Add(24 * time.Hour)} svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: license, SkipCreateTestUsers: true}) @@ -833,3 +833,49 @@ func TestHostScriptDetails(t *testing.T) { }) } } + +func TestHostScriptDetailsSupportedPlatform(t *testing.T) { + ds := new(mock.Store) + license := &fleet.LicenseInfo{Tier: fleet.TierPremium, Expiration: time.Now().Add(24 * time.Hour)} + svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: license, SkipCreateTestUsers: true}) + ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) + + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + + ds.GetHostScriptDetailsFunc = func(ctx context.Context, hostID uint, teamID *uint, opts fleet.ListOptions) ([]*fleet.HostScriptDetail, *fleet.PaginationMetadata, error) { + return []*fleet.HostScriptDetail{{HostID: hostID, ScriptID: 1337, Name: "some-script.sh"}}, nil, nil + } + + for _, tt := range []struct { + platform string + supported bool + }{ + {"darwin", true}, + {"ubuntu", false}, + {"centos", false}, + {"rhel", false}, + {"debian", false}, + {"windows", false}, + } { + t.Run(tt.platform, func(t *testing.T) { + ds.GetHostScriptDetailsFuncInvoked = false + ds.HostLiteFunc = func(ctx context.Context, hostID uint) (*fleet.Host, error) { + return &fleet.Host{ID: hostID, Platform: tt.platform}, nil + } + + res, _, err := svc.GetHostScriptDetails(ctx, 42, fleet.ListOptions{}) + require.NoError(t, err) + if tt.supported { + require.NotNil(t, res) + require.Len(t, res, 1) + require.True(t, ds.GetHostScriptDetailsFuncInvoked) + } else { + require.NotNil(t, res) + require.Len(t, res, 0) + require.False(t, ds.GetHostScriptDetailsFuncInvoked) + } + }) + } +}