mirror of
https://github.com/google-gemini/gemini-cli
synced 2026-04-21 13:37:17 +00:00
refactor(core): move 5xx fallback guard from retry layer to policy layer
Revert the authType check in retryWithBackoff and instead add an early-return guard in handleFallback for 5xx errors. A model switch cannot fix a server error, so the fallback dialog should not be offered — the error is propagated to the user after retries are exhausted. This keeps retry.ts free of auth-type concerns and places the policy decision where it belongs (the fallback handler).
This commit is contained in:
parent
88d8eab578
commit
b1685fefb3
7 changed files with 94 additions and 281 deletions
|
|
@ -162,59 +162,6 @@ describe('createContentGenerator', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should wrap custom baseUrl generators to avoid SDK streaming hangs', async () => {
|
||||
const mockGenerator = {
|
||||
models: {
|
||||
generateContent: vi.fn(),
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn(),
|
||||
embedContent: vi.fn(),
|
||||
},
|
||||
} as unknown as GoogleGenAI;
|
||||
vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never);
|
||||
|
||||
const generator = await createContentGenerator(
|
||||
{
|
||||
apiKey: 'test-api-key',
|
||||
authType: AuthType.GATEWAY,
|
||||
baseUrl: 'https://example.com',
|
||||
},
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(generator).toBeInstanceOf(LoggingContentGenerator);
|
||||
});
|
||||
|
||||
it('should wrap generators when GOOGLE_GEMINI_BASE_URL is set', async () => {
|
||||
const mockGenerator = {
|
||||
models: {
|
||||
generateContent: vi.fn(),
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn(),
|
||||
embedContent: vi.fn(),
|
||||
},
|
||||
} as unknown as GoogleGenAI;
|
||||
vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never);
|
||||
vi.stubEnv('GOOGLE_GEMINI_BASE_URL', 'https://litellm.example.com');
|
||||
|
||||
const generator = await createContentGenerator(
|
||||
{
|
||||
apiKey: 'test-api-key',
|
||||
authType: AuthType.USE_GEMINI,
|
||||
},
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(GoogleGenAI).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
httpOptions: expect.objectContaining({
|
||||
baseUrl: 'https://litellm.example.com',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
expect(generator).toBeInstanceOf(LoggingContentGenerator);
|
||||
});
|
||||
|
||||
it('should use standard User-Agent for a2a-server running outside VS Code', async () => {
|
||||
const mockConfig = {
|
||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ import { FakeContentGenerator } from './fakeContentGenerator.js';
|
|||
import { parseCustomHeaders } from '../utils/customHeaderUtils.js';
|
||||
import { determineSurface } from '../utils/surface.js';
|
||||
import { RecordingContentGenerator } from './recordingContentGenerator.js';
|
||||
import { CustomBaseUrlContentGenerator } from './customBaseUrlContentGenerator.js';
|
||||
import { getVersion, resolveModel } from '../../index.js';
|
||||
import type { LlmRole } from '../telemetry/llmRole.js';
|
||||
|
||||
|
|
@ -102,20 +101,6 @@ export type ContentGeneratorConfig = {
|
|||
customHeaders?: Record<string, string>;
|
||||
};
|
||||
|
||||
function getConfiguredBaseUrl(
|
||||
config: ContentGeneratorConfig,
|
||||
): string | undefined {
|
||||
if (config.baseUrl) {
|
||||
return config.baseUrl;
|
||||
}
|
||||
|
||||
if (config.vertexai) {
|
||||
return process.env['GOOGLE_VERTEX_BASE_URL'] || undefined;
|
||||
}
|
||||
|
||||
return process.env['GOOGLE_GEMINI_BASE_URL'] || undefined;
|
||||
}
|
||||
|
||||
export async function createContentGeneratorConfig(
|
||||
config: Config,
|
||||
authType: AuthType | undefined,
|
||||
|
|
@ -293,9 +278,8 @@ export async function createContentGenerator(
|
|||
headers: Record<string, string>;
|
||||
} = { headers };
|
||||
|
||||
const configuredBaseUrl = getConfiguredBaseUrl(config);
|
||||
if (configuredBaseUrl) {
|
||||
httpOptions.baseUrl = configuredBaseUrl;
|
||||
if (config.baseUrl) {
|
||||
httpOptions.baseUrl = config.baseUrl;
|
||||
}
|
||||
|
||||
const googleGenAI = new GoogleGenAI({
|
||||
|
|
@ -304,10 +288,7 @@ export async function createContentGenerator(
|
|||
httpOptions,
|
||||
...(apiVersionEnv && { apiVersion: apiVersionEnv }),
|
||||
});
|
||||
const contentGenerator = configuredBaseUrl
|
||||
? new CustomBaseUrlContentGenerator(googleGenAI.models)
|
||||
: googleGenAI.models;
|
||||
return new LoggingContentGenerator(contentGenerator, gcConfig);
|
||||
return new LoggingContentGenerator(googleGenAI.models, gcConfig);
|
||||
}
|
||||
throw new Error(
|
||||
`Error creating contentGenerator: Unsupported authType: ${config.authType}`,
|
||||
|
|
|
|||
|
|
@ -1,81 +0,0 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
import type {
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import { CustomBaseUrlContentGenerator } from './customBaseUrlContentGenerator.js';
|
||||
import { LlmRole } from '../telemetry/llmRole.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
|
||||
describe('CustomBaseUrlContentGenerator', () => {
|
||||
it('adapts generateContent to a single-chunk stream', async () => {
|
||||
const response = {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: 'custom base url response' }],
|
||||
},
|
||||
finishReason: 'STOP',
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
const wrapped = {
|
||||
generateContent: vi.fn().mockResolvedValue(response),
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn(),
|
||||
embedContent: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
const generator = new CustomBaseUrlContentGenerator(wrapped);
|
||||
const request = {
|
||||
model: 'test-model',
|
||||
contents: 'test prompt',
|
||||
} as GenerateContentParameters;
|
||||
|
||||
const stream = await generator.generateContentStream(
|
||||
request,
|
||||
'prompt-id',
|
||||
LlmRole.MAIN,
|
||||
);
|
||||
const chunks: GenerateContentResponse[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
expect(wrapped.generateContent).toHaveBeenCalledWith(
|
||||
request,
|
||||
'prompt-id',
|
||||
LlmRole.MAIN,
|
||||
);
|
||||
expect(wrapped.generateContentStream).not.toHaveBeenCalled();
|
||||
expect(chunks).toEqual([response]);
|
||||
});
|
||||
|
||||
it('propagates upstream errors from generateContent', async () => {
|
||||
const error = new Error('Bad Request');
|
||||
const wrapped = {
|
||||
generateContent: vi.fn().mockRejectedValue(error),
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn(),
|
||||
embedContent: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
const generator = new CustomBaseUrlContentGenerator(wrapped);
|
||||
|
||||
await expect(
|
||||
generator.generateContentStream(
|
||||
{
|
||||
model: 'bad-model',
|
||||
contents: 'test prompt',
|
||||
} as GenerateContentParameters,
|
||||
'prompt-id',
|
||||
LlmRole.MAIN,
|
||||
),
|
||||
).rejects.toThrow(error);
|
||||
expect(wrapped.generateContentStream).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
|
@ -1,112 +0,0 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
CountTokensParameters,
|
||||
CountTokensResponse,
|
||||
EmbedContentParameters,
|
||||
EmbedContentResponse,
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import type { LlmRole } from '../telemetry/llmRole.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
|
||||
/**
|
||||
* The SDK streaming path can hang when requests are routed through a custom
|
||||
* Gemini-compatible base URL and the upstream returns an immediate HTTP error.
|
||||
* Route those calls through the non-streaming API and adapt the response into a
|
||||
* single-chunk async generator so upper layers can keep the same interface.
|
||||
*/
|
||||
export class CustomBaseUrlContentGenerator implements ContentGenerator {
|
||||
constructor(private readonly wrapped: ContentGenerator) {}
|
||||
|
||||
get userTier() {
|
||||
return this.wrapped.userTier;
|
||||
}
|
||||
|
||||
get userTierName() {
|
||||
return this.wrapped.userTierName;
|
||||
}
|
||||
|
||||
get paidTier() {
|
||||
return this.wrapped.paidTier;
|
||||
}
|
||||
|
||||
async generateContent(
|
||||
request: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
role: LlmRole,
|
||||
): Promise<GenerateContentResponse> {
|
||||
return this.wrapped.generateContent(request, userPromptId, role);
|
||||
}
|
||||
|
||||
async generateContentStream(
|
||||
request: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
role: LlmRole,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
const startTime = Date.now();
|
||||
debugLogger.debug(
|
||||
'[DEBUG] [CustomBaseUrlContentGenerator] Falling back from generateContentStream to generateContent',
|
||||
{
|
||||
model: request.model,
|
||||
promptId: userPromptId,
|
||||
role,
|
||||
},
|
||||
);
|
||||
|
||||
let response: GenerateContentResponse;
|
||||
try {
|
||||
response = await this.wrapped.generateContent(
|
||||
request,
|
||||
userPromptId,
|
||||
role,
|
||||
);
|
||||
} catch (error) {
|
||||
debugLogger.debug(
|
||||
'[DEBUG] [CustomBaseUrlContentGenerator] Fallback generateContent failed',
|
||||
{
|
||||
model: request.model,
|
||||
promptId: userPromptId,
|
||||
role,
|
||||
durationMs: Date.now() - startTime,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
},
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
|
||||
debugLogger.debug(
|
||||
'[DEBUG] [CustomBaseUrlContentGenerator] Fallback generateContent succeeded',
|
||||
{
|
||||
model: request.model,
|
||||
promptId: userPromptId,
|
||||
role,
|
||||
durationMs: Date.now() - startTime,
|
||||
responseId: response.responseId ?? null,
|
||||
candidateCount: response.candidates?.length ?? 0,
|
||||
},
|
||||
);
|
||||
|
||||
return (async function* () {
|
||||
yield response;
|
||||
})();
|
||||
}
|
||||
|
||||
async countTokens(
|
||||
request: CountTokensParameters,
|
||||
): Promise<CountTokensResponse> {
|
||||
return this.wrapped.countTokens(request);
|
||||
}
|
||||
|
||||
async embedContent(
|
||||
request: EmbedContentParameters,
|
||||
): Promise<EmbedContentResponse> {
|
||||
return this.wrapped.embedContent(request);
|
||||
}
|
||||
}
|
||||
|
|
@ -412,4 +412,84 @@ describe('handleFallback', () => {
|
|||
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('5xx error handling', () => {
|
||||
let policyConfig: Config;
|
||||
let policyHandler: Mock<FallbackModelHandler>;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
policyHandler = vi.fn().mockResolvedValue('retry_once');
|
||||
policyConfig = createMockConfig();
|
||||
vi.mocked(policyConfig.getFallbackModelHandler).mockReturnValue(
|
||||
policyHandler,
|
||||
);
|
||||
});
|
||||
|
||||
it.each([
|
||||
[AuthType.USE_GEMINI, 'gemini-api-key'],
|
||||
[AuthType.USE_VERTEX_AI, 'vertex-ai'],
|
||||
[AuthType.LOGIN_WITH_GOOGLE, 'oauth-personal'],
|
||||
[AuthType.COMPUTE_ADC, 'compute-default-credentials'],
|
||||
[AuthType.GATEWAY, 'gateway'],
|
||||
[undefined, 'undefined (no auth)'],
|
||||
])(
|
||||
'returns null without invoking fallback handler for 500 errors regardless of auth type (%s)',
|
||||
async (authType, _label) => {
|
||||
const serverError = new Error('Internal Server Error');
|
||||
(serverError as NodeJS.ErrnoException & { status?: number }).status =
|
||||
500;
|
||||
|
||||
const result = await handleFallback(
|
||||
policyConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
authType,
|
||||
serverError,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(policyHandler).not.toHaveBeenCalled();
|
||||
},
|
||||
);
|
||||
|
||||
it('returns null for other 5xx errors (e.g. 502, 503)', async () => {
|
||||
for (const status of [502, 503, 504]) {
|
||||
const serverError = new Error(`Server Error ${status}`);
|
||||
(serverError as NodeJS.ErrnoException & { status?: number }).status =
|
||||
status;
|
||||
|
||||
const result = await handleFallback(
|
||||
policyConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
serverError,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
}
|
||||
});
|
||||
|
||||
it('does not affect non-5xx errors (e.g. 429 quota) for API key auth', async () => {
|
||||
const quotaError = new TerminalQuotaError('Quota exceeded', {
|
||||
code: 429,
|
||||
message: 'quota',
|
||||
details: [],
|
||||
});
|
||||
|
||||
vi.mocked(policyConfig.getModel).mockReturnValue(
|
||||
DEFAULT_GEMINI_MODEL_AUTO,
|
||||
);
|
||||
|
||||
const result = await handleFallback(
|
||||
policyConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AuthType.USE_GEMINI,
|
||||
quotaError,
|
||||
);
|
||||
|
||||
// 429 quota errors should always reach the handler regardless of auth type
|
||||
expect(policyHandler).toHaveBeenCalled();
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import {
|
|||
} from '../utils/secure-browser-launcher.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { getErrorStatus } from '../utils/httpErrors.js';
|
||||
import type { FallbackIntent, FallbackRecommendation } from './types.js';
|
||||
import { classifyFailureKind } from '../availability/errorClassification.js';
|
||||
import {
|
||||
|
|
@ -28,6 +29,15 @@ export async function handleFallback(
|
|||
authType?: string,
|
||||
error?: unknown,
|
||||
): Promise<string | boolean | null> {
|
||||
// Server errors (5xx) should not trigger the fallback dialog. These are real
|
||||
// server errors (e.g. from a custom proxy or a transient backend issue) that
|
||||
// should propagate after retries are exhausted rather than prompting the user
|
||||
// to switch models — a model switch cannot fix a server-side problem.
|
||||
const errorStatus = getErrorStatus(error);
|
||||
if (errorStatus !== undefined && errorStatus >= 500 && errorStatus < 600) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const chain = resolvePolicyChain(config);
|
||||
const { failedPolicy, candidates } = buildFallbackPolicyContext(
|
||||
chain,
|
||||
|
|
|
|||
|
|
@ -331,19 +331,7 @@ export async function retryWithBackoff<T>(
|
|||
debugLogger.warn(
|
||||
`Attempt ${attempt} failed${errorMessage ? `: ${errorMessage}` : ''}. Max attempts reached`,
|
||||
);
|
||||
// For 500 errors, only trigger the fallback dialog for OAuth-based auth
|
||||
// (LOGIN_WITH_GOOGLE / COMPUTE_ADC). API key users — including those routing
|
||||
// through a custom base URL like LiteLLM — receive actual server errors
|
||||
// that should be propagated rather than spawning an interactive dialog that
|
||||
// will never resolve.
|
||||
const is500FallbackEligible =
|
||||
authType === 'oauth-personal' ||
|
||||
authType === 'compute-default-credentials';
|
||||
if (
|
||||
onPersistent429 &&
|
||||
(classifiedError instanceof RetryableQuotaError ||
|
||||
is500FallbackEligible)
|
||||
) {
|
||||
if (onPersistent429) {
|
||||
try {
|
||||
const fallbackModel = await onPersistent429(
|
||||
authType,
|
||||
|
|
|
|||
Loading…
Reference in a new issue