fleet/server/mdm/apple/vpp/api_test.go
Ian Littman 2f25580c3a
Only allow FLEET_DEV_* env vars when --dev is passed, allow overriding configs one at a time in dev (#38652)
Resolves #38484. This includes a CI job change to make sure we don't
introduce any more env vars that don't get proxied (and thus turned off
outside `--dev`).

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files)
for more information.

- [x] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)

## Testing

- [x] Added/updated automated tests

Manual QA touched hot paths, but did _not_ manually test every
FLEET_DEV_* environment variable change.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
  * Centralized dev-mode environment management for consistent FLEET_DEV_*
    handling and test-friendly overrides.
  * Dev mode allows targeted overrides of certain dev-only configuration
    when running with `--dev`.

* **Chores**
  * Migrated environment access to the centralized dev-mode helper across
    the codebase.
  * Added CI checks to enforce proper usage of FLEET_DEV_* variables.

* **Documentation**
  * Added guidance on dev-mode environment variable rules and overrides.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Victor Lyuboslavsky <2685025+getvictor@users.noreply.github.com>
2026-01-27 14:32:56 -06:00

379 lines
10 KiB
Go

package vpp
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/dev_mode"
"github.com/stretchr/testify/require"
)
// setupFakeServer starts an httptest server backed by handler and points the
// VPP client at it by setting a dev-mode override of FLEET_DEV_VPP_URL to the
// server's URL. The override is registered against t (presumably undone by a
// t.Cleanup inside dev_mode.SetOverride — confirm there), and the server is
// closed via t.Cleanup when t finishes.
func setupFakeServer(t *testing.T, handler http.HandlerFunc) {
	// Mark as a helper so failures surface at the caller's line.
	t.Helper()
	server := httptest.NewServer(handler)
	dev_mode.SetOverride("FLEET_DEV_VPP_URL", server.URL, t)
	t.Cleanup(server.Close)
}
// TestGetConfig checks that GetConfig returns the location name from a 200
// response, and that 401 and 500 responses produce errors carrying Apple's
// error number or the raw response body respectively.
func TestGetConfig(t *testing.T) {
	cases := []struct {
		name           string
		token          string
		handler        http.HandlerFunc
		wantName       string
		expectedErrMsg string
	}{
		{
			name:  "valid token",
			token: "valid_token",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)
				fmt.Fprintln(w, `{"locationName": "Test Location"}`)
			},
			wantName: "Test Location",
		},
		{
			name:  "invalid token",
			token: "invalid_token",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusUnauthorized)
				fmt.Fprintln(w, `{"errorNumber": 9622}`)
			},
			expectedErrMsg: "making request to Apple VPP endpoint: Apple VPP endpoint returned error: (error number: 9622)",
		},
		{
			name:  "server error",
			token: "valid_token",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusInternalServerError)
				fmt.Fprintln(w, `Internal Server Error`)
			},
			expectedErrMsg: "calling Apple VPP endpoint failed with status 500: Internal Server Error\n",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			setupFakeServer(t, tc.handler)

			gotName, err := GetConfig(tc.token)
			if tc.expectedErrMsg == "" {
				require.NoError(t, err)
			} else {
				require.ErrorContains(t, err, tc.expectedErrMsg)
			}
			// On error the returned name is expected to be empty.
			require.Equal(t, tc.wantName, gotName)
		})
	}
}
// TestAssociateAssets covers a successful association (validating the method,
// path, bearer token and JSON body the client sends to the fake server) and
// the error messages produced for 5xx and 4xx responses.
func TestAssociateAssets(t *testing.T) {
	// newParams returns a fresh request so no two cases share slices.
	newParams := func() *AssociateAssetsRequest {
		return &AssociateAssetsRequest{
			Assets:        []Asset{{AdamID: "12345", PricingParam: "STDQ"}},
			SerialNumbers: []string{"SN12345"},
		}
	}

	cases := []struct {
		name           string
		token          string
		params         *AssociateAssetsRequest
		handler        http.HandlerFunc
		expectedErrMsg string
	}{
		{
			name:   "valid request",
			token:  "valid_token",
			params: newParams(),
			handler: func(w http.ResponseWriter, r *http.Request) {
				require.Equal(t, http.MethodPost, r.Method)
				require.Equal(t, "/assets/associate", r.URL.Path)
				require.Equal(t, "Bearer valid_token", r.Header.Get("Authorization"))

				raw, readErr := io.ReadAll(r.Body)
				require.NoError(t, readErr)
				var got AssociateAssetsRequest
				require.NoError(t, json.Unmarshal(raw, &got))
				require.Equal(t, []Asset{{AdamID: "12345", PricingParam: "STDQ"}}, got.Assets)
				require.Equal(t, []string{"SN12345"}, got.SerialNumbers)

				_, _ = w.Write([]byte(`{"eventId": "123"}`))
			},
		},
		{
			name:   "server error",
			token:  "valid_token",
			params: newParams(),
			handler: func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusInternalServerError)
				fmt.Fprintln(w, `Internal Server Error`)
			},
			expectedErrMsg: "calling Apple VPP endpoint failed with status 500: Internal Server Error\n",
		},
		{
			name:   "client error",
			token:  "valid_token",
			params: newParams(),
			handler: func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusBadRequest)
				fmt.Fprintln(w, `{"errorInfo":{},"errorMessage":"Bad Request","errorNumber":400}`)
			},
			expectedErrMsg: "making request to Apple VPP endpoint: Apple VPP endpoint returned error: Bad Request (error number: 400)",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			setupFakeServer(t, tc.handler)

			_, err := AssociateAssets(tc.token, tc.params)
			if tc.expectedErrMsg == "" {
				require.NoError(t, err)
				return
			}
			require.ErrorContains(t, err, tc.expectedErrMsg)
		})
	}
}
// TestGetAssets exercises GetAssets against a fake server: query/header
// validation, 5xx and 4xx error messages, and the client's behavior when
// responses are slower than its 1s timeout (counting how many requests the
// server received in each case).
func TestGetAssets(t *testing.T) {
	// Swap in a client with a short timeout so the timeout cases finish quickly.
	savedClient := client
	client = fleethttp.NewClient(fleethttp.WithTimeout(time.Second))
	t.Cleanup(func() { client = savedClient })

	// slowResponse is longer than the 1s client timeout configured above.
	const slowResponse = time.Second + 500*time.Millisecond

	type assetsResponse struct {
		Assets []Asset `json:"assets"`
	}
	sampleAssets := func() []Asset {
		return []Asset{
			{AdamID: "12345", PricingParam: "STDQ"},
			{AdamID: "67890", PricingParam: "PLUS"},
		}
	}
	writeAssets := func(w http.ResponseWriter) {
		w.WriteHeader(http.StatusOK)
		require.NoError(t, json.NewEncoder(w).Encode(assetsResponse{Assets: sampleAssets()}))
	}

	var requests atomic.Int64

	cases := []struct {
		name             string
		token            string
		filter           *AssetFilter
		handler          http.HandlerFunc
		expectedAssets   []Asset
		expectedErrMsg   string
		expectedRequests int
	}{
		{
			name:   "valid token and filters",
			token:  "valid_token",
			filter: &AssetFilter{AdamID: "12345"},
			handler: func(w http.ResponseWriter, r *http.Request) {
				require.Equal(t, http.MethodGet, r.Method)
				require.Equal(t, "/assets", r.URL.Path)
				require.Equal(t, "Bearer valid_token", r.Header.Get("Authorization"))
				require.Equal(t, "12345", r.URL.Query().Get("adamId"))
				writeAssets(w)
			},
			expectedAssets:   sampleAssets(),
			expectedRequests: 1,
		},
		{
			name:  "server error",
			token: "valid_token",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusInternalServerError)
				fmt.Fprintln(w, `Internal Server Error`)
			},
			expectedErrMsg:   "calling Apple VPP endpoint failed with status 500: Internal Server Error\n",
			expectedRequests: 1,
		},
		{
			name:  "client error",
			token: "valid_token",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusBadRequest)
				fmt.Fprintln(w, `{"errorInfo":{},"errorMessage":"Bad Request","errorNumber":400}`)
			},
			expectedErrMsg:   "retrieving assets: Apple VPP endpoint returned error: Bad Request (error number: 400)",
			expectedRequests: 1,
		},
		{
			name:  "always times out",
			token: "valid_token",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				time.Sleep(slowResponse)
				writeAssets(w)
			},
			expectedErrMsg:   "context deadline exceeded (Client.Timeout exceeded while awaiting headers)",
			expectedRequests: 3,
		},
		{
			name:  "times out then valid",
			token: "valid_token",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				// Only the first request is slow; the retry succeeds.
				if requests.Load() < 2 {
					time.Sleep(slowResponse)
				}
				writeAssets(w)
			},
			expectedAssets:   sampleAssets(),
			expectedRequests: 2,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			requests.Store(0)
			setupFakeServer(t, func(w http.ResponseWriter, r *http.Request) {
				requests.Add(1)
				tc.handler(w, r)
			})

			got, err := GetAssets(t.Context(), tc.token, tc.filter)
			if tc.expectedErrMsg == "" {
				require.NoError(t, err)
				require.Equal(t, tc.expectedAssets, got)
			} else {
				require.ErrorContains(t, err, tc.expectedErrMsg)
			}
			require.EqualValues(t, tc.expectedRequests, requests.Load())
		})
	}
}
// TestDoRetryAfter exercises how do reacts to failing responses with and
// without a Retry-After header.
//
// For wantErr cases the fake server answers every request with the case's
// failing handler, so do must give up and return an error after wantCalls
// requests. For the other cases the server fails the first wantCalls-1
// requests and then answers with an empty 200, so do must retry until it
// succeeds. In both cases the number of requests and the elapsed time
// (bounded by wantCalls seconds) are checked.
func TestDoRetryAfter(t *testing.T) {
	tests := []struct {
		name      string
		handler   http.HandlerFunc
		wantCalls int
		wantErr   bool
	}{
		{
			name: "no retry-after header",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusInternalServerError)
				_, err := w.Write([]byte("{}"))
				require.NoError(t, err)
			},
			wantCalls: 1,
			wantErr:   true,
		},
		{
			name: "invalid retry-after header",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Add("Retry-After", "foo")
				w.WriteHeader(http.StatusInternalServerError)
				_, err := w.Write([]byte("{}"))
				require.NoError(t, err)
			},
			wantCalls: 1,
			wantErr:   true,
		},
		{
			name: "three retries",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Add("Retry-After", "1")
				w.WriteHeader(http.StatusInternalServerError)
				_, err := w.Write([]byte("{}"))
				require.NoError(t, err)
			},
			wantCalls: 3,
			wantErr:   false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Incremented on the server goroutine, read on the test goroutine.
			var calls atomic.Int64
			setupFakeServer(t, func(w http.ResponseWriter, r *http.Request) {
				n := calls.Add(1)
				// Error cases never recover; success cases fail until the final
				// expected attempt, which falls through to an empty 200.
				if tt.wantErr || n < int64(tt.wantCalls) {
					tt.handler(w, r)
					return
				}
			})
			start := time.Now()
			req, err := http.NewRequest(http.MethodGet, dev_mode.Env("FLEET_DEV_VPP_URL"), nil)
			require.NoError(t, err)
			err = do[any](req, "test-token", nil)
			if tt.wantErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
			require.EqualValues(t, tt.wantCalls, calls.Load())
			require.WithinRange(t, time.Now(), start, start.Add(time.Duration(tt.wantCalls)*time.Second))
		})
	}
}
// TestGetBaseURL checks that getBaseURL returns Apple's production VPP URL by
// default, and the FLEET_DEV_VPP_URL dev-mode override value when one is set.
func TestGetBaseURL(t *testing.T) {
	const productionURL = "https://vpp.itunes.apple.com/mdm/v2"

	t.Run("Default URL", func(t *testing.T) {
		require.Equal(t, productionURL, getBaseURL())
	})

	t.Run("Custom URL", func(t *testing.T) {
		const overrideURL = "http://localhost:8000"
		dev_mode.SetOverride("FLEET_DEV_VPP_URL", overrideURL, t)
		require.Equal(t, overrideURL, getBaseURL())
	})
}