diff --git a/changes/31850-retry-apple-vpp-api-timeout b/changes/31850-retry-apple-vpp-api-timeout new file mode 100644 index 0000000000..5792254a40 --- /dev/null +++ b/changes/31850-retry-apple-vpp-api-timeout @@ -0,0 +1 @@ +* Added retries with backoff when Apple's assets API fails with a timeout error. diff --git a/ee/server/service/software_installers.go b/ee/server/service/software_installers.go index 9b9dbf83ed..3b5efd9655 100644 --- a/ee/server/service/software_installers.go +++ b/ee/server/service/software_installers.go @@ -1225,7 +1225,7 @@ func (svc *Service) InstallVPPAppPostValidation(ctx context.Context, host *fleet // this app is not assigned to this device, check if we have licenses // left and assign it. if len(assignments) == 0 { - assets, err := vpp.GetAssets(token, &vpp.AssetFilter{AdamID: vppApp.AdamID}) + assets, err := vpp.GetAssets(ctx, token, &vpp.AssetFilter{AdamID: vppApp.AdamID}) if err != nil { return "", ctxerr.Wrap(ctx, err, "getting assets from VPP API") } diff --git a/ee/server/service/vpp.go b/ee/server/service/vpp.go index 96f5a95378..52c3d3a686 100644 --- a/ee/server/service/vpp.go +++ b/ee/server/service/vpp.go @@ -150,7 +150,7 @@ func (svc *Service) BatchAssociateVPPApps(ctx context.Context, teamName string, var missingAssets []string - assets, err := vpp.GetAssets(token, nil) + assets, err := vpp.GetAssets(ctx, token, nil) if err != nil { return nil, ctxerr.Wrap(ctx, err, "unable to retrieve assets") } @@ -218,7 +218,7 @@ func (svc *Service) GetAppStoreApps(ctx context.Context, teamID *uint) ([]*fleet return nil, ctxerr.Wrap(ctx, err, "retrieving VPP token") } - assets, err := vpp.GetAssets(vppToken, nil) + assets, err := vpp.GetAssets(ctx, vppToken, nil) if err != nil { return nil, ctxerr.Wrap(ctx, err, "fetching Apple VPP assets") } @@ -368,7 +368,7 @@ func (svc *Service) AddAppStoreApp(ctx context.Context, teamID *uint, appID flee return 0, ctxerr.Wrap(ctx, err, "retrieving VPP token") } - assets, err := vpp.GetAssets(vppToken, &vpp.AssetFilter{AdamID: appID.AdamID}) + assets, err := vpp.GetAssets(ctx, vppToken, &vpp.AssetFilter{AdamID: appID.AdamID}) if err != nil { return 0, ctxerr.Wrap(ctx, err, "retrieving VPP asset") } diff --git a/server/mdm/apple/vpp/api.go b/server/mdm/apple/vpp/api.go index 86a91b5054..7d93dcc658 100644 --- a/server/mdm/apple/vpp/api.go +++ b/server/mdm/apple/vpp/api.go @@ -2,9 +2,12 @@ package vpp import ( "bytes" + "context" "encoding/json" + "errors" "fmt" "io" + "net" "net/http" "net/url" "os" @@ -153,7 +156,33 @@ type AssetFilter struct { } // GetAssets fetches the assets from Apple's VPP API with optional filters. -func GetAssets(token string, filter *AssetFilter) ([]Asset, error) { +func GetAssets(ctx context.Context, token string, filter *AssetFilter) ([]Asset, error) { + var assets []Asset + var returnErr error + + _ = retry.Do(func() error { + var err error + assets, err = getAssetsOnce(ctx, token, filter) + returnErr = err + + var ne net.Error + // if we still have some time left on the current request's context + // deadline and the error is a timeout, we may retry + if dl, _ := ctx.Deadline(); (dl.IsZero() || time.Until(dl) >= time.Second) && errors.As(err, &ne) && ne.Timeout() { + // will retry + return err + } + // returnErr may be != nil, but it's not an error that we should retry + return nil + }, + retry.WithBackoffMultiplier(3), + retry.WithInterval(100*time.Millisecond), + retry.WithMaxAttempts(3), + ) + return assets, returnErr +} + +func getAssetsOnce(ctx context.Context, token string, filter *AssetFilter) ([]Asset, error) { baseURL := getBaseURL() + "/assets" reqURL, err := url.Parse(baseURL) if err != nil { @@ -175,7 +204,7 @@ func GetAssets(token string, filter *AssetFilter) ([]Asset, error) { reqURL.RawQuery = query.Encode() } - req, err := http.NewRequest(http.MethodGet, reqURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil) if err != nil { return nil, fmt.Errorf("creating request to Apple VPP endpoint: %w", err) } diff --git a/server/mdm/apple/vpp/api_test.go b/server/mdm/apple/vpp/api_test.go index 5d265b9b19..b9d125bba4 100644 --- a/server/mdm/apple/vpp/api_test.go +++ b/server/mdm/apple/vpp/api_test.go @@ -7,9 +7,11 @@ import ( "net/http" "net/http/httptest" "os" + "sync/atomic" "testing" "time" + "github.com/fleetdm/fleet/v4/pkg/fleethttp" "github.com/stretchr/testify/require" ) @@ -153,13 +155,22 @@ func TestAssociateAssets(t *testing.T) { } func TestGetAssets(t *testing.T) { + originalClient := client + client = fleethttp.NewClient(fleethttp.WithTimeout(time.Second)) + t.Cleanup(func() { + client = originalClient + }) + + var requestCount atomic.Int64 + tests := []struct { - name string - token string - filter *AssetFilter - handler http.HandlerFunc - expectedAssets []Asset - expectedErrMsg string + name string + token string + filter *AssetFilter + handler http.HandlerFunc + expectedAssets []Asset + expectedErrMsg string + expectedRequests int }{ { name: "valid token and filters", @@ -191,7 +202,8 @@ func TestGetAssets(t *testing.T) { {AdamID: "12345", PricingParam: "STDQ"}, {AdamID: "67890", PricingParam: "PLUS"}, }, - expectedErrMsg: "", + expectedErrMsg: "", + expectedRequests: 1, }, { name: "server error", @@ -201,8 +213,9 @@ func TestGetAssets(t *testing.T) { w.WriteHeader(http.StatusInternalServerError) fmt.Fprintln(w, `Internal Server Error`) }, - expectedAssets: nil, - expectedErrMsg: "calling Apple VPP endpoint failed with status 500: Internal Server Error\n", + expectedAssets: nil, + expectedErrMsg: "calling Apple VPP endpoint failed with status 500: Internal Server Error\n", + expectedRequests: 1, }, { name: "client error", @@ -212,16 +225,73 @@ func TestGetAssets(t *testing.T) { w.WriteHeader(http.StatusBadRequest) fmt.Fprintln(w, `{"errorInfo":{},"errorMessage":"Bad Request","errorNumber":400}`) }, - expectedAssets: nil, - expectedErrMsg: "retrieving assets: Apple VPP endpoint returned error: Bad Request (error number: 400)", + expectedAssets: nil, + expectedErrMsg: "retrieving assets: Apple VPP endpoint returned error: Bad Request (error number: 400)", + expectedRequests: 1, + }, + { + name: "always times out", + token: "valid_token", + filter: nil, + handler: func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Second + 500*time.Millisecond) // longer than the 1s client timeout + type resp struct { + Assets []Asset `json:"assets"` + } + assets := resp{ + Assets: []Asset{ + {AdamID: "12345", PricingParam: "STDQ"}, + {AdamID: "67890", PricingParam: "PLUS"}, + }, + } + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(assets)) + }, + expectedAssets: nil, + expectedErrMsg: "context deadline exceeded (Client.Timeout exceeded while awaiting headers)", + expectedRequests: 3, + }, + { + name: "times out then valid", + token: "valid_token", + filter: nil, + handler: func(w http.ResponseWriter, r *http.Request) { + if requestCount.Load() < 2 { + time.Sleep(time.Second + 500*time.Millisecond) // longer than the 1s client timeout + } + + type resp struct { + Assets []Asset `json:"assets"` + } + assets := resp{ + Assets: []Asset{ + {AdamID: "12345", PricingParam: "STDQ"}, + {AdamID: "67890", PricingParam: "PLUS"}, + }, + } + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(assets)) + }, + expectedAssets: []Asset{ + {AdamID: "12345", PricingParam: "STDQ"}, + {AdamID: "67890", PricingParam: "PLUS"}, + }, + expectedErrMsg: "", + expectedRequests: 2, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - setupFakeServer(t, tt.handler) + requestCount.Store(0) - assets, err := GetAssets(tt.token, tt.filter) + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount.Add(1) + tt.handler(w, r) + }) + setupFakeServer(t, h) + + assets, err := GetAssets(t.Context(), tt.token, tt.filter) if tt.expectedErrMsg != "" { require.Error(t, err) require.Contains(t, err.Error(), tt.expectedErrMsg) @@ -229,6 +299,7 @@ func TestGetAssets(t *testing.T) { require.NoError(t, err) require.Equal(t, tt.expectedAssets, assets) } + require.EqualValues(t, tt.expectedRequests, requestCount.Load()) }) } }