Bugfix: retry VPP assets API call on Apple timeout, until our own context hits its timeout (#33313)

This commit is contained in:
Martin Angers 2025-09-23 10:46:30 -04:00 committed by GitHub
parent 834ab62ed0
commit 64f27c69aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 120 additions and 19 deletions

View file

@ -0,0 +1 @@
* Added retries with backoff when Apple's assets API fails with a timeout error.

View file

@ -1225,7 +1225,7 @@ func (svc *Service) InstallVPPAppPostValidation(ctx context.Context, host *fleet
// this app is not assigned to this device, check if we have licenses
// left and assign it.
if len(assignments) == 0 {
assets, err := vpp.GetAssets(token, &vpp.AssetFilter{AdamID: vppApp.AdamID})
assets, err := vpp.GetAssets(ctx, token, &vpp.AssetFilter{AdamID: vppApp.AdamID})
if err != nil {
return "", ctxerr.Wrap(ctx, err, "getting assets from VPP API")
}

View file

@ -150,7 +150,7 @@ func (svc *Service) BatchAssociateVPPApps(ctx context.Context, teamName string,
var missingAssets []string
assets, err := vpp.GetAssets(token, nil)
assets, err := vpp.GetAssets(ctx, token, nil)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "unable to retrieve assets")
}
@ -218,7 +218,7 @@ func (svc *Service) GetAppStoreApps(ctx context.Context, teamID *uint) ([]*fleet
return nil, ctxerr.Wrap(ctx, err, "retrieving VPP token")
}
assets, err := vpp.GetAssets(vppToken, nil)
assets, err := vpp.GetAssets(ctx, vppToken, nil)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "fetching Apple VPP assets")
}
@ -368,7 +368,7 @@ func (svc *Service) AddAppStoreApp(ctx context.Context, teamID *uint, appID flee
return 0, ctxerr.Wrap(ctx, err, "retrieving VPP token")
}
assets, err := vpp.GetAssets(vppToken, &vpp.AssetFilter{AdamID: appID.AdamID})
assets, err := vpp.GetAssets(ctx, vppToken, &vpp.AssetFilter{AdamID: appID.AdamID})
if err != nil {
return 0, ctxerr.Wrap(ctx, err, "retrieving VPP asset")
}

View file

@ -2,9 +2,12 @@ package vpp
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
@ -153,7 +156,33 @@ type AssetFilter struct {
}
// GetAssets fetches the assets from Apple's VPP API with optional filters.
func GetAssets(token string, filter *AssetFilter) ([]Asset, error) {
func GetAssets(ctx context.Context, token string, filter *AssetFilter) ([]Asset, error) {
var assets []Asset
var returnErr error
_ = retry.Do(func() error {
var err error
assets, err = getAssetsOnce(ctx, token, filter)
returnErr = err
var ne net.Error
// if we still have some time left on the current request's context
// deadline and the error is a timeout, we may retry
if dl, _ := ctx.Deadline(); (dl.IsZero() || time.Until(dl) >= time.Second) && errors.As(err, &ne) && ne.Timeout() {
// will retry
return err
}
// returnErr may be != nil, but it's not an error that we should retry
return nil
},
retry.WithBackoffMultiplier(3),
retry.WithInterval(100*time.Millisecond),
retry.WithMaxAttempts(3),
)
return assets, returnErr
}
func getAssetsOnce(ctx context.Context, token string, filter *AssetFilter) ([]Asset, error) {
baseURL := getBaseURL() + "/assets"
reqURL, err := url.Parse(baseURL)
if err != nil {
@ -175,7 +204,7 @@ func GetAssets(token string, filter *AssetFilter) ([]Asset, error) {
reqURL.RawQuery = query.Encode()
}
req, err := http.NewRequest(http.MethodGet, reqURL.String(), nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
if err != nil {
return nil, fmt.Errorf("creating request to Apple VPP endpoint: %w", err)
}

View file

@ -7,9 +7,11 @@ import (
"net/http"
"net/http/httptest"
"os"
"sync/atomic"
"testing"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/stretchr/testify/require"
)
@ -153,13 +155,22 @@ func TestAssociateAssets(t *testing.T) {
}
func TestGetAssets(t *testing.T) {
originalClient := client
client = fleethttp.NewClient(fleethttp.WithTimeout(time.Second))
t.Cleanup(func() {
client = originalClient
})
var requestCount atomic.Int64
tests := []struct {
name string
token string
filter *AssetFilter
handler http.HandlerFunc
expectedAssets []Asset
expectedErrMsg string
name string
token string
filter *AssetFilter
handler http.HandlerFunc
expectedAssets []Asset
expectedErrMsg string
expectedRequests int
}{
{
name: "valid token and filters",
@ -191,7 +202,8 @@ func TestGetAssets(t *testing.T) {
{AdamID: "12345", PricingParam: "STDQ"},
{AdamID: "67890", PricingParam: "PLUS"},
},
expectedErrMsg: "",
expectedErrMsg: "",
expectedRequests: 1,
},
{
name: "server error",
@ -201,8 +213,9 @@ func TestGetAssets(t *testing.T) {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintln(w, `Internal Server Error`)
},
expectedAssets: nil,
expectedErrMsg: "calling Apple VPP endpoint failed with status 500: Internal Server Error\n",
expectedAssets: nil,
expectedErrMsg: "calling Apple VPP endpoint failed with status 500: Internal Server Error\n",
expectedRequests: 1,
},
{
name: "client error",
@ -212,16 +225,73 @@ func TestGetAssets(t *testing.T) {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintln(w, `{"errorInfo":{},"errorMessage":"Bad Request","errorNumber":400}`)
},
expectedAssets: nil,
expectedErrMsg: "retrieving assets: Apple VPP endpoint returned error: Bad Request (error number: 400)",
expectedAssets: nil,
expectedErrMsg: "retrieving assets: Apple VPP endpoint returned error: Bad Request (error number: 400)",
expectedRequests: 1,
},
{
name: "always times out",
token: "valid_token",
filter: nil,
handler: func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Second + 500*time.Millisecond) // longer than the 1s client timeout
type resp struct {
Assets []Asset `json:"assets"`
}
assets := resp{
Assets: []Asset{
{AdamID: "12345", PricingParam: "STDQ"},
{AdamID: "67890", PricingParam: "PLUS"},
},
}
w.WriteHeader(http.StatusOK)
require.NoError(t, json.NewEncoder(w).Encode(assets))
},
expectedAssets: nil,
expectedErrMsg: "context deadline exceeded (Client.Timeout exceeded while awaiting headers)",
expectedRequests: 3,
},
{
name: "times out then valid",
token: "valid_token",
filter: nil,
handler: func(w http.ResponseWriter, r *http.Request) {
if requestCount.Load() < 2 {
time.Sleep(time.Second + 500*time.Millisecond) // longer than the 1s client timeout
}
type resp struct {
Assets []Asset `json:"assets"`
}
assets := resp{
Assets: []Asset{
{AdamID: "12345", PricingParam: "STDQ"},
{AdamID: "67890", PricingParam: "PLUS"},
},
}
w.WriteHeader(http.StatusOK)
require.NoError(t, json.NewEncoder(w).Encode(assets))
},
expectedAssets: []Asset{
{AdamID: "12345", PricingParam: "STDQ"},
{AdamID: "67890", PricingParam: "PLUS"},
},
expectedErrMsg: "",
expectedRequests: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setupFakeServer(t, tt.handler)
requestCount.Store(0)
assets, err := GetAssets(tt.token, tt.filter)
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount.Add(1)
tt.handler(w, r)
})
setupFakeServer(t, h)
assets, err := GetAssets(t.Context(), tt.token, tt.filter)
if tt.expectedErrMsg != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.expectedErrMsg)
@ -229,6 +299,7 @@ func TestGetAssets(t *testing.T) {
require.NoError(t, err)
require.Equal(t, tt.expectedAssets, assets)
}
require.EqualValues(t, tt.expectedRequests, requestCount.Load())
})
}
}