fleet/tools/mdm/migration/mdmproxy/mdmproxy.go
Victor Lyuboslavsky e2d9a9016c
Add gosimple linter (#23250)
#23249

Add gosimple linter to golangci-lint CI job.
2024-10-29 14:17:51 -05:00

403 lines
12 KiB
Go

package main
import (
	"bufio"
	"bytes"
	"crypto/subtle"
	"crypto/tls"
	"errors"
	"flag"
	"fmt"
	"hash/fnv"
	"io"
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"howett.net/plist"
)
// mdmProxy routes MDM traffic between the existing MDM server and Fleet
// during a migration. Which devices go to Fleet is decided by an explicit
// UDID list plus a percentage rollout; both can be replaced at runtime
// through the authenticated admin endpoints.
type mdmProxy struct {
	// Upstream servers and the reverse proxies that forward to them.
	existingServerURL string
	existingHostname  string
	existingProxy     *httputil.ReverseProxy
	fleetServerURL    string
	fleetProxy        *httputil.ReverseProxy

	// mutex guards every read and write of migrateUDIDs and migratePercentage.
	mutex             sync.RWMutex
	migrateUDIDs      map[string]struct{}
	migratePercentage int

	// token must be presented as "Bearer <token>" to change the migration
	// settings remotely; an empty token disables remote changes.
	token string

	// debug logs request/response bodies for Fleet-bound traffic;
	// logSkipped logs requests rejected as junk or non-MDM.
	debug      bool
	logSkipped bool
}
// skipRequest reports whether r looks like junk traffic (typically web
// scanner probes) that should be refused without being proxied: a POST to
// the root path, or a path containing any of a fixed set of fragments.
func skipRequest(r *http.Request) bool {
	if r.Method == http.MethodPost && r.URL.Path == "/" {
		return true
	}
	for _, fragment := range []string{".php", ".git", ".yml", ".txt", ".py", "wp-", "private"} {
		if strings.Contains(r.URL.Path, fragment) {
			return true
		}
	}
	return false
}
// handleProxy is the catch-all handler that decides where each request goes:
//   - requests whose URL carries a host, and junk matched by skipRequest,
//     are refused with 403;
//   - SCEP, micromdm API (/v1, /push), GET / and /repo traffic always goes
//     to the existing server;
//   - any other path outside /mdm is refused with 403;
//   - /mdm requests are routed by the UDID in their plist body: migrated
//     devices go to Fleet, everything else (including empty bodies, which
//     carry no UDID) goes to the existing server. A non-empty body without a
//     parseable UDID is answered with 422.
func (m *mdmProxy) handleProxy(w http.ResponseWriter, r *http.Request) {
	path := r.URL.Path

	if r.URL.Host != "" {
		log.Printf("%s %s Forbidden", r.Method, r.URL.String())
		http.Error(w, "Forbidden", http.StatusForbidden)
		return
	}

	if skipRequest(r) {
		if m.logSkipped {
			log.Printf("Forbidden skipped request: %s %s", r.Method, r.URL.String())
		}
		http.Error(w, "Forbidden", http.StatusForbidden)
		return
	}

	// Traffic the existing server keeps regardless of migration state.
	// label names the matching rule for the log line.
	label := ""
	switch {
	case strings.Contains(path, "scep"):
		label = "SCEP"
	case strings.HasPrefix(path, "/v1"), strings.HasPrefix(path, "/push"):
		label = "API"
	case path == "/" && r.Method == http.MethodGet:
		label = "Home"
	case strings.HasPrefix(path, "/repo"):
		label = "Repo"
	}
	if label != "" {
		log.Printf("%s %s -> Existing (%s)", r.Method, r.URL.String(), label)
		m.existingProxy.ServeHTTP(w, r)
		return
	}

	if !strings.HasPrefix(path, "/mdm") {
		if m.logSkipped {
			log.Printf("Forbidden non-mdm request: %s %s", r.Method, r.URL.String())
		}
		http.Error(w, "Forbidden", http.StatusForbidden)
		return
	}

	// The body has to be read to find the UDID, so buffer it and re-attach a
	// fresh reader for whichever reverse proxy forwards the request.
	body, readErr := io.ReadAll(r.Body)
	_ = r.Body.Close()
	if readErr != nil {
		log.Println("Failed to read request body: ", readErr.Error())
		http.Error(w, "Unable to read request body", http.StatusUnprocessableEntity)
		return
	}
	r.Body = io.NopCloser(bytes.NewReader(body))

	udid, udidErr := udidFromRequestBody(body)
	if udidErr != nil {
		log.Printf("%s %s Failed to get UDID: %v", r.Method, r.URL.String(), udidErr)
		http.Error(w, udidErr.Error(), http.StatusUnprocessableEntity)
		return
	}

	if udid == "" || !m.isUDIDMigrated(udid) {
		log.Printf("%s %s (%s) -> Existing", r.Method, r.URL.String(), udid)
		m.existingProxy.ServeHTTP(w, r)
		return
	}

	log.Printf("%s %s (%s) -> Fleet", r.Method, r.URL.String(), udid)
	if m.debug {
		log.Printf("Fleet request: %s", string(body))
	}
	m.fleetProxy.ServeHTTP(w, r)
}
// authorizeUpdate reports whether r may change the migration settings.
// Remote updates are disabled while m.token is empty. Otherwise the request
// must carry "Authorization: Bearer <token>". On failure it writes a 401
// response to w and returns false; the caller must then stop handling r.
func (m *mdmProxy) authorizeUpdate(w http.ResponseWriter, r *http.Request) bool {
	if m.token == "" {
		http.Error(w, "Set auth token to enable remote updates", http.StatusUnauthorized)
		return false
	}
	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		http.Error(w, "Authorization header must be provided", http.StatusUnauthorized)
		return false
	}
	if !strings.HasPrefix(authHeader, "Bearer ") {
		http.Error(w, "Authorization header must start with \"Bearer \"", http.StatusUnauthorized)
		return false
	}
	// Constant-time comparison so response timing does not reveal how many
	// leading bytes of the token matched (only a length mismatch is observable).
	if subtle.ConstantTimeCompare([]byte(authHeader), []byte("Bearer "+m.token)) != 1 {
		http.Error(w, "Authorization header does not match", http.StatusUnauthorized)
		return false
	}
	return true
}

// handleUpdatePercentage replaces the migration percentage with the integer
// sent as the request body (surrounding whitespace is ignored).
// It responds 401 on auth failure, 500 if the body cannot be read,
// 422 if the body is not an integer, and 400 if it is outside [0, 100].
func (m *mdmProxy) handleUpdatePercentage(w http.ResponseWriter, r *http.Request) {
	if !m.authorizeUpdate(w, r) {
		return
	}

	defer r.Body.Close()
	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to read body: %v", err), http.StatusInternalServerError)
		return
	}

	// Trim so a trailing newline (e.g. from `echo 50 | curl --data-binary @-`)
	// is accepted.
	percentage, err := strconv.Atoi(strings.TrimSpace(string(body)))
	if err != nil {
		http.Error(w, fmt.Sprintf("Cannot read body as integer: %v", err), http.StatusUnprocessableEntity)
		return
	}
	if percentage < 0 || percentage > 100 {
		http.Error(w, "Percentage should be in range [0, 100]", http.StatusBadRequest)
		return
	}

	// Hold the write lock only for the assignment. Logging and writing the
	// response happen after it is released, so proxied requests (which take
	// the read lock) never wait on this client's I/O.
	m.mutex.Lock()
	m.migratePercentage = percentage
	m.mutex.Unlock()

	msg := fmt.Sprintf("Migrate percentage updated: %v\n", percentage)
	log.Print(msg)
	fmt.Fprint(w, msg)
}

// handleUpdateMigrateUDIDs replaces the always-migrate UDID set with the
// whitespace-separated UDIDs in the request body. It responds 401 on auth
// failure and 500 if the body cannot be scanned.
func (m *mdmProxy) handleUpdateMigrateUDIDs(w http.ResponseWriter, r *http.Request) {
	if !m.authorizeUpdate(w, r) {
		return
	}

	defer r.Body.Close()
	udids, err := processUDIDs(r.Body)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Swap the set under the write lock, then release it before formatting
	// and sending the (potentially large) confirmation message. udids is a
	// freshly built map that is only ever read after this point, so reading
	// it without the lock is safe.
	m.mutex.Lock()
	m.migrateUDIDs = udids
	m.mutex.Unlock()

	msg := fmt.Sprintf("Migrate UDIDs updated: %v\n", udids)
	log.Print(msg)
	fmt.Fprint(w, msg)
}
// processUDIDs reads whitespace-separated UDIDs from in and returns them as
// a set. Duplicates collapse to one entry; empty input yields an empty,
// non-nil set. A read error from in (or a single token longer than
// bufio.MaxScanTokenSize) is returned wrapped, with a nil set.
func processUDIDs(in io.Reader) (map[string]struct{}, error) {
	set := map[string]struct{}{}

	words := bufio.NewScanner(in)
	words.Split(bufio.ScanWords)
	for words.Scan() {
		set[strings.TrimSpace(words.Text())] = struct{}{}
	}

	if scanErr := words.Err(); scanErr != nil {
		return nil, fmt.Errorf("Failed to scan UDIDs: %w", scanErr)
	}
	return set, nil
}
// isUDIDMigrated reports whether requests for udid should go to Fleet.
// A UDID in the explicit migrate list always migrates; any other UDID
// migrates when its hash bucket falls under the current percentage.
// Both settings are read under the read lock.
func (m *mdmProxy) isUDIDMigrated(udid string) bool {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	_, listed := m.migrateUDIDs[udid]
	return listed || udidIncludedByPercentage(udid, m.migratePercentage)
}
// udidFromRequestBody extracts the top-level "UDID" key from a plist request
// body. An empty body returns ("", nil), since some requests carry no UDID.
// A body that fails to unmarshal, or has no non-empty UDID, returns an error;
// the unmarshal error message embeds the raw body.
func udidFromRequestBody(body []byte) (string, error) {
	if len(body) == 0 {
		return "", nil
	}

	var payload struct {
		UDID string `plist:""`
	}
	if _, err := plist.Unmarshal(body, &payload); err != nil {
		return "", fmt.Errorf("unmarshal request: %w body: %s", err, string(body))
	}

	if payload.UDID == "" {
		return "", errors.New("request body does not contain UDID")
	}
	return payload.UDID, nil
}
// hashUDID returns the 32-bit FNV-1a hash of udid, widened to uint. The value
// is stable across runs, so a given UDID always lands in the same bucket.
func hashUDID(udid string) uint {
	h := fnv.New32a()
	// hash.Hash documents that Write never returns an error.
	_, _ = h.Write([]byte(udid))
	sum := h.Sum32()
	return uint(sum)
}
// udidIncludedByPercentage maps udid to a bucket in [0, 100) via hashUDID and
// reports whether that bucket is below percentage. So 0 includes no UDIDs,
// 100 includes all of them, and raising the percentage only ever adds UDIDs.
func udidIncludedByPercentage(udid string, percentage int) bool {
	bucket := int(hashUDID(udid) % 100) //nolint:gosec // G115 false positive
	return bucket < percentage
}
// makeExistingProxy returns a reverse proxy forwarding to existingURL.
// TLS verification of the upstream uses existingDNSName as the server name,
// so the "old" server's certificate still validates even if existingURL
// addresses it by a different host. It panics if existingURL cannot be parsed.
func makeExistingProxy(existingURL, existingDNSName string) *httputil.ReverseProxy {
	targetURL, err := url.Parse(existingURL)
	if err != nil {
		panic("failed to parse existing-url: " + err.Error())
	}
	proxy := httputil.NewSingleHostReverseProxy(targetURL)

	// Clone so the ServerName override does not leak into http.DefaultTransport.
	transport := http.DefaultTransport.(*http.Transport).Clone()
	// Clone copies TLSClientConfig only when it is non-nil. It stays nil when
	// the default transport was never configured for HTTP/2 (for example with
	// GODEBUG=http2client=0), which would make the assignment below panic.
	if transport.TLSClientConfig == nil {
		// TLS 1.2 is crypto/tls's client default; stated explicitly for linters.
		transport.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}
	transport.TLSClientConfig.ServerName = existingDNSName
	proxy.Transport = transport
	return proxy
}
// makeFleetProxy returns a reverse proxy forwarding to fleetURL, panicking if
// the URL cannot be parsed. With debug enabled, each upstream response body
// is buffered, logged, and re-attached so the client still receives it.
func makeFleetProxy(fleetURL string, debug bool) *httputil.ReverseProxy {
	target, parseErr := url.Parse(fleetURL)
	if parseErr != nil {
		panic("failed to parse fleet-url: " + parseErr.Error())
	}
	proxy := httputil.NewSingleHostReverseProxy(target)
	if !debug {
		return proxy
	}

	proxy.ModifyResponse = func(resp *http.Response) error {
		payload, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return readErr
		}
		if closeErr := resp.Body.Close(); closeErr != nil {
			return closeErr
		}
		resp.Body = io.NopCloser(bytes.NewReader(payload))
		log.Println("Fleet response: ", string(payload))
		return nil
	}
	return proxy
}
// main parses flags and then either answers a one-off --check query or
// serves the migration proxy on --server-address.
//
// --check prints whether the given UDID is migrated under the configured
// settings and exits 0. An out-of-range --migrate-percentage, a missing
// required flag, or a failure to listen ends the process with a non-zero
// status via log.Fatal.
func main() {
	authToken := flag.String("auth-token", "", "Auth token for remote flag updates (remote updates disabled if not provided)")
	existingURL := flag.String("existing-url", "", "Existing MDM server URL (full path) (required)")
	existingHostname := flag.String("existing-hostname", "", "Hostname for existing MDM server (eg. 'mdm.example.com') (required)")
	fleetURL := flag.String("fleet-url", "", "Fleet MDM server URL (full path) (required)")
	migratePercentage := flag.Int("migrate-percentage", 0, "Percentage of clients to migrate from existing MDM to Fleet")
	migrateUDIDs := flag.String("migrate-udids", "", "Space/newline-delimited list of UDIDs to migrate always")
	serverAddr := flag.String("server-address", ":8080", "Address for server to listen on")
	debug := flag.Bool("debug", false, "Enable debug logging")
	logSkipped := flag.Bool("log-skipped", false, "Log skipped requests (usually from web scanners)")
	check := flag.String("check", "", "Print whether the specified UDID is migrated with the current configuration, then exit")
	flag.Parse()

	// Apply the same bounds the /admin/percentage endpoint enforces.
	if *migratePercentage < 0 || *migratePercentage > 100 {
		log.Fatalf("--migrate-percentage must be in range [0, 100], got %d", *migratePercentage)
	}

	udids, err := processUDIDs(bytes.NewBufferString(*migrateUDIDs))
	if err != nil {
		panic(err)
	}

	proxy := mdmProxy{
		token:             *authToken,
		existingServerURL: *existingURL,
		fleetServerURL:    *fleetURL,
		existingHostname:  *existingHostname,
		migratePercentage: *migratePercentage,
		migrateUDIDs:      udids,
		existingProxy:     makeExistingProxy(*existingURL, *existingHostname),
		fleetProxy:        makeFleetProxy(*fleetURL, *debug),
		debug:             *debug,
		logSkipped:        *logSkipped,
	}

	// --check only needs the migration settings, so it runs before the
	// server flags are validated.
	if len(*check) > 0 {
		if proxy.isUDIDMigrated(*check) {
			fmt.Printf("%s IS migrated\n", *check)
		} else {
			fmt.Printf("%s IS NOT migrated\n", *check)
		}
		os.Exit(0)
	}

	// Check required flags
	if *existingURL == "" {
		log.Fatal("--existing-url must be set")
	}
	if *existingHostname == "" {
		log.Fatal("--existing-hostname must be set")
	}
	if *fleetURL == "" {
		log.Fatal("--fleet-url must be set")
	}

	log.Printf("--migrate-udids set: %v", udids)
	log.Printf("--migrate-percentage set: %d", *migratePercentage)
	log.Printf("--existing-url set: %s", *existingURL)
	log.Printf("--existing-hostname set: %s", *existingHostname)
	log.Printf("--fleet-url set: %s", *fleetURL)
	log.Printf("--debug set: %v", *debug)
	log.Printf("--log-skipped set: %v", *logSkipped)
	if *authToken != "" {
		log.Printf("--auth-token set. Remote configuration enabled.")
	} else {
		log.Printf("--auth-token is empty. Remote configuration disabled.")
	}

	mux := http.NewServeMux()
	// Health check endpoint used for load balancers
	mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
		_, err := w.Write([]byte("OK"))
		if err != nil {
			log.Printf("/healthz error: %v", err)
		}
	})
	// Remote management of migration (enabled if auth token set)
	mux.HandleFunc("/admin/udids", proxy.handleUpdateMigrateUDIDs)
	mux.HandleFunc("/admin/percentage", proxy.handleUpdatePercentage)
	// Handler for the actual proxying
	mux.HandleFunc("/", proxy.handleProxy)

	log.Printf("Starting server on %s", *serverAddr)
	server := &http.Server{
		Addr:              *serverAddr,
		ReadHeaderTimeout: 10 * time.Second,
		Handler:           mux,
	}
	// ListenAndServe always returns a non-nil error; nothing here calls
	// Shutdown, so any return is a failure (e.g. address already in use).
	// Exit non-zero so process supervisors notice it.
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("Error starting server: %s", err)
	}
}