mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 21:47:20 +00:00
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #34528 # Details This PR implements the agent changes for allowing Fleet admins to require that users authenticate with an IdP prior to having their devices set up. I'll comment on changes inline but the high-level flow is: 1. Orbit calls the enroll endpoint as usual. This is triggered lazily by any one of a number of subsystems like device token rotation or requesting Fleet config. 2. If the enroll endpoint returns the new `ErrEndUserAuthRequired` response, then it opens a window to the `/mdm/sso` Fleet page and retries the enroll endpoint every 30 seconds indefinitely. 3. Any other non-200 response to the enroll request is treated as before (limited # of retries, with backoff) # Checklist for submitter If some of the following don't apply, delete the relevant line. - [ ] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files) for more information. Will add changelog when the story is done. ## Testing - [X] Added/updated automated tests Added test for new retry logic - [X] QA'd all new/changed functionality manually This is somewhat hard to test without the associated backend PR: https://github.com/fleetdm/fleet/pull/34835 ## fleetd/orbit/Fleet Desktop - [X] Verified compatibility with the latest released version of Fleet (see [Must rule](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/workflows/fleetd-development-and-release-strategy.md)) This is compatible with all Fleet versions, since older ones won't send the new error.
- [X] If the change applies to only one platform, confirmed that `runtime.GOOS` is used as needed to isolate changes This is compatible with all platforms, although it currently should only ever run on Windows and Linux since macOS devices will have end-user auth taken care of before they even download Orbit. - [ ] Verified that fleetd runs on macOS, Linux and Windows Testing this now. - [ ] Verified auto-update works from the released version of component to the new version (see [tools/tuf/test](../tools/tuf/test/README.md)) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added SSO (Single Sign-On) enrollment support for end-user authentication * Enhanced error messaging for authentication-required scenarios * **Bug Fixes** * Improved error handling and retry logic for enrollment failures <!-- end of auto-generated comment: release notes by coderabbit.ai -->
280 lines
7.5 KiB
Go
280 lines
7.5 KiB
Go
package service
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"mime"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
var (
	// errInvalidScheme is returned by newBaseClient when the server address
	// does not use the https scheme and plain HTTP was not allowed (neither
	// insecureSkipVerify nor a local host).
	errInvalidScheme = errors.New("address must start with https:// for remote connections")
)
|
|
|
|
// httpClient is the single *http.Client method that baseClient relies on to
// send requests. Depending on this interface rather than the concrete type
// lets tests substitute a mock transport.
type httpClient interface {
	Do(*http.Request) (*http.Response, error)
}
|
|
|
|
type baseClient struct {
|
|
baseURL *url.URL
|
|
http httpClient
|
|
urlPrefix string
|
|
insecureSkipVerify bool
|
|
// serverCapabilities is a map of capabilities that the server supports.
|
|
// This map is updated on each response we receive from the server.
|
|
serverCapabilities fleet.CapabilityMap
|
|
// clientCapabilities is a map of capabilities that the client supports.
|
|
// This list is given when the client is instantiated and shouldn't be
|
|
// modified afterwards.
|
|
clientCapabilities fleet.CapabilityMap
|
|
}
|
|
|
|
// parseResponse processes the status code and parses the response body.
|
|
// It does not close the response body (should be closed by the caller).
|
|
func (bc *baseClient) parseResponse(verb, path string, response *http.Response, responseDest interface{}) error {
|
|
switch response.StatusCode {
|
|
case http.StatusNotFound:
|
|
return notFoundErr{
|
|
msg: extractServerErrorText(response.Body),
|
|
}
|
|
case http.StatusUnauthorized:
|
|
errText := extractServerErrorText(response.Body)
|
|
if strings.Contains(errText, "password reset required") {
|
|
return ErrPasswordResetRequired
|
|
}
|
|
if strings.Contains(errText, "END_USER_AUTH_REQUIRED") {
|
|
return ErrEndUserAuthRequired
|
|
}
|
|
return ErrUnauthenticated
|
|
case http.StatusPaymentRequired:
|
|
return ErrMissingLicense
|
|
default:
|
|
if response.StatusCode >= 200 && response.StatusCode < 300 {
|
|
break
|
|
}
|
|
|
|
e := &statusCodeErr{
|
|
code: response.StatusCode,
|
|
body: extractServerErrorText(response.Body),
|
|
}
|
|
return fmt.Errorf("%s %s received status %w", verb, path, e)
|
|
}
|
|
|
|
bc.setServerCapabilities(response)
|
|
|
|
if responseDest != nil {
|
|
if e, ok := responseDest.(bodyHandler); ok {
|
|
if err := e.Handle(response); err != nil {
|
|
return fmt.Errorf("%s %s error with custom body handler contents: %w", verb, path, err)
|
|
}
|
|
} else if response.StatusCode != http.StatusNoContent {
|
|
b, err := io.ReadAll(response.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("reading response body: %w", err)
|
|
}
|
|
if err := json.Unmarshal(b, &responseDest); err != nil {
|
|
const maxBodyLen = 200
|
|
truncatedBytes, isHTML := truncateAndDetectHTML(b, maxBodyLen)
|
|
|
|
if isHTML {
|
|
return fmt.Errorf("decode %s %s response: %w, (server returned HTML instead of JSON), body: %s", verb, path, err, truncatedBytes)
|
|
}
|
|
return fmt.Errorf("decode %s %s response: %w, body: %s", verb, path, err, truncatedBytes)
|
|
}
|
|
if e, ok := responseDest.(fleet.Errorer); ok {
|
|
if e.Error() != nil {
|
|
return fmt.Errorf("%s %s error: %w", verb, path, e.Error())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
bc.setServerCapabilities(response)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (bc *baseClient) url(path, rawQuery string) *url.URL {
|
|
u := *bc.baseURL
|
|
u.Path = bc.urlPrefix + path
|
|
u.RawQuery = rawQuery
|
|
return &u
|
|
}
|
|
|
|
// setServerCapabilities updates the server capabilities based on the response
|
|
// from the server.
|
|
func (bc *baseClient) setServerCapabilities(response *http.Response) {
|
|
capabilities := response.Header.Get(fleet.CapabilitiesHeader)
|
|
bc.serverCapabilities.PopulateFromString(capabilities)
|
|
}
|
|
|
|
func (bc *baseClient) GetServerCapabilities() fleet.CapabilityMap {
|
|
return bc.serverCapabilities
|
|
}
|
|
|
|
// setClientCapabilities header is used to set a header with the client
|
|
// capabilities in the given request.
|
|
//
|
|
// This method is defined in baseClient because other clients generally have
|
|
// custom implementations of a method to perform the requests to the server.
|
|
func (bc *baseClient) setClientCapabilitiesHeader(req *http.Request) {
|
|
if len(bc.clientCapabilities) == 0 {
|
|
return
|
|
}
|
|
|
|
if req.Header == nil {
|
|
req.Header = http.Header{}
|
|
}
|
|
|
|
req.Header.Set(fleet.CapabilitiesHeader, bc.clientCapabilities.String())
|
|
}
|
|
|
|
func newBaseClient(
|
|
addr string,
|
|
insecureSkipVerify bool,
|
|
rootCA, urlPrefix string,
|
|
fleetClientCert *tls.Certificate,
|
|
capabilities fleet.CapabilityMap,
|
|
signerWrapper func(*http.Client) *http.Client,
|
|
) (*baseClient, error) {
|
|
baseURL, err := url.Parse(addr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing URL: %w", err)
|
|
}
|
|
|
|
allowHTTP := insecureSkipVerify || strings.Contains(baseURL.Host, "localhost") || strings.Contains(baseURL.Host, "127.0.0.1")
|
|
if baseURL.Scheme != "https" && !allowHTTP {
|
|
return nil, errInvalidScheme
|
|
}
|
|
|
|
rootCAPool := x509.NewCertPool()
|
|
|
|
tlsConfig := &tls.Config{
|
|
// Osquery itself requires >= TLS 1.2.
|
|
// https://github.com/osquery/osquery/blob/9713ad9e28f1cfe6c16a823fb88bd531e39e192d/osquery/remote/transports/tls.cpp#L97-L98
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
|
|
if fleetClientCert != nil {
|
|
tlsConfig.Certificates = []tls.Certificate{*fleetClientCert}
|
|
}
|
|
|
|
switch {
|
|
case rootCA != "":
|
|
// read in the root cert file specified in the context
|
|
certs, err := os.ReadFile(rootCA)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading root CA: %w", err)
|
|
}
|
|
// add certs to pool
|
|
if ok := rootCAPool.AppendCertsFromPEM(certs); !ok {
|
|
return nil, errors.New("failed to add certificates to root CA pool")
|
|
}
|
|
tlsConfig.RootCAs = rootCAPool
|
|
case insecureSkipVerify:
|
|
// Ignoring "G402: TLS InsecureSkipVerify set true", needed for development/testing.
|
|
tlsConfig.InsecureSkipVerify = true //nolint:gosec
|
|
default:
|
|
rootCAPool, err = x509.SystemCertPool()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("loading system cert pool: %w", err)
|
|
}
|
|
tlsConfig.RootCAs = rootCAPool
|
|
}
|
|
|
|
httpClient := fleethttp.NewClient(fleethttp.WithTLSClientConfig(tlsConfig))
|
|
if signerWrapper != nil {
|
|
httpClient = signerWrapper(httpClient)
|
|
}
|
|
client := &baseClient{
|
|
baseURL: baseURL,
|
|
http: httpClient,
|
|
insecureSkipVerify: insecureSkipVerify,
|
|
urlPrefix: urlPrefix,
|
|
clientCapabilities: capabilities,
|
|
serverCapabilities: fleet.CapabilityMap{},
|
|
}
|
|
return client, nil
|
|
}
|
|
|
|
// bodyHandler is implemented by response destinations that consume the raw
// *http.Response themselves. parseResponse hands such destinations the
// response instead of JSON-decoding the body into them.
type bodyHandler interface {
	Handle(resp *http.Response) error
}
|
|
|
|
// FileResponse is a response destination whose Handle method writes the
// response body to a file on disk instead of JSON-decoding it.
type FileResponse struct {
	// DestPath is the directory the file is written into. DestFile is the
	// file name used when none is taken from the Content-Disposition header.
	DestPath, DestFile string
	// destFilePath is the full path chosen by Handle (see GetFilePath).
	destFilePath string
	// SkipMediaType disables reading the file name from Content-Disposition.
	SkipMediaType bool
	// ProgressFunc, if set, receives the byte count of every read from the
	// response body.
	ProgressFunc func(int)
}
|
|
|
|
func (f *FileResponse) Handle(resp *http.Response) error {
|
|
var filename string
|
|
if !f.SkipMediaType {
|
|
_, params, err := mime.ParseMediaType(resp.Header.Get("Content-Disposition"))
|
|
if err != nil {
|
|
return fmt.Errorf("parsing media type from response header: %w", err)
|
|
}
|
|
filename = params["filename"]
|
|
}
|
|
|
|
if filename == "" {
|
|
filename = f.DestFile
|
|
}
|
|
if filename == "" {
|
|
filename = uuid.NewString()
|
|
}
|
|
|
|
f.destFilePath = filepath.Join(f.DestPath, filename)
|
|
destFile, err := os.Create(f.destFilePath)
|
|
if err != nil {
|
|
return fmt.Errorf("creating file: %w", err)
|
|
}
|
|
defer destFile.Close()
|
|
|
|
var respBodyReader io.Reader = resp.Body
|
|
if f.ProgressFunc != nil {
|
|
respBodyReader = &progressReader{
|
|
Reader: respBodyReader,
|
|
progressFunc: f.ProgressFunc,
|
|
}
|
|
}
|
|
|
|
_, err = io.Copy(destFile, respBodyReader)
|
|
if err != nil {
|
|
return fmt.Errorf("copying from http stream to file: %w", err)
|
|
}
|
|
|
|
if err := destFile.Close(); err != nil {
|
|
return fmt.Errorf("closing file after copy: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (f *FileResponse) GetFilePath() string {
|
|
return f.destFilePath
|
|
}
|
|
|
|
// progressReader wraps an io.Reader and reports the byte count of every Read,
// including zero-byte reads, to progressFunc. progressFunc must be non-nil
// because it is called unconditionally.
type progressReader struct {
	io.Reader
	progressFunc func(n int)
}

// Read reads from the wrapped reader and passes the byte count to
// progressFunc. It then returns the wrapped reader's results unchanged.
func (pr *progressReader) Read(buf []byte) (n int, err error) {
	n, err = pr.Reader.Read(buf)
	pr.progressFunc(n)
	return n, err
}
|