Change status code 500=>408 when the MDM protocol endpoints time out reading the request body (#19698)

This commit is contained in:
Martin Angers 2024-06-12 16:30:49 -04:00 committed by GitHub
parent 33b087955b
commit 468a9ff608
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 81 additions and 0 deletions

View file

@ -0,0 +1 @@
* Fixed the `/mdm/apple/mdm` endpoint so that it returns status code 408 (request timeout) instead of 500 (internal server error) when encountering a timeout reading the request body.

View file

@ -145,6 +145,11 @@ func RawCommandEnqueueHandler(enqueuer storage.CommandEnqueuer, pusher push.Push
b, err := mdmhttp.ReadAllAndReplaceBody(r)
if err != nil {
logger.Info("msg", "reading body", "err", err)
var toErr interface{ Timeout() bool }
if errors.As(err, &toErr) && toErr.Timeout() {
http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout)
return
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
@ -299,6 +304,11 @@ func StorePushCertHandler(storage storage.PushCertStore, logger log.Logger) http
b, err := mdmhttp.ReadAllAndReplaceBody(r)
if err != nil {
logger.Info("msg", "reading body", "err", err)
var toErr interface{ Timeout() bool }
if errors.As(err, &toErr) && toErr.Timeout() {
http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout)
return
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}

View file

@ -0,0 +1,54 @@
package http
import (
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestHTTPServerTimeoutError(t *testing.T) {
// ensure that a read timeout error is properly detected
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
code := http.StatusOK
if _, err := io.ReadAll(r.Body); err != nil {
var toErr interface{ Timeout() bool }
if errors.As(err, &toErr) && toErr.Timeout() {
code = http.StatusRequestTimeout
} else {
code = http.StatusInternalServerError
}
}
w.WriteHeader(code)
}))
srv.Config.ReadTimeout = time.Second
srv.Start()
defer srv.Close()
req, err := http.NewRequest("POST", srv.URL, slowReader{b: []byte("slowly send this")})
require.NoError(t, err)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
code := res.StatusCode
require.Equal(t, http.StatusRequestTimeout, code)
}
type slowReader struct {
b []byte
}
func (s slowReader) Read(p []byte) (n int, err error) {
if len(s.b) == 0 {
return 0, io.EOF
}
time.Sleep(200 * time.Millisecond)
n = copy(p, s.b[:len(s.b)/2])
s.b = s.b[n:]
return n, nil
}

View file

@ -32,6 +32,11 @@ func CheckinHandler(svc service.Checkin, logger log.Logger) http.HandlerFunc {
bodyBytes, err := mdmhttp.ReadAllAndReplaceBody(r)
if err != nil {
logger.Info("msg", "reading body", "err", err)
var toErr interface{ Timeout() bool }
if errors.As(err, &toErr) && toErr.Timeout() {
http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout)
return
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
@ -56,6 +61,11 @@ func CommandAndReportResultsHandler(svc service.CommandAndReportResults, logger
bodyBytes, err := mdmhttp.ReadAllAndReplaceBody(r)
if err != nil {
logger.Info("msg", "reading body", "err", err)
var toErr interface{ Timeout() bool }
if errors.As(err, &toErr) && toErr.Timeout() {
http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout)
return
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}

View file

@ -3,6 +3,7 @@ package mdm
import (
"context"
"crypto/x509"
"errors"
"net/http"
"net/url"
@ -84,6 +85,11 @@ func CertExtractMdmSignatureMiddleware(next http.Handler, logger log.Logger) htt
b, err := mdmhttp.ReadAllAndReplaceBody(r)
if err != nil {
logger.Info("msg", "reading body", "err", err)
var toErr interface{ Timeout() bool }
if errors.As(err, &toErr) && toErr.Timeout() {
http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout)
return
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}