diff --git a/server/mdm/apple/itunes/api.go b/server/mdm/apple/itunes/api.go index fafac6a82f..f81fe878eb 100644 --- a/server/mdm/apple/itunes/api.go +++ b/server/mdm/apple/itunes/api.go @@ -12,6 +12,7 @@ import ( "time" "github.com/fleetdm/fleet/v4/pkg/fleethttp" + "github.com/fleetdm/fleet/v4/pkg/retry" ) type AssetMetadata struct { @@ -86,6 +87,16 @@ func do[T any](req *http.Request, dest *T) error { if len(limitedBody) > 1000 { limitedBody = limitedBody[:1000] } + + if resp.StatusCode >= http.StatusInternalServerError { + return retry.Do( + func() error { return do(req, dest) }, + retry.WithInterval(1*time.Second), + retry.WithMaxAttempts(4), + ) + + } + return fmt.Errorf("calling Apple iTunes endpoint failed with status %d: %s", resp.StatusCode, string(limitedBody)) } diff --git a/server/mdm/apple/itunes/api_test.go b/server/mdm/apple/itunes/api_test.go index a43031fcea..bec950ff6f 100644 --- a/server/mdm/apple/itunes/api_test.go +++ b/server/mdm/apple/itunes/api_test.go @@ -1,8 +1,11 @@ package itunes import ( + "net/http" + "net/http/httptest" "os" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -19,3 +22,69 @@ func TestGetBaseURL(t *testing.T) { require.Equal(t, customURL, getBaseURL()) }) } + +func setupFakeServer(t *testing.T, handler http.HandlerFunc) { + server := httptest.NewServer(handler) + os.Setenv("FLEET_DEV_ITUNES_URL", server.URL) + t.Cleanup(server.Close) +} + +func TestDoRetries(t *testing.T) { + tests := []struct { + name string + handler http.HandlerFunc + wantCalls int + wantErr bool + }{ + { + name: "success status code", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("{}")) + require.NoError(t, err) + }, + wantCalls: 1, + wantErr: true, + }, + { + name: "bad requests", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, err := w.Write([]byte("{}")) + require.NoError(t, err) + }, + wantCalls: 1, + wantErr: true, + }, + { + name: "500 requests retries", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + _, err := w.Write([]byte("{}")) + require.NoError(t, err) + }, + wantCalls: 4, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var calls int + setupFakeServer(t, func(w http.ResponseWriter, r *http.Request) { + calls++ + if calls < tt.wantCalls { + tt.handler(w, r) + return + } + }) + + start := time.Now() + req, err := http.NewRequest(http.MethodGet, os.Getenv("FLEET_DEV_ITUNES_URL"), nil) + require.NoError(t, err) + err = do[any](req, nil) + require.NoError(t, err) + require.Equal(t, tt.wantCalls, calls) + require.WithinRange(t, time.Now(), start, start.Add(time.Duration(tt.wantCalls)*time.Second)) + }) + } +}