mirror of
https://github.com/fleetdm/fleet
synced 2026-04-22 05:57:36 +00:00
403 lines
12 KiB
Go
403 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"hash/fnv"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"howett.net/plist"
|
|
)
|
|
|
|
// mdmProxy bundles the configuration and the two reverse proxies used to
// route each incoming request either to the existing MDM server or to Fleet.
type mdmProxy struct {
	// Upstream servers and the reverse proxies that forward to them.
	existingServerURL string
	existingHostname  string
	fleetServerURL    string
	existingProxy     *httputil.ReverseProxy
	fleetProxy        *httputil.ReverseProxy

	// mutex syncs reads/updates of migrateUDIDs and migratePercentage.
	mutex sync.RWMutex
	// migrateUDIDs is the set of UDIDs that are always routed to Fleet.
	migrateUDIDs map[string]struct{}
	// migratePercentage is the share (out of 100 hash buckets) of the
	// remaining UDIDs that are routed to Fleet.
	migratePercentage int

	// token authenticates remote updates to migrateUDIDs and
	// migratePercentage; remote updates are refused while it is empty.
	token string

	// debug enables logging of request bodies sent to Fleet.
	debug bool
	// logSkipped enables logging of requests rejected as junk or non-MDM.
	logSkipped bool
}
|
|
|
|
// skipRequest reports whether r should be thrown out as a common junk
// request: its path contains one of a fixed list of markers, or it is a POST
// to the root path.
func skipRequest(r *http.Request) bool {
	path := r.URL.Path

	// A POST to "/" is treated as junk regardless of anything else.
	if path == "/" && r.Method == http.MethodPost {
		return true
	}

	junkMarkers := []string{".php", ".git", ".yml", ".txt", ".py", "wp-", "private"}
	for _, marker := range junkMarkers {
		if strings.Contains(path, marker) {
			return true
		}
	}
	return false
}
|
|
|
|
func (m *mdmProxy) handleProxy(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Host != "" {
|
|
log.Printf("%s %s Forbidden", r.Method, r.URL.String())
|
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
if skipRequest(r) {
|
|
if m.logSkipped {
|
|
log.Printf("Forbidden skipped request: %s %s", r.Method, r.URL.String())
|
|
}
|
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Send all SCEP requests to the existing server
|
|
if strings.Contains(r.URL.Path, "scep") {
|
|
log.Printf("%s %s -> Existing (SCEP)", r.Method, r.URL.String())
|
|
m.existingProxy.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
// Send all micromdm API requests to the existing server
|
|
if strings.HasPrefix(r.URL.Path, "/v1") || strings.HasPrefix(r.URL.Path, "/push") {
|
|
log.Printf("%s %s -> Existing (API)", r.Method, r.URL.String())
|
|
m.existingProxy.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
if r.URL.Path == "/" && r.Method == http.MethodGet {
|
|
log.Printf("%s %s -> Existing (Home)", r.Method, r.URL.String())
|
|
m.existingProxy.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
// Send all micromdm repo requests to the existing server
|
|
if strings.HasPrefix(r.URL.Path, "/repo") {
|
|
log.Printf("%s %s -> Existing (Repo)", r.Method, r.URL.String())
|
|
m.existingProxy.ServeHTTP(w, r)
|
|
return
|
|
|
|
}
|
|
|
|
if !strings.HasPrefix(r.URL.Path, "/mdm") {
|
|
if m.logSkipped {
|
|
log.Printf("Forbidden non-mdm request: %s %s", r.Method, r.URL.String())
|
|
}
|
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Read the body of the request
|
|
body, err := io.ReadAll(r.Body)
|
|
_ = r.Body.Close()
|
|
if err != nil {
|
|
log.Println("Failed to read request body: ", err.Error())
|
|
http.Error(w, "Unable to read request body", http.StatusUnprocessableEntity)
|
|
return
|
|
}
|
|
// Reset body so that the reverse proxy request includes it
|
|
r.Body = io.NopCloser(bytes.NewReader(body))
|
|
|
|
// Get the UDID from request
|
|
udid, err := udidFromRequestBody(body)
|
|
if err != nil {
|
|
log.Printf("%s %s Failed to get UDID: %v", r.Method, r.URL.String(), err)
|
|
http.Error(w, err.Error(), http.StatusUnprocessableEntity)
|
|
return
|
|
}
|
|
|
|
// Migrated UDIDs go to the Fleet server, otherwise requests go to the existing server.
|
|
if udid != "" && m.isUDIDMigrated(udid) {
|
|
log.Printf("%s %s (%s) -> Fleet", r.Method, r.URL.String(), udid)
|
|
if m.debug {
|
|
log.Printf("Fleet request: %s", string(body))
|
|
}
|
|
m.fleetProxy.ServeHTTP(w, r)
|
|
} else {
|
|
log.Printf("%s %s (%s) -> Existing", r.Method, r.URL.String(), udid)
|
|
m.existingProxy.ServeHTTP(w, r)
|
|
}
|
|
}
|
|
|
|
func (m *mdmProxy) handleUpdatePercentage(w http.ResponseWriter, r *http.Request) {
|
|
if m.token == "" {
|
|
http.Error(w, "Set auth token to enable remote updates", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
http.Error(w, "Authorization header must be provided", http.StatusUnauthorized)
|
|
return
|
|
|
|
}
|
|
if !strings.HasPrefix(authHeader, "Bearer ") {
|
|
http.Error(w, "Authorization header must start with \"Bearer \"", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
if authHeader != "Bearer "+m.token {
|
|
http.Error(w, "Authorization header does not match", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(r.Body)
|
|
defer r.Body.Close()
|
|
if err != nil {
|
|
http.Error(w, fmt.Sprintf("Failed to read body: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
percentage, err := strconv.Atoi(string(body))
|
|
if err != nil {
|
|
http.Error(w, fmt.Sprintf("Cannot read body as integer: %v", err), http.StatusUnprocessableEntity)
|
|
return
|
|
}
|
|
if percentage < 0 || percentage > 100 {
|
|
http.Error(w, "Percentage should be in range (0, 100)", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
m.migratePercentage = percentage
|
|
|
|
msg := fmt.Sprintf("Migrate percentage updated: %v\n", percentage)
|
|
log.Print(msg)
|
|
fmt.Fprint(w, msg)
|
|
}
|
|
|
|
func (m *mdmProxy) handleUpdateMigrateUDIDs(w http.ResponseWriter, r *http.Request) {
|
|
if m.token == "" {
|
|
http.Error(w, "Set auth token to enable remote updates", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
http.Error(w, "Authorization header must be provided", http.StatusUnauthorized)
|
|
return
|
|
|
|
}
|
|
if !strings.HasPrefix(authHeader, "Bearer ") {
|
|
http.Error(w, "Authorization header must start with \"Bearer \"", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
if authHeader != "Bearer "+m.token {
|
|
http.Error(w, "Authorization header does not match", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
defer r.Body.Close()
|
|
udids, err := processUDIDs(r.Body)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
m.migrateUDIDs = udids
|
|
|
|
msg := fmt.Sprintf("Migrate UDIDs updated: %v\n", udids)
|
|
log.Print(msg)
|
|
fmt.Fprint(w, msg)
|
|
}
|
|
|
|
// processUDIDs reads whitespace-separated words from in and returns them as a
// set. Duplicate words collapse into one entry; input without any words
// yields an empty, non-nil set. A scanner failure is returned wrapped and the
// set is nil in that case.
func processUDIDs(in io.Reader) (map[string]struct{}, error) {
	words := bufio.NewScanner(in)
	words.Split(bufio.ScanWords)

	set := map[string]struct{}{}
	for words.Scan() {
		token := strings.TrimSpace(words.Text())
		set[token] = struct{}{}
	}

	if scanErr := words.Err(); scanErr != nil {
		return nil, fmt.Errorf("Failed to scan UDIDs: %w", scanErr)
	}
	return set, nil
}
|
|
|
|
func (m *mdmProxy) isUDIDMigrated(udid string) bool {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
// If the UDID is manually included, it's always migrated
|
|
if _, ok := m.migrateUDIDs[udid]; ok {
|
|
return true
|
|
}
|
|
|
|
// Otherwise migrate by percentage
|
|
return udidIncludedByPercentage(udid, m.migratePercentage)
|
|
}
|
|
|
|
func udidFromRequestBody(body []byte) (string, error) {
|
|
// Not all requests (eg. SCEP) contain a UDID. Return empty without an error in this case.
|
|
if len(body) == 0 {
|
|
return "", nil
|
|
}
|
|
|
|
type mdmRequest struct {
|
|
UDID string `plist:""`
|
|
}
|
|
var req mdmRequest
|
|
_, err := plist.Unmarshal(body, &req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("unmarshal request: %w body: %s", err, string(body))
|
|
}
|
|
if req.UDID == "" {
|
|
return "", errors.New("request body does not contain UDID")
|
|
}
|
|
|
|
return req.UDID, nil
|
|
}
|
|
|
|
// hashUDID returns the 32-bit FNV-1a hash of udid, widened to uint.
func hashUDID(udid string) uint {
	hasher := fnv.New32a()
	// hash.Hash documents that Write never returns an error, so its results
	// are discarded.
	_, _ = hasher.Write([]byte(udid))
	sum := hasher.Sum32()
	return uint(sum)
}
|
|
|
|
func udidIncludedByPercentage(udid string, percentage int) bool {
|
|
index := hashUDID(udid) % 100
|
|
return int(index) < percentage //nolint:gosec // G115 false positive
|
|
}
|
|
|
|
// makeExistingProxy builds the reverse proxy that forwards to the existing
// MDM server at existingURL.
//
// The proxy uses a clone of http.DefaultTransport whose TLS ServerName is set
// to existingDNSName, so TLS validation uses the "old" server name.
//
// It panics if existingURL cannot be parsed.
func makeExistingProxy(existingURL, existingDNSName string) *httputil.ReverseProxy {
	targetURL, err := url.Parse(existingURL)
	if err != nil {
		// This function parses the existing server's URL, so the message
		// names existing-url rather than fleet-url.
		panic("failed to parse existing-url: " + err.Error())
	}
	proxy := httputil.NewSingleHostReverseProxy(targetURL)

	// Allow TLS validation to use the "old" server name
	transport := http.DefaultTransport.(*http.Transport).Clone()
	// NOTE(review): TLSClientConfig is dereferenced without a nil check; this
	// relies on the cloned default transport carrying a non-nil config.
	// Confirm for the Go version and GODEBUG settings in use.
	transport.TLSClientConfig.ServerName = existingDNSName
	proxy.Transport = transport

	return proxy
}
|
|
|
|
// makeFleetProxy builds the reverse proxy that forwards to the Fleet server
// at fleetURL. When debug is true, every upstream response body is read in
// full, logged, and then replaced with an in-memory copy so it can still be
// relayed to the client.
//
// It panics if fleetURL cannot be parsed.
func makeFleetProxy(fleetURL string, debug bool) *httputil.ReverseProxy {
	target, parseErr := url.Parse(fleetURL)
	if parseErr != nil {
		panic("failed to parse fleet-url: " + parseErr.Error())
	}

	rp := httputil.NewSingleHostReverseProxy(target)
	if !debug {
		return rp
	}

	rp.ModifyResponse = func(resp *http.Response) error {
		payload, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return readErr
		}
		if closeErr := resp.Body.Close(); closeErr != nil {
			return closeErr
		}
		// Put the consumed bytes back so the response can still be relayed.
		resp.Body = io.NopCloser(bytes.NewReader(payload))

		log.Println("Fleet response: ", string(payload))

		return nil
	}
	return rp
}
|
|
|
|
func main() {
|
|
authToken := flag.String("auth-token", "", "Auth token for remote flag updates (remote updates disabled if not provided)")
|
|
existingURL := flag.String("existing-url", "", "Existing MDM server URL (full path) (required)")
|
|
existingHostname := flag.String("existing-hostname", "", "Hostname for existing MDM server (eg. 'mdm.example.com') (required)")
|
|
fleetURL := flag.String("fleet-url", "", "Fleet MDM server URL (full path) (required)")
|
|
migratePercentage := flag.Int("migrate-percentage", 0, "Percentage of clients to migrate from existing MDM to Fleet")
|
|
migrateUDIDs := flag.String("migrate-udids", "", "Space/newline-delimited list of UDIDs to migrate always")
|
|
serverAddr := flag.String("server-address", ":8080", "Address for server to listen on")
|
|
debug := flag.Bool("debug", false, "Enable debug logging")
|
|
logSkipped := flag.Bool("log-skipped", false, "Log skipped requests (usually from web scanners)")
|
|
check := flag.String("check", "", "Print whether the specified UDID is migrated with the current configuration, then exit")
|
|
flag.Parse()
|
|
|
|
udids, err := processUDIDs(bytes.NewBufferString(*migrateUDIDs))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
proxy := mdmProxy{
|
|
token: *authToken,
|
|
existingServerURL: *existingURL,
|
|
fleetServerURL: *fleetURL,
|
|
existingHostname: *existingHostname,
|
|
migratePercentage: *migratePercentage,
|
|
migrateUDIDs: udids,
|
|
existingProxy: makeExistingProxy(*existingURL, *existingHostname),
|
|
fleetProxy: makeFleetProxy(*fleetURL, *debug),
|
|
debug: *debug,
|
|
logSkipped: *logSkipped,
|
|
}
|
|
|
|
if len(*check) > 0 {
|
|
if proxy.isUDIDMigrated(*check) {
|
|
fmt.Printf("%s IS migrated\n", *check)
|
|
} else {
|
|
fmt.Printf("%s IS NOT migrated\n", *check)
|
|
}
|
|
os.Exit(0)
|
|
}
|
|
|
|
// Check required flags
|
|
if *existingURL == "" {
|
|
log.Fatal("--existing-url must be set")
|
|
}
|
|
if *existingHostname == "" {
|
|
log.Fatal("--existing-hostname must be set")
|
|
}
|
|
if *fleetURL == "" {
|
|
log.Fatal("--fleet-url must be set")
|
|
}
|
|
|
|
log.Printf("--migrate-udids set: %v", udids)
|
|
log.Printf("--migrate-percentage set: %d", *migratePercentage)
|
|
log.Printf("--existing-url set: %s", *existingURL)
|
|
log.Printf("--existing-hostname set: %s", *existingHostname)
|
|
log.Printf("--fleet-url set: %s", *fleetURL)
|
|
log.Printf("--debug set: %v", *debug)
|
|
log.Printf("--log-skipped set: %v", *logSkipped)
|
|
if *authToken != "" {
|
|
log.Printf("--auth-token set. Remote configuration enabled.")
|
|
} else {
|
|
log.Printf("--auth-token is empty. Remote configuration disabled.")
|
|
}
|
|
|
|
mux := http.NewServeMux()
|
|
// Health check endpoint used for load balancers
|
|
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
|
_, err := w.Write([]byte("OK"))
|
|
if err != nil {
|
|
log.Printf("/healthz error: %v", err)
|
|
}
|
|
})
|
|
// Remote management of migration (enabled if auth token set)
|
|
mux.HandleFunc("/admin/udids", proxy.handleUpdateMigrateUDIDs)
|
|
mux.HandleFunc("/admin/percentage", proxy.handleUpdatePercentage)
|
|
// Handler for the actual proxying
|
|
mux.HandleFunc("/", proxy.handleProxy)
|
|
|
|
log.Printf("Starting server on %s", *serverAddr)
|
|
server := &http.Server{
|
|
Addr: *serverAddr,
|
|
ReadHeaderTimeout: 10 * time.Second,
|
|
Handler: mux,
|
|
}
|
|
err = server.ListenAndServe()
|
|
if err != nil {
|
|
fmt.Printf("Error starting server: %s\n", err)
|
|
}
|
|
}
|