fleet/server/service/testing_client.go
Roberto Dip 33d788caed
add support for displaying EULA during ADE/DEP (#11532)
Related to #11350 and the sub-tasks for stuff that happens in setup
assistant: #11477 and #11479

This adds back-end and UI logic to show an EULA during DEP enrollment if
one was uploaded via the UI. If no EULA was uploaded, we proceed to
enroll the device right after authentication.


https://user-images.githubusercontent.com/4419992/236316655-282ee74a-5f79-4095-a950-82b77b80a5c0.mov
2023-05-05 14:36:13 -03:00

400 lines
12 KiB
Go

package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"os"
"regexp"
"sync"
"testing"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/live_query/live_query_mock"
"github.com/fleetdm/fleet/v4/server/pubsub"
"github.com/fleetdm/fleet/v4/server/sso"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/ghodss/yaml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// withDS is a test-suite helper, embedded by withServer, that owns a MySQL
// datastore created in SetupSuite and closed in TearDownSuite.
type withDS struct {
// s is the testify suite; s.T() yields the *testing.T currently running.
s *suite.Suite
// ds is the datastore created by mysql.CreateNamedMySQLDS in SetupSuite.
ds *mysql.Datastore
}
// SetupSuite creates the MySQL test datastore named dbName and adds the
// all-hosts label via test.AddAllHostsLabel. It then stores the AppConfig
// fields the tests rely on: the organization name and the server URL.
func (ts *withDS) SetupSuite(dbName string) {
	t := ts.s.T()
	ctx := context.Background()

	ts.ds = mysql.CreateNamedMySQLDS(t, dbName)
	test.AddAllHostsLabel(t, ts.ds)

	// Populate the AppConfig fields that must be present for the tests.
	cfg, err := ts.ds.AppConfig(ctx)
	require.NoError(t, err)
	cfg.OrgInfo.OrgName = "FleetTest"
	cfg.ServerSettings.ServerURL = "https://example.org"
	require.NoError(t, ts.ds.SaveAppConfig(ctx, cfg))
}
// TearDownSuite closes the datastore that SetupSuite opened.
func (ts *withDS) TearDownSuite() {
	ds := ts.ds
	ds.Close()
}
// withServer extends withDS with a running Fleet test server and helpers
// to send (optionally authenticated) HTTP requests to it.
type withServer struct {
withDS
// server is the HTTP test server started in SetupSuite; requests target server.URL.
server *httptest.Server
// users holds the users returned by RunServerForTestsWithDS, keyed by email.
users map[string]fleet.User
// token is the bearer token DoRaw (and thus Do/DoJSON) sends; SetupSuite sets it to the admin token.
token string
// cachedAdminToken caches admin1's login token because the login endpoint is
// rate-limited; set it to "" to force getTestAdminToken to log in again.
cachedAdminToken string
// cachedTokensMu guards cachedTokens.
cachedTokensMu sync.Mutex
cachedTokens map[string]string // email -> auth token
// lq is the live query implementation passed to the test server in SetupSuite.
// NOTE(review): presumably assigned by the embedding suite before SetupSuite runs — confirm.
lq *live_query_mock.MockLiveQuery
}
// SetupSuite prepares the datastore through withDS.SetupSuite and starts a
// Fleet test server on top of it. The server uses an in-memory query results
// store and has the osquery enroll cooldown disabled. Finally, it logs in as
// the admin user so that requests are authenticated by default.
func (ts *withServer) SetupSuite(dbName string) {
	ts.withDS.SetupSuite(dbName)

	queryResults := pubsub.NewInmemQueryResults()
	fleetCfg := config.TestConfig()
	fleetCfg.Osquery.EnrollCooldown = 0

	opts := &TestServerOpts{
		Rs:          queryResults,
		Lq:          ts.lq,
		FleetConfig: &fleetCfg,
	}
	ts.users, ts.server = RunServerForTestsWithDS(ts.s.T(), ts.ds, opts)

	// getTestAdminToken already fills cachedAdminToken; assigning it again
	// keeps both fields in sync.
	ts.token = ts.getTestAdminToken()
	ts.cachedAdminToken = ts.token
}
// TearDownSuite delegates to withDS.TearDownSuite, which closes the datastore.
// NOTE(review): ts.server is not closed here; presumably
// RunServerForTestsWithDS registers its own cleanup — confirm.
func (ts *withServer) TearDownSuite() {
	ds := &ts.withDS
	ds.TearDownSuite()
}
// commonTearDownTest removes the data a test created so the next test starts
// from a clean state. It deletes:
//   - every host visible to admin1, after clearing the host's software,
//   - every label that is not built-in,
//   - every user not present in ts.users,
//   - every team visible to admin1,
//   - every global policy.
//
// Software counts are recalculated (SyncHostsSoftware) right after the hosts
// are removed and once more at the very end.
func (ts *withServer) commonTearDownTest(t *testing.T) {
	ctx := context.Background()

	// A single, explicitly named admin filter is reused for hosts and teams so
	// that no later loop variable can shadow it.
	admin := ts.users["admin1@example.com"]
	adminFilter := fleet.TeamFilter{User: &admin}

	hosts, err := ts.ds.ListHosts(ctx, adminFilter, fleet.HostListOptions{})
	require.NoError(t, err)
	for _, host := range hosts {
		require.NoError(t, ts.ds.UpdateHostSoftware(ctx, host.ID, nil))
		require.NoError(t, ts.ds.DeleteHost(ctx, host.ID))
	}
	// recalculating software counts removes the software entries
	require.NoError(t, ts.ds.SyncHostsSoftware(ctx, time.Now()))

	lbls, err := ts.ds.ListLabels(ctx, fleet.TeamFilter{}, fleet.ListOptions{})
	require.NoError(t, err)
	for _, lbl := range lbls {
		if lbl.LabelType == fleet.LabelTypeBuiltIn {
			continue
		}
		require.NoError(t, ts.ds.DeleteLabel(ctx, lbl.Name))
	}

	// Keep only the users the test server created at startup.
	users, err := ts.ds.ListUsers(ctx, fleet.UserListOptions{})
	require.NoError(t, err)
	for _, user := range users {
		if _, ok := ts.users[user.Email]; ok {
			continue
		}
		require.NoError(t, ts.ds.DeleteUser(ctx, user.ID))
	}

	teams, err := ts.ds.ListTeams(ctx, adminFilter, fleet.ListOptions{})
	require.NoError(t, err)
	for _, tm := range teams {
		require.NoError(t, ts.ds.DeleteTeam(ctx, tm.ID))
	}

	globalPolicies, err := ts.ds.ListGlobalPolicies(ctx)
	require.NoError(t, err)
	if len(globalPolicies) > 0 {
		globalPolicyIDs := make([]uint, 0, len(globalPolicies))
		for _, gp := range globalPolicies {
			globalPolicyIDs = append(globalPolicyIDs, gp.ID)
		}
		_, err = ts.ds.DeleteGlobalPolicies(ctx, globalPolicyIDs)
		require.NoError(t, err)
	}

	// SyncHostsSoftware performs a cleanup.
	require.NoError(t, ts.ds.SyncHostsSoftware(ctx, time.Now()))
}
// Do JSON-encodes params and sends it as the request body to path using the
// suite's bearer token. It asserts that the response has expectedStatusCode.
// The response body is closed automatically when the current test finishes.
func (ts *withServer) Do(verb, path string, params interface{}, expectedStatusCode int, queryParams ...string) *http.Response {
	t := ts.s.T()

	payload, err := json.Marshal(params)
	require.NoError(t, err)

	resp := ts.DoRaw(verb, path, payload, expectedStatusCode, queryParams...)
	t.Cleanup(func() { resp.Body.Close() })
	return resp
}
// DoRawWithHeaders sends rawBytes as the body of a verb request to path on the
// test server. It adds the given headers and the queryParams, which must be a
// flat list of key/value pairs. It asserts that the response has
// expectedStatusCode. The caller must close the returned response body.
func (ts *withServer) DoRawWithHeaders(
	verb string, path string, rawBytes []byte, expectedStatusCode int, headers map[string]string, queryParams ...string,
) *http.Response {
	t := ts.s.T()

	// io.NopCloser hides the *bytes.Buffer from http.NewRequest, so the
	// request's ContentLength and GetBody are left unset. That changes how the
	// client treats 307/308 redirects, so keep it unless callers that expect a
	// 307 have been re-checked.
	body := io.NopCloser(bytes.NewBuffer(rawBytes))
	req, err := http.NewRequest(verb, ts.server.URL+path, body)
	require.NoError(t, err)
	for name, value := range headers {
		req.Header.Add(name, value)
	}

	client := fleethttp.NewClient()

	if len(queryParams)%2 != 0 {
		require.Fail(t, "need even number of params: key value")
	}
	if len(queryParams) > 0 {
		// Merge the pairs into any query string already present in path.
		query := req.URL.Query()
		for i := 1; i < len(queryParams); i += 2 {
			query.Add(queryParams[i-1], queryParams[i])
		}
		req.URL.RawQuery = query.Encode()
	}

	resp, err := client.Do(req)
	require.NoError(t, err)
	require.Equal(t, expectedStatusCode, resp.StatusCode)
	return resp
}
// DoRaw behaves like DoRawWithHeaders but authenticates the request with the
// suite's current bearer token (ts.token).
func (ts *withServer) DoRaw(verb string, path string, rawBytes []byte, expectedStatusCode int, queryParams ...string) *http.Response {
	authHeaders := map[string]string{
		"Authorization": fmt.Sprintf("Bearer %s", ts.token),
	}
	return ts.DoRawWithHeaders(verb, path, rawBytes, expectedStatusCode, authHeaders, queryParams...)
}
// DoRawNoAuth sends rawBytes to path without any extra headers (in particular
// without an Authorization header) and asserts the response status.
func (ts *withServer) DoRawNoAuth(verb string, path string, rawBytes []byte, expectedStatusCode int) *http.Response {
	var noHeaders map[string]string
	return ts.DoRawWithHeaders(verb, path, rawBytes, expectedStatusCode, noHeaders)
}
// DoJSON performs Do and decodes the JSON response body into v. When v
// implements errorer, the error it reports after decoding must be nil.
func (ts *withServer) DoJSON(verb, path string, params interface{}, expectedStatusCode int, v interface{}, queryParams ...string) {
	resp := ts.Do(verb, path, params, expectedStatusCode, queryParams...)
	t := ts.s.T()
	require.NoError(t, json.NewDecoder(resp.Body).Decode(v))

	e, ok := v.(errorer)
	if !ok {
		return
	}
	require.NoError(t, e.error())
}
// getTestAdminToken returns an auth token for the admin1 test user.
func (ts *withServer) getTestAdminToken() string {
	// The login endpoint is rate-limited, so reuse the cached admin token when
	// there is one. A test that logs the admin out must reset cachedAdminToken
	// to "" so that the next call logs in again.
	if ts.cachedAdminToken != "" {
		return ts.cachedAdminToken
	}
	admin := testUsers["admin1"]
	ts.cachedAdminToken = ts.getTestToken(admin.Email, admin.PlaintextPassword)
	return ts.cachedAdminToken
}
// getCachedUserToken returns the auth token for the test user with the given
// email. On the first request for an email it logs in with password and caches
// the resulting token. Access to the cache is serialized by cachedTokensMu.
func (ts *withServer) getCachedUserToken(email, password string) string {
	ts.cachedTokensMu.Lock()
	defer ts.cachedTokensMu.Unlock()

	if cached, found := ts.cachedTokens[email]; found {
		return cached
	}
	if ts.cachedTokens == nil {
		ts.cachedTokens = make(map[string]string)
	}
	fresh := ts.getTestToken(email, password)
	ts.cachedTokens[email] = fresh
	return fresh
}
// getTestToken logs in through the login endpoint with the given credentials
// and returns the session token from the response.
//
// The test fails if the request errors, if the body cannot be decoded, or if
// the body lists any errors. A non-200 status is recorded as a non-fatal
// assertion failure.
func (ts *withServer) getTestToken(email string, password string) string {
	t := ts.s.T()

	payload, err := json.Marshal(&loginRequest{Email: email, Password: password})
	require.NoError(t, err)

	loginURL := ts.server.URL + "/api/latest/fleet/login"
	resp, err := http.Post(loginURL, "application/json", io.NopCloser(bytes.NewBuffer(payload)))
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusOK, resp.StatusCode)

	var loginResp struct {
		User  *fleet.User         `json:"user"`
		Token string              `json:"token"`
		Err   []map[string]string `json:"errors,omitempty"`
	}
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&loginResp))
	require.Len(t, loginResp.Err, 0)
	return loginResp.Token
}
// applyConfig parses spec as YAML and PATCHes the result to the app config
// endpoint, expecting 200 OK.
func (ts *withServer) applyConfig(spec []byte) {
	var patch interface{}
	require.NoError(ts.s.T(), yaml.Unmarshal(spec, &patch))
	ts.Do("PATCH", "/api/latest/fleet/config", patch, http.StatusOK)
}
// getConfig fetches the app config and returns the decoded response.
func (ts *withServer) getConfig() *appConfigResponse {
	var cfgResp *appConfigResponse
	ts.DoJSON("GET", "/api/latest/fleet/config", nil, http.StatusOK, &cfgResp)
	return cfgResp
}
// LoginSSOUser performs an SSO login on the regular (non-MDM) SSO endpoint.
// It expects the callback to answer 200 OK and returns the decoded SAML auth
// together with the callback's response body as a string.
func (ts *withServer) LoginSSOUser(username, password string) (fleet.Auth, string) {
	t := ts.s.T()
	auth, callbackResp := ts.loginSSOUser(username, password, "/api/v1/fleet/sso", http.StatusOK)
	defer callbackResp.Body.Close()

	raw, err := io.ReadAll(callbackResp.Body)
	require.NoError(t, err)
	return auth, string(raw)
}
// LoginMDMSSOUser performs an SSO login on the MDM SSO endpoint and expects the
// callback to answer with a temporary redirect. It returns that callback
// response; the caller must close its body.
func (ts *withServer) LoginMDMSSOUser(username, password string) *http.Response {
	const mdmSSOPath = "/api/v1/fleet/mdm/sso"
	_, callbackResp := ts.loginSSOUser(username, password, mdmSSOPath, http.StatusTemporaryRedirect)
	return callbackResp
}
// loginSSOUser drives a full SAML SSO login against the test IdP and the Fleet
// SSO endpoint mounted at basePath:
//
//  1. POST basePath to obtain the IdP URL.
//  2. GET that URL without following redirects; the Location header carries
//     the AuthState and the credentials form URL.
//  3. Submit username/password to the IdP.
//  4. Extract the SAMLResponse from the returned HTML form.
//  5. POST it to basePath+"/callback", asserting callbackStatus.
//
// It returns the decoded SAML auth and the callback response; the caller must
// close that response's body. The test is skipped unless SAML_IDP_TEST is set.
func (ts *withServer) loginSSOUser(username, password string, basePath string, callbackStatus int) (fleet.Auth, *http.Response) {
	t := ts.s.T()
	if _, ok := os.LookupEnv("SAML_IDP_TEST"); !ok {
		t.Skip("SSO tests are disabled")
	}

	var resIni initiateSSOResponse
	ts.DoJSON("POST", basePath, map[string]string{}, http.StatusOK, &resIni)

	jar, err := cookiejar.New(nil)
	require.NoError(t, err)
	client := fleethttp.NewClient(
		fleethttp.WithFollowRedir(false),
		fleethttp.WithCookieJar(jar),
	)

	// Only the Location header of this response is needed. Its body must
	// still be closed; it used to leak because the variable was reused below.
	redirResp, err := client.Get(resIni.URL)
	require.NoError(t, err)
	redirResp.Body.Close()

	// From the redirect Location header we can get the AuthState and the URL
	// to which we submit the credentials.
	parsed, err := url.Parse(redirResp.Header.Get("Location"))
	require.NoError(t, err)
	data := url.Values{
		"username":  {username},
		"password":  {password},
		"AuthState": {parsed.Query().Get("AuthState")},
	}
	formResp, err := client.PostForm(parsed.Scheme+"://"+parsed.Host+parsed.Path, data)
	require.NoError(t, err)
	defer formResp.Body.Close()

	// The response is an HTML form. We extract the base64-encoded SAMLResponse
	// from it to submit to the Fleet server. The capture stops at the
	// attribute's closing quote instead of the last quote on the line, so any
	// attributes after value="..." are not swallowed.
	body, err := io.ReadAll(formResp.Body)
	require.NoError(t, err)
	re := regexp.MustCompile(`value="([^"]*)"`)
	matches := re.FindSubmatch(body)
	require.NotEmptyf(t, matches, "callback HTML doesn't contain a SAMLResponse value, got body: %s", body)
	rawSSOResp := string(matches[1])

	auth, err := sso.DecodeAuthResponse(rawSSOResp)
	require.NoError(t, err)

	q := url.QueryEscape(rawSSOResp)
	res := ts.DoRawNoAuth("POST", basePath+"/callback?SAMLResponse="+q, nil, callbackStatus)
	return auth, res
}
// lastActivityMatches fetches the most recent activity (highest id) and checks
// it against the provided properties. An empty name or details, or a zero id,
// skips the corresponding check. When details is given it is compared as JSON.
// It returns the ID of that latest activity.
func (ts *withServer) lastActivityMatches(name, details string, id uint) uint {
	t := ts.s.T()

	var activitiesResp listActivitiesResponse
	ts.DoJSON("GET", "/api/latest/fleet/activities", nil, http.StatusOK, &activitiesResp,
		"order_key", "a.id", "order_direction", "desc", "per_page", "1")
	require.True(t, len(activitiesResp.Activities) > 0)

	latest := activitiesResp.Activities[0]
	if name != "" {
		assert.Equal(t, name, latest.Type)
	}
	if details != "" {
		require.NotNil(t, latest.Details)
		assert.JSONEq(t, details, string(*latest.Details))
	}
	if id > 0 {
		assert.Equal(t, id, latest.ID)
	}
	return latest.ID
}
// lastActivityOfTypeMatches finds the most recent activity whose type is name,
// searching only the 10 latest activities, and checks it against the provided
// properties. Empty details or a zero id skips that check. It returns the ID of
// the matching activity.
//
// Unlike lastActivityMatches, the asserted activity does not have to be the
// very last one. The test fails if no activity of that type is among the 10
// latest.
func (ts *withServer) lastActivityOfTypeMatches(name, details string, id uint) uint {
	t := ts.s.T()

	var activitiesResp listActivitiesResponse
	ts.DoJSON("GET", "/api/latest/fleet/activities", nil, http.StatusOK,
		&activitiesResp, "order_key", "a.id", "order_direction", "desc", "per_page", "10")
	require.True(t, len(activitiesResp.Activities) > 0)

	// The list is ordered by descending id, so the first match is the latest
	// activity of this type.
	for _, candidate := range activitiesResp.Activities {
		if candidate.Type != name {
			continue
		}
		if details != "" {
			require.NotNil(t, candidate.Details)
			assert.JSONEq(t, details, string(*candidate.Details))
		}
		if id > 0 {
			assert.Equal(t, id, candidate.ID)
		}
		return candidate.ID
	}

	t.Fatalf("no activity of type %s found in the last %d activities", name, len(activitiesResp.Activities))
	return 0
}