mirror of
https://github.com/fleetdm/fleet
synced 2026-05-22 08:28:52 +00:00
add Go client to consume device endpoints (#5987)
This adds a new API client named DeviceClient to server/service, meant to consume device endpoints and be used from Fleet Desktop. Some of the logic to make requests and parse responses was very repetitive, so I introduced a private baseClient type and moved some of the logic of the existent Client there. Related to #5685 and #5697
This commit is contained in:
parent
621fe84e43
commit
842ebbb2ae
7 changed files with 356 additions and 94 deletions
111
server/service/base_client.go
Normal file
111
server/service/base_client.go
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
|
||||
)
|
||||
|
||||
// httpClient interface allows the HTTP methods to be mocked.
|
||||
type httpClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
type baseClient struct {
|
||||
baseURL *url.URL
|
||||
http httpClient
|
||||
urlPrefix string
|
||||
insecureSkipVerify bool
|
||||
}
|
||||
|
||||
func (bc *baseClient) parseResponse(verb, path string, response *http.Response, responseDest interface{}) error {
|
||||
switch response.StatusCode {
|
||||
case http.StatusOK:
|
||||
// ok
|
||||
case http.StatusNotFound:
|
||||
return notFoundErr{}
|
||||
case http.StatusUnauthorized:
|
||||
return ErrUnauthenticated
|
||||
case http.StatusPaymentRequired:
|
||||
return ErrMissingLicense
|
||||
default:
|
||||
return fmt.Errorf(
|
||||
"%s %s received status %d %s",
|
||||
verb, path,
|
||||
response.StatusCode,
|
||||
extractServerErrorText(response.Body),
|
||||
)
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(response.Body).Decode(&responseDest); err != nil {
|
||||
return fmt.Errorf("decode %s %s response: %w", verb, path, err)
|
||||
}
|
||||
|
||||
if e, ok := responseDest.(errorer); ok {
|
||||
if e.error() != nil {
|
||||
return fmt.Errorf("%s %s error: %w", verb, path, e.error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bc *baseClient) url(path, rawQuery string) *url.URL {
|
||||
u := *bc.baseURL
|
||||
u.Path = bc.urlPrefix + path
|
||||
u.RawQuery = rawQuery
|
||||
return &u
|
||||
}
|
||||
|
||||
func newBaseClient(addr string, insecureSkipVerify bool, rootCA, urlPrefix string) (*baseClient, error) {
|
||||
baseURL, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing URL: %w", err)
|
||||
}
|
||||
|
||||
if baseURL.Scheme != "https" && !strings.Contains(baseURL.Host, "localhost") && !strings.Contains(baseURL.Host, "127.0.0.1") {
|
||||
return nil, errors.New("address must start with https:// for remote connections")
|
||||
}
|
||||
|
||||
rootCAPool := x509.NewCertPool()
|
||||
if rootCA != "" {
|
||||
// read in the root cert file specified in the context
|
||||
certs, err := ioutil.ReadFile(rootCA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading root CA: %w", err)
|
||||
}
|
||||
|
||||
// add certs to pool
|
||||
if ok := rootCAPool.AppendCertsFromPEM(certs); !ok {
|
||||
return nil, errors.New("failed to add certificates to root CA pool")
|
||||
}
|
||||
} else if !insecureSkipVerify {
|
||||
// Use only the system certs (doesn't work on Windows)
|
||||
rootCAPool, err = x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading system cert pool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
httpClient := fleethttp.NewClient(fleethttp.WithTLSClientConfig(&tls.Config{
|
||||
InsecureSkipVerify: insecureSkipVerify,
|
||||
RootCAs: rootCAPool,
|
||||
}))
|
||||
|
||||
client := &baseClient{
|
||||
baseURL: baseURL,
|
||||
http: httpClient,
|
||||
insecureSkipVerify: insecureSkipVerify,
|
||||
urlPrefix: urlPrefix,
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
var (
|
||||
ErrUnauthenticated = errors.New("unauthenticated, or invalid token")
|
||||
ErrMissingLicense = errors.New("missing or invalid license")
|
||||
)
|
||||
|
||||
type SetupAlreadyErr interface {
|
||||
107
server/service/base_client_test.go
Normal file
107
server/service/base_client_test.go
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUrlGeneration(t *testing.T) {
|
||||
t.Run("without prefix", func(t *testing.T) {
|
||||
bc, err := newBaseClient("https://test.com", true, "", "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://test.com/test/path", bc.url("test/path", "").String())
|
||||
require.Equal(t, "https://test.com/test/path?raw=query", bc.url("test/path", "raw=query").String())
|
||||
})
|
||||
|
||||
t.Run("with prefix", func(t *testing.T) {
|
||||
bc, err := newBaseClient("https://test.com", true, "", "prefix/")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://test.com/prefix/test/path", bc.url("test/path", "").String())
|
||||
require.Equal(t, "https://test.com/prefix/test/path?raw=query", bc.url("test/path", "raw=query").String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseResponseKnownErrors(t *testing.T) {
|
||||
cases := []struct {
|
||||
message string
|
||||
code int
|
||||
out error
|
||||
}{
|
||||
{"not found errors", http.StatusNotFound, notFoundErr{}},
|
||||
{"unauthenticated errors", http.StatusUnauthorized, ErrUnauthenticated},
|
||||
{"license errors", http.StatusPaymentRequired, ErrMissingLicense},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.message, func(t *testing.T) {
|
||||
bc, err := newBaseClient("https://test.com", true, "", "")
|
||||
require.NoError(t, err)
|
||||
response := &http.Response{
|
||||
StatusCode: c.code,
|
||||
Body: ioutil.NopCloser(bytes.NewBufferString(`{"test": "ok"}`)),
|
||||
}
|
||||
err = bc.parseResponse("GET", "", response, &struct{}{})
|
||||
require.ErrorIs(t, err, c.out)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponseOK(t *testing.T) {
|
||||
bc, err := newBaseClient("https://test.com", true, "", "")
|
||||
require.NoError(t, err)
|
||||
response := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(bytes.NewBufferString(`{"test": "ok"}`)),
|
||||
}
|
||||
|
||||
var resDest struct{ Test string }
|
||||
err = bc.parseResponse("", "", response, &resDest)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "ok", resDest.Test)
|
||||
}
|
||||
|
||||
func TestParseResponseGeneralErrors(t *testing.T) {
|
||||
t.Run("general HTTP errors", func(t *testing.T) {
|
||||
bc, err := newBaseClient("https://test.com", true, "", "")
|
||||
require.NoError(t, err)
|
||||
response := &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: ioutil.NopCloser(bytes.NewBufferString(`{"test": "ok"}`)),
|
||||
}
|
||||
err = bc.parseResponse("GET", "", response, &struct{}{})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("parse errors", func(t *testing.T) {
|
||||
bc, err := newBaseClient("https://test.com", true, "", "")
|
||||
require.NoError(t, err)
|
||||
response := &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: ioutil.NopCloser(bytes.NewBufferString(`invalid json`)),
|
||||
}
|
||||
err = bc.parseResponse("GET", "", response, &struct{}{})
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewBaseClient(t *testing.T) {
|
||||
t.Run("invalid addresses are an error", func(t *testing.T) {
|
||||
_, err := newBaseClient("invalid", true, "", "")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("http is only valid in development", func(t *testing.T) {
|
||||
_, err := newBaseClient("http://test.com", true, "", "")
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = newBaseClient("http://localhost:8080", true, "", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = newBaseClient("http://127.0.0.1:8080", true, "", "")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
|
@ -3,36 +3,23 @@ package service
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
)
|
||||
|
||||
// httpClient interface allows the HTTP methods to be mocked.
|
||||
type httpClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// Client is used to consume Fleet APIs from Go code
|
||||
type Client struct {
|
||||
addr string
|
||||
baseURL *url.URL
|
||||
urlPrefix string
|
||||
token string
|
||||
http httpClient
|
||||
insecureSkipVerify bool
|
||||
*baseClient
|
||||
addr string
|
||||
token string
|
||||
|
||||
writer io.Writer
|
||||
}
|
||||
|
|
@ -42,46 +29,15 @@ type ClientOption func(*Client) error
|
|||
func NewClient(addr string, insecureSkipVerify bool, rootCA, urlPrefix string, options ...ClientOption) (*Client, error) {
|
||||
// TODO #265 refactor all optional parameters to functional options
|
||||
// API breaking change, needs a major version release
|
||||
baseURL, err := url.Parse(addr)
|
||||
baseClient, err := newBaseClient(addr, insecureSkipVerify, rootCA, urlPrefix)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing URL: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if baseURL.Scheme != "https" && !strings.Contains(baseURL.Host, "localhost") && !strings.Contains(baseURL.Host, "127.0.0.1") {
|
||||
return nil, errors.New("address must start with https:// for remote connections")
|
||||
}
|
||||
|
||||
rootCAPool := x509.NewCertPool()
|
||||
if rootCA != "" {
|
||||
// read in the root cert file specified in the context
|
||||
certs, err := ioutil.ReadFile(rootCA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading root CA: %w", err)
|
||||
}
|
||||
|
||||
// add certs to pool
|
||||
if ok := rootCAPool.AppendCertsFromPEM(certs); !ok {
|
||||
return nil, errors.New("failed to add certificates to root CA pool")
|
||||
}
|
||||
} else if !insecureSkipVerify {
|
||||
// Use only the system certs (doesn't work on Windows)
|
||||
rootCAPool, err = x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading system cert pool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
httpClient := fleethttp.NewClient(fleethttp.WithTLSClientConfig(&tls.Config{
|
||||
InsecureSkipVerify: insecureSkipVerify,
|
||||
RootCAs: rootCAPool,
|
||||
}))
|
||||
|
||||
client := &Client{
|
||||
addr: addr,
|
||||
baseURL: baseURL,
|
||||
http: httpClient,
|
||||
insecureSkipVerify: insecureSkipVerify,
|
||||
urlPrefix: urlPrefix,
|
||||
baseClient: baseClient,
|
||||
addr: addr,
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
|
|
@ -179,13 +135,6 @@ func (c *Client) SetToken(t string) {
|
|||
c.token = t
|
||||
}
|
||||
|
||||
func (c *Client) url(path, rawQuery string) *url.URL {
|
||||
u := *c.baseURL
|
||||
u.Path = c.urlPrefix + path
|
||||
u.RawQuery = rawQuery
|
||||
return &u
|
||||
}
|
||||
|
||||
// http.RoundTripper that will log debug information about the request and
|
||||
// response, including paths, timing, and body.
|
||||
//
|
||||
|
|
@ -240,34 +189,7 @@ func (c *Client) authenticatedRequestWithQuery(params interface{}, verb string,
|
|||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
switch response.StatusCode {
|
||||
case http.StatusOK:
|
||||
// ok
|
||||
case http.StatusNotFound:
|
||||
return notFoundErr{}
|
||||
case http.StatusUnauthorized:
|
||||
return ErrUnauthenticated
|
||||
default:
|
||||
return fmt.Errorf(
|
||||
"%s %s received status %d %s",
|
||||
verb, path,
|
||||
response.StatusCode,
|
||||
extractServerErrorText(response.Body),
|
||||
)
|
||||
}
|
||||
|
||||
err = json.NewDecoder(response.Body).Decode(&responseDest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode %s %s response: %w", verb, path, err)
|
||||
}
|
||||
|
||||
if e, ok := responseDest.(errorer); ok {
|
||||
if e.error() != nil {
|
||||
return fmt.Errorf("%s %s error: %s", verb, path, e.error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return c.parseResponse(verb, path, response, responseDest)
|
||||
}
|
||||
|
||||
func (c *Client) authenticatedRequest(params interface{}, verb string, path string, responseDest interface{}) error {
|
||||
|
|
|
|||
|
|
@ -88,12 +88,14 @@ func TestLiveQueryWithContext(t *testing.T) {
|
|||
baseURL, err := url.Parse(ts.URL)
|
||||
require.NoError(t, err)
|
||||
client := &Client{
|
||||
baseURL: baseURL,
|
||||
urlPrefix: "",
|
||||
token: "1234",
|
||||
http: fleethttp.NewClient(),
|
||||
insecureSkipVerify: false,
|
||||
writer: nil,
|
||||
baseClient: &baseClient{
|
||||
baseURL: baseURL,
|
||||
http: fleethttp.NewClient(),
|
||||
insecureSkipVerify: false,
|
||||
urlPrefix: "",
|
||||
},
|
||||
token: "1234",
|
||||
writer: nil,
|
||||
}
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancelFunc()
|
||||
|
|
|
|||
59
server/service/device_client.go
Normal file
59
server/service/device_client.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
)
|
||||
|
||||
// Device client is used to consume `/device/...` endpoints,
|
||||
// and meant to be used by Fleet Desktop
|
||||
type DeviceClient struct {
|
||||
*baseClient
|
||||
token string
|
||||
}
|
||||
|
||||
func (dc *DeviceClient) request(verb string, path string, query string, responseDest interface{}) error {
|
||||
var bodyBytes []byte
|
||||
request, err := http.NewRequest(
|
||||
verb,
|
||||
dc.url(path, query).String(),
|
||||
bytes.NewBuffer(bodyBytes),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := dc.http.Do(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s %s: %w", verb, path, err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
return dc.parseResponse(verb, path, response, responseDest)
|
||||
}
|
||||
|
||||
// NewDeviceClient instantiates a new client to perform requests against device endpoints
|
||||
func NewDeviceClient(addr string, token string, insecureSkipVerify bool, rootCA string) (*DeviceClient, error) {
|
||||
baseClient, err := newBaseClient(addr, insecureSkipVerify, rootCA, "")
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &DeviceClient{baseClient: baseClient, token: token}, nil
|
||||
}
|
||||
|
||||
// ListDevicePolicies fetches all policies for the device with the provided token
|
||||
func (dc *DeviceClient) ListDevicePolicies() ([]*fleet.HostPolicy, error) {
|
||||
verb, path := "GET", "/api/latest/fleet/device/"+dc.token+"/policies"
|
||||
var responseBody listDevicePoliciesResponse
|
||||
err := dc.request(verb, path, "", &responseBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return responseBody.Policies, nil
|
||||
}
|
||||
60
server/service/device_client_test.go
Normal file
60
server/service/device_client_test.go
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockHttpClient struct {
|
||||
resBody string
|
||||
statusCode int
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockHttpClient) Do(req *http.Request) (*http.Response, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
res := &http.Response{
|
||||
StatusCode: m.statusCode,
|
||||
Body: ioutil.NopCloser(bytes.NewBufferString(m.resBody)),
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func TestDeviceClientListPolicies(t *testing.T) {
|
||||
client, err := NewDeviceClient("https://test.com", "test-token", true, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
mockRequestDoer := &mockHttpClient{}
|
||||
client.http = mockRequestDoer
|
||||
|
||||
t.Run("with wrong license", func(t *testing.T) {
|
||||
mockRequestDoer.statusCode = http.StatusPaymentRequired
|
||||
_, err = client.ListDevicePolicies()
|
||||
require.ErrorIs(t, err, ErrMissingLicense)
|
||||
})
|
||||
|
||||
t.Run("with empty policies", func(t *testing.T) {
|
||||
mockRequestDoer.statusCode = http.StatusOK
|
||||
mockRequestDoer.resBody = `{"policies": []}`
|
||||
policies, err := client.ListDevicePolicies()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 0)
|
||||
})
|
||||
|
||||
t.Run("with policies", func(t *testing.T) {
|
||||
mockRequestDoer.statusCode = http.StatusOK
|
||||
mockRequestDoer.resBody = `{"policies": [{"id": 1}]}`
|
||||
policies, err := client.ListDevicePolicies()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 1)
|
||||
require.Equal(t, uint(1), policies[0].ID)
|
||||
})
|
||||
}
|
||||
Loading…
Reference in a new issue