diff --git a/docs/REST API/rest-api.md b/docs/REST API/rest-api.md index bd576cd916..16610e4eb3 100644 --- a/docs/REST API/rest-api.md +++ b/docs/REST API/rest-api.md @@ -3931,7 +3931,7 @@ This endpoint tells Fleet to run a custom an MDM command, on the targeted macOS | Name | Type | In | Description | | ------------------------- | ------ | ----- | ------------------------------------------------------------------------- | -| command | string | json | A base64-encoded MDM command as described in [Apple's documentation](https://developer.apple.com/documentation/devicemanagement/commands_and_queries) | +| command | string | json | A base64-encoded MDM command as described in [Apple's documentation](https://developer.apple.com/documentation/devicemanagement/commands_and_queries). Supported formats are standard (RFC 4648) and raw (unpadded) encoding (RFC 4648 section 3.2) | | device_ids | array | json | An array of host UUIDs enrolled in Fleet's MDM on which the command should run. | Note that the `EraseDevice` and `DeviceLock` commands are _available in Fleet Premium_ only. diff --git a/server/service/apple_mdm.go b/server/service/apple_mdm.go index 28c38ab66b..bf583c0d93 100644 --- a/server/service/apple_mdm.go +++ b/server/service/apple_mdm.go @@ -3,7 +3,6 @@ package service import ( "bytes" "context" - "encoding/base64" "encoding/json" "errors" "fmt" @@ -19,6 +18,7 @@ import ( "github.com/VividCortex/mysqlerr" "github.com/docker/go-units" "github.com/fleetdm/fleet/v4/pkg/file" + "github.com/fleetdm/fleet/v4/server" "github.com/fleetdm/fleet/v4/server/authz" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/contexts/license" @@ -974,7 +974,11 @@ func (svc *Service) EnqueueMDMAppleCommand( } } - rawXMLCmd, err := base64.RawStdEncoding.DecodeString(rawBase64Cmd) + // using a padding agnostic decoder because we released this using + // base64.RawStdEncoding, but it was causing problems as many standard + // libraries default to padded strings. We're now supporting both for + // backwards compatibility. + rawXMLCmd, err := server.Base64DecodePaddingAgnostic(rawBase64Cmd) if err != nil { err = fleet.NewInvalidArgumentError("command", "unable to decode base64 command").WithStatus(http.StatusBadRequest) diff --git a/server/service/integration_mdm_test.go b/server/service/integration_mdm_test.go index 0ea8fa4ff3..bd53ae2fc2 100644 --- a/server/service/integration_mdm_test.go +++ b/server/service/integration_mdm_test.go @@ -3628,7 +3628,9 @@ func (s *integrationMDMTestSuite) TestEnqueueMDMCommand() { uuid1 := uuid.New().String() s.Do("POST", "/api/latest/fleet/mdm/apple/enqueue", enqueueMDMAppleCommandRequest{ - Command: base64Cmd(newRawCmd(uuid1)), + // explicitly use standard encoding to make sure it also works + // see #11384 + Command: base64.StdEncoding.EncodeToString([]byte(newRawCmd(uuid1))), DeviceIDs: []string{"no-such-host"}, }, http.StatusNotFound) diff --git a/server/utils.go b/server/utils.go index 06367dea8b..ce0444386c 100644 --- a/server/utils.go +++ b/server/utils.go @@ -138,3 +138,10 @@ func GetTemplate(templatePath string, templateName string) (*template.Template, return t, nil } + +// Base64DecodePaddingAgnostic decodes a base64 string that might be encoded +// using raw encoding or standard encoding (padded) +func Base64DecodePaddingAgnostic(s string) ([]byte, error) { + us := strings.TrimRight(s, string(base64.StdPadding)) + return base64.RawStdEncoding.DecodeString(us) +} diff --git a/server/utils_test.go b/server/utils_test.go index 3240080d27..1a57e0c3e8 100644 --- a/server/utils_test.go +++ b/server/utils_test.go @@ -1,6 +1,7 @@ package server import ( + "encoding/base64" "errors" "net/url" "testing" @@ -101,3 +102,24 @@ func TestMaskURLError(t *testing.T) { require.NotContains(t, masked.Error(), "42") }) } + +func TestBase64DecodePaddingAgnostic(t *testing.T) { + cases := []struct { + in string + want []byte + err error + }{ + {"", []byte{}, nil}, + {"==", []byte{}, nil}, + {"==", []byte{}, nil}, + {"dGVzdA==", []byte("test"), nil}, + {"dGVzdA", []byte("test"), nil}, + {"dGVzdA==ABC", []byte("tes"), base64.CorruptInputError(6)}, + } + + for _, c := range cases { + got, err := Base64DecodePaddingAgnostic(c.in) + require.Equal(t, c.err, err) + require.Equal(t, got, c.want) + } +}