mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 13:37:30 +00:00
Added deny list for checking external user submitted urls (#39947)
This PR changes 3 things. 1. Validate `admin_url` + all URLs for HTTPS/non-private 2. Add custom `DialContext` hook in fleethttp.NewClient(), this is needed for DNS-rebinding protection at connection time 3. Validate Smallstep SCEP challenge endpoint # **IMPORTANT** There are two validations occurring. 1. `CheckURLForSSRF` 2. `SSRFDialContext` ## Why? `CheckURLForSSRF` checks the hostname. It resolves DNS, validates the ip, and then returns an error to the user. It protects certificate authority create/update API endpoints. But then `GetSmallstepSCEPChallenge` calls `http.NewRequest(http.MethodPost, ca.ChallengeURL, ...)` with the original hostname This is where `SSRFDialContext` comes into play. It fires when an actual HTTP request is attempted. Meaning Fleet would first build the request, encode the body, set up TLS, etc., before being blocked at the dial. `CheckURLForSSRF` stops the operation before any of that work happens. `SSRFDialContext` protects the actual challenge fetch that happens later at enrollment time. They're not always called together. The dial-time check is the only thing protecting the enrollment request and DNS rebinding. ## Should we remove `CheckURLForSSRF` This is debatable and I don't have a strong opinion. Removing `CheckURLForSSRF` would still provide the same protection. However, it would return a generic connection error from the HTTP client which would make it slightly hard to diagnose why it is broken. ## What's next I implemented this for certificate authorities. I am sure there are other places in the code base that take user submitted urls and could also use this check. That is outside the scope of this particular PR. But worthy to investigate in the near future. If some of the following don't apply, delete the relevant line. - [x] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files) for more information. ## Testing - [x] Added/updated automated tests - [x] QA'd all new/changed functionality manually <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Security** * Added SSRF protections for validating external URLs and blocking private/IP-metadata ranges; dev mode can bypass checks for local testing * **New Features** * Introduced an SSRF-protected HTTP transport and an option to supply a custom transport per client * **Tests** * Added comprehensive tests covering SSRF validation, dialing behavior, and resolution edge cases <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
f437c13f19
commit
3d4a3e1b87
10 changed files with 591 additions and 65 deletions
1
changes/14284-external-deny-list
Normal file
1
changes/14284-external-deny-list
Normal file
|
|
@ -0,0 +1 @@
|
|||
* Added deny list for checking external urls the fleet server will attempt to contact that are user submitted. Refer to pkg/fleethttp/ssrf.go for full list. In development, the --dev flag skips this check so that testing locally is not impacted. Certificate authorities is the first place this is implemented.
|
||||
|
|
@ -498,6 +498,10 @@ team_settings:
|
|||
// At the same time, GitOps uploads Apple profiles that use the newly configured CAs.
|
||||
func (s *enterpriseIntegrationGitopsTestSuite) TestCAIntegrations() {
|
||||
t := s.T()
|
||||
|
||||
dev_mode.IsEnabled = true
|
||||
t.Cleanup(func() { dev_mode.IsEnabled = false })
|
||||
|
||||
user := s.createGitOpsUser(t)
|
||||
fleetctlConfig := s.createFleetctlConfig(t, user)
|
||||
|
||||
|
|
@ -3477,7 +3481,6 @@ team_settings:
|
|||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
|
||||
"github.com/fleetdm/fleet/v4/server/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
|
|
@ -208,8 +208,8 @@ func (svc *Service) validatePayload(p *fleet.CertificateAuthorityPayload, errPre
|
|||
}
|
||||
|
||||
func (svc *Service) validateDigicert(ctx context.Context, digicertCA *fleet.DigiCertCA, errPrefix string) error {
|
||||
if err := validateURL(digicertCA.URL, "DigiCert", errPrefix); err != nil {
|
||||
return err
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, digicertCA.URL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sDigiCert URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
if digicertCA.APIToken == "" || digicertCA.APIToken == fleet.MaskedPassword {
|
||||
return fleet.NewInvalidArgumentError("api_token", fmt.Sprintf("%sInvalid API token. Please correct and try again.", errPrefix))
|
||||
|
|
@ -321,8 +321,8 @@ func (svc *Service) validateHydrant(ctx context.Context, hydrantCA *fleet.Hydran
|
|||
if err := validateCAName(hydrantCA.Name, errPrefix); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateURL(hydrantCA.URL, "Hydrant", errPrefix); err != nil {
|
||||
return err
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, hydrantCA.URL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sHydrant URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
if hydrantCA.ClientID == "" {
|
||||
return fleet.NewInvalidArgumentError("client_id", fmt.Sprintf("%sInvalid Hydrant Client ID. Please correct and try again.", errPrefix))
|
||||
|
|
@ -346,8 +346,8 @@ func (svc *Service) validateEST(ctx context.Context, estProxyCA *fleet.ESTProxyC
|
|||
if err := validateCAName(estProxyCA.Name, errPrefix); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateURL(estProxyCA.URL, "EST", errPrefix); err != nil {
|
||||
return err
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, estProxyCA.URL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sEST URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
if estProxyCA.Username == "" {
|
||||
return fleet.NewInvalidArgumentError("username", fmt.Sprintf("%sInvalid EST Username. Please correct and try again.", errPrefix))
|
||||
|
|
@ -361,18 +361,12 @@ func (svc *Service) validateEST(ctx context.Context, estProxyCA *fleet.ESTProxyC
|
|||
return nil
|
||||
}
|
||||
|
||||
func validateURL(caURL, displayType, errPrefix string) error {
|
||||
if u, err := url.ParseRequestURI(caURL); err != nil {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sInvalid %s URL. Please correct and try again.", errPrefix, displayType))
|
||||
} else if u.Scheme != "https" && u.Scheme != "http" {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%s%s URL scheme must be https or http", errPrefix, displayType))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *Service) validateNDESSCEPProxy(ctx context.Context, ndesSCEP *fleet.NDESSCEPProxyCA, errPrefix string) error {
|
||||
if err := validateURL(ndesSCEP.URL, "NDES SCEP", errPrefix); err != nil {
|
||||
return err
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, ndesSCEP.URL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sNDES SCEP URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, ndesSCEP.AdminURL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("admin_url", fmt.Sprintf("%sNDES SCEP admin URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
if err := svc.scepConfigService.ValidateSCEPURL(ctx, ndesSCEP.URL); err != nil {
|
||||
level.Error(svc.logger).Log("msg", "Failed to validate NDES SCEP URL", "err", err)
|
||||
|
|
@ -396,8 +390,8 @@ func (svc *Service) validateCustomSCEPProxy(ctx context.Context, customSCEP *fle
|
|||
if err := validateCAName(customSCEP.Name, errPrefix); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateURL(customSCEP.URL, "SCEP", errPrefix); err != nil {
|
||||
return err
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, customSCEP.URL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sCustom SCEP Proxy URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
if customSCEP.Challenge == "" || customSCEP.Challenge == fleet.MaskedPassword {
|
||||
return fleet.NewInvalidArgumentError("challenge", fmt.Sprintf("%sCustom SCEP Proxy challenge cannot be empty", errPrefix))
|
||||
|
|
@ -413,8 +407,8 @@ func (svc *Service) validateSmallstepSCEPProxy(ctx context.Context, smallstepSCE
|
|||
if err := validateCAName(smallstepSCEP.Name, errPrefix); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateURL(smallstepSCEP.URL, "Smallstep SCEP", errPrefix); err != nil {
|
||||
return err
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, smallstepSCEP.URL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sSmallstep SCEP URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
if smallstepSCEP.Username == "" {
|
||||
return fleet.NewInvalidArgumentError("username", fmt.Sprintf("%sSmallstep username cannot be empty", errPrefix))
|
||||
|
|
@ -1257,10 +1251,9 @@ func (svc *Service) validateDigicertUpdate(ctx context.Context, digicert *fleet.
|
|||
}
|
||||
}
|
||||
if digicert.URL != nil {
|
||||
if err := validateURL(*digicert.URL, "DigiCert", errPrefix); err != nil {
|
||||
return err
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, *digicert.URL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sDigiCert URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
|
||||
// We want to generate a DigiCertCA struct with all required fields to verify the new URL.
|
||||
// If URL or APIToken are not being updated we use the existing values from oldCA
|
||||
digicertCA := fleet.DigiCertCA{
|
||||
|
|
@ -1338,10 +1331,9 @@ func (svc *Service) validateHydrantUpdate(ctx context.Context, hydrant *fleet.Hy
|
|||
}
|
||||
}
|
||||
if hydrant.URL != nil {
|
||||
if err := validateURL(*hydrant.URL, "Hydrant", errPrefix); err != nil {
|
||||
return err
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, *hydrant.URL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sHydrant URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
|
||||
hydrantCAToVerify := fleet.ESTProxyCA{ // The hydrant service for verification only requires the URL.
|
||||
URL: *hydrant.URL,
|
||||
}
|
||||
|
|
@ -1370,10 +1362,9 @@ func (svc *Service) validateCustomESTUpdate(ctx context.Context, estUpdate *flee
|
|||
}
|
||||
}
|
||||
if estUpdate.URL != nil {
|
||||
if err := validateURL(*estUpdate.URL, "EST", errPrefix); err != nil {
|
||||
return err
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, *estUpdate.URL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sEST URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
|
||||
hydrantCAToVerify := fleet.ESTProxyCA{ // The EST service for verification only requires the URL.
|
||||
URL: *estUpdate.URL,
|
||||
}
|
||||
|
|
@ -1399,8 +1390,8 @@ func (svc *Service) validateNDESSCEPProxyUpdate(ctx context.Context, ndesSCEP *f
|
|||
// some methods in this fuction require the NDESSCEPProxyCA type so we convert the ndes update payload here
|
||||
|
||||
if ndesSCEP.URL != nil {
|
||||
if err := validateURL(*ndesSCEP.URL, "NDES SCEP", errPrefix); err != nil {
|
||||
return err
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, *ndesSCEP.URL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sNDES SCEP URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
if err := svc.scepConfigService.ValidateSCEPURL(ctx, *ndesSCEP.URL); err != nil {
|
||||
level.Error(svc.logger).Log("msg", "Failed to validate NDES SCEP URL", "err", err)
|
||||
|
|
@ -1414,6 +1405,9 @@ func (svc *Service) validateNDESSCEPProxyUpdate(ctx context.Context, ndesSCEP *f
|
|||
}
|
||||
}
|
||||
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, *ndesSCEP.AdminURL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("admin_url", fmt.Sprintf("%sNDES SCEP admin URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
// We want to generate a NDESSCEPProxyCA struct with all required fields to verify the admin URL.
|
||||
// If URL, Username or Password are not being updated we use the existing values from oldCA
|
||||
NDESProxy := fleet.NDESSCEPProxyCA{
|
||||
|
|
@ -1457,8 +1451,8 @@ func (svc *Service) validateCustomSCEPProxyUpdate(ctx context.Context, customSCE
|
|||
}
|
||||
}
|
||||
if customSCEP.URL != nil {
|
||||
if err := validateURL(*customSCEP.URL, "SCEP", errPrefix); err != nil {
|
||||
return err
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, *customSCEP.URL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sCustom SCEP Proxy URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
if err := svc.scepConfigService.ValidateSCEPURL(ctx, *customSCEP.URL); err != nil {
|
||||
level.Error(svc.logger).Log("msg", "Failed to validate custom SCEP URL", "err", err)
|
||||
|
|
@ -1481,8 +1475,8 @@ func (svc *Service) validateSmallstepSCEPProxyUpdate(ctx context.Context, smalls
|
|||
}
|
||||
}
|
||||
if smallstep.URL != nil {
|
||||
if err := validateURL(*smallstep.URL, "SCEP", errPrefix); err != nil {
|
||||
return err
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, *smallstep.URL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sSmallstep SCEP URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
if err := svc.scepConfigService.ValidateSCEPURL(ctx, *smallstep.URL); err != nil {
|
||||
level.Error(svc.logger).Log("msg", "Failed to validate Smallstep SCEP URL", "err", err)
|
||||
|
|
@ -1506,8 +1500,8 @@ func (svc *Service) validateSmallstepSCEPProxyUpdate(ctx context.Context, smalls
|
|||
|
||||
// Additional validation if url was updated
|
||||
if smallstep.ChallengeURL != nil {
|
||||
if err := validateURL(*smallstep.ChallengeURL, "Challenge", errPrefix); err != nil {
|
||||
return err
|
||||
if err := fleethttp.CheckURLForSSRF(ctx, *smallstep.ChallengeURL, nil); err != nil {
|
||||
return fleet.NewInvalidArgumentError("challenge_url", fmt.Sprintf("%sChallenge URL is invalid: %v", errPrefix, err))
|
||||
}
|
||||
smallstepSCEPProxy.ChallengeURL = *smallstep.ChallengeURL
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/ee/server/service/est"
|
||||
"github.com/fleetdm/fleet/v4/server/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/dev_mode"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mock"
|
||||
scep_mock "github.com/fleetdm/fleet/v4/server/mock/scep"
|
||||
|
|
@ -144,6 +145,10 @@ func setupMockCAServers(t *testing.T) (digicertServer, hydrantServer *httptest.S
|
|||
}
|
||||
|
||||
func TestCreatingCertificateAuthorities(t *testing.T) {
|
||||
// Enable dev mode so CheckURLForSSRF skips the private-IP blocklist for the duration of this test.
|
||||
dev_mode.IsEnabled = true
|
||||
t.Cleanup(func() { dev_mode.IsEnabled = false })
|
||||
|
||||
digicertServer, hydrantServer := setupMockCAServers(t)
|
||||
digicertURL := digicertServer.URL
|
||||
hydrantURL := hydrantServer.URL
|
||||
|
|
@ -534,7 +539,7 @@ func TestCreatingCertificateAuthorities(t *testing.T) {
|
|||
}
|
||||
|
||||
createdCA, err := svc.NewCertificateAuthority(ctx, createDigicertRequest)
|
||||
require.ErrorContains(t, err, "Invalid DigiCert URL")
|
||||
require.ErrorContains(t, err, "DigiCert URL is invalid")
|
||||
require.Len(t, createdCAs, 0)
|
||||
require.Nil(t, createdCA)
|
||||
})
|
||||
|
|
@ -675,7 +680,7 @@ func TestCreatingCertificateAuthorities(t *testing.T) {
|
|||
}
|
||||
|
||||
createdCA, err := svc.NewCertificateAuthority(ctx, createHydrantRequest)
|
||||
require.ErrorContains(t, err, "Invalid Hydrant URL.")
|
||||
require.ErrorContains(t, err, "Hydrant URL is invalid")
|
||||
require.Len(t, createdCAs, 0)
|
||||
require.Nil(t, createdCA)
|
||||
})
|
||||
|
|
@ -763,7 +768,7 @@ func TestCreatingCertificateAuthorities(t *testing.T) {
|
|||
}
|
||||
|
||||
createdCA, err := svc.NewCertificateAuthority(ctx, createCustomSCEPRequest)
|
||||
require.ErrorContains(t, err, "Invalid SCEP URL.")
|
||||
require.ErrorContains(t, err, "Custom SCEP Proxy URL is invalid")
|
||||
require.Len(t, createdCAs, 0)
|
||||
require.Nil(t, createdCA)
|
||||
})
|
||||
|
|
@ -819,7 +824,7 @@ func TestCreatingCertificateAuthorities(t *testing.T) {
|
|||
}
|
||||
|
||||
createdCA, err := svc.NewCertificateAuthority(ctx, createNDESSCEPRequest)
|
||||
require.ErrorContains(t, err, "Invalid NDES SCEP URL.")
|
||||
require.ErrorContains(t, err, "NDES SCEP URL is invalid")
|
||||
require.Len(t, createdCAs, 0)
|
||||
require.Nil(t, createdCA)
|
||||
})
|
||||
|
|
@ -979,7 +984,7 @@ func TestCreatingCertificateAuthorities(t *testing.T) {
|
|||
}
|
||||
|
||||
createdCA, err := svc.NewCertificateAuthority(ctx, createSmallstepRequest)
|
||||
require.ErrorContains(t, err, "Invalid Smallstep SCEP URL.")
|
||||
require.ErrorContains(t, err, "Smallstep SCEP URL is invalid")
|
||||
require.Len(t, createdCAs, 0)
|
||||
require.Nil(t, createdCA)
|
||||
})
|
||||
|
|
@ -1092,7 +1097,9 @@ func TestCreatingCertificateAuthorities(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUpdatingCertificateAuthorities(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Enable dev mode so CheckURLForSSRF skips the private-IP blocklist for the duration of this test.
|
||||
dev_mode.IsEnabled = true
|
||||
t.Cleanup(func() { dev_mode.IsEnabled = false })
|
||||
|
||||
digicertServer, hydrantServer := setupMockCAServers(t)
|
||||
digicertURL := digicertServer.URL
|
||||
|
|
@ -1360,7 +1367,7 @@ func TestUpdatingCertificateAuthorities(t *testing.T) {
|
|||
}
|
||||
|
||||
err := svc.UpdateCertificateAuthority(ctx, digicertID, payload)
|
||||
require.EqualError(t, err, "validation failed: url Couldn't edit certificate authority. Invalid DigiCert URL. Please correct and try again.")
|
||||
require.ErrorContains(t, err, "DigiCert URL is invalid")
|
||||
})
|
||||
|
||||
t.Run("Bad URL Path", func(t *testing.T) {
|
||||
|
|
@ -1502,7 +1509,7 @@ func TestUpdatingCertificateAuthorities(t *testing.T) {
|
|||
}
|
||||
|
||||
err := svc.UpdateCertificateAuthority(ctx, hydrantID, payload)
|
||||
require.EqualError(t, err, "validation failed: url Couldn't edit certificate authority. Invalid Hydrant URL. Please correct and try again.")
|
||||
require.ErrorContains(t, err, "Hydrant URL is invalid")
|
||||
})
|
||||
|
||||
t.Run("Bad URL", func(t *testing.T) {
|
||||
|
|
@ -1588,7 +1595,7 @@ func TestUpdatingCertificateAuthorities(t *testing.T) {
|
|||
}
|
||||
|
||||
err := svc.UpdateCertificateAuthority(ctx, scepID, payload)
|
||||
require.EqualError(t, err, "validation failed: url Couldn't edit certificate authority. Invalid SCEP URL. Please correct and try again.")
|
||||
require.ErrorContains(t, err, "Custom SCEP Proxy URL is invalid")
|
||||
})
|
||||
|
||||
t.Run("Requires challenge when updating URL", func(t *testing.T) {
|
||||
|
|
@ -1651,7 +1658,7 @@ func TestUpdatingCertificateAuthorities(t *testing.T) {
|
|||
}
|
||||
|
||||
err := svc.UpdateCertificateAuthority(ctx, ndesID, payload)
|
||||
require.EqualError(t, err, "validation failed: url Couldn't edit certificate authority. Invalid NDES SCEP URL. Please correct and try again.")
|
||||
require.ErrorContains(t, err, "NDES SCEP URL is invalid")
|
||||
})
|
||||
|
||||
t.Run("Bad SCEP URL", func(t *testing.T) {
|
||||
|
|
@ -1810,7 +1817,7 @@ func TestUpdatingCertificateAuthorities(t *testing.T) {
|
|||
}
|
||||
|
||||
err := svc.UpdateCertificateAuthority(ctx, smallstepID, payload)
|
||||
require.EqualError(t, err, "validation failed: url Couldn't edit certificate authority. Invalid SCEP URL. Please correct and try again.")
|
||||
require.ErrorContains(t, err, "Smallstep SCEP URL is invalid")
|
||||
})
|
||||
|
||||
t.Run("Invalid Challenge URL format", func(t *testing.T) {
|
||||
|
|
@ -1826,7 +1833,7 @@ func TestUpdatingCertificateAuthorities(t *testing.T) {
|
|||
}
|
||||
|
||||
err := svc.UpdateCertificateAuthority(ctx, smallstepID, payload)
|
||||
require.EqualError(t, err, "validation failed: url Couldn't edit certificate authority. Invalid Challenge URL. Please correct and try again.")
|
||||
require.ErrorContains(t, err, "Challenge URL is invalid")
|
||||
})
|
||||
|
||||
t.Run("Bad Smallstep SCEP URL", func(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -513,7 +513,7 @@ func (s *SCEPConfigService) GetNDESSCEPChallenge(ctx context.Context, proxy flee
|
|||
// Get the challenge from NDES
|
||||
client := fleethttp.NewClient(fleethttp.WithTimeout(*s.Timeout))
|
||||
client.Transport = ntlmssp.Negotiator{
|
||||
RoundTripper: fleethttp.NewTransport(),
|
||||
RoundTripper: fleethttp.NewSSRFProtectedTransport(),
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodGet, adminURL, http.NoBody)
|
||||
if err != nil {
|
||||
|
|
@ -586,8 +586,10 @@ func (s *SCEPConfigService) ValidateSmallstepChallengeURL(ctx context.Context, c
|
|||
}
|
||||
|
||||
func (s *SCEPConfigService) GetSmallstepSCEPChallenge(ctx context.Context, ca fleet.SmallstepSCEPProxyCA) (string, error) {
|
||||
// Get the challenge from Smallstep
|
||||
client := fleethttp.NewClient(fleethttp.WithTimeout(30 * time.Second))
|
||||
client := fleethttp.NewClient(
|
||||
fleethttp.WithTimeout(30*time.Second),
|
||||
fleethttp.WithTransport(fleethttp.NewSSRFProtectedTransport()),
|
||||
)
|
||||
var reqBody bytes.Buffer
|
||||
if err := json.NewEncoder(&reqBody).Encode(fleet.SmallstepChallengeRequestBody{
|
||||
Webhook: fleet.SmallstepChallengeWebhook{
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import (
|
|||
type clientOpts struct {
|
||||
timeout time.Duration
|
||||
tlsConf *tls.Config
|
||||
transport http.RoundTripper
|
||||
noFollow bool
|
||||
cookieJar http.CookieJar
|
||||
}
|
||||
|
|
@ -56,6 +57,14 @@ func WithCookieJar(jar http.CookieJar) ClientOpt {
|
|||
}
|
||||
}
|
||||
|
||||
// WithTransport sets an explicit RoundTripper on the HTTP client. When set,
|
||||
// this takes precedence over WithTLSClientConfig.
|
||||
func WithTransport(t http.RoundTripper) ClientOpt {
|
||||
return func(o *clientOpts) {
|
||||
o.transport = t
|
||||
}
|
||||
}
|
||||
|
||||
// NewClient returns an HTTP client configured according to the provided
|
||||
// options.
|
||||
func NewClient(opts ...ClientOpt) *http.Client {
|
||||
|
|
@ -71,7 +80,10 @@ func NewClient(opts ...ClientOpt) *http.Client {
|
|||
if co.noFollow {
|
||||
cli.CheckRedirect = noFollowRedirect
|
||||
}
|
||||
if co.tlsConf != nil {
|
||||
switch {
|
||||
case co.transport != nil:
|
||||
cli.Transport = co.transport
|
||||
case co.tlsConf != nil:
|
||||
cli.Transport = NewTransport(WithTLSConfig(co.tlsConf))
|
||||
}
|
||||
if co.cookieJar != nil {
|
||||
|
|
@ -113,6 +125,14 @@ func NewTransport(opts ...TransportOpt) *http.Transport {
|
|||
return tr
|
||||
}
|
||||
|
||||
// Override DialContext with the SSRF-blocking dialer so that every
|
||||
// outbound TCP connection is checked against the blocklist at dial time.
|
||||
func NewSSRFProtectedTransport(opts ...TransportOpt) *http.Transport {
|
||||
tr := NewTransport(opts...)
|
||||
tr.DialContext = SSRFDialContext(nil, nil, nil)
|
||||
return tr
|
||||
}
|
||||
|
||||
func noFollowRedirect(*http.Request, []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,7 +66,6 @@ func TestTransport(t *testing.T) {
|
|||
assert.NotEqual(t, defaultTLSConf, tr.TLSClientConfig)
|
||||
}
|
||||
assert.NotNil(t, tr.Proxy)
|
||||
assert.NotNil(t, tr.DialContext)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
195
pkg/fleethttp/ssrf.go
Normal file
195
pkg/fleethttp/ssrf.go
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
package fleethttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/dev_mode"
|
||||
)
|
||||
|
||||
var (
|
||||
// https://en.wikipedia.org/wiki/Reserved_IP_addresses
|
||||
ipv4Blocklist = []string{
|
||||
"0.0.0.0/8", // Current network (only valid as source address)
|
||||
"10.0.0.0/8", // Private network
|
||||
"100.64.0.0/10", // Shared Address Space
|
||||
"127.0.0.0/8", // Loopback
|
||||
"169.254.0.0/16", // Link-local
|
||||
"172.16.0.0/12", // Private network
|
||||
"192.0.0.0/24", // IETF Protocol Assignments
|
||||
"192.0.2.0/24", // TEST-NET-1, documentation and examples
|
||||
"192.88.99.0/24", // IPv6 to IPv4 relay (includes 2002::/16)
|
||||
"192.168.0.0/16", // Private network
|
||||
"198.18.0.0/15", // Network benchmark tests
|
||||
"198.51.100.0/24", // TEST-NET-2, documentation and examples
|
||||
"203.0.113.0/24", // TEST-NET-3, documentation and examples
|
||||
"224.0.0.0/4", // IP multicast (former Class D network)
|
||||
"240.0.0.0/4", // Reserved (former Class E network)
|
||||
"255.255.255.255/32", // Broadcast
|
||||
}
|
||||
|
||||
ipv6Blocklist = []string{
|
||||
"::1/128", // Loopback
|
||||
"64:ff9b::/96", // IPv4/IPv6 translation (RFC 6052)
|
||||
"64:ff9b:1::/48", // Local-use IPv4/IPv6 translation (RFC 8215)
|
||||
"100::/64", // Discard prefix (RFC 6666)
|
||||
"2001::/32", // Teredo tunneling
|
||||
"2001:10::/28", // Deprecated (previously ORCHID)
|
||||
"2001:20::/28", // ORCHIDv2
|
||||
"2001:db8::/32", // Documentation and example source code
|
||||
"2002::/16", // 6to4
|
||||
"3fff::/20", // Documentation (RFC 9637, 2024)
|
||||
"5f00::/16", // IPv6 Segment Routing (SRv6)
|
||||
"fc00::/7", // Unique local address
|
||||
"fe80::/10", // Link-local address
|
||||
"ff00::/8", // Multicast
|
||||
}
|
||||
)
|
||||
|
||||
var blockedCIDRs []*net.IPNet
|
||||
|
||||
func init() {
|
||||
for _, cidr := range append(ipv4Blocklist, ipv6Blocklist...) {
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("fleethttp: invalid blocked CIDR %q: %v", cidr, err))
|
||||
}
|
||||
blockedCIDRs = append(blockedCIDRs, network)
|
||||
}
|
||||
}
|
||||
|
||||
// isBlockedIP returns true when ip falls within any of the protected ranges.
|
||||
func isBlockedIP(ip net.IP) bool {
|
||||
if ip4 := ip.To4(); ip4 != nil && len(ip) == net.IPv6len {
|
||||
ip = ip4
|
||||
}
|
||||
for _, network := range blockedCIDRs {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SSRFError is returned when a URL resolves to a protected IP range.
|
||||
type SSRFError struct {
|
||||
URL string
|
||||
IP net.IP
|
||||
}
|
||||
|
||||
func (e *SSRFError) Error() string {
|
||||
return fmt.Sprintf("URL %q resolves to a blocked address", e.URL)
|
||||
}
|
||||
|
||||
func checkResolvedAddrs(ctx context.Context, host, rawURL string, resolver func(context.Context, string) ([]string, error)) ([]net.IP, error) {
|
||||
addrs, err := resolver(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving host %q: %w", host, err)
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("host %q resolved to no addresses", host)
|
||||
}
|
||||
safe := make([]net.IP, 0, len(addrs))
|
||||
for _, addr := range addrs {
|
||||
h, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
h = addr
|
||||
}
|
||||
ip := net.ParseIP(h)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("resolved address %q for host %q is not a valid IP", h, host)
|
||||
}
|
||||
if isBlockedIP(ip) {
|
||||
return nil, &SSRFError{URL: rawURL, IP: ip}
|
||||
}
|
||||
safe = append(safe, ip)
|
||||
}
|
||||
return safe, nil
|
||||
}
|
||||
|
||||
// CheckURLForSSRF validates rawURL against SSRF attack vectors using a static blocklist.
|
||||
func CheckURLForSSRF(ctx context.Context, rawURL string, resolver func(ctx context.Context, host string) ([]string, error)) error {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
scheme := parsed.Scheme
|
||||
if scheme != "http" && scheme != "https" {
|
||||
return fmt.Errorf("URL scheme %q is not allowed; must be http or https", scheme)
|
||||
}
|
||||
|
||||
hostname := parsed.Hostname()
|
||||
if hostname == "" {
|
||||
return errors.New("URL has no host")
|
||||
}
|
||||
|
||||
if dev_mode.IsEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
if isBlockedIP(ip) {
|
||||
return &SSRFError{URL: rawURL, IP: ip}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if resolver == nil {
|
||||
resolver = net.DefaultResolver.LookupHost
|
||||
}
|
||||
_, err = checkResolvedAddrs(ctx, hostname, rawURL, resolver)
|
||||
return err
|
||||
}
|
||||
|
||||
// SSRFDialContext returns a DialContext function that validates against SSRF attack vectors using a static blocklist.
|
||||
func SSRFDialContext(
|
||||
base *net.Dialer,
|
||||
resolver func(ctx context.Context, host string) ([]string, error),
|
||||
dial func(ctx context.Context, network, addr string) (net.Conn, error),
|
||||
) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if base == nil {
|
||||
base = &net.Dialer{}
|
||||
}
|
||||
if resolver == nil {
|
||||
resolver = net.DefaultResolver.LookupHost
|
||||
}
|
||||
if dial == nil {
|
||||
dial = base.DialContext
|
||||
}
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if dev_mode.IsEnabled {
|
||||
return dial(ctx, network, addr)
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ssrf dial: splitting host/port from %q: %w", addr, err)
|
||||
}
|
||||
|
||||
safeIPs, err := checkResolvedAddrs(ctx, host, net.JoinHostPort(host, port), resolver)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// net.Dialer has no API to accept a pre-resolved IP list
|
||||
// This is similar to what go does with dialSerial
|
||||
// /usr/local/go/src/net/dial.go#dialSerial
|
||||
var lastErr error
|
||||
for _, ip := range safeIPs {
|
||||
var conn net.Conn
|
||||
conn, lastErr = dial(ctx, network, net.JoinHostPort(ip.String(), port))
|
||||
if lastErr == nil {
|
||||
return conn, nil
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
// Context cancelled/timed out — no point trying remaining IPs.
|
||||
return nil, lastErr
|
||||
}
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
}
|
||||
301
pkg/fleethttp/ssrf_test.go
Normal file
301
pkg/fleethttp/ssrf_test.go
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
package fleethttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// noopResolver returns a known public IP so that CheckURLForSSRF reaches the
|
||||
// IP-range check without triggering real DNS lookups.
|
||||
func noopResolver(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{"93.184.216.34"}, nil // example.com
|
||||
}
|
||||
|
||||
func TestCheckURLForSSRFBlockedLiteralIPs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
blocked := []string{
|
||||
"http://127.0.0.1/mscep/mscep.dll",
|
||||
"http://127.255.255.255/path",
|
||||
"http://10.0.0.1/admin",
|
||||
"http://10.255.255.255/admin",
|
||||
"http://172.16.0.1/admin",
|
||||
"http://172.31.255.255/admin",
|
||||
"http://192.168.0.1/admin",
|
||||
"http://192.168.255.255/admin",
|
||||
"http://169.254.169.254/latest/meta-data/",
|
||||
"http://169.254.0.1/whatever",
|
||||
"http://100.64.0.1/admin",
|
||||
"http://100.127.255.255/admin",
|
||||
"http://0.0.0.0/path",
|
||||
"http://[::1]/path",
|
||||
"http://[fe80::1]/path",
|
||||
"http://[fc00::1]/path",
|
||||
"http://[fdff::1]/path",
|
||||
}
|
||||
|
||||
for _, u := range blocked {
|
||||
t.Run(u, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := CheckURLForSSRF(context.Background(), u, noopResolver)
|
||||
require.Error(t, err, "expected SSRF block for %s", u)
|
||||
var ssrfErr *SSRFError
|
||||
assert.True(t, errors.As(err, &ssrfErr), "expected SSRFError for %s, got %T: %v", u, err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckURLForSSRFAllowedPublicIPs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
allowed := []string{
|
||||
"https://ndes.corp.example.com/mscep/mscep.dll",
|
||||
"https://93.184.216.34/path", // example.com
|
||||
"http://8.8.8.8/path", // Google DNS
|
||||
"https://1.1.1.1/path", // Cloudflare DNS
|
||||
}
|
||||
|
||||
for _, u := range allowed {
|
||||
t.Run(u, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := CheckURLForSSRF(context.Background(), u, noopResolver)
|
||||
assert.NoError(t, err, "expected no SSRF block for %s", u)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckURLForSSRFDNSResolutionBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Simulate a hostname that resolves to a private IP
|
||||
privateResolver := func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{"192.168.1.100"}, nil
|
||||
}
|
||||
|
||||
err := CheckURLForSSRF(context.Background(), "https://attacker-controlled.example.com/admin", privateResolver)
|
||||
require.Error(t, err)
|
||||
var ssrfErr *SSRFError
|
||||
assert.True(t, errors.As(err, &ssrfErr))
|
||||
assert.Equal(t, net.ParseIP("192.168.1.100").String(), ssrfErr.IP.String())
|
||||
}
|
||||
|
||||
func TestCheckURLForSSRFMetadataEndpoints(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
metadataURLs := []string{
|
||||
"http://169.254.169.254/latest/meta-data/iam/security-credentials/",
|
||||
"http://169.254.169.254/metadata/instance?api-version=2021-02-01",
|
||||
}
|
||||
|
||||
for _, u := range metadataURLs {
|
||||
t.Run(u, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := CheckURLForSSRF(context.Background(), u, noopResolver)
|
||||
require.Error(t, err)
|
||||
var ssrfErr *SSRFError
|
||||
assert.True(t, errors.As(err, &ssrfErr))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckURLForSSRFBadScheme(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := CheckURLForSSRF(context.Background(), "file:///etc/passwd", noopResolver)
|
||||
require.Error(t, err)
|
||||
var ssrfErr *SSRFError
|
||||
assert.False(t, errors.As(err, &ssrfErr), "bad-scheme error should not be an SSRFError")
|
||||
}
|
||||
|
||||
func TestCheckURLForSSRFResolverError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
failResolver := func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, errors.New("simulated DNS failure")
|
||||
}
|
||||
err := CheckURLForSSRF(context.Background(), "https://cant-resolve.example.com/admin", failResolver)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "resolving host")
|
||||
}
|
||||
|
||||
func TestCheckURLForSSRF_UnparseableAddressFailsClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// A custom resolver returning a non-IP string will be blocked
|
||||
badResolver := func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{"not-an-ip"}, nil
|
||||
}
|
||||
err := CheckURLForSSRF(context.Background(), "https://example.com/admin", badResolver)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not a valid IP")
|
||||
}
|
||||
|
||||
func TestCheckURLForSSRFMultipleResolutions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mixedResolver := func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{"93.184.216.34", "10.0.0.1"}, nil
|
||||
}
|
||||
err := CheckURLForSSRF(context.Background(), "https://mixed.example.com/admin", mixedResolver)
|
||||
require.Error(t, err)
|
||||
var ssrfErr *SSRFError
|
||||
assert.True(t, errors.As(err, &ssrfErr))
|
||||
assert.Equal(t, net.ParseIP("10.0.0.1").String(), ssrfErr.IP.String())
|
||||
}
|
||||
|
||||
func TestCheckURLForSSRFIPv4MappedBypass(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// An attacker could supply an IPv4-mapped IPv6 address like ::ffff:192.168.1.1
|
||||
// to reach a private IPv4 host while bypassing the IPv4 blocklist check.
|
||||
blocked := []string{
|
||||
"http://[::ffff:192.168.1.1]/admin", // RFC 1918 private
|
||||
"http://[::ffff:127.0.0.1]/admin", // Loopback
|
||||
"http://[::ffff:169.254.169.254]/admin", // Link-local metadata
|
||||
"http://[::ffff:10.0.0.1]/admin", // RFC 1918 private
|
||||
}
|
||||
for _, u := range blocked {
|
||||
t.Run(u, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := CheckURLForSSRF(context.Background(), u, noopResolver)
|
||||
require.Error(t, err, "expected SSRF block for IPv4-mapped %s", u)
|
||||
var ssrfErr *SSRFError
|
||||
assert.True(t, errors.As(err, &ssrfErr), "expected SSRFError for %s, got %T: %v", u, err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckURLForSSRFSSRFErrorMessage(t *testing.T) {
|
||||
err := CheckURLForSSRF(context.Background(), "http://127.0.0.1/admin", noopResolver)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "blocked address")
|
||||
assert.Contains(t, err.Error(), "127.0.0.1")
|
||||
}
|
||||
|
||||
// noopDial is used as the dial parameter so tests never open real sockets.
|
||||
func noopDial(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return nil, errors.New("no-op dial: connection not attempted in tests")
|
||||
}
|
||||
|
||||
// captureDial records the addr passed to dial without opening a real socket.
|
||||
func captureDial(got *string) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return func(_ context.Context, _, addr string) (net.Conn, error) {
|
||||
*got = addr
|
||||
return nil, errors.New("no-op dial: connection not attempted in tests")
|
||||
}
|
||||
}
|
||||
|
||||
// staticResolver returns a fixed list of IPs for any host.
|
||||
func staticResolver(ips ...string) func(ctx context.Context, host string) ([]string, error) {
|
||||
return func(_ context.Context, _ string) ([]string, error) {
|
||||
return ips, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSRFDialContextBlocksPrivateIPs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
blocked := []struct {
|
||||
addr string
|
||||
ip string
|
||||
}{
|
||||
{"127.0.0.1:80", "127.0.0.1"},
|
||||
{"10.0.0.1:443", "10.0.0.1"},
|
||||
{"172.16.0.1:8080", "172.16.0.1"},
|
||||
{"192.168.1.1:443", "192.168.1.1"},
|
||||
{"169.254.169.254:80", "169.254.169.254"},
|
||||
}
|
||||
|
||||
for _, tc := range blocked {
|
||||
t.Run(tc.addr, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dial := SSRFDialContext(nil, staticResolver(tc.ip), noopDial)
|
||||
conn, err := dial(context.Background(), "tcp", tc.addr)
|
||||
require.Error(t, err, "expected dial to be blocked for %s", tc.addr)
|
||||
assert.Nil(t, conn)
|
||||
var ssrfErr *SSRFError
|
||||
assert.True(t, errors.As(err, &ssrfErr), "expected SSRFError for %s, got %T: %v", tc.addr, err, err)
|
||||
assert.Equal(t, net.ParseIP(tc.ip).String(), ssrfErr.IP.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSRFDialContextAllowsPublicIPs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
publicIPs := []string{
|
||||
"93.184.216.34",
|
||||
"8.8.8.8",
|
||||
"1.1.1.1",
|
||||
}
|
||||
|
||||
for _, publicIP := range publicIPs {
|
||||
t.Run(publicIP, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dial := SSRFDialContext(nil, staticResolver(publicIP), noopDial)
|
||||
_, err := dial(context.Background(), "tcp", publicIP+":80")
|
||||
var ssrfErr *SSRFError
|
||||
assert.False(t, errors.As(err, &ssrfErr), "public IP %s should not be SSRF-blocked", publicIP)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSRFDialContextBlocksMixedResolution(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Simulates DNS rebinding: resolver returns one public and one private IP.
|
||||
dial := SSRFDialContext(nil, staticResolver("93.184.216.34", "192.168.1.100"), noopDial)
|
||||
|
||||
_, err := dial(context.Background(), "tcp", "attacker.example.com:443")
|
||||
require.Error(t, err)
|
||||
var ssrfErr *SSRFError
|
||||
assert.True(t, errors.As(err, &ssrfErr))
|
||||
assert.Equal(t, net.ParseIP("192.168.1.100").String(), ssrfErr.IP.String())
|
||||
}
|
||||
|
||||
func TestSSRFDialContextResolverError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
failResolver := func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, errors.New("simulated DNS failure")
|
||||
}
|
||||
dial := SSRFDialContext(nil, failResolver, noopDial)
|
||||
_, err := dial(context.Background(), "tcp", "cant-resolve.example.com:443")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "resolving")
|
||||
}
|
||||
|
||||
func TestSSRFDialContextDialsResolvedIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var gotAddr string
|
||||
dialFn := SSRFDialContext(nil, staticResolver("93.184.216.34"), captureDial(&gotAddr))
|
||||
_, _ = dialFn(context.Background(), "tcp", "example.com:443")
|
||||
|
||||
// The dialer must receive the resolved IP, not "example.com".
|
||||
host, port, err := net.SplitHostPort(gotAddr)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "443", port)
|
||||
assert.NotEmpty(t, host)
|
||||
assert.NotEqual(t, "example.com", host, "dialer must receive resolved IP, not the hostname")
|
||||
assert.Equal(t, net.ParseIP("93.184.216.34").String(), net.ParseIP(host).String())
|
||||
}
|
||||
|
||||
func TestSSRFDialContextNilsUseDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dial := SSRFDialContext(nil, nil, nil)
|
||||
require.NotNil(t, dial)
|
||||
}
|
||||
|
||||
func TestNewSSRFProtectedTransportHasDialContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tr := NewSSRFProtectedTransport()
|
||||
require.NotNil(t, tr.DialContext, "NewSSRFProtectedTransport() must set DialContext for SSRF protection")
|
||||
}
|
||||
|
|
@ -16,6 +16,7 @@ import (
|
|||
|
||||
eeservice "github.com/fleetdm/fleet/v4/ee/server/service"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/dev_mode"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/apple/mobileconfig"
|
||||
|
|
@ -34,6 +35,9 @@ import (
|
|||
func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
|
||||
t := s.T()
|
||||
|
||||
dev_mode.IsEnabled = true
|
||||
t.Cleanup(func() { dev_mode.IsEnabled = false })
|
||||
|
||||
// TODO(hca): test each CA type activities once implemented
|
||||
|
||||
// TODO(hca) test each CA type cannot configure without private key?
|
||||
|
|
@ -229,7 +233,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
|
|||
{
|
||||
testName: "non-http",
|
||||
url: "nonhttp://bad.com",
|
||||
errMessage: "URL scheme must be https or http",
|
||||
errMessage: "must be http or https",
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -261,7 +265,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
|
|||
res := s.Do("POST", "/api/v1/fleet/spec/certificate_authorities", req, http.StatusUnprocessableEntity)
|
||||
errMsg := extractServerErrorText(res.Body)
|
||||
require.Contains(t, errMsg, "certificate_authorities.ndes_scep_proxy")
|
||||
require.Contains(t, errMsg, "Invalid NDES SCEP URL")
|
||||
require.Contains(t, errMsg, "NDES SCEP URL is invalid")
|
||||
checkNDESApplied(t, nil)
|
||||
})
|
||||
|
||||
|
|
@ -432,7 +436,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
|
|||
errMsg := extractServerErrorText(res.Body)
|
||||
require.Contains(t, errMsg, "certificate_authorities.digicert")
|
||||
if tc.errMessage == "Invalid URL" {
|
||||
require.Contains(t, errMsg, "Invalid DigiCert URL")
|
||||
require.Contains(t, errMsg, "DigiCert URL is invalid")
|
||||
} else {
|
||||
require.Contains(t, errMsg, tc.errMessage)
|
||||
}
|
||||
|
|
@ -779,7 +783,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
|
|||
errMsg := extractServerErrorText(res.Body)
|
||||
require.Contains(t, errMsg, "certificate_authorities.custom_scep_proxy")
|
||||
if tc.errMessage == "Invalid URL" {
|
||||
require.Contains(t, errMsg, "Invalid SCEP URL")
|
||||
require.Contains(t, errMsg, "Custom SCEP Proxy URL is invalid")
|
||||
} else {
|
||||
require.Contains(t, errMsg, tc.errMessage)
|
||||
}
|
||||
|
|
@ -1004,7 +1008,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
|
|||
errMsg := extractServerErrorText(res.Body)
|
||||
require.Contains(t, errMsg, "certificate_authorities.smallstep")
|
||||
if tc.errMessage == "Invalid URL" {
|
||||
require.Contains(t, errMsg, "Invalid Smallstep SCEP URL")
|
||||
require.Contains(t, errMsg, "Smallstep SCEP URL is invalid")
|
||||
} else {
|
||||
require.Contains(t, errMsg, tc.errMessage)
|
||||
}
|
||||
|
|
@ -1107,7 +1111,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
|
|||
res := s.Do("POST", "/api/v1/fleet/spec/certificate_authorities", req, http.StatusUnprocessableEntity)
|
||||
errMsg := extractServerErrorText(res.Body)
|
||||
require.Contains(t, errMsg, "certificate_authorities.smallstep")
|
||||
require.Contains(t, errMsg, "Invalid Smallstep SCEP URL")
|
||||
require.Contains(t, errMsg, "Smallstep SCEP URL is invalid")
|
||||
})
|
||||
|
||||
t.Run("smallstep challenge url not set", func(t *testing.T) {
|
||||
|
|
@ -1264,7 +1268,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
|
|||
errMsg := extractServerErrorText(res.Body)
|
||||
require.Contains(t, errMsg, "certificate_authorities.hydrant")
|
||||
if tc.errMessage == "Invalid URL" {
|
||||
require.Contains(t, errMsg, "Invalid Hydrant URL")
|
||||
require.Contains(t, errMsg, "Hydrant URL is invalid")
|
||||
} else {
|
||||
require.Contains(t, errMsg, tc.errMessage)
|
||||
}
|
||||
|
|
@ -1390,7 +1394,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
|
|||
errMsg := extractServerErrorText(res.Body)
|
||||
require.Contains(t, errMsg, "certificate_authorities.custom_est_proxy")
|
||||
if tc.errMessage == "Invalid URL" {
|
||||
require.Contains(t, errMsg, "Invalid EST URL")
|
||||
require.Contains(t, errMsg, "EST URL is invalid")
|
||||
} else {
|
||||
require.Contains(t, errMsg, tc.errMessage)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue