diff --git a/changes/add-trusted-proxies-config b/changes/add-trusted-proxies-config new file mode 100644 index 0000000000..815b8e10f2 --- /dev/null +++ b/changes/add-trusted-proxies-config @@ -0,0 +1 @@ +- Added the FLEET_SERVER_TRUSTED_PROXIES configuration. \ No newline at end of file diff --git a/go.mod b/go.mod index 1ee14ea096..5b0b6193b7 100644 --- a/go.mod +++ b/go.mod @@ -116,6 +116,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 github.com/prometheus/client_golang v1.21.1 github.com/quasilyte/go-ruleguard/dsl v0.3.22 + github.com/realclientip/realclientip-go v1.0.0 github.com/remitly-oss/httpsig-go v1.2.0 github.com/rs/zerolog v1.32.0 github.com/russellhaering/goxmldsig v1.4.0 diff --git a/go.sum b/go.sum index 6c71ac1182..d88b1edf90 100644 --- a/go.sum +++ b/go.sum @@ -750,6 +750,8 @@ github.com/quasilyte/go-ruleguard/dsl v0.3.22 h1:wd8zkOhSNr+I+8Qeciml08ivDt1pSXe github.com/quasilyte/go-ruleguard/dsl v0.3.22/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU= github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 h1:MkV+77GLUNo5oJ0jf870itWm3D0Sjh7+Za9gazKc5LQ= github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/realclientip/realclientip-go v1.0.0 h1:+yPxeC0mEaJzq1BfCt2h4BxlyrvIIBzR6suDc3BEF1U= +github.com/realclientip/realclientip-go v1.0.0/go.mod h1:CXnUdVwFRcXFJIRb/dTYqbT7ud48+Pi2pFm80bxDmcI= github.com/remitly-oss/httpsig-go v1.2.0 h1:rI634TJkh+US3qkWQfkJ7VDJgCvlIbyEepsEw+37W50= github.com/remitly-oss/httpsig-go v1.2.0/go.mod h1:HYfozYlK9Zv9GYyw+eIuXugk1OV2kjowVrvdv0KQ4XU= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= diff --git a/server/config/config.go b/server/config/config.go index f2dd2964b7..b4f397648a 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -117,6 +117,7 @@ type ServerConfig struct { VPPVerifyRequestDelay time.Duration `yaml:"vpp_verify_request_delay"` CleanupDistTargetsAge time.Duration `yaml:"cleanup_dist_targets_age"` MaxInstallerSizeBytes int64 `yaml:"max_installer_size"` + TrustedProxies string `yaml:"trusted_proxies"` } func (s *ServerConfig) DefaultHTTPServer(ctx context.Context, handler http.Handler) *http.Server { @@ -1179,6 +1180,8 @@ func (man Manager) addConfigs() { man.addConfigDuration("server.vpp_verify_request_delay", 5*time.Second, "Delay in between requests to verify VPP app installs") man.addConfigDuration("server.cleanup_dist_targets_age", 24*time.Hour, "Specifies the cleanup age for completed live query distributed targets.") man.addConfigByteSize("server.max_installer_size", installersize.Human(installersize.DefaultMaxInstallerSize), "Maximum size in bytes for software installer uploads (e.g. 10GiB, 500MB, 1G)") + man.addConfigString("server.trusted_proxies", "", + "Trusted proxy configuration for client IP extraction: 'none' (RemoteAddr only), a header name (e.g., 'True-Client-IP'), a hop count (e.g., '2'), or comma-separated IP/CIDR ranges") // Hide the sandbox flag as we don't want it to be discoverable for users for now man.hideConfig("server.sandbox_enabled") @@ -1643,6 +1646,7 @@ func (man Manager) LoadConfig() FleetConfig { VPPVerifyRequestDelay: man.getConfigDuration("server.vpp_verify_request_delay"), CleanupDistTargetsAge: man.getConfigDuration("server.cleanup_dist_targets_age"), MaxInstallerSizeBytes: man.getConfigByteSize("server.max_installer_size"), + TrustedProxies: man.getConfigString("server.trusted_proxies"), }, Auth: AuthConfig{ BcryptCost: man.getConfigInt("auth.bcrypt_cost"), diff --git a/server/platform/endpointer/clientip.go b/server/platform/endpointer/clientip.go new file mode 100644 index 0000000000..54f4d3efb2 --- /dev/null +++ b/server/platform/endpointer/clientip.go @@ -0,0 +1,85 @@ +package endpointer + +import ( + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/realclientip/realclientip-go" +) + +// NewClientIPStrategy creates a ClientIPStrategy based on the trusted_proxies configuration. +// +// Config values: +// - "" (empty): Legacy behavior for backwards compatibility - trusts True-Client-IP, +// X-Real-IP, and leftmost X-Forwarded-For. This is deprecated; use "none" when +// exposing the server directly to the internet. +// - "none": Ignores all headers, uses only RemoteAddr. +// - A header name prefixed with `header:` (e.g., "header:True-Client-IP"): +// Trust this single-IP header, fall back to RemoteAddr. +// - A number (e.g., "2"): Trust X-Forwarded-For with this many proxy hops +// - Comma-separated IPs/CIDRs (e.g., "10.0.0.0/8,192.168.0.0/16"): +// Trust X-Forwarded-For from requests originating from these proxy ranges. +func NewClientIPStrategy(trustedProxies string) (realclientip.Strategy, error) { + trustedProxies = strings.TrimSpace(trustedProxies) + + var strategy realclientip.Strategy + var err error + + if trustedProxies == "" { + // Empty: legacy behavior for backwards compatibility. + return &legacyStrategy{}, nil + } else if strings.EqualFold(trustedProxies, "none") { + // "none": Trust no one; return (non-spoofable) RemoteAddr only. + return realclientip.RemoteAddrStrategy{}, nil + } else if headerName, ok := strings.CutPrefix(trustedProxies, "header:"); ok { + // Check if the value is a single IP header name. + strategy, err = realclientip.NewSingleIPHeaderStrategy(headerName) + if err != nil { + return nil, fmt.Errorf("invalid header name %q: %w", trustedProxies, err) + } + } else if hopCount, err := strconv.Atoi(trustedProxies); err == nil { + // Check if it's a number (hop count). + if hopCount < 1 { + return nil, fmt.Errorf("trusted_proxies hop count must be >= 1, got %d", hopCount) + } + strategy, err = realclientip.NewRightmostTrustedCountStrategy("X-Forwarded-For", hopCount) + if err != nil { + return nil, fmt.Errorf("failed to create hop count strategy: %w", err) + } + } else { + // Otherwise, parse as comma-separated IP ranges. + rangeStrs := strings.Split(trustedProxies, ",") + for i := range rangeStrs { + rangeStrs[i] = strings.TrimSpace(rangeStrs[i]) + } + + trustedRanges, err := realclientip.AddressesAndRangesToIPNets(rangeStrs...) + if err != nil { + return nil, fmt.Errorf("invalid trusted_proxies IP ranges: %w", err) + } + + strategy, err = realclientip.NewRightmostTrustedRangeStrategy("X-Forwarded-For", trustedRanges) + if err != nil { + return nil, fmt.Errorf("failed to create IP range strategy: %w", err) + } + } + + // Chain strategy with RemoteAddr as fallback. + return realclientip.NewChainStrategy(strategy, realclientip.RemoteAddrStrategy{}), nil +} + +// legacyStrategy implements the original ExtractIP behavior for backwards compatibility. +// This is deprecated; if your server is exposed directly to the internet, switch to +// the "none" strategy. +type legacyStrategy struct{} + +func (s *legacyStrategy) ClientIP(headers http.Header, remoteAddr string) string { + // Build a minimal http.Request to pass to extractIP + r := &http.Request{ + Header: headers, + RemoteAddr: remoteAddr, + } + return extractIP(r) +} diff --git a/server/platform/endpointer/clientip_test.go b/server/platform/endpointer/clientip_test.go new file mode 100644 index 0000000000..c736ee543b --- /dev/null +++ b/server/platform/endpointer/clientip_test.go @@ -0,0 +1,350 @@ +package endpointer + +import ( + "net/http" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Helper to create headers with proper canonicalization. +// Pass in pairs of header name and value. For example: +// makeHeaders("X-Forwarded-For", "1.1.1.1", "X-Real-IP", "2.2.2.2") +func makeHeaders(kvs ...string) http.Header { + h := http.Header{} + for i := 0; i < len(kvs); i += 2 { + h.Set(kvs[i], kvs[i+1]) + } + return h +} + +func TestNewClientIPStrategy(t *testing.T) { + tests := []struct { + name string + trustedProxies string + wantErr bool + errContains string + }{ + { + name: "empty uses legacy strategy", + trustedProxies: "", + wantErr: false, + }, + { + name: "none uses RemoteAddr strategy", + trustedProxies: "none", + wantErr: false, + }, + { + name: "None (case insensitive)", + trustedProxies: "None", + wantErr: false, + }, + { + name: "NONE (case insensitive)", + trustedProxies: "NONE", + wantErr: false, + }, + { + name: "True-Client-IP header", + trustedProxies: "header:True-Client-IP", + wantErr: false, + }, + { + name: "X-Real-IP header", + trustedProxies: "header:X-Real-IP", + wantErr: false, + }, + { + name: "CF-Connecting-IP header", + trustedProxies: "header:CF-Connecting-IP", + wantErr: false, + }, + { + name: "X-forwarded-for header", + trustedProxies: "header:X-Forwarded-For", + // This is not a valid single-IP header value + wantErr: true, + }, + { + name: "Forwarded header", + trustedProxies: "header:Forwarded", + // This is not a valid single-IP header value + wantErr: true, + }, + { + name: "hop count 1", + trustedProxies: "1", + wantErr: false, + }, + { + name: "hop count 2", + trustedProxies: "2", + wantErr: false, + }, + { + name: "hop count 0 is invalid", + trustedProxies: "0", + wantErr: true, + errContains: "hop count must be >= 1", + }, + { + name: "single IP range", + trustedProxies: "10.0.0.0/8", + wantErr: false, + }, + { + name: "multiple IP ranges", + trustedProxies: "10.0.0.0/8, 192.168.0.0/16, 172.16.0.0/12", + wantErr: false, + }, + { + name: "single IP address", + trustedProxies: "192.168.1.1", + wantErr: false, + }, + { + name: "invalid IP range", + trustedProxies: "not-an-ip", + wantErr: true, + errContains: "invalid trusted_proxies", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + strategy, err := NewClientIPStrategy(tt.trustedProxies) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + require.NoError(t, err) + require.NotNil(t, strategy) + }) + } +} + +func TestClientIPStrategy_Legacy(t *testing.T) { + strategy, err := NewClientIPStrategy("") + require.NoError(t, err) + + tests := []struct { + name string + headers http.Header + remoteAddr string + wantIP string + }{ + { + name: "uses True-Client-IP first", + headers: makeHeaders("True-Client-IP", "1.1.1.1", "X-Real-IP", "2.2.2.2", "X-Forwarded-For", "3.3.3.3, 4.4.4.4"), + remoteAddr: "9.9.9.9:12345", + wantIP: "1.1.1.1", + }, + { + name: "uses X-Real-IP second", + headers: makeHeaders("X-Real-IP", "2.2.2.2", "X-Forwarded-For", "3.3.3.3, 4.4.4.4"), + remoteAddr: "9.9.9.9:12345", + wantIP: "2.2.2.2", + }, + { + name: "uses leftmost X-Forwarded-For third", + headers: makeHeaders("X-Forwarded-For", "3.3.3.3, 4.4.4.4"), + remoteAddr: "9.9.9.9:12345", + wantIP: "3.3.3.3", + }, + { + name: "falls back to RemoteAddr", + headers: http.Header{}, + remoteAddr: "9.9.9.9:12345", + wantIP: "9.9.9.9", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := strategy.ClientIP(tt.headers, tt.remoteAddr) + assert.Equal(t, tt.wantIP, ip) + }) + } +} + +func TestClientIPStrategy_None(t *testing.T) { + strategy, err := NewClientIPStrategy("none") + require.NoError(t, err) + + tests := []struct { + name string + headers http.Header + remoteAddr string + wantIP string + }{ + { + name: "ignores True-Client-IP", + headers: makeHeaders("True-Client-IP", "1.1.1.1"), + remoteAddr: "9.9.9.9:12345", + wantIP: "9.9.9.9", + }, + { + name: "ignores X-Real-IP", + headers: makeHeaders("X-Real-IP", "2.2.2.2"), + remoteAddr: "9.9.9.9:12345", + wantIP: "9.9.9.9", + }, + { + name: "ignores X-Forwarded-For", + headers: makeHeaders("X-Forwarded-For", "3.3.3.3, 4.4.4.4"), + remoteAddr: "9.9.9.9:12345", + wantIP: "9.9.9.9", + }, + { + name: "uses RemoteAddr only", + headers: http.Header{}, + remoteAddr: "9.9.9.9:12345", + wantIP: "9.9.9.9", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := strategy.ClientIP(tt.headers, tt.remoteAddr) + assert.Equal(t, tt.wantIP, ip) + }) + } +} + +func TestClientIPStrategy_SingleIPHeader(t *testing.T) { + strategy, err := NewClientIPStrategy("header:True-Client-IP") + require.NoError(t, err) + + tests := []struct { + name string + headers http.Header + remoteAddr string + wantIP string + }{ + { + name: "uses True-Client-IP when present", + headers: makeHeaders("True-Client-IP", "1.1.1.1"), + remoteAddr: "9.9.9.9:12345", + wantIP: "1.1.1.1", + }, + { + name: "falls back to RemoteAddr when header missing", + headers: http.Header{}, + remoteAddr: "9.9.9.9:12345", + wantIP: "9.9.9.9", + }, + { + name: "ignores X-Forwarded-For", + headers: makeHeaders("X-Forwarded-For", "3.3.3.3"), + remoteAddr: "9.9.9.9:12345", + wantIP: "9.9.9.9", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := strategy.ClientIP(tt.headers, tt.remoteAddr) + assert.Equal(t, tt.wantIP, ip) + }) + } +} + +func TestClientIPStrategy_HopCount(t *testing.T) { + tests := []struct { + name string + hops int + headers http.Header + remoteAddr string + wantIP string + }{ + { + name: "extracts correct IP with 2 hops", + hops: 2, + headers: makeHeaders("X-Forwarded-For", "1.1.1.1, 2.2.2.2, 3.3.3.3"), + remoteAddr: "9.9.9.9:12345", + wantIP: "2.2.2.2", + }, + { + name: "extracts correct IP with 1 hops", + hops: 1, + headers: makeHeaders("X-Forwarded-For", "1.1.1.1, 2.2.2.2, 3.3.3.3"), + remoteAddr: "9.9.9.9:12345", + wantIP: "3.3.3.3", + }, + { + name: "falls back to RemoteAddr when header missing", + hops: 2, + headers: http.Header{}, + remoteAddr: "9.9.9.9:12345", + wantIP: "9.9.9.9", + }, + { + name: "falls back to RemoteAddr when hops > header length", + hops: 2, + headers: makeHeaders("X-Forwarded-For", "1.1.1.1"), + remoteAddr: "9.9.9.9:12345", + wantIP: "9.9.9.9", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + strategy, err := NewClientIPStrategy(strconv.Itoa(tt.hops)) + require.NoError(t, err) + + ip := strategy.ClientIP(tt.headers, tt.remoteAddr) + assert.Equal(t, tt.wantIP, ip) + }) + } +} + +func TestClientIPStrategy_IPRanges(t *testing.T) { + // Trust private IP ranges + strategy, err := NewClientIPStrategy("10.0.0.0/8, 192.168.0.0/16") + require.NoError(t, err) + + tests := []struct { + name string + headers http.Header + remoteAddr string + wantIP string + }{ + { + name: "extracts client IP skipping trusted proxies", + headers: makeHeaders("X-Forwarded-For", "1.1.1.1, 10.0.0.5, 192.168.1.1"), + remoteAddr: "10.0.0.1:12345", + wantIP: "1.1.1.1", + }, + { + name: "returns rightmost non-trusted IP", + headers: makeHeaders("X-Forwarded-For", "8.8.8.8, 1.1.1.1, 10.0.0.5"), + remoteAddr: "10.0.0.1:12345", + wantIP: "1.1.1.1", + }, + { + name: "returns RemoteAddr when all IPs are trusted", + headers: makeHeaders("X-Forwarded-For", "192.168.0.1, 10.0.0.5"), + remoteAddr: "99.99.99.99:12345", + wantIP: "99.99.99.99", + }, + { + name: "falls back to RemoteAddr when header missing", + headers: http.Header{}, + remoteAddr: "9.9.9.9:12345", + wantIP: "9.9.9.9", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := strategy.ClientIP(tt.headers, tt.remoteAddr) + assert.Equal(t, tt.wantIP, ip) + }) + } +} diff --git a/server/platform/endpointer/endpoint_utils.go b/server/platform/endpointer/endpoint_utils.go index f86453859b..5592bddd25 100644 --- a/server/platform/endpointer/endpoint_utils.go +++ b/server/platform/endpointer/endpoint_utils.go @@ -246,7 +246,7 @@ var ( xRealIP = http.CanonicalHeaderKey("X-Real-IP") ) -func ExtractIP(r *http.Request) string { +func extractIP(r *http.Request) string { ip := r.RemoteAddr if i := strings.LastIndexByte(ip, ':'); i != -1 { ip = ip[:i] diff --git a/server/service/handler.go b/server/service/handler.go index 35ea068167..a9de416d59 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -104,6 +104,12 @@ func MakeHandler( fn(&eopts) } + // Create the client IP extraction strategy based on config. + ipStrategy, err := endpointer.NewClientIPStrategy(config.Server.TrustedProxies) + if err != nil { + panic(fmt.Sprintf("invalid server.trusted_proxies configuration: %v", err)) + } + fleetAPIOptions := []kithttp.ServerOption{ kithttp.ServerBefore( kithttp.PopulateRequestContext, // populate the request context with common fields @@ -133,7 +139,17 @@ func MakeHandler( } } - r.Use(publicIP) + // Add middleware to extract the client IP and set it in the request context. + r.Use(func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := ipStrategy.ClientIP(r.Header, r.RemoteAddr) + if ip != "" { + r.RemoteAddr = ip + } + handler.ServeHTTP(w, r.WithContext(publicip.NewContext(r.Context(), ip))) + }) + }) + if eopts.httpSigVerifier != nil { r.Use(eopts.httpSigVerifier) } @@ -147,16 +163,6 @@ func MakeHandler( return r } -func publicIP(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ip := endpointer.ExtractIP(r) - if ip != "" { - r.RemoteAddr = ip - } - handler.ServeHTTP(w, r.WithContext(publicip.NewContext(r.Context(), ip))) - }) -} - // PrometheusMetricsHandler wraps the provided handler with prometheus metrics // middleware and returns the resulting handler that should be mounted for that // route.