From 842ebbb2ae0ecee134d91f2e353a4c5f349b5f40 Mon Sep 17 00:00:00 2001 From: Roberto Dip Date: Wed, 1 Jun 2022 20:05:05 -0300 Subject: [PATCH] add Go client to consume device endpoints (#5987) This adds a new API client named DeviceClient to server/service, meant to consume device endpoints and be used from Fleet Desktop. Some of the logic to make requests and parse responses was very repetitive, so I introduced a private baseClient type and moved some of the logic of the existent Client there. Related to #5685 and #5697 --- server/service/base_client.go | 111 ++++++++++++++++++ ...client_errors.go => base_client_errors.go} | 1 + server/service/base_client_test.go | 107 +++++++++++++++++ server/service/client.go | 98 ++-------------- server/service/client_live_query_test.go | 14 ++- server/service/device_client.go | 59 ++++++++++ server/service/device_client_test.go | 60 ++++++++++ 7 files changed, 356 insertions(+), 94 deletions(-) create mode 100644 server/service/base_client.go rename server/service/{client_errors.go => base_client_errors.go} (96%) create mode 100644 server/service/base_client_test.go create mode 100644 server/service/device_client.go create mode 100644 server/service/device_client_test.go diff --git a/server/service/base_client.go b/server/service/base_client.go new file mode 100644 index 0000000000..cc926386c0 --- /dev/null +++ b/server/service/base_client.go @@ -0,0 +1,111 @@ +package service + +import ( + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + + "github.com/fleetdm/fleet/v4/pkg/fleethttp" +) + +// httpClient interface allows the HTTP methods to be mocked. +type httpClient interface { + Do(req *http.Request) (*http.Response, error) +} + +type baseClient struct { + baseURL *url.URL + http httpClient + urlPrefix string + insecureSkipVerify bool +} + +func (bc *baseClient) parseResponse(verb, path string, response *http.Response, responseDest interface{}) error { + switch response.StatusCode { + case http.StatusOK: + // ok + case http.StatusNotFound: + return notFoundErr{} + case http.StatusUnauthorized: + return ErrUnauthenticated + case http.StatusPaymentRequired: + return ErrMissingLicense + default: + return fmt.Errorf( + "%s %s received status %d %s", + verb, path, + response.StatusCode, + extractServerErrorText(response.Body), + ) + } + + if err := json.NewDecoder(response.Body).Decode(&responseDest); err != nil { + return fmt.Errorf("decode %s %s response: %w", verb, path, err) + } + + if e, ok := responseDest.(errorer); ok { + if e.error() != nil { + return fmt.Errorf("%s %s error: %w", verb, path, e.error()) + } + } + + return nil +} + +func (bc *baseClient) url(path, rawQuery string) *url.URL { + u := *bc.baseURL + u.Path = bc.urlPrefix + path + u.RawQuery = rawQuery + return &u +} + +func newBaseClient(addr string, insecureSkipVerify bool, rootCA, urlPrefix string) (*baseClient, error) { + baseURL, err := url.Parse(addr) + if err != nil { + return nil, fmt.Errorf("parsing URL: %w", err) + } + + if baseURL.Scheme != "https" && !strings.Contains(baseURL.Host, "localhost") && !strings.Contains(baseURL.Host, "127.0.0.1") { + return nil, errors.New("address must start with https:// for remote connections") + } + + rootCAPool := x509.NewCertPool() + if rootCA != "" { + // read in the root cert file specified in the context + certs, err := ioutil.ReadFile(rootCA) + if err != nil { + return nil, fmt.Errorf("reading root CA: %w", err) + } + + // add certs to pool + if ok := rootCAPool.AppendCertsFromPEM(certs); !ok { + return nil, errors.New("failed to add certificates to root CA pool") + } + } else if !insecureSkipVerify { + // Use only the system certs (doesn't work on Windows) + rootCAPool, err = x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("loading system cert pool: %w", err) + } + } + + httpClient := fleethttp.NewClient(fleethttp.WithTLSClientConfig(&tls.Config{ + InsecureSkipVerify: insecureSkipVerify, + RootCAs: rootCAPool, + })) + + client := &baseClient{ + baseURL: baseURL, + http: httpClient, + insecureSkipVerify: insecureSkipVerify, + urlPrefix: urlPrefix, + } + + return client, nil +} diff --git a/server/service/client_errors.go b/server/service/base_client_errors.go similarity index 96% rename from server/service/client_errors.go rename to server/service/base_client_errors.go index e860825caa..bc08d1c7cf 100644 --- a/server/service/client_errors.go +++ b/server/service/base_client_errors.go @@ -9,6 +9,7 @@ import ( var ( ErrUnauthenticated = errors.New("unauthenticated, or invalid token") + ErrMissingLicense = errors.New("missing or invalid license") ) type SetupAlreadyErr interface { diff --git a/server/service/base_client_test.go b/server/service/base_client_test.go new file mode 100644 index 0000000000..5ea1cdb40f --- /dev/null +++ b/server/service/base_client_test.go @@ -0,0 +1,107 @@ +package service + +import ( + "bytes" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUrlGeneration(t *testing.T) { + t.Run("without prefix", func(t *testing.T) { + bc, err := newBaseClient("https://test.com", true, "", "") + require.NoError(t, err) + require.Equal(t, "https://test.com/test/path", bc.url("test/path", "").String()) + require.Equal(t, "https://test.com/test/path?raw=query", bc.url("test/path", "raw=query").String()) + }) + + t.Run("with prefix", func(t *testing.T) { + bc, err := newBaseClient("https://test.com", true, "", "prefix/") + require.NoError(t, err) + require.Equal(t, "https://test.com/prefix/test/path", bc.url("test/path", "").String()) + require.Equal(t, "https://test.com/prefix/test/path?raw=query", bc.url("test/path", "raw=query").String()) + }) +} + +func TestParseResponseKnownErrors(t *testing.T) { + cases := []struct { + message string + code int + out error + }{ + {"not found errors", http.StatusNotFound, notFoundErr{}}, + {"unauthenticated errors", http.StatusUnauthorized, ErrUnauthenticated}, + {"license errors", http.StatusPaymentRequired, ErrMissingLicense}, + } + + for _, c := range cases { + t.Run(c.message, func(t *testing.T) { + bc, err := newBaseClient("https://test.com", true, "", "") + require.NoError(t, err) + response := &http.Response{ + StatusCode: c.code, + Body: ioutil.NopCloser(bytes.NewBufferString(`{"test": "ok"}`)), + } + err = bc.parseResponse("GET", "", response, &struct{}{}) + require.ErrorIs(t, err, c.out) + }) + } +} + +func TestParseResponseOK(t *testing.T) { + bc, err := newBaseClient("https://test.com", true, "", "") + require.NoError(t, err) + response := &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(bytes.NewBufferString(`{"test": "ok"}`)), + } + + var resDest struct{ Test string } + err = bc.parseResponse("", "", response, &resDest) + require.NoError(t, err) + require.Equal(t, "ok", resDest.Test) +} + +func TestParseResponseGeneralErrors(t *testing.T) { + t.Run("general HTTP errors", func(t *testing.T) { + bc, err := newBaseClient("https://test.com", true, "", "") + require.NoError(t, err) + response := &http.Response{ + StatusCode: http.StatusBadRequest, + Body: ioutil.NopCloser(bytes.NewBufferString(`{"test": "ok"}`)), + } + err = bc.parseResponse("GET", "", response, &struct{}{}) + require.Error(t, err) + }) + + t.Run("parse errors", func(t *testing.T) { + bc, err := newBaseClient("https://test.com", true, "", "") + require.NoError(t, err) + response := &http.Response{ + StatusCode: http.StatusBadRequest, + Body: ioutil.NopCloser(bytes.NewBufferString(`invalid json`)), + } + err = bc.parseResponse("GET", "", response, &struct{}{}) + require.Error(t, err) + }) +} + +func TestNewBaseClient(t *testing.T) { + t.Run("invalid addresses are an error", func(t *testing.T) { + _, err := newBaseClient("invalid", true, "", "") + require.Error(t, err) + }) + + t.Run("http is only valid in development", func(t *testing.T) { + _, err := newBaseClient("http://test.com", true, "", "") + require.Error(t, err) + + _, err = newBaseClient("http://localhost:8080", true, "", "") + require.NoError(t, err) + + _, err = newBaseClient("http://127.0.0.1:8080", true, "", "") + require.NoError(t, err) + }) +} diff --git a/server/service/client.go b/server/service/client.go index 322b595b79..cf230c886c 100644 --- a/server/service/client.go +++ b/server/service/client.go @@ -3,36 +3,23 @@ package service import ( "bytes" "context" - "crypto/tls" - "crypto/x509" "encoding/json" "errors" "fmt" "io" - "io/ioutil" "net/http" - "net/url" "os" - "strings" "time" - "github.com/fleetdm/fleet/v4/pkg/fleethttp" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" ) -// httpClient interface allows the HTTP methods to be mocked. -type httpClient interface { - Do(req *http.Request) (*http.Response, error) -} - +// Client is used to consume Fleet APIs from Go code type Client struct { - addr string - baseURL *url.URL - urlPrefix string - token string - http httpClient - insecureSkipVerify bool + *baseClient + addr string + token string writer io.Writer } @@ -42,46 +29,15 @@ type ClientOption func(*Client) error func NewClient(addr string, insecureSkipVerify bool, rootCA, urlPrefix string, options ...ClientOption) (*Client, error) { // TODO #265 refactor all optional parameters to functional options // API breaking change, needs a major version release - baseURL, err := url.Parse(addr) + baseClient, err := newBaseClient(addr, insecureSkipVerify, rootCA, urlPrefix) + if err != nil { - return nil, fmt.Errorf("parsing URL: %w", err) + return nil, err } - if baseURL.Scheme != "https" && !strings.Contains(baseURL.Host, "localhost") && !strings.Contains(baseURL.Host, "127.0.0.1") { - return nil, errors.New("address must start with https:// for remote connections") - } - - rootCAPool := x509.NewCertPool() - if rootCA != "" { - // read in the root cert file specified in the context - certs, err := ioutil.ReadFile(rootCA) - if err != nil { - return nil, fmt.Errorf("reading root CA: %w", err) - } - - // add certs to pool - if ok := rootCAPool.AppendCertsFromPEM(certs); !ok { - return nil, errors.New("failed to add certificates to root CA pool") - } - } else if !insecureSkipVerify { - // Use only the system certs (doesn't work on Windows) - rootCAPool, err = x509.SystemCertPool() - if err != nil { - return nil, fmt.Errorf("loading system cert pool: %w", err) - } - } - - httpClient := fleethttp.NewClient(fleethttp.WithTLSClientConfig(&tls.Config{ - InsecureSkipVerify: insecureSkipVerify, - RootCAs: rootCAPool, - })) - client := &Client{ - addr: addr, - baseURL: baseURL, - http: httpClient, - insecureSkipVerify: insecureSkipVerify, - urlPrefix: urlPrefix, + baseClient: baseClient, + addr: addr, } for _, option := range options { @@ -179,13 +135,6 @@ func (c *Client) SetToken(t string) { c.token = t } -func (c *Client) url(path, rawQuery string) *url.URL { - u := *c.baseURL - u.Path = c.urlPrefix + path - u.RawQuery = rawQuery - return &u -} - // http.RoundTripper that will log debug information about the request and // response, including paths, timing, and body. // @@ -240,34 +189,7 @@ func (c *Client) authenticatedRequestWithQuery(params interface{}, verb string, } defer response.Body.Close() - switch response.StatusCode { - case http.StatusOK: - // ok - case http.StatusNotFound: - return notFoundErr{} - case http.StatusUnauthorized: - return ErrUnauthenticated - default: - return fmt.Errorf( - "%s %s received status %d %s", - verb, path, - response.StatusCode, - extractServerErrorText(response.Body), - ) - } - - err = json.NewDecoder(response.Body).Decode(&responseDest) - if err != nil { - return fmt.Errorf("decode %s %s response: %w", verb, path, err) - } - - if e, ok := responseDest.(errorer); ok { - if e.error() != nil { - return fmt.Errorf("%s %s error: %s", verb, path, e.error()) - } - } - - return nil + return c.parseResponse(verb, path, response, responseDest) } func (c *Client) authenticatedRequest(params interface{}, verb string, path string, responseDest interface{}) error { diff --git a/server/service/client_live_query_test.go b/server/service/client_live_query_test.go index f082ce237b..b52beeaffc 100644 --- a/server/service/client_live_query_test.go +++ b/server/service/client_live_query_test.go @@ -88,12 +88,14 @@ func TestLiveQueryWithContext(t *testing.T) { baseURL, err := url.Parse(ts.URL) require.NoError(t, err) client := &Client{ - baseURL: baseURL, - urlPrefix: "", - token: "1234", - http: fleethttp.NewClient(), - insecureSkipVerify: false, - writer: nil, + baseClient: &baseClient{ + baseURL: baseURL, + http: fleethttp.NewClient(), + insecureSkipVerify: false, + urlPrefix: "", + }, + token: "1234", + writer: nil, } ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second) defer cancelFunc() diff --git a/server/service/device_client.go b/server/service/device_client.go new file mode 100644 index 0000000000..f8f9ff2ee2 --- /dev/null +++ b/server/service/device_client.go @@ -0,0 +1,59 @@ +package service + +import ( + "bytes" + "fmt" + "net/http" + + "github.com/fleetdm/fleet/v4/server/fleet" +) + +// Device client is used to consume `/device/...` endpoints, +// and meant to be used by Fleet Desktop +type DeviceClient struct { + *baseClient + token string +} + +func (dc *DeviceClient) request(verb string, path string, query string, responseDest interface{}) error { + var bodyBytes []byte + request, err := http.NewRequest( + verb, + dc.url(path, query).String(), + bytes.NewBuffer(bodyBytes), + ) + + if err != nil { + return err + } + + response, err := dc.http.Do(request) + if err != nil { + return fmt.Errorf("%s %s: %w", verb, path, err) + } + defer response.Body.Close() + + return dc.parseResponse(verb, path, response, responseDest) +} + +// NewDeviceClient instantiates a new client to perform requests against device endpoints +func NewDeviceClient(addr string, token string, insecureSkipVerify bool, rootCA string) (*DeviceClient, error) { + baseClient, err := newBaseClient(addr, insecureSkipVerify, rootCA, "") + + if err != nil { + return nil, err + } + + return &DeviceClient{baseClient: baseClient, token: token}, nil +} + +// ListDevicePolicies fetches all policies for the device with the provided token +func (dc *DeviceClient) ListDevicePolicies() ([]*fleet.HostPolicy, error) { + verb, path := "GET", "/api/latest/fleet/device/"+dc.token+"/policies" + var responseBody listDevicePoliciesResponse + err := dc.request(verb, path, "", &responseBody) + if err != nil { + return nil, err + } + return responseBody.Policies, nil +} diff --git a/server/service/device_client_test.go b/server/service/device_client_test.go new file mode 100644 index 0000000000..202edc89b2 --- /dev/null +++ b/server/service/device_client_test.go @@ -0,0 +1,60 @@ +package service + +import ( + "bytes" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +type mockHttpClient struct { + resBody string + statusCode int + err error +} + +func (m *mockHttpClient) Do(req *http.Request) (*http.Response, error) { + if m.err != nil { + return nil, m.err + } + + res := &http.Response{ + StatusCode: m.statusCode, + Body: ioutil.NopCloser(bytes.NewBufferString(m.resBody)), + } + + return res, nil +} + +func TestDeviceClientListPolicies(t *testing.T) { + client, err := NewDeviceClient("https://test.com", "test-token", true, "") + require.NoError(t, err) + + mockRequestDoer := &mockHttpClient{} + client.http = mockRequestDoer + + t.Run("with wrong license", func(t *testing.T) { + mockRequestDoer.statusCode = http.StatusPaymentRequired + _, err = client.ListDevicePolicies() + require.ErrorIs(t, err, ErrMissingLicense) + }) + + t.Run("with empty policies", func(t *testing.T) { + mockRequestDoer.statusCode = http.StatusOK + mockRequestDoer.resBody = `{"policies": []}` + policies, err := client.ListDevicePolicies() + require.NoError(t, err) + require.Len(t, policies, 0) + }) + + t.Run("with policies", func(t *testing.T) { + mockRequestDoer.statusCode = http.StatusOK + mockRequestDoer.resBody = `{"policies": [{"id": 1}]}` + policies, err := client.ListDevicePolicies() + require.NoError(t, err) + require.Len(t, policies, 1) + require.Equal(t, uint(1), policies[0].ID) + }) +}