Revert "Added deny list for checking external user submitted urls"

This reverts commit 3d4a3e1b87.
This commit is contained in:
Konstantin Sykulev 2026-02-24 16:29:08 -06:00 committed by GitHub
parent 58f8e290d9
commit 8757d365bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 65 additions and 591 deletions

View file

@ -1 +0,0 @@
* Added deny list for checking external urls the fleet server will attempt to contact that are user submitted. Refer to pkg/fleethttp/ssrf.go for full list. In development, the --dev flag skips this check so that testing locally is not impacted. Certificate authorities is the first place this is implemented.

View file

@ -522,10 +522,6 @@ settings:
// At the same time, GitOps uploads Apple profiles that use the newly configured CAs.
func (s *enterpriseIntegrationGitopsTestSuite) TestCAIntegrations() {
t := s.T()
dev_mode.IsEnabled = true
t.Cleanup(func() { dev_mode.IsEnabled = false })
user := s.createGitOpsUser(t)
fleetctlConfig := s.createFleetctlConfig(t, user)
@ -3505,6 +3501,7 @@ team_settings:
} else {
require.NoError(t, err)
}
})
}
}

View file

@ -4,10 +4,10 @@ import (
"context"
"errors"
"fmt"
"net/url"
"regexp"
"strings"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -208,8 +208,8 @@ func (svc *Service) validatePayload(p *fleet.CertificateAuthorityPayload, errPre
}
func (svc *Service) validateDigicert(ctx context.Context, digicertCA *fleet.DigiCertCA, errPrefix string) error {
if err := fleethttp.CheckURLForSSRF(ctx, digicertCA.URL, nil); err != nil {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sDigiCert URL is invalid: %v", errPrefix, err))
if err := validateURL(digicertCA.URL, "DigiCert", errPrefix); err != nil {
return err
}
if digicertCA.APIToken == "" || digicertCA.APIToken == fleet.MaskedPassword {
return fleet.NewInvalidArgumentError("api_token", fmt.Sprintf("%sInvalid API token. Please correct and try again.", errPrefix))
@ -321,8 +321,8 @@ func (svc *Service) validateHydrant(ctx context.Context, hydrantCA *fleet.Hydran
if err := validateCAName(hydrantCA.Name, errPrefix); err != nil {
return err
}
if err := fleethttp.CheckURLForSSRF(ctx, hydrantCA.URL, nil); err != nil {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sHydrant URL is invalid: %v", errPrefix, err))
if err := validateURL(hydrantCA.URL, "Hydrant", errPrefix); err != nil {
return err
}
if hydrantCA.ClientID == "" {
return fleet.NewInvalidArgumentError("client_id", fmt.Sprintf("%sInvalid Hydrant Client ID. Please correct and try again.", errPrefix))
@ -346,8 +346,8 @@ func (svc *Service) validateEST(ctx context.Context, estProxyCA *fleet.ESTProxyC
if err := validateCAName(estProxyCA.Name, errPrefix); err != nil {
return err
}
if err := fleethttp.CheckURLForSSRF(ctx, estProxyCA.URL, nil); err != nil {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sEST URL is invalid: %v", errPrefix, err))
if err := validateURL(estProxyCA.URL, "EST", errPrefix); err != nil {
return err
}
if estProxyCA.Username == "" {
return fleet.NewInvalidArgumentError("username", fmt.Sprintf("%sInvalid EST Username. Please correct and try again.", errPrefix))
@ -361,12 +361,18 @@ func (svc *Service) validateEST(ctx context.Context, estProxyCA *fleet.ESTProxyC
return nil
}
func (svc *Service) validateNDESSCEPProxy(ctx context.Context, ndesSCEP *fleet.NDESSCEPProxyCA, errPrefix string) error {
if err := fleethttp.CheckURLForSSRF(ctx, ndesSCEP.URL, nil); err != nil {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sNDES SCEP URL is invalid: %v", errPrefix, err))
func validateURL(caURL, displayType, errPrefix string) error {
if u, err := url.ParseRequestURI(caURL); err != nil {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sInvalid %s URL. Please correct and try again.", errPrefix, displayType))
} else if u.Scheme != "https" && u.Scheme != "http" {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%s%s URL scheme must be https or http", errPrefix, displayType))
}
if err := fleethttp.CheckURLForSSRF(ctx, ndesSCEP.AdminURL, nil); err != nil {
return fleet.NewInvalidArgumentError("admin_url", fmt.Sprintf("%sNDES SCEP admin URL is invalid: %v", errPrefix, err))
return nil
}
func (svc *Service) validateNDESSCEPProxy(ctx context.Context, ndesSCEP *fleet.NDESSCEPProxyCA, errPrefix string) error {
if err := validateURL(ndesSCEP.URL, "NDES SCEP", errPrefix); err != nil {
return err
}
if err := svc.scepConfigService.ValidateSCEPURL(ctx, ndesSCEP.URL); err != nil {
level.Error(svc.logger).Log("msg", "Failed to validate NDES SCEP URL", "err", err)
@ -390,8 +396,8 @@ func (svc *Service) validateCustomSCEPProxy(ctx context.Context, customSCEP *fle
if err := validateCAName(customSCEP.Name, errPrefix); err != nil {
return err
}
if err := fleethttp.CheckURLForSSRF(ctx, customSCEP.URL, nil); err != nil {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sCustom SCEP Proxy URL is invalid: %v", errPrefix, err))
if err := validateURL(customSCEP.URL, "SCEP", errPrefix); err != nil {
return err
}
if customSCEP.Challenge == "" || customSCEP.Challenge == fleet.MaskedPassword {
return fleet.NewInvalidArgumentError("challenge", fmt.Sprintf("%sCustom SCEP Proxy challenge cannot be empty", errPrefix))
@ -407,8 +413,8 @@ func (svc *Service) validateSmallstepSCEPProxy(ctx context.Context, smallstepSCE
if err := validateCAName(smallstepSCEP.Name, errPrefix); err != nil {
return err
}
if err := fleethttp.CheckURLForSSRF(ctx, smallstepSCEP.URL, nil); err != nil {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sSmallstep SCEP URL is invalid: %v", errPrefix, err))
if err := validateURL(smallstepSCEP.URL, "Smallstep SCEP", errPrefix); err != nil {
return err
}
if smallstepSCEP.Username == "" {
return fleet.NewInvalidArgumentError("username", fmt.Sprintf("%sSmallstep username cannot be empty", errPrefix))
@ -1251,9 +1257,10 @@ func (svc *Service) validateDigicertUpdate(ctx context.Context, digicert *fleet.
}
}
if digicert.URL != nil {
if err := fleethttp.CheckURLForSSRF(ctx, *digicert.URL, nil); err != nil {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sDigiCert URL is invalid: %v", errPrefix, err))
if err := validateURL(*digicert.URL, "DigiCert", errPrefix); err != nil {
return err
}
// We want to generate a DigiCertCA struct with all required fields to verify the new URL.
// If URL or APIToken are not being updated we use the existing values from oldCA
digicertCA := fleet.DigiCertCA{
@ -1331,9 +1338,10 @@ func (svc *Service) validateHydrantUpdate(ctx context.Context, hydrant *fleet.Hy
}
}
if hydrant.URL != nil {
if err := fleethttp.CheckURLForSSRF(ctx, *hydrant.URL, nil); err != nil {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sHydrant URL is invalid: %v", errPrefix, err))
if err := validateURL(*hydrant.URL, "Hydrant", errPrefix); err != nil {
return err
}
hydrantCAToVerify := fleet.ESTProxyCA{ // The hydrant service for verification only requires the URL.
URL: *hydrant.URL,
}
@ -1362,9 +1370,10 @@ func (svc *Service) validateCustomESTUpdate(ctx context.Context, estUpdate *flee
}
}
if estUpdate.URL != nil {
if err := fleethttp.CheckURLForSSRF(ctx, *estUpdate.URL, nil); err != nil {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sEST URL is invalid: %v", errPrefix, err))
if err := validateURL(*estUpdate.URL, "EST", errPrefix); err != nil {
return err
}
hydrantCAToVerify := fleet.ESTProxyCA{ // The EST service for verification only requires the URL.
URL: *estUpdate.URL,
}
@ -1390,8 +1399,8 @@ func (svc *Service) validateNDESSCEPProxyUpdate(ctx context.Context, ndesSCEP *f
// some methods in this fuction require the NDESSCEPProxyCA type so we convert the ndes update payload here
if ndesSCEP.URL != nil {
if err := fleethttp.CheckURLForSSRF(ctx, *ndesSCEP.URL, nil); err != nil {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sNDES SCEP URL is invalid: %v", errPrefix, err))
if err := validateURL(*ndesSCEP.URL, "NDES SCEP", errPrefix); err != nil {
return err
}
if err := svc.scepConfigService.ValidateSCEPURL(ctx, *ndesSCEP.URL); err != nil {
level.Error(svc.logger).Log("msg", "Failed to validate NDES SCEP URL", "err", err)
@ -1405,9 +1414,6 @@ func (svc *Service) validateNDESSCEPProxyUpdate(ctx context.Context, ndesSCEP *f
}
}
if err := fleethttp.CheckURLForSSRF(ctx, *ndesSCEP.AdminURL, nil); err != nil {
return fleet.NewInvalidArgumentError("admin_url", fmt.Sprintf("%sNDES SCEP admin URL is invalid: %v", errPrefix, err))
}
// We want to generate a NDESSCEPProxyCA struct with all required fields to verify the admin URL.
// If URL, Username or Password are not being updated we use the existing values from oldCA
NDESProxy := fleet.NDESSCEPProxyCA{
@ -1451,8 +1457,8 @@ func (svc *Service) validateCustomSCEPProxyUpdate(ctx context.Context, customSCE
}
}
if customSCEP.URL != nil {
if err := fleethttp.CheckURLForSSRF(ctx, *customSCEP.URL, nil); err != nil {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sCustom SCEP Proxy URL is invalid: %v", errPrefix, err))
if err := validateURL(*customSCEP.URL, "SCEP", errPrefix); err != nil {
return err
}
if err := svc.scepConfigService.ValidateSCEPURL(ctx, *customSCEP.URL); err != nil {
level.Error(svc.logger).Log("msg", "Failed to validate custom SCEP URL", "err", err)
@ -1475,8 +1481,8 @@ func (svc *Service) validateSmallstepSCEPProxyUpdate(ctx context.Context, smalls
}
}
if smallstep.URL != nil {
if err := fleethttp.CheckURLForSSRF(ctx, *smallstep.URL, nil); err != nil {
return fleet.NewInvalidArgumentError("url", fmt.Sprintf("%sSmallstep SCEP URL is invalid: %v", errPrefix, err))
if err := validateURL(*smallstep.URL, "SCEP", errPrefix); err != nil {
return err
}
if err := svc.scepConfigService.ValidateSCEPURL(ctx, *smallstep.URL); err != nil {
level.Error(svc.logger).Log("msg", "Failed to validate Smallstep SCEP URL", "err", err)
@ -1500,8 +1506,8 @@ func (svc *Service) validateSmallstepSCEPProxyUpdate(ctx context.Context, smalls
// Additional validation if url was updated
if smallstep.ChallengeURL != nil {
if err := fleethttp.CheckURLForSSRF(ctx, *smallstep.ChallengeURL, nil); err != nil {
return fleet.NewInvalidArgumentError("challenge_url", fmt.Sprintf("%sChallenge URL is invalid: %v", errPrefix, err))
if err := validateURL(*smallstep.ChallengeURL, "Challenge", errPrefix); err != nil {
return err
}
smallstepSCEPProxy.ChallengeURL = *smallstep.ChallengeURL
}

View file

@ -15,7 +15,6 @@ import (
"github.com/fleetdm/fleet/v4/ee/server/service/est"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/dev_mode"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
scep_mock "github.com/fleetdm/fleet/v4/server/mock/scep"
@ -145,10 +144,6 @@ func setupMockCAServers(t *testing.T) (digicertServer, hydrantServer *httptest.S
}
func TestCreatingCertificateAuthorities(t *testing.T) {
// Enable dev mode so CheckURLForSSRF skips the private-IP blocklist for the duration of this test.
dev_mode.IsEnabled = true
t.Cleanup(func() { dev_mode.IsEnabled = false })
digicertServer, hydrantServer := setupMockCAServers(t)
digicertURL := digicertServer.URL
hydrantURL := hydrantServer.URL
@ -539,7 +534,7 @@ func TestCreatingCertificateAuthorities(t *testing.T) {
}
createdCA, err := svc.NewCertificateAuthority(ctx, createDigicertRequest)
require.ErrorContains(t, err, "DigiCert URL is invalid")
require.ErrorContains(t, err, "Invalid DigiCert URL")
require.Len(t, createdCAs, 0)
require.Nil(t, createdCA)
})
@ -680,7 +675,7 @@ func TestCreatingCertificateAuthorities(t *testing.T) {
}
createdCA, err := svc.NewCertificateAuthority(ctx, createHydrantRequest)
require.ErrorContains(t, err, "Hydrant URL is invalid")
require.ErrorContains(t, err, "Invalid Hydrant URL.")
require.Len(t, createdCAs, 0)
require.Nil(t, createdCA)
})
@ -768,7 +763,7 @@ func TestCreatingCertificateAuthorities(t *testing.T) {
}
createdCA, err := svc.NewCertificateAuthority(ctx, createCustomSCEPRequest)
require.ErrorContains(t, err, "Custom SCEP Proxy URL is invalid")
require.ErrorContains(t, err, "Invalid SCEP URL.")
require.Len(t, createdCAs, 0)
require.Nil(t, createdCA)
})
@ -824,7 +819,7 @@ func TestCreatingCertificateAuthorities(t *testing.T) {
}
createdCA, err := svc.NewCertificateAuthority(ctx, createNDESSCEPRequest)
require.ErrorContains(t, err, "NDES SCEP URL is invalid")
require.ErrorContains(t, err, "Invalid NDES SCEP URL.")
require.Len(t, createdCAs, 0)
require.Nil(t, createdCA)
})
@ -984,7 +979,7 @@ func TestCreatingCertificateAuthorities(t *testing.T) {
}
createdCA, err := svc.NewCertificateAuthority(ctx, createSmallstepRequest)
require.ErrorContains(t, err, "Smallstep SCEP URL is invalid")
require.ErrorContains(t, err, "Invalid Smallstep SCEP URL.")
require.Len(t, createdCAs, 0)
require.Nil(t, createdCA)
})
@ -1097,9 +1092,7 @@ func TestCreatingCertificateAuthorities(t *testing.T) {
}
func TestUpdatingCertificateAuthorities(t *testing.T) {
// Enable dev mode so CheckURLForSSRF skips the private-IP blocklist for the duration of this test.
dev_mode.IsEnabled = true
t.Cleanup(func() { dev_mode.IsEnabled = false })
t.Parallel()
digicertServer, hydrantServer := setupMockCAServers(t)
digicertURL := digicertServer.URL
@ -1367,7 +1360,7 @@ func TestUpdatingCertificateAuthorities(t *testing.T) {
}
err := svc.UpdateCertificateAuthority(ctx, digicertID, payload)
require.ErrorContains(t, err, "DigiCert URL is invalid")
require.EqualError(t, err, "validation failed: url Couldn't edit certificate authority. Invalid DigiCert URL. Please correct and try again.")
})
t.Run("Bad URL Path", func(t *testing.T) {
@ -1509,7 +1502,7 @@ func TestUpdatingCertificateAuthorities(t *testing.T) {
}
err := svc.UpdateCertificateAuthority(ctx, hydrantID, payload)
require.ErrorContains(t, err, "Hydrant URL is invalid")
require.EqualError(t, err, "validation failed: url Couldn't edit certificate authority. Invalid Hydrant URL. Please correct and try again.")
})
t.Run("Bad URL", func(t *testing.T) {
@ -1595,7 +1588,7 @@ func TestUpdatingCertificateAuthorities(t *testing.T) {
}
err := svc.UpdateCertificateAuthority(ctx, scepID, payload)
require.ErrorContains(t, err, "Custom SCEP Proxy URL is invalid")
require.EqualError(t, err, "validation failed: url Couldn't edit certificate authority. Invalid SCEP URL. Please correct and try again.")
})
t.Run("Requires challenge when updating URL", func(t *testing.T) {
@ -1658,7 +1651,7 @@ func TestUpdatingCertificateAuthorities(t *testing.T) {
}
err := svc.UpdateCertificateAuthority(ctx, ndesID, payload)
require.ErrorContains(t, err, "NDES SCEP URL is invalid")
require.EqualError(t, err, "validation failed: url Couldn't edit certificate authority. Invalid NDES SCEP URL. Please correct and try again.")
})
t.Run("Bad SCEP URL", func(t *testing.T) {
@ -1817,7 +1810,7 @@ func TestUpdatingCertificateAuthorities(t *testing.T) {
}
err := svc.UpdateCertificateAuthority(ctx, smallstepID, payload)
require.ErrorContains(t, err, "Smallstep SCEP URL is invalid")
require.EqualError(t, err, "validation failed: url Couldn't edit certificate authority. Invalid SCEP URL. Please correct and try again.")
})
t.Run("Invalid Challenge URL format", func(t *testing.T) {
@ -1833,7 +1826,7 @@ func TestUpdatingCertificateAuthorities(t *testing.T) {
}
err := svc.UpdateCertificateAuthority(ctx, smallstepID, payload)
require.ErrorContains(t, err, "Challenge URL is invalid")
require.EqualError(t, err, "validation failed: url Couldn't edit certificate authority. Invalid Challenge URL. Please correct and try again.")
})
t.Run("Bad Smallstep SCEP URL", func(t *testing.T) {

View file

@ -513,7 +513,7 @@ func (s *SCEPConfigService) GetNDESSCEPChallenge(ctx context.Context, proxy flee
// Get the challenge from NDES
client := fleethttp.NewClient(fleethttp.WithTimeout(*s.Timeout))
client.Transport = ntlmssp.Negotiator{
RoundTripper: fleethttp.NewSSRFProtectedTransport(),
RoundTripper: fleethttp.NewTransport(),
}
req, err := http.NewRequest(http.MethodGet, adminURL, http.NoBody)
if err != nil {
@ -586,10 +586,8 @@ func (s *SCEPConfigService) ValidateSmallstepChallengeURL(ctx context.Context, c
}
func (s *SCEPConfigService) GetSmallstepSCEPChallenge(ctx context.Context, ca fleet.SmallstepSCEPProxyCA) (string, error) {
client := fleethttp.NewClient(
fleethttp.WithTimeout(30*time.Second),
fleethttp.WithTransport(fleethttp.NewSSRFProtectedTransport()),
)
// Get the challenge from Smallstep
client := fleethttp.NewClient(fleethttp.WithTimeout(30 * time.Second))
var reqBody bytes.Buffer
if err := json.NewEncoder(&reqBody).Encode(fleet.SmallstepChallengeRequestBody{
Webhook: fleet.SmallstepChallengeWebhook{

View file

@ -18,7 +18,6 @@ import (
type clientOpts struct {
timeout time.Duration
tlsConf *tls.Config
transport http.RoundTripper
noFollow bool
cookieJar http.CookieJar
}
@ -57,14 +56,6 @@ func WithCookieJar(jar http.CookieJar) ClientOpt {
}
}
// WithTransport sets an explicit RoundTripper on the HTTP client. When set,
// this takes precedence over WithTLSClientConfig.
func WithTransport(t http.RoundTripper) ClientOpt {
return func(o *clientOpts) {
o.transport = t
}
}
// NewClient returns an HTTP client configured according to the provided
// options.
func NewClient(opts ...ClientOpt) *http.Client {
@ -80,10 +71,7 @@ func NewClient(opts ...ClientOpt) *http.Client {
if co.noFollow {
cli.CheckRedirect = noFollowRedirect
}
switch {
case co.transport != nil:
cli.Transport = co.transport
case co.tlsConf != nil:
if co.tlsConf != nil {
cli.Transport = NewTransport(WithTLSConfig(co.tlsConf))
}
if co.cookieJar != nil {
@ -125,14 +113,6 @@ func NewTransport(opts ...TransportOpt) *http.Transport {
return tr
}
// Override DialContext with the SSRF-blocking dialer so that every
// outbound TCP connection is checked against the blocklist at dial time.
func NewSSRFProtectedTransport(opts ...TransportOpt) *http.Transport {
tr := NewTransport(opts...)
tr.DialContext = SSRFDialContext(nil, nil, nil)
return tr
}
func noFollowRedirect(*http.Request, []*http.Request) error {
return http.ErrUseLastResponse
}

View file

@ -66,6 +66,7 @@ func TestTransport(t *testing.T) {
assert.NotEqual(t, defaultTLSConf, tr.TLSClientConfig)
}
assert.NotNil(t, tr.Proxy)
assert.NotNil(t, tr.DialContext)
})
}
}

View file

@ -1,195 +0,0 @@
package fleethttp
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"github.com/fleetdm/fleet/v4/server/dev_mode"
)
var (
// https://en.wikipedia.org/wiki/Reserved_IP_addresses
ipv4Blocklist = []string{
"0.0.0.0/8", // Current network (only valid as source address)
"10.0.0.0/8", // Private network
"100.64.0.0/10", // Shared Address Space
"127.0.0.0/8", // Loopback
"169.254.0.0/16", // Link-local
"172.16.0.0/12", // Private network
"192.0.0.0/24", // IETF Protocol Assignments
"192.0.2.0/24", // TEST-NET-1, documentation and examples
"192.88.99.0/24", // IPv6 to IPv4 relay (includes 2002::/16)
"192.168.0.0/16", // Private network
"198.18.0.0/15", // Network benchmark tests
"198.51.100.0/24", // TEST-NET-2, documentation and examples
"203.0.113.0/24", // TEST-NET-3, documentation and examples
"224.0.0.0/4", // IP multicast (former Class D network)
"240.0.0.0/4", // Reserved (former Class E network)
"255.255.255.255/32", // Broadcast
}
ipv6Blocklist = []string{
"::1/128", // Loopback
"64:ff9b::/96", // IPv4/IPv6 translation (RFC 6052)
"64:ff9b:1::/48", // Local-use IPv4/IPv6 translation (RFC 8215)
"100::/64", // Discard prefix (RFC 6666)
"2001::/32", // Teredo tunneling
"2001:10::/28", // Deprecated (previously ORCHID)
"2001:20::/28", // ORCHIDv2
"2001:db8::/32", // Documentation and example source code
"2002::/16", // 6to4
"3fff::/20", // Documentation (RFC 9637, 2024)
"5f00::/16", // IPv6 Segment Routing (SRv6)
"fc00::/7", // Unique local address
"fe80::/10", // Link-local address
"ff00::/8", // Multicast
}
)
var blockedCIDRs []*net.IPNet
func init() {
for _, cidr := range append(ipv4Blocklist, ipv6Blocklist...) {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
panic(fmt.Sprintf("fleethttp: invalid blocked CIDR %q: %v", cidr, err))
}
blockedCIDRs = append(blockedCIDRs, network)
}
}
// isBlockedIP returns true when ip falls within any of the protected ranges.
func isBlockedIP(ip net.IP) bool {
if ip4 := ip.To4(); ip4 != nil && len(ip) == net.IPv6len {
ip = ip4
}
for _, network := range blockedCIDRs {
if network.Contains(ip) {
return true
}
}
return false
}
// SSRFError is returned when a URL resolves to a protected IP range.
type SSRFError struct {
URL string
IP net.IP
}
func (e *SSRFError) Error() string {
return fmt.Sprintf("URL %q resolves to a blocked address", e.URL)
}
func checkResolvedAddrs(ctx context.Context, host, rawURL string, resolver func(context.Context, string) ([]string, error)) ([]net.IP, error) {
addrs, err := resolver(ctx, host)
if err != nil {
return nil, fmt.Errorf("resolving host %q: %w", host, err)
}
if len(addrs) == 0 {
return nil, fmt.Errorf("host %q resolved to no addresses", host)
}
safe := make([]net.IP, 0, len(addrs))
for _, addr := range addrs {
h, _, err := net.SplitHostPort(addr)
if err != nil {
h = addr
}
ip := net.ParseIP(h)
if ip == nil {
return nil, fmt.Errorf("resolved address %q for host %q is not a valid IP", h, host)
}
if isBlockedIP(ip) {
return nil, &SSRFError{URL: rawURL, IP: ip}
}
safe = append(safe, ip)
}
return safe, nil
}
// CheckURLForSSRF validates rawURL against SSRF attack vectors using a static blocklist.
func CheckURLForSSRF(ctx context.Context, rawURL string, resolver func(ctx context.Context, host string) ([]string, error)) error {
parsed, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("invalid URL: %w", err)
}
scheme := parsed.Scheme
if scheme != "http" && scheme != "https" {
return fmt.Errorf("URL scheme %q is not allowed; must be http or https", scheme)
}
hostname := parsed.Hostname()
if hostname == "" {
return errors.New("URL has no host")
}
if dev_mode.IsEnabled {
return nil
}
if ip := net.ParseIP(hostname); ip != nil {
if isBlockedIP(ip) {
return &SSRFError{URL: rawURL, IP: ip}
}
return nil
}
if resolver == nil {
resolver = net.DefaultResolver.LookupHost
}
_, err = checkResolvedAddrs(ctx, hostname, rawURL, resolver)
return err
}
// SSRFDialContext returns a DialContext function that validates against SSRF attack vectors using a static blocklist.
func SSRFDialContext(
base *net.Dialer,
resolver func(ctx context.Context, host string) ([]string, error),
dial func(ctx context.Context, network, addr string) (net.Conn, error),
) func(ctx context.Context, network, addr string) (net.Conn, error) {
if base == nil {
base = &net.Dialer{}
}
if resolver == nil {
resolver = net.DefaultResolver.LookupHost
}
if dial == nil {
dial = base.DialContext
}
return func(ctx context.Context, network, addr string) (net.Conn, error) {
if dev_mode.IsEnabled {
return dial(ctx, network, addr)
}
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("ssrf dial: splitting host/port from %q: %w", addr, err)
}
safeIPs, err := checkResolvedAddrs(ctx, host, net.JoinHostPort(host, port), resolver)
if err != nil {
return nil, err
}
// net.Dialer has no API to accept a pre-resolved IP list
// This is similar to what go does with dialSerial
// /usr/local/go/src/net/dial.go#dialSerial
var lastErr error
for _, ip := range safeIPs {
var conn net.Conn
conn, lastErr = dial(ctx, network, net.JoinHostPort(ip.String(), port))
if lastErr == nil {
return conn, nil
}
if ctx.Err() != nil {
// Context cancelled/timed out — no point trying remaining IPs.
return nil, lastErr
}
}
return nil, lastErr
}
}

View file

@ -1,301 +0,0 @@
package fleethttp
import (
"context"
"errors"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// noopResolver returns a known public IP so that CheckURLForSSRF reaches the
// IP-range check without triggering real DNS lookups.
func noopResolver(_ context.Context, _ string) ([]string, error) {
return []string{"93.184.216.34"}, nil // example.com
}
func TestCheckURLForSSRFBlockedLiteralIPs(t *testing.T) {
t.Parallel()
blocked := []string{
"http://127.0.0.1/mscep/mscep.dll",
"http://127.255.255.255/path",
"http://10.0.0.1/admin",
"http://10.255.255.255/admin",
"http://172.16.0.1/admin",
"http://172.31.255.255/admin",
"http://192.168.0.1/admin",
"http://192.168.255.255/admin",
"http://169.254.169.254/latest/meta-data/",
"http://169.254.0.1/whatever",
"http://100.64.0.1/admin",
"http://100.127.255.255/admin",
"http://0.0.0.0/path",
"http://[::1]/path",
"http://[fe80::1]/path",
"http://[fc00::1]/path",
"http://[fdff::1]/path",
}
for _, u := range blocked {
t.Run(u, func(t *testing.T) {
t.Parallel()
err := CheckURLForSSRF(context.Background(), u, noopResolver)
require.Error(t, err, "expected SSRF block for %s", u)
var ssrfErr *SSRFError
assert.True(t, errors.As(err, &ssrfErr), "expected SSRFError for %s, got %T: %v", u, err, err)
})
}
}
func TestCheckURLForSSRFAllowedPublicIPs(t *testing.T) {
t.Parallel()
allowed := []string{
"https://ndes.corp.example.com/mscep/mscep.dll",
"https://93.184.216.34/path", // example.com
"http://8.8.8.8/path", // Google DNS
"https://1.1.1.1/path", // Cloudflare DNS
}
for _, u := range allowed {
t.Run(u, func(t *testing.T) {
t.Parallel()
err := CheckURLForSSRF(context.Background(), u, noopResolver)
assert.NoError(t, err, "expected no SSRF block for %s", u)
})
}
}
func TestCheckURLForSSRFDNSResolutionBlocked(t *testing.T) {
t.Parallel()
// Simulate a hostname that resolves to a private IP
privateResolver := func(_ context.Context, _ string) ([]string, error) {
return []string{"192.168.1.100"}, nil
}
err := CheckURLForSSRF(context.Background(), "https://attacker-controlled.example.com/admin", privateResolver)
require.Error(t, err)
var ssrfErr *SSRFError
assert.True(t, errors.As(err, &ssrfErr))
assert.Equal(t, net.ParseIP("192.168.1.100").String(), ssrfErr.IP.String())
}
func TestCheckURLForSSRFMetadataEndpoints(t *testing.T) {
t.Parallel()
metadataURLs := []string{
"http://169.254.169.254/latest/meta-data/iam/security-credentials/",
"http://169.254.169.254/metadata/instance?api-version=2021-02-01",
}
for _, u := range metadataURLs {
t.Run(u, func(t *testing.T) {
t.Parallel()
err := CheckURLForSSRF(context.Background(), u, noopResolver)
require.Error(t, err)
var ssrfErr *SSRFError
assert.True(t, errors.As(err, &ssrfErr))
})
}
}
func TestCheckURLForSSRFBadScheme(t *testing.T) {
t.Parallel()
err := CheckURLForSSRF(context.Background(), "file:///etc/passwd", noopResolver)
require.Error(t, err)
var ssrfErr *SSRFError
assert.False(t, errors.As(err, &ssrfErr), "bad-scheme error should not be an SSRFError")
}
func TestCheckURLForSSRFResolverError(t *testing.T) {
t.Parallel()
failResolver := func(_ context.Context, _ string) ([]string, error) {
return nil, errors.New("simulated DNS failure")
}
err := CheckURLForSSRF(context.Background(), "https://cant-resolve.example.com/admin", failResolver)
require.Error(t, err)
assert.Contains(t, err.Error(), "resolving host")
}
func TestCheckURLForSSRF_UnparseableAddressFailsClosed(t *testing.T) {
t.Parallel()
// A custom resolver returning a non-IP string will be blocked
badResolver := func(_ context.Context, _ string) ([]string, error) {
return []string{"not-an-ip"}, nil
}
err := CheckURLForSSRF(context.Background(), "https://example.com/admin", badResolver)
require.Error(t, err)
assert.Contains(t, err.Error(), "not a valid IP")
}
func TestCheckURLForSSRFMultipleResolutions(t *testing.T) {
t.Parallel()
mixedResolver := func(_ context.Context, _ string) ([]string, error) {
return []string{"93.184.216.34", "10.0.0.1"}, nil
}
err := CheckURLForSSRF(context.Background(), "https://mixed.example.com/admin", mixedResolver)
require.Error(t, err)
var ssrfErr *SSRFError
assert.True(t, errors.As(err, &ssrfErr))
assert.Equal(t, net.ParseIP("10.0.0.1").String(), ssrfErr.IP.String())
}
func TestCheckURLForSSRFIPv4MappedBypass(t *testing.T) {
t.Parallel()
// An attacker could supply an IPv4-mapped IPv6 address like ::ffff:192.168.1.1
// to reach a private IPv4 host while bypassing the IPv4 blocklist check.
blocked := []string{
"http://[::ffff:192.168.1.1]/admin", // RFC 1918 private
"http://[::ffff:127.0.0.1]/admin", // Loopback
"http://[::ffff:169.254.169.254]/admin", // Link-local metadata
"http://[::ffff:10.0.0.1]/admin", // RFC 1918 private
}
for _, u := range blocked {
t.Run(u, func(t *testing.T) {
t.Parallel()
err := CheckURLForSSRF(context.Background(), u, noopResolver)
require.Error(t, err, "expected SSRF block for IPv4-mapped %s", u)
var ssrfErr *SSRFError
assert.True(t, errors.As(err, &ssrfErr), "expected SSRFError for %s, got %T: %v", u, err, err)
})
}
}
func TestCheckURLForSSRFSSRFErrorMessage(t *testing.T) {
err := CheckURLForSSRF(context.Background(), "http://127.0.0.1/admin", noopResolver)
require.Error(t, err)
assert.Contains(t, err.Error(), "blocked address")
assert.Contains(t, err.Error(), "127.0.0.1")
}
// noopDial is used as the dial parameter so tests never open real sockets.
func noopDial(_ context.Context, _, _ string) (net.Conn, error) {
return nil, errors.New("no-op dial: connection not attempted in tests")
}
// captureDial records the addr passed to dial without opening a real socket.
func captureDial(got *string) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(_ context.Context, _, addr string) (net.Conn, error) {
*got = addr
return nil, errors.New("no-op dial: connection not attempted in tests")
}
}
// staticResolver returns a fixed list of IPs for any host.
func staticResolver(ips ...string) func(ctx context.Context, host string) ([]string, error) {
return func(_ context.Context, _ string) ([]string, error) {
return ips, nil
}
}
func TestSSRFDialContextBlocksPrivateIPs(t *testing.T) {
t.Parallel()
blocked := []struct {
addr string
ip string
}{
{"127.0.0.1:80", "127.0.0.1"},
{"10.0.0.1:443", "10.0.0.1"},
{"172.16.0.1:8080", "172.16.0.1"},
{"192.168.1.1:443", "192.168.1.1"},
{"169.254.169.254:80", "169.254.169.254"},
}
for _, tc := range blocked {
t.Run(tc.addr, func(t *testing.T) {
t.Parallel()
dial := SSRFDialContext(nil, staticResolver(tc.ip), noopDial)
conn, err := dial(context.Background(), "tcp", tc.addr)
require.Error(t, err, "expected dial to be blocked for %s", tc.addr)
assert.Nil(t, conn)
var ssrfErr *SSRFError
assert.True(t, errors.As(err, &ssrfErr), "expected SSRFError for %s, got %T: %v", tc.addr, err, err)
assert.Equal(t, net.ParseIP(tc.ip).String(), ssrfErr.IP.String())
})
}
}
func TestSSRFDialContextAllowsPublicIPs(t *testing.T) {
t.Parallel()
publicIPs := []string{
"93.184.216.34",
"8.8.8.8",
"1.1.1.1",
}
for _, publicIP := range publicIPs {
t.Run(publicIP, func(t *testing.T) {
t.Parallel()
dial := SSRFDialContext(nil, staticResolver(publicIP), noopDial)
_, err := dial(context.Background(), "tcp", publicIP+":80")
var ssrfErr *SSRFError
assert.False(t, errors.As(err, &ssrfErr), "public IP %s should not be SSRF-blocked", publicIP)
})
}
}
func TestSSRFDialContextBlocksMixedResolution(t *testing.T) {
t.Parallel()
// Simulates DNS rebinding: resolver returns one public and one private IP.
dial := SSRFDialContext(nil, staticResolver("93.184.216.34", "192.168.1.100"), noopDial)
_, err := dial(context.Background(), "tcp", "attacker.example.com:443")
require.Error(t, err)
var ssrfErr *SSRFError
assert.True(t, errors.As(err, &ssrfErr))
assert.Equal(t, net.ParseIP("192.168.1.100").String(), ssrfErr.IP.String())
}
func TestSSRFDialContextResolverError(t *testing.T) {
t.Parallel()
failResolver := func(_ context.Context, _ string) ([]string, error) {
return nil, errors.New("simulated DNS failure")
}
dial := SSRFDialContext(nil, failResolver, noopDial)
_, err := dial(context.Background(), "tcp", "cant-resolve.example.com:443")
require.Error(t, err)
assert.Contains(t, err.Error(), "resolving")
}
func TestSSRFDialContextDialsResolvedIP(t *testing.T) {
t.Parallel()
var gotAddr string
dialFn := SSRFDialContext(nil, staticResolver("93.184.216.34"), captureDial(&gotAddr))
_, _ = dialFn(context.Background(), "tcp", "example.com:443")
// The dialer must receive the resolved IP, not "example.com".
host, port, err := net.SplitHostPort(gotAddr)
require.NoError(t, err)
assert.Equal(t, "443", port)
assert.NotEmpty(t, host)
assert.NotEqual(t, "example.com", host, "dialer must receive resolved IP, not the hostname")
assert.Equal(t, net.ParseIP("93.184.216.34").String(), net.ParseIP(host).String())
}
func TestSSRFDialContextNilsUseDefaults(t *testing.T) {
t.Parallel()
dial := SSRFDialContext(nil, nil, nil)
require.NotNil(t, dial)
}
func TestNewSSRFProtectedTransportHasDialContext(t *testing.T) {
t.Parallel()
tr := NewSSRFProtectedTransport()
require.NotNil(t, tr.DialContext, "NewSSRFProtectedTransport() must set DialContext for SSRF protection")
}

View file

@ -16,7 +16,6 @@ import (
eeservice "github.com/fleetdm/fleet/v4/ee/server/service"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/dev_mode"
"github.com/fleetdm/fleet/v4/server/fleet"
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
"github.com/fleetdm/fleet/v4/server/mdm/apple/mobileconfig"
@ -35,9 +34,6 @@ import (
func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
t := s.T()
dev_mode.IsEnabled = true
t.Cleanup(func() { dev_mode.IsEnabled = false })
// TODO(hca): test each CA type activities once implemented
// TODO(hca) test each CA type cannot configure without private key?
@ -233,7 +229,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
{
testName: "non-http",
url: "nonhttp://bad.com",
errMessage: "must be http or https",
errMessage: "URL scheme must be https or http",
},
}
@ -265,7 +261,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
res := s.Do("POST", "/api/v1/fleet/spec/certificate_authorities", req, http.StatusUnprocessableEntity)
errMsg := extractServerErrorText(res.Body)
require.Contains(t, errMsg, "certificate_authorities.ndes_scep_proxy")
require.Contains(t, errMsg, "NDES SCEP URL is invalid")
require.Contains(t, errMsg, "Invalid NDES SCEP URL")
checkNDESApplied(t, nil)
})
@ -436,7 +432,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
errMsg := extractServerErrorText(res.Body)
require.Contains(t, errMsg, "certificate_authorities.digicert")
if tc.errMessage == "Invalid URL" {
require.Contains(t, errMsg, "DigiCert URL is invalid")
require.Contains(t, errMsg, "Invalid DigiCert URL")
} else {
require.Contains(t, errMsg, tc.errMessage)
}
@ -783,7 +779,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
errMsg := extractServerErrorText(res.Body)
require.Contains(t, errMsg, "certificate_authorities.custom_scep_proxy")
if tc.errMessage == "Invalid URL" {
require.Contains(t, errMsg, "Custom SCEP Proxy URL is invalid")
require.Contains(t, errMsg, "Invalid SCEP URL")
} else {
require.Contains(t, errMsg, tc.errMessage)
}
@ -1008,7 +1004,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
errMsg := extractServerErrorText(res.Body)
require.Contains(t, errMsg, "certificate_authorities.smallstep")
if tc.errMessage == "Invalid URL" {
require.Contains(t, errMsg, "Smallstep SCEP URL is invalid")
require.Contains(t, errMsg, "Invalid Smallstep SCEP URL")
} else {
require.Contains(t, errMsg, tc.errMessage)
}
@ -1111,7 +1107,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
res := s.Do("POST", "/api/v1/fleet/spec/certificate_authorities", req, http.StatusUnprocessableEntity)
errMsg := extractServerErrorText(res.Body)
require.Contains(t, errMsg, "certificate_authorities.smallstep")
require.Contains(t, errMsg, "Smallstep SCEP URL is invalid")
require.Contains(t, errMsg, "Invalid Smallstep SCEP URL")
})
t.Run("smallstep challenge url not set", func(t *testing.T) {
@ -1268,7 +1264,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
errMsg := extractServerErrorText(res.Body)
require.Contains(t, errMsg, "certificate_authorities.hydrant")
if tc.errMessage == "Invalid URL" {
require.Contains(t, errMsg, "Hydrant URL is invalid")
require.Contains(t, errMsg, "Invalid Hydrant URL")
} else {
require.Contains(t, errMsg, tc.errMessage)
}
@ -1394,7 +1390,7 @@ func (s *integrationMDMTestSuite) TestBatchApplyCertificateAuthorities() {
errMsg := extractServerErrorText(res.Body)
require.Contains(t, errMsg, "certificate_authorities.custom_est_proxy")
if tc.errMessage == "Invalid URL" {
require.Contains(t, errMsg, "EST URL is invalid")
require.Contains(t, errMsg, "Invalid EST URL")
} else {
require.Contains(t, errMsg, tc.errMessage)
}