refactor(core): move 5xx fallback guard from retry layer to policy layer

Revert the authType check in retryWithBackoff and instead add an
early-return guard in handleFallback for 5xx errors. A model switch
cannot fix a server error, so the fallback dialog should not be
offered — the error is propagated to the user after retries are
exhausted.

This keeps retry.ts free of auth-type concerns and places the policy
decision where it belongs (the fallback handler).
This commit is contained in:
Mai Nakagawa 2026-04-11 21:29:47 +09:00
parent 88d8eab578
commit b1685fefb3
No known key found for this signature in database
GPG key ID: D0D676E60A6D4042
7 changed files with 94 additions and 281 deletions

View file

@@ -162,59 +162,6 @@ describe('createContentGenerator', () => {
);
});
it('should wrap custom baseUrl generators to avoid SDK streaming hangs', async () => {
  // Stub out the SDK client so the test only observes how the factory wraps it.
  const fakeGenAi = {
    models: {
      generateContent: vi.fn(),
      generateContentStream: vi.fn(),
      countTokens: vi.fn(),
      embedContent: vi.fn(),
    },
  } as unknown as GoogleGenAI;
  vi.mocked(GoogleGenAI).mockImplementation(() => fakeGenAi as never);

  const generator = await createContentGenerator(
    {
      apiKey: 'test-api-key',
      authType: AuthType.GATEWAY,
      baseUrl: 'https://example.com',
    },
    mockConfig,
  );

  // Whatever wrapping happens internally, the factory must hand back a
  // logging wrapper as the outermost layer.
  expect(generator).toBeInstanceOf(LoggingContentGenerator);
});
it('should wrap generators when GOOGLE_GEMINI_BASE_URL is set', async () => {
  const fakeGenAi = {
    models: {
      generateContent: vi.fn(),
      generateContentStream: vi.fn(),
      countTokens: vi.fn(),
      embedContent: vi.fn(),
    },
  } as unknown as GoogleGenAI;
  vi.mocked(GoogleGenAI).mockImplementation(() => fakeGenAi as never);
  vi.stubEnv('GOOGLE_GEMINI_BASE_URL', 'https://litellm.example.com');

  const generator = await createContentGenerator(
    {
      apiKey: 'test-api-key',
      authType: AuthType.USE_GEMINI,
    },
    mockConfig,
  );

  // The env override must be forwarded to the SDK constructor's httpOptions…
  const expectedCtorArgs = expect.objectContaining({
    httpOptions: expect.objectContaining({
      baseUrl: 'https://litellm.example.com',
    }),
  });
  expect(GoogleGenAI).toHaveBeenCalledWith(expectedCtorArgs);
  // …and the resulting generator still goes through the logging wrapper.
  expect(generator).toBeInstanceOf(LoggingContentGenerator);
});
it('should use standard User-Agent for a2a-server running outside VS Code', async () => {
const mockConfig = {
getModel: vi.fn().mockReturnValue('gemini-pro'),

View file

@@ -26,7 +26,6 @@ import { FakeContentGenerator } from './fakeContentGenerator.js';
import { parseCustomHeaders } from '../utils/customHeaderUtils.js';
import { determineSurface } from '../utils/surface.js';
import { RecordingContentGenerator } from './recordingContentGenerator.js';
import { CustomBaseUrlContentGenerator } from './customBaseUrlContentGenerator.js';
import { getVersion, resolveModel } from '../../index.js';
import type { LlmRole } from '../telemetry/llmRole.js';
@@ -102,20 +101,6 @@ export type ContentGeneratorConfig = {
customHeaders?: Record<string, string>;
};
/**
 * Resolves the base URL the content generator should talk to.
 *
 * Precedence: an explicitly configured `baseUrl` wins; otherwise fall back to
 * the environment override matching the backend (Vertex vs. Gemini). Empty
 * env values are normalized to `undefined`.
 */
function getConfiguredBaseUrl(
  config: ContentGeneratorConfig,
): string | undefined {
  if (config.baseUrl) {
    return config.baseUrl;
  }
  const envVar = config.vertexai
    ? 'GOOGLE_VERTEX_BASE_URL'
    : 'GOOGLE_GEMINI_BASE_URL';
  return process.env[envVar] || undefined;
}
export async function createContentGeneratorConfig(
config: Config,
authType: AuthType | undefined,
@@ -293,9 +278,8 @@ export async function createContentGenerator(
headers: Record<string, string>;
} = { headers };
const configuredBaseUrl = getConfiguredBaseUrl(config);
if (configuredBaseUrl) {
httpOptions.baseUrl = configuredBaseUrl;
if (config.baseUrl) {
httpOptions.baseUrl = config.baseUrl;
}
const googleGenAI = new GoogleGenAI({
@@ -304,10 +288,7 @@ export async function createContentGenerator(
httpOptions,
...(apiVersionEnv && { apiVersion: apiVersionEnv }),
});
const contentGenerator = configuredBaseUrl
? new CustomBaseUrlContentGenerator(googleGenAI.models)
: googleGenAI.models;
return new LoggingContentGenerator(contentGenerator, gcConfig);
return new LoggingContentGenerator(googleGenAI.models, gcConfig);
}
throw new Error(
`Error creating contentGenerator: Unsupported authType: ${config.authType}`,

View file

@@ -1,81 +0,0 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, expect, it, vi } from 'vitest';
import type {
GenerateContentParameters,
GenerateContentResponse,
} from '@google/genai';
import { CustomBaseUrlContentGenerator } from './customBaseUrlContentGenerator.js';
import { LlmRole } from '../telemetry/llmRole.js';
import type { ContentGenerator } from './contentGenerator.js';
describe('CustomBaseUrlContentGenerator', () => {
  // Builds a ContentGenerator stub whose generateContent uses the given mock;
  // every other method is an untouched spy so delegation can be asserted.
  const makeWrapped = (generateContent: ReturnType<typeof vi.fn>) =>
    ({
      generateContent,
      generateContentStream: vi.fn(),
      countTokens: vi.fn(),
      embedContent: vi.fn(),
    }) as unknown as ContentGenerator;

  it('adapts generateContent to a single-chunk stream', async () => {
    const response = {
      candidates: [
        {
          content: { parts: [{ text: 'custom base url response' }] },
          finishReason: 'STOP',
        },
      ],
    } as unknown as GenerateContentResponse;
    const wrapped = makeWrapped(vi.fn().mockResolvedValue(response));
    const generator = new CustomBaseUrlContentGenerator(wrapped);

    const request = {
      model: 'test-model',
      contents: 'test prompt',
    } as GenerateContentParameters;
    const stream = await generator.generateContentStream(
      request,
      'prompt-id',
      LlmRole.MAIN,
    );

    // Drain the adapter's stream; it must contain exactly one chunk.
    const chunks: GenerateContentResponse[] = [];
    for await (const chunk of stream) {
      chunks.push(chunk);
    }

    expect(wrapped.generateContent).toHaveBeenCalledWith(
      request,
      'prompt-id',
      LlmRole.MAIN,
    );
    expect(wrapped.generateContentStream).not.toHaveBeenCalled();
    expect(chunks).toEqual([response]);
  });

  it('propagates upstream errors from generateContent', async () => {
    const error = new Error('Bad Request');
    const wrapped = makeWrapped(vi.fn().mockRejectedValue(error));
    const generator = new CustomBaseUrlContentGenerator(wrapped);

    await expect(
      generator.generateContentStream(
        {
          model: 'bad-model',
          contents: 'test prompt',
        } as GenerateContentParameters,
        'prompt-id',
        LlmRole.MAIN,
      ),
    ).rejects.toThrow(error);
    expect(wrapped.generateContentStream).not.toHaveBeenCalled();
  });
});

View file

@@ -1,112 +0,0 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
CountTokensParameters,
CountTokensResponse,
EmbedContentParameters,
EmbedContentResponse,
GenerateContentParameters,
GenerateContentResponse,
} from '@google/genai';
import type { LlmRole } from '../telemetry/llmRole.js';
import { debugLogger } from '../utils/debugLogger.js';
import type { ContentGenerator } from './contentGenerator.js';
/**
 * The SDK streaming path can hang when requests are routed through a custom
 * Gemini-compatible base URL and the upstream returns an immediate HTTP error.
 * Route those calls through the non-streaming API and adapt the response into a
 * single-chunk async generator so upper layers can keep the same interface.
 */
export class CustomBaseUrlContentGenerator implements ContentGenerator {
  constructor(private readonly wrapped: ContentGenerator) {}

  // Tier metadata is passed straight through from the wrapped generator.
  get userTier() {
    return this.wrapped.userTier;
  }

  get userTierName() {
    return this.wrapped.userTierName;
  }

  get paidTier() {
    return this.wrapped.paidTier;
  }

  /** Non-streaming calls need no adaptation; delegate directly. */
  async generateContent(
    request: GenerateContentParameters,
    userPromptId: string,
    role: LlmRole,
  ): Promise<GenerateContentResponse> {
    return this.wrapped.generateContent(request, userPromptId, role);
  }

  /**
   * Satisfies the streaming interface by issuing one non-streaming request
   * and yielding its response as the only chunk. Failures from the wrapped
   * call are logged and re-thrown unchanged.
   */
  async generateContentStream(
    request: GenerateContentParameters,
    userPromptId: string,
    role: LlmRole,
  ): Promise<AsyncGenerator<GenerateContentResponse>> {
    const startedAt = Date.now();
    // Shared log fields for all three debug messages below.
    const logContext = {
      model: request.model,
      promptId: userPromptId,
      role,
    };
    debugLogger.debug(
      '[DEBUG] [CustomBaseUrlContentGenerator] Falling back from generateContentStream to generateContent',
      logContext,
    );

    let response: GenerateContentResponse;
    try {
      response = await this.wrapped.generateContent(
        request,
        userPromptId,
        role,
      );
    } catch (error) {
      debugLogger.debug(
        '[DEBUG] [CustomBaseUrlContentGenerator] Fallback generateContent failed',
        {
          ...logContext,
          durationMs: Date.now() - startedAt,
          error: error instanceof Error ? error.message : String(error),
        },
      );
      throw error;
    }

    debugLogger.debug(
      '[DEBUG] [CustomBaseUrlContentGenerator] Fallback generateContent succeeded',
      {
        ...logContext,
        durationMs: Date.now() - startedAt,
        responseId: response.responseId ?? null,
        candidateCount: response.candidates?.length ?? 0,
      },
    );

    // Adapt the single response into the async-generator shape callers expect.
    async function* singleChunk() {
      yield response;
    }
    return singleChunk();
  }

  /** Token counting is unaffected by the streaming issue; delegate. */
  async countTokens(
    request: CountTokensParameters,
  ): Promise<CountTokensResponse> {
    return this.wrapped.countTokens(request);
  }

  /** Embeddings are unaffected by the streaming issue; delegate. */
  async embedContent(
    request: EmbedContentParameters,
  ): Promise<EmbedContentResponse> {
    return this.wrapped.embedContent(request);
  }
}

View file

@@ -412,4 +412,84 @@ describe('handleFallback', () => {
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
});
});
describe('5xx error handling', () => {
  let policyConfig: Config;
  let policyHandler: Mock<FallbackModelHandler>;

  beforeEach(() => {
    vi.clearAllMocks();
    policyHandler = vi.fn().mockResolvedValue('retry_once');
    policyConfig = createMockConfig();
    vi.mocked(policyConfig.getFallbackModelHandler).mockReturnValue(
      policyHandler,
    );
  });

  // Attaches an HTTP status to a plain Error, mimicking transport-layer errors.
  const makeHttpError = (message: string, status: number) => {
    const err = new Error(message) as NodeJS.ErrnoException & {
      status?: number;
    };
    err.status = status;
    return err;
  };

  it.each([
    [AuthType.USE_GEMINI, 'gemini-api-key'],
    [AuthType.USE_VERTEX_AI, 'vertex-ai'],
    [AuthType.LOGIN_WITH_GOOGLE, 'oauth-personal'],
    [AuthType.COMPUTE_ADC, 'compute-default-credentials'],
    [AuthType.GATEWAY, 'gateway'],
    [undefined, 'undefined (no auth)'],
  ])(
    'returns null without invoking fallback handler for 500 errors regardless of auth type (%s)',
    async (authType, _label) => {
      const result = await handleFallback(
        policyConfig,
        MOCK_PRO_MODEL,
        authType,
        makeHttpError('Internal Server Error', 500),
      );
      expect(result).toBeNull();
      expect(policyHandler).not.toHaveBeenCalled();
    },
  );

  it('returns null for other 5xx errors (e.g. 502, 503)', async () => {
    for (const status of [502, 503, 504]) {
      const result = await handleFallback(
        policyConfig,
        MOCK_PRO_MODEL,
        AuthType.LOGIN_WITH_GOOGLE,
        makeHttpError(`Server Error ${status}`, status),
      );
      expect(result).toBeNull();
    }
  });

  it('does not affect non-5xx errors (e.g. 429 quota) for API key auth', async () => {
    const quotaError = new TerminalQuotaError('Quota exceeded', {
      code: 429,
      message: 'quota',
      details: [],
    });
    vi.mocked(policyConfig.getModel).mockReturnValue(
      DEFAULT_GEMINI_MODEL_AUTO,
    );

    const result = await handleFallback(
      policyConfig,
      MOCK_PRO_MODEL,
      AuthType.USE_GEMINI,
      quotaError,
    );

    // 429 quota errors should always reach the handler regardless of auth type
    expect(policyHandler).toHaveBeenCalled();
    expect(result).toBe(true);
  });
});
});

View file

@@ -11,6 +11,7 @@ import {
} from '../utils/secure-browser-launcher.js';
import { debugLogger } from '../utils/debugLogger.js';
import { getErrorMessage } from '../utils/errors.js';
import { getErrorStatus } from '../utils/httpErrors.js';
import type { FallbackIntent, FallbackRecommendation } from './types.js';
import { classifyFailureKind } from '../availability/errorClassification.js';
import {
@@ -28,6 +29,15 @@ export async function handleFallback(
authType?: string,
error?: unknown,
): Promise<string | boolean | null> {
// Server errors (5xx) should not trigger the fallback dialog. These are real
// server errors (e.g. from a custom proxy or a transient backend issue) that
// should propagate after retries are exhausted rather than prompting the user
// to switch models — a model switch cannot fix a server-side problem.
const errorStatus = getErrorStatus(error);
if (errorStatus !== undefined && errorStatus >= 500 && errorStatus < 600) {
return null;
}
const chain = resolvePolicyChain(config);
const { failedPolicy, candidates } = buildFallbackPolicyContext(
chain,

View file

@@ -331,19 +331,7 @@ export async function retryWithBackoff<T>(
debugLogger.warn(
`Attempt ${attempt} failed${errorMessage ? `: ${errorMessage}` : ''}. Max attempts reached`,
);
// For 500 errors, only trigger the fallback dialog for OAuth-based auth
// (LOGIN_WITH_GOOGLE / COMPUTE_ADC). API key users — including those routing
// through a custom base URL like LiteLLM — receive actual server errors
// that should be propagated rather than spawning an interactive dialog that
// will never resolve.
const is500FallbackEligible =
authType === 'oauth-personal' ||
authType === 'compute-default-credentials';
if (
onPersistent429 &&
(classifiedError instanceof RetryableQuotaError ||
is500FallbackEligible)
) {
if (onPersistent429) {
try {
const fallbackModel = await onPersistent429(
authType,