diff --git a/changes/17728-send-408-instead-of-500-for-apple-mdm-timeout b/changes/17728-send-408-instead-of-500-for-apple-mdm-timeout new file mode 100644 index 0000000000..369f33441b --- /dev/null +++ b/changes/17728-send-408-instead-of-500-for-apple-mdm-timeout @@ -0,0 +1 @@ +* Fixed the `/mdm/apple/mdm` endpoint so that it returns status code 408 (request timeout) instead of 500 (internal server error) when encountering a timeout reading the request body. diff --git a/server/mdm/nanomdm/http/api/api.go b/server/mdm/nanomdm/http/api/api.go index 5649470f85..6cdcda6893 100644 --- a/server/mdm/nanomdm/http/api/api.go +++ b/server/mdm/nanomdm/http/api/api.go @@ -145,6 +145,11 @@ func RawCommandEnqueueHandler(enqueuer storage.CommandEnqueuer, pusher push.Push b, err := mdmhttp.ReadAllAndReplaceBody(r) if err != nil { logger.Info("msg", "reading body", "err", err) + var toErr interface{ Timeout() bool } + if errors.As(err, &toErr) && toErr.Timeout() { + http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout) + return + } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } @@ -299,6 +304,11 @@ func StorePushCertHandler(storage storage.PushCertStore, logger log.Logger) http b, err := mdmhttp.ReadAllAndReplaceBody(r) if err != nil { logger.Info("msg", "reading body", "err", err) + var toErr interface{ Timeout() bool } + if errors.As(err, &toErr) && toErr.Timeout() { + http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout) + return + } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } diff --git a/server/mdm/nanomdm/http/http_test.go b/server/mdm/nanomdm/http/http_test.go new file mode 100644 index 0000000000..26fe47dc4f --- /dev/null +++ b/server/mdm/nanomdm/http/http_test.go @@ -0,0 +1,54 @@ +package http + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestHTTPServerTimeoutError(t *testing.T) { + // ensure that a read timeout error is properly detected + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + code := http.StatusOK + if _, err := io.ReadAll(r.Body); err != nil { + var toErr interface{ Timeout() bool } + if errors.As(err, &toErr) && toErr.Timeout() { + code = http.StatusRequestTimeout + } else { + code = http.StatusInternalServerError + } + } + w.WriteHeader(code) + })) + + srv.Config.ReadTimeout = time.Second + srv.Start() + defer srv.Close() + + req, err := http.NewRequest("POST", srv.URL, slowReader{b: []byte("slowly send this")}) + require.NoError(t, err) + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + code := res.StatusCode + require.Equal(t, http.StatusRequestTimeout, code) +} + +type slowReader struct { + b []byte +} + +func (s slowReader) Read(p []byte) (n int, err error) { + if len(s.b) == 0 { + return 0, io.EOF + } + + time.Sleep(200 * time.Millisecond) + n = copy(p, s.b[:len(s.b)/2]) + s.b = s.b[n:] + return n, nil +} diff --git a/server/mdm/nanomdm/http/mdm/mdm.go b/server/mdm/nanomdm/http/mdm/mdm.go index a3bd21796a..bad75bce2b 100644 --- a/server/mdm/nanomdm/http/mdm/mdm.go +++ b/server/mdm/nanomdm/http/mdm/mdm.go @@ -32,6 +32,11 @@ func CheckinHandler(svc service.Checkin, logger log.Logger) http.HandlerFunc { bodyBytes, err := mdmhttp.ReadAllAndReplaceBody(r) if err != nil { logger.Info("msg", "reading body", "err", err) + var toErr interface{ Timeout() bool } + if errors.As(err, &toErr) && toErr.Timeout() { + http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout) + return + } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } @@ -56,6 +61,11 @@ func CommandAndReportResultsHandler(svc service.CommandAndReportResults, logger bodyBytes, err := mdmhttp.ReadAllAndReplaceBody(r) if err != nil { logger.Info("msg", "reading body", "err", err) + var toErr interface{ Timeout() bool } + if errors.As(err, &toErr) && toErr.Timeout() { + http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout) + return + } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } diff --git a/server/mdm/nanomdm/http/mdm/mdm_cert.go b/server/mdm/nanomdm/http/mdm/mdm_cert.go index 69ef246ef6..93308ec841 100644 --- a/server/mdm/nanomdm/http/mdm/mdm_cert.go +++ b/server/mdm/nanomdm/http/mdm/mdm_cert.go @@ -3,6 +3,7 @@ package mdm import ( "context" "crypto/x509" + "errors" "net/http" "net/url" @@ -84,6 +85,11 @@ func CertExtractMdmSignatureMiddleware(next http.Handler, logger log.Logger) htt b, err := mdmhttp.ReadAllAndReplaceBody(r) if err != nil { logger.Info("msg", "reading body", "err", err) + var toErr interface{ Timeout() bool } + if errors.As(err, &toErr) && toErr.Timeout() { + http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout) + return + } http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return }