diff --git a/server/service/global_policies.go b/server/service/global_policies.go index 28a9c3789a..98ae211718 100644 --- a/server/service/global_policies.go +++ b/server/service/global_policies.go @@ -609,9 +609,32 @@ func (e AutofillError) Internal() string { } func (svc *Service) AutofillPolicySql(ctx context.Context, sql string) (description string, resolution string, err error) { + vc, ok := viewer.FromContext(ctx) + if !ok { + svc.authz.SkipAuthorization(ctx) + return "", "", fleet.ErrNoContext + } + // We expect that only users with policy write permissions will autofill policies. - if err = svc.authz.Authorize(ctx, &fleet.Policy{}, fleet.ActionWrite); err != nil { - return "", "", err + if vc.User.GlobalRole != nil || len(vc.User.Teams) == 0 { + if err = svc.authz.Authorize(ctx, &fleet.Policy{}, fleet.ActionWrite); err != nil { + return "", "", err + } + } else { + // Check if this user has team policy write permissions. + teamID := vc.User.Teams[0].Team.ID + for _, teamUser := range vc.User.Teams { + if teamUser.Role == fleet.RoleAdmin || teamUser.Role == fleet.RoleMaintainer || teamUser.Role == fleet.RoleGitOps { + teamID = teamUser.Team.ID + break + } + } + err = svc.authz.Authorize( + ctx, &fleet.Policy{PolicyData: fleet.PolicyData{TeamID: &teamID}}, fleet.ActionWrite, + ) + if err != nil { + return "", "", err + } } appConfig, err := svc.ds.AppConfig(ctx) diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index d7905364f5..c38ee3b5ce 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -8737,3 +8737,127 @@ func (s *integrationEnterpriseTestSuite) cleanupQuery(queryID uint) { var delResp deleteQueryByIDResponse s.DoJSON("DELETE", fmt.Sprintf("/api/latest/fleet/queries/id/%d", queryID), nil, http.StatusOK, &delResp) } + +func (s *integrationEnterpriseTestSuite) TestAutofillPoliciesAuthTeamUser() { + t := s.T() + startMockServer := func(t *testing.T) string { + // create a test http server + srv := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + switch r.URL.Path { + case "/ok": + var body map[string]interface{} + err := json.NewDecoder(r.Body).Decode(&body) + if err != nil { + t.Log(err) + w.WriteHeader(http.StatusBadRequest) + return + } + _, _ = w.Write([]byte(`{"risks":"description", "whatWillProbablyHappenDuringMaintenance":"resolution"}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + }, + ), + ) + t.Cleanup(srv.Close) + return srv.URL + } + mockUrl := startMockServer(t) + originalUrl := getHumanInterpretationFromOsquerySqlUrl + originalTimeout := getHumanInterpretationFromOsquerySqlTimeout + t.Cleanup( + func() { + getHumanInterpretationFromOsquerySqlUrl = originalUrl + getHumanInterpretationFromOsquerySqlTimeout = originalTimeout + }, + ) + + // Create teams + team1, err := s.ds.NewTeam( + context.Background(), &fleet.Team{ + ID: 42, + Name: "team1" + t.Name(), + Description: "desc team1", + }, + ) + require.NoError(t, err) + team2, err := s.ds.NewTeam( + context.Background(), &fleet.Team{ + ID: 43, + Name: "team2" + t.Name(), + Description: "desc team2", + }, + ) + require.NoError(t, err) + + oldToken := s.token + t.Cleanup( + func() { + s.token = oldToken + }, + ) + + switchUser := func(t *testing.T, role string) { + password := test.GoodPassword + email := role + "-testteam@user.com" + u := &fleet.User{ + Name: "test team user", + Email: email, + GlobalRole: nil, + Teams: []fleet.UserTeam{ + { + Team: *team2, + Role: fleet.RoleObserver, + }, + { + Team: *team1, + Role: role, + }, + }, + } + require.NoError(t, u.SetPassword(password, 10, 10)) + _, err = s.ds.NewUser(context.Background(), u) + require.NoError(t, err) + + s.token = s.getTestToken(email, password) + } + + req := autofillPoliciesRequest{ + SQL: "select 1", + } + getHumanInterpretationFromOsquerySqlUrl = mockUrl + "/ok" + + tests := []struct { + role string + pass bool + }{ + {role: fleet.RoleAdmin, pass: true}, + {role: fleet.RoleMaintainer, pass: true}, + {role: fleet.RoleGitOps, pass: true}, + {role: fleet.RoleObserver, pass: false}, + {role: fleet.RoleObserverPlus, pass: false}, + } + + for _, tt := range tests { + t.Run( + tt.role, func(t *testing.T) { + switchUser(t, tt.role) + if tt.pass { + var res autofillPoliciesResponse + s.DoJSON("POST", "/api/latest/fleet/autofill/policy", req, http.StatusOK, &res) + assert.Equal(t, "description", res.Description) + assert.Equal(t, "resolution", res.Resolution) + } else { + _ = s.Do("POST", "/api/latest/fleet/autofill/policy", req, http.StatusForbidden) + } + }, + ) + } + +}