mirror of
https://github.com/fleetdm/fleet
synced 2026-05-24 01:18:42 +00:00
Change status code 500=>408 when the MDM protocol endpoints time out reading the request body (#19698)
This commit is contained in:
parent
33b087955b
commit
468a9ff608
5 changed files with 81 additions and 0 deletions
|
|
@ -0,0 +1 @@
|
|||
* Fixed the `/mdm/apple/mdm` endpoint so that it returns status code 408 (request timeout) instead of 500 (internal server error) when encountering a timeout reading the request body.
|
||||
|
|
@ -145,6 +145,11 @@ func RawCommandEnqueueHandler(enqueuer storage.CommandEnqueuer, pusher push.Push
|
|||
b, err := mdmhttp.ReadAllAndReplaceBody(r)
|
||||
if err != nil {
|
||||
logger.Info("msg", "reading body", "err", err)
|
||||
var toErr interface{ Timeout() bool }
|
||||
if errors.As(err, &toErr) && toErr.Timeout() {
|
||||
http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -299,6 +304,11 @@ func StorePushCertHandler(storage storage.PushCertStore, logger log.Logger) http
|
|||
b, err := mdmhttp.ReadAllAndReplaceBody(r)
|
||||
if err != nil {
|
||||
logger.Info("msg", "reading body", "err", err)
|
||||
var toErr interface{ Timeout() bool }
|
||||
if errors.As(err, &toErr) && toErr.Timeout() {
|
||||
http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
54
server/mdm/nanomdm/http/http_test.go
Normal file
54
server/mdm/nanomdm/http/http_test.go
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHTTPServerTimeoutError(t *testing.T) {
|
||||
// ensure that a read timeout error is properly detected
|
||||
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
code := http.StatusOK
|
||||
if _, err := io.ReadAll(r.Body); err != nil {
|
||||
var toErr interface{ Timeout() bool }
|
||||
if errors.As(err, &toErr) && toErr.Timeout() {
|
||||
code = http.StatusRequestTimeout
|
||||
} else {
|
||||
code = http.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
w.WriteHeader(code)
|
||||
}))
|
||||
|
||||
srv.Config.ReadTimeout = time.Second
|
||||
srv.Start()
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", srv.URL, slowReader{b: []byte("slowly send this")})
|
||||
require.NoError(t, err)
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
code := res.StatusCode
|
||||
require.Equal(t, http.StatusRequestTimeout, code)
|
||||
}
|
||||
|
||||
type slowReader struct {
|
||||
b []byte
|
||||
}
|
||||
|
||||
func (s slowReader) Read(p []byte) (n int, err error) {
|
||||
if len(s.b) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
n = copy(p, s.b[:len(s.b)/2])
|
||||
s.b = s.b[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
|
@ -32,6 +32,11 @@ func CheckinHandler(svc service.Checkin, logger log.Logger) http.HandlerFunc {
|
|||
bodyBytes, err := mdmhttp.ReadAllAndReplaceBody(r)
|
||||
if err != nil {
|
||||
logger.Info("msg", "reading body", "err", err)
|
||||
var toErr interface{ Timeout() bool }
|
||||
if errors.As(err, &toErr) && toErr.Timeout() {
|
||||
http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -56,6 +61,11 @@ func CommandAndReportResultsHandler(svc service.CommandAndReportResults, logger
|
|||
bodyBytes, err := mdmhttp.ReadAllAndReplaceBody(r)
|
||||
if err != nil {
|
||||
logger.Info("msg", "reading body", "err", err)
|
||||
var toErr interface{ Timeout() bool }
|
||||
if errors.As(err, &toErr) && toErr.Timeout() {
|
||||
http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package mdm
|
|||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
|
|
@ -84,6 +85,11 @@ func CertExtractMdmSignatureMiddleware(next http.Handler, logger log.Logger) htt
|
|||
b, err := mdmhttp.ReadAllAndReplaceBody(r)
|
||||
if err != nil {
|
||||
logger.Info("msg", "reading body", "err", err)
|
||||
var toErr interface{ Timeout() bool }
|
||||
if errors.As(err, &toErr) && toErr.Timeout() {
|
||||
http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue