diff --git a/packages/core/src/core/contentGenerator.test.ts b/packages/core/src/core/contentGenerator.test.ts index 83a3fa5536..35d7879f96 100644 --- a/packages/core/src/core/contentGenerator.test.ts +++ b/packages/core/src/core/contentGenerator.test.ts @@ -162,59 +162,6 @@ describe('createContentGenerator', () => { ); }); - it('should wrap custom baseUrl generators to avoid SDK streaming hangs', async () => { - const mockGenerator = { - models: { - generateContent: vi.fn(), - generateContentStream: vi.fn(), - countTokens: vi.fn(), - embedContent: vi.fn(), - }, - } as unknown as GoogleGenAI; - vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never); - - const generator = await createContentGenerator( - { - apiKey: 'test-api-key', - authType: AuthType.GATEWAY, - baseUrl: 'https://example.com', - }, - mockConfig, - ); - - expect(generator).toBeInstanceOf(LoggingContentGenerator); - }); - - it('should wrap generators when GOOGLE_GEMINI_BASE_URL is set', async () => { - const mockGenerator = { - models: { - generateContent: vi.fn(), - generateContentStream: vi.fn(), - countTokens: vi.fn(), - embedContent: vi.fn(), - }, - } as unknown as GoogleGenAI; - vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never); - vi.stubEnv('GOOGLE_GEMINI_BASE_URL', 'https://litellm.example.com'); - - const generator = await createContentGenerator( - { - apiKey: 'test-api-key', - authType: AuthType.USE_GEMINI, - }, - mockConfig, - ); - - expect(GoogleGenAI).toHaveBeenCalledWith( - expect.objectContaining({ - httpOptions: expect.objectContaining({ - baseUrl: 'https://litellm.example.com', - }), - }), - ); - expect(generator).toBeInstanceOf(LoggingContentGenerator); - }); - it('should use standard User-Agent for a2a-server running outside VS Code', async () => { const mockConfig = { getModel: vi.fn().mockReturnValue('gemini-pro'), diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index aed1d800de..4fc56b59b4 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -26,7 +26,6 @@ import { FakeContentGenerator } from './fakeContentGenerator.js'; import { parseCustomHeaders } from '../utils/customHeaderUtils.js'; import { determineSurface } from '../utils/surface.js'; import { RecordingContentGenerator } from './recordingContentGenerator.js'; -import { CustomBaseUrlContentGenerator } from './customBaseUrlContentGenerator.js'; import { getVersion, resolveModel } from '../../index.js'; import type { LlmRole } from '../telemetry/llmRole.js'; @@ -102,20 +101,6 @@ export type ContentGeneratorConfig = { customHeaders?: Record; }; -function getConfiguredBaseUrl( - config: ContentGeneratorConfig, -): string | undefined { - if (config.baseUrl) { - return config.baseUrl; - } - - if (config.vertexai) { - return process.env['GOOGLE_VERTEX_BASE_URL'] || undefined; - } - - return process.env['GOOGLE_GEMINI_BASE_URL'] || undefined; -} - export async function createContentGeneratorConfig( config: Config, authType: AuthType | undefined, @@ -293,9 +278,8 @@ export async function createContentGenerator( headers: Record; } = { headers }; - const configuredBaseUrl = getConfiguredBaseUrl(config); - if (configuredBaseUrl) { - httpOptions.baseUrl = configuredBaseUrl; + if (config.baseUrl) { + httpOptions.baseUrl = config.baseUrl; } const googleGenAI = new GoogleGenAI({ @@ -304,10 +288,7 @@ export async function createContentGenerator( httpOptions, ...(apiVersionEnv && { apiVersion: apiVersionEnv }), }); - const contentGenerator = configuredBaseUrl - ? new CustomBaseUrlContentGenerator(googleGenAI.models) - : googleGenAI.models; - return new LoggingContentGenerator(contentGenerator, gcConfig); + return new LoggingContentGenerator(googleGenAI.models, gcConfig); } throw new Error( `Error creating contentGenerator: Unsupported authType: ${config.authType}`, diff --git a/packages/core/src/core/customBaseUrlContentGenerator.test.ts b/packages/core/src/core/customBaseUrlContentGenerator.test.ts deleted file mode 100644 index 3b80e87a43..0000000000 --- a/packages/core/src/core/customBaseUrlContentGenerator.test.ts +++ /dev/null @@ -1,81 +0,0 @@ -/** - * @license - * Copyright 2026 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import { describe, expect, it, vi } from 'vitest'; -import type { - GenerateContentParameters, - GenerateContentResponse, -} from '@google/genai'; -import { CustomBaseUrlContentGenerator } from './customBaseUrlContentGenerator.js'; -import { LlmRole } from '../telemetry/llmRole.js'; -import type { ContentGenerator } from './contentGenerator.js'; - -describe('CustomBaseUrlContentGenerator', () => { - it('adapts generateContent to a single-chunk stream', async () => { - const response = { - candidates: [ - { - content: { - parts: [{ text: 'custom base url response' }], - }, - finishReason: 'STOP', - }, - ], - } as unknown as GenerateContentResponse; - const wrapped = { - generateContent: vi.fn().mockResolvedValue(response), - generateContentStream: vi.fn(), - countTokens: vi.fn(), - embedContent: vi.fn(), - } as unknown as ContentGenerator; - const generator = new CustomBaseUrlContentGenerator(wrapped); - const request = { - model: 'test-model', - contents: 'test prompt', - } as GenerateContentParameters; - - const stream = await generator.generateContentStream( - request, - 'prompt-id', - LlmRole.MAIN, - ); - const chunks: GenerateContentResponse[] = []; - for await (const chunk of stream) { - chunks.push(chunk); - } - - expect(wrapped.generateContent).toHaveBeenCalledWith( - request, - 'prompt-id', - LlmRole.MAIN, - ); - expect(wrapped.generateContentStream).not.toHaveBeenCalled(); - expect(chunks).toEqual([response]); - }); - - it('propagates upstream errors from generateContent', async () => { - const error = new Error('Bad Request'); - const wrapped = { - generateContent: vi.fn().mockRejectedValue(error), - generateContentStream: vi.fn(), - countTokens: vi.fn(), - embedContent: vi.fn(), - } as unknown as ContentGenerator; - const generator = new CustomBaseUrlContentGenerator(wrapped); - - await expect( - generator.generateContentStream( - { - model: 'bad-model', - contents: 'test prompt', - } as GenerateContentParameters, - 'prompt-id', - LlmRole.MAIN, - ), - ).rejects.toThrow(error); - expect(wrapped.generateContentStream).not.toHaveBeenCalled(); - }); -}); diff --git a/packages/core/src/core/customBaseUrlContentGenerator.ts b/packages/core/src/core/customBaseUrlContentGenerator.ts deleted file mode 100644 index 3b5d7b9eaf..0000000000 --- a/packages/core/src/core/customBaseUrlContentGenerator.ts +++ /dev/null @@ -1,112 +0,0 @@ -/** - * @license - * Copyright 2026 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import type { - CountTokensParameters, - CountTokensResponse, - EmbedContentParameters, - EmbedContentResponse, - GenerateContentParameters, - GenerateContentResponse, -} from '@google/genai'; -import type { LlmRole } from '../telemetry/llmRole.js'; -import { debugLogger } from '../utils/debugLogger.js'; -import type { ContentGenerator } from './contentGenerator.js'; - -/** - * The SDK streaming path can hang when requests are routed through a custom - * Gemini-compatible base URL and the upstream returns an immediate HTTP error. - * Route those calls through the non-streaming API and adapt the response into a - * single-chunk async generator so upper layers can keep the same interface. - */ -export class CustomBaseUrlContentGenerator implements ContentGenerator { - constructor(private readonly wrapped: ContentGenerator) {} - - get userTier() { - return this.wrapped.userTier; - } - - get userTierName() { - return this.wrapped.userTierName; - } - - get paidTier() { - return this.wrapped.paidTier; - } - - async generateContent( - request: GenerateContentParameters, - userPromptId: string, - role: LlmRole, - ): Promise { - return this.wrapped.generateContent(request, userPromptId, role); - } - - async generateContentStream( - request: GenerateContentParameters, - userPromptId: string, - role: LlmRole, - ): Promise> { - const startTime = Date.now(); - debugLogger.debug( - '[DEBUG] [CustomBaseUrlContentGenerator] Falling back from generateContentStream to generateContent', - { - model: request.model, - promptId: userPromptId, - role, - }, - ); - - let response: GenerateContentResponse; - try { - response = await this.wrapped.generateContent( - request, - userPromptId, - role, - ); - } catch (error) { - debugLogger.debug( - '[DEBUG] [CustomBaseUrlContentGenerator] Fallback generateContent failed', - { - model: request.model, - promptId: userPromptId, - role, - durationMs: Date.now() - startTime, - error: error instanceof Error ? error.message : String(error), - }, - ); - throw error; - } - - debugLogger.debug( - '[DEBUG] [CustomBaseUrlContentGenerator] Fallback generateContent succeeded', - { - model: request.model, - promptId: userPromptId, - role, - durationMs: Date.now() - startTime, - responseId: response.responseId ?? null, - candidateCount: response.candidates?.length ?? 0, - }, - ); - - return (async function* () { - yield response; - })(); - } - - async countTokens( - request: CountTokensParameters, - ): Promise { - return this.wrapped.countTokens(request); - } - - async embedContent( - request: EmbedContentParameters, - ): Promise { - return this.wrapped.embedContent(request); - } -} diff --git a/packages/core/src/fallback/handler.test.ts b/packages/core/src/fallback/handler.test.ts index 698a5d7cfb..a83d597449 100644 --- a/packages/core/src/fallback/handler.test.ts +++ b/packages/core/src/fallback/handler.test.ts @@ -412,4 +412,84 @@ describe('handleFallback', () => { expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled(); }); }); + + describe('5xx error handling', () => { + let policyConfig: Config; + let policyHandler: Mock; + + beforeEach(() => { + vi.clearAllMocks(); + policyHandler = vi.fn().mockResolvedValue('retry_once'); + policyConfig = createMockConfig(); + vi.mocked(policyConfig.getFallbackModelHandler).mockReturnValue( + policyHandler, + ); + }); + + it.each([ + [AuthType.USE_GEMINI, 'gemini-api-key'], + [AuthType.USE_VERTEX_AI, 'vertex-ai'], + [AuthType.LOGIN_WITH_GOOGLE, 'oauth-personal'], + [AuthType.COMPUTE_ADC, 'compute-default-credentials'], + [AuthType.GATEWAY, 'gateway'], + [undefined, 'undefined (no auth)'], + ])( + 'returns null without invoking fallback handler for 500 errors regardless of auth type (%s)', + async (authType, _label) => { + const serverError = new Error('Internal Server Error'); + (serverError as NodeJS.ErrnoException & { status?: number }).status = + 500; + + const result = await handleFallback( + policyConfig, + MOCK_PRO_MODEL, + authType, + serverError, + ); + + expect(result).toBeNull(); + expect(policyHandler).not.toHaveBeenCalled(); + }, + ); + + it('returns null for other 5xx errors (e.g. 502, 503)', async () => { + for (const status of [502, 503, 504]) { + const serverError = new Error(`Server Error ${status}`); + (serverError as NodeJS.ErrnoException & { status?: number }).status = + status; + + const result = await handleFallback( + policyConfig, + MOCK_PRO_MODEL, + AuthType.LOGIN_WITH_GOOGLE, + serverError, + ); + + expect(result).toBeNull(); + } + }); + + it('does not affect non-5xx errors (e.g. 429 quota) for API key auth', async () => { + const quotaError = new TerminalQuotaError('Quota exceeded', { + code: 429, + message: 'quota', + details: [], + }); + + vi.mocked(policyConfig.getModel).mockReturnValue( + DEFAULT_GEMINI_MODEL_AUTO, + ); + + const result = await handleFallback( + policyConfig, + MOCK_PRO_MODEL, + AuthType.USE_GEMINI, + quotaError, + ); + + // 429 quota errors should always reach the handler regardless of auth type + expect(policyHandler).toHaveBeenCalled(); + expect(result).toBe(true); + }); + }); }); diff --git a/packages/core/src/fallback/handler.ts b/packages/core/src/fallback/handler.ts index f216f9216c..5de5858bd7 100644 --- a/packages/core/src/fallback/handler.ts +++ b/packages/core/src/fallback/handler.ts @@ -11,6 +11,7 @@ import { } from '../utils/secure-browser-launcher.js'; import { debugLogger } from '../utils/debugLogger.js'; import { getErrorMessage } from '../utils/errors.js'; +import { getErrorStatus } from '../utils/httpErrors.js'; import type { FallbackIntent, FallbackRecommendation } from './types.js'; import { classifyFailureKind } from '../availability/errorClassification.js'; import { @@ -28,6 +29,15 @@ export async function handleFallback( authType?: string, error?: unknown, ): Promise { + // Server errors (5xx) should not trigger the fallback dialog. These are real + // server errors (e.g. from a custom proxy or a transient backend issue) that + // should propagate after retries are exhausted rather than prompting the user + // to switch models — a model switch cannot fix a server-side problem. + const errorStatus = getErrorStatus(error); + if (errorStatus !== undefined && errorStatus >= 500 && errorStatus < 600) { + return null; + } + const chain = resolvePolicyChain(config); const { failedPolicy, candidates } = buildFallbackPolicyContext( chain, diff --git a/packages/core/src/utils/retry.ts b/packages/core/src/utils/retry.ts index 4b6e47bd37..46765216b9 100644 --- a/packages/core/src/utils/retry.ts +++ b/packages/core/src/utils/retry.ts @@ -331,19 +331,7 @@ export async function retryWithBackoff( debugLogger.warn( `Attempt ${attempt} failed${errorMessage ? `: ${errorMessage}` : ''}. Max attempts reached`, ); - // For 500 errors, only trigger the fallback dialog for OAuth-based auth - // (LOGIN_WITH_GOOGLE / COMPUTE_ADC). API key users — including those routing - // through a custom base URL like LiteLLM — receive actual server errors - // that should be propagated rather than spawning an interactive dialog that - // will never resolve. - const is500FallbackEligible = - authType === 'oauth-personal' || - authType === 'compute-default-credentials'; - if ( - onPersistent429 && - (classifiedError instanceof RetryableQuotaError || - is500FallbackEligible) - ) { + if (onPersistent429) { try { const fallbackModel = await onPersistent429( authType,