add Go client to consume device endpoints (#5987)

This adds a new API client named DeviceClient to server/service, meant to consume device endpoints and be used from Fleet Desktop.

Some of the logic to make requests and parse responses was very repetitive, so I introduced a private baseClient type and moved some of the logic of the existent Client there.

Related to #5685 and #5697
This commit is contained in:
Roberto Dip 2022-06-01 20:05:05 -03:00 committed by GitHub
parent 621fe84e43
commit 842ebbb2ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 356 additions and 94 deletions

View file

@ -0,0 +1,111 @@
package service
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
)
// httpClient interface allows the HTTP methods to be mocked.
type httpClient interface {
Do(req *http.Request) (*http.Response, error)
}
type baseClient struct {
baseURL *url.URL
http httpClient
urlPrefix string
insecureSkipVerify bool
}
func (bc *baseClient) parseResponse(verb, path string, response *http.Response, responseDest interface{}) error {
switch response.StatusCode {
case http.StatusOK:
// ok
case http.StatusNotFound:
return notFoundErr{}
case http.StatusUnauthorized:
return ErrUnauthenticated
case http.StatusPaymentRequired:
return ErrMissingLicense
default:
return fmt.Errorf(
"%s %s received status %d %s",
verb, path,
response.StatusCode,
extractServerErrorText(response.Body),
)
}
if err := json.NewDecoder(response.Body).Decode(&responseDest); err != nil {
return fmt.Errorf("decode %s %s response: %w", verb, path, err)
}
if e, ok := responseDest.(errorer); ok {
if e.error() != nil {
return fmt.Errorf("%s %s error: %w", verb, path, e.error())
}
}
return nil
}
func (bc *baseClient) url(path, rawQuery string) *url.URL {
u := *bc.baseURL
u.Path = bc.urlPrefix + path
u.RawQuery = rawQuery
return &u
}
func newBaseClient(addr string, insecureSkipVerify bool, rootCA, urlPrefix string) (*baseClient, error) {
baseURL, err := url.Parse(addr)
if err != nil {
return nil, fmt.Errorf("parsing URL: %w", err)
}
if baseURL.Scheme != "https" && !strings.Contains(baseURL.Host, "localhost") && !strings.Contains(baseURL.Host, "127.0.0.1") {
return nil, errors.New("address must start with https:// for remote connections")
}
rootCAPool := x509.NewCertPool()
if rootCA != "" {
// read in the root cert file specified in the context
certs, err := ioutil.ReadFile(rootCA)
if err != nil {
return nil, fmt.Errorf("reading root CA: %w", err)
}
// add certs to pool
if ok := rootCAPool.AppendCertsFromPEM(certs); !ok {
return nil, errors.New("failed to add certificates to root CA pool")
}
} else if !insecureSkipVerify {
// Use only the system certs (doesn't work on Windows)
rootCAPool, err = x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("loading system cert pool: %w", err)
}
}
httpClient := fleethttp.NewClient(fleethttp.WithTLSClientConfig(&tls.Config{
InsecureSkipVerify: insecureSkipVerify,
RootCAs: rootCAPool,
}))
client := &baseClient{
baseURL: baseURL,
http: httpClient,
insecureSkipVerify: insecureSkipVerify,
urlPrefix: urlPrefix,
}
return client, nil
}

View file

@ -9,6 +9,7 @@ import (
var (
ErrUnauthenticated = errors.New("unauthenticated, or invalid token")
ErrMissingLicense = errors.New("missing or invalid license")
)
type SetupAlreadyErr interface {

View file

@ -0,0 +1,107 @@
package service
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestUrlGeneration(t *testing.T) {
t.Run("without prefix", func(t *testing.T) {
bc, err := newBaseClient("https://test.com", true, "", "")
require.NoError(t, err)
require.Equal(t, "https://test.com/test/path", bc.url("test/path", "").String())
require.Equal(t, "https://test.com/test/path?raw=query", bc.url("test/path", "raw=query").String())
})
t.Run("with prefix", func(t *testing.T) {
bc, err := newBaseClient("https://test.com", true, "", "prefix/")
require.NoError(t, err)
require.Equal(t, "https://test.com/prefix/test/path", bc.url("test/path", "").String())
require.Equal(t, "https://test.com/prefix/test/path?raw=query", bc.url("test/path", "raw=query").String())
})
}
func TestParseResponseKnownErrors(t *testing.T) {
cases := []struct {
message string
code int
out error
}{
{"not found errors", http.StatusNotFound, notFoundErr{}},
{"unauthenticated errors", http.StatusUnauthorized, ErrUnauthenticated},
{"license errors", http.StatusPaymentRequired, ErrMissingLicense},
}
for _, c := range cases {
t.Run(c.message, func(t *testing.T) {
bc, err := newBaseClient("https://test.com", true, "", "")
require.NoError(t, err)
response := &http.Response{
StatusCode: c.code,
Body: ioutil.NopCloser(bytes.NewBufferString(`{"test": "ok"}`)),
}
err = bc.parseResponse("GET", "", response, &struct{}{})
require.ErrorIs(t, err, c.out)
})
}
}
func TestParseResponseOK(t *testing.T) {
bc, err := newBaseClient("https://test.com", true, "", "")
require.NoError(t, err)
response := &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewBufferString(`{"test": "ok"}`)),
}
var resDest struct{ Test string }
err = bc.parseResponse("", "", response, &resDest)
require.NoError(t, err)
require.Equal(t, "ok", resDest.Test)
}
func TestParseResponseGeneralErrors(t *testing.T) {
t.Run("general HTTP errors", func(t *testing.T) {
bc, err := newBaseClient("https://test.com", true, "", "")
require.NoError(t, err)
response := &http.Response{
StatusCode: http.StatusBadRequest,
Body: ioutil.NopCloser(bytes.NewBufferString(`{"test": "ok"}`)),
}
err = bc.parseResponse("GET", "", response, &struct{}{})
require.Error(t, err)
})
t.Run("parse errors", func(t *testing.T) {
bc, err := newBaseClient("https://test.com", true, "", "")
require.NoError(t, err)
response := &http.Response{
StatusCode: http.StatusBadRequest,
Body: ioutil.NopCloser(bytes.NewBufferString(`invalid json`)),
}
err = bc.parseResponse("GET", "", response, &struct{}{})
require.Error(t, err)
})
}
func TestNewBaseClient(t *testing.T) {
t.Run("invalid addresses are an error", func(t *testing.T) {
_, err := newBaseClient("invalid", true, "", "")
require.Error(t, err)
})
t.Run("http is only valid in development", func(t *testing.T) {
_, err := newBaseClient("http://test.com", true, "", "")
require.Error(t, err)
_, err = newBaseClient("http://localhost:8080", true, "", "")
require.NoError(t, err)
_, err = newBaseClient("http://127.0.0.1:8080", true, "", "")
require.NoError(t, err)
})
}

View file

@ -3,36 +3,23 @@ package service
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
)
// httpClient interface allows the HTTP methods to be mocked.
type httpClient interface {
Do(req *http.Request) (*http.Response, error)
}
// Client is used to consume Fleet APIs from Go code
type Client struct {
addr string
baseURL *url.URL
urlPrefix string
token string
http httpClient
insecureSkipVerify bool
*baseClient
addr string
token string
writer io.Writer
}
@ -42,46 +29,15 @@ type ClientOption func(*Client) error
func NewClient(addr string, insecureSkipVerify bool, rootCA, urlPrefix string, options ...ClientOption) (*Client, error) {
// TODO #265 refactor all optional parameters to functional options
// API breaking change, needs a major version release
baseURL, err := url.Parse(addr)
baseClient, err := newBaseClient(addr, insecureSkipVerify, rootCA, urlPrefix)
if err != nil {
return nil, fmt.Errorf("parsing URL: %w", err)
return nil, err
}
if baseURL.Scheme != "https" && !strings.Contains(baseURL.Host, "localhost") && !strings.Contains(baseURL.Host, "127.0.0.1") {
return nil, errors.New("address must start with https:// for remote connections")
}
rootCAPool := x509.NewCertPool()
if rootCA != "" {
// read in the root cert file specified in the context
certs, err := ioutil.ReadFile(rootCA)
if err != nil {
return nil, fmt.Errorf("reading root CA: %w", err)
}
// add certs to pool
if ok := rootCAPool.AppendCertsFromPEM(certs); !ok {
return nil, errors.New("failed to add certificates to root CA pool")
}
} else if !insecureSkipVerify {
// Use only the system certs (doesn't work on Windows)
rootCAPool, err = x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("loading system cert pool: %w", err)
}
}
httpClient := fleethttp.NewClient(fleethttp.WithTLSClientConfig(&tls.Config{
InsecureSkipVerify: insecureSkipVerify,
RootCAs: rootCAPool,
}))
client := &Client{
addr: addr,
baseURL: baseURL,
http: httpClient,
insecureSkipVerify: insecureSkipVerify,
urlPrefix: urlPrefix,
baseClient: baseClient,
addr: addr,
}
for _, option := range options {
@ -179,13 +135,6 @@ func (c *Client) SetToken(t string) {
c.token = t
}
func (c *Client) url(path, rawQuery string) *url.URL {
u := *c.baseURL
u.Path = c.urlPrefix + path
u.RawQuery = rawQuery
return &u
}
// http.RoundTripper that will log debug information about the request and
// response, including paths, timing, and body.
//
@ -240,34 +189,7 @@ func (c *Client) authenticatedRequestWithQuery(params interface{}, verb string,
}
defer response.Body.Close()
switch response.StatusCode {
case http.StatusOK:
// ok
case http.StatusNotFound:
return notFoundErr{}
case http.StatusUnauthorized:
return ErrUnauthenticated
default:
return fmt.Errorf(
"%s %s received status %d %s",
verb, path,
response.StatusCode,
extractServerErrorText(response.Body),
)
}
err = json.NewDecoder(response.Body).Decode(&responseDest)
if err != nil {
return fmt.Errorf("decode %s %s response: %w", verb, path, err)
}
if e, ok := responseDest.(errorer); ok {
if e.error() != nil {
return fmt.Errorf("%s %s error: %s", verb, path, e.error())
}
}
return nil
return c.parseResponse(verb, path, response, responseDest)
}
func (c *Client) authenticatedRequest(params interface{}, verb string, path string, responseDest interface{}) error {

View file

@ -88,12 +88,14 @@ func TestLiveQueryWithContext(t *testing.T) {
baseURL, err := url.Parse(ts.URL)
require.NoError(t, err)
client := &Client{
baseURL: baseURL,
urlPrefix: "",
token: "1234",
http: fleethttp.NewClient(),
insecureSkipVerify: false,
writer: nil,
baseClient: &baseClient{
baseURL: baseURL,
http: fleethttp.NewClient(),
insecureSkipVerify: false,
urlPrefix: "",
},
token: "1234",
writer: nil,
}
ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second)
defer cancelFunc()

View file

@ -0,0 +1,59 @@
package service
import (
"bytes"
"fmt"
"net/http"
"github.com/fleetdm/fleet/v4/server/fleet"
)
// Device client is used to consume `/device/...` endpoints,
// and meant to be used by Fleet Desktop
type DeviceClient struct {
*baseClient
token string
}
func (dc *DeviceClient) request(verb string, path string, query string, responseDest interface{}) error {
var bodyBytes []byte
request, err := http.NewRequest(
verb,
dc.url(path, query).String(),
bytes.NewBuffer(bodyBytes),
)
if err != nil {
return err
}
response, err := dc.http.Do(request)
if err != nil {
return fmt.Errorf("%s %s: %w", verb, path, err)
}
defer response.Body.Close()
return dc.parseResponse(verb, path, response, responseDest)
}
// NewDeviceClient instantiates a new client to perform requests against device endpoints
func NewDeviceClient(addr string, token string, insecureSkipVerify bool, rootCA string) (*DeviceClient, error) {
baseClient, err := newBaseClient(addr, insecureSkipVerify, rootCA, "")
if err != nil {
return nil, err
}
return &DeviceClient{baseClient: baseClient, token: token}, nil
}
// ListDevicePolicies fetches all policies for the device with the provided token
func (dc *DeviceClient) ListDevicePolicies() ([]*fleet.HostPolicy, error) {
verb, path := "GET", "/api/latest/fleet/device/"+dc.token+"/policies"
var responseBody listDevicePoliciesResponse
err := dc.request(verb, path, "", &responseBody)
if err != nil {
return nil, err
}
return responseBody.Policies, nil
}

View file

@ -0,0 +1,60 @@
package service
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
type mockHttpClient struct {
resBody string
statusCode int
err error
}
func (m *mockHttpClient) Do(req *http.Request) (*http.Response, error) {
if m.err != nil {
return nil, m.err
}
res := &http.Response{
StatusCode: m.statusCode,
Body: ioutil.NopCloser(bytes.NewBufferString(m.resBody)),
}
return res, nil
}
func TestDeviceClientListPolicies(t *testing.T) {
client, err := NewDeviceClient("https://test.com", "test-token", true, "")
require.NoError(t, err)
mockRequestDoer := &mockHttpClient{}
client.http = mockRequestDoer
t.Run("with wrong license", func(t *testing.T) {
mockRequestDoer.statusCode = http.StatusPaymentRequired
_, err = client.ListDevicePolicies()
require.ErrorIs(t, err, ErrMissingLicense)
})
t.Run("with empty policies", func(t *testing.T) {
mockRequestDoer.statusCode = http.StatusOK
mockRequestDoer.resBody = `{"policies": []}`
policies, err := client.ListDevicePolicies()
require.NoError(t, err)
require.Len(t, policies, 0)
})
t.Run("with policies", func(t *testing.T) {
mockRequestDoer.statusCode = http.StatusOK
mockRequestDoer.resBody = `{"policies": [{"id": 1}]}`
policies, err := client.ListDevicePolicies()
require.NoError(t, err)
require.Len(t, policies, 1)
require.Equal(t, uint(1), policies[0].ID)
})
}