fix(ai-builder): Expose credential account context to prevent prompt/credential mismatch (#28100)

This commit is contained in:
Albert Alises 2026-04-08 15:22:10 +02:00 committed by GitHub
parent b39fc5d612
commit c2fbf9d643
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 266 additions and 6 deletions

View file

@ -40,6 +40,7 @@
"linkedom": "^0.18.9",
"luxon": "catalog:",
"nanoid": "catalog:",
"p-limit": "^3.1.0",
"pdf-parse": "^1.1.1",
"turndown": "^7.2.0",
"zod": "catalog:",

View file

@ -92,5 +92,55 @@ describe('get-credential tool', () => {
);
expect(context.credentialService.get).toHaveBeenCalledWith('nonexistent');
});
it('includes accountIdentifier when getAccountContext is available', async () => {
const context = createMockContext();
const credential = makeCredentialDetail();
(context.credentialService.get as jest.Mock).mockResolvedValue(credential);
context.credentialService.getAccountContext = jest
.fn()
.mockResolvedValue({ accountIdentifier: 'user@example.com' });
const tool = createGetCredentialTool(context);
const result = (await tool.execute!({ credentialId: 'cred-123' }, {} as never)) as Record<
string,
unknown
>;
expect(context.credentialService.getAccountContext).toHaveBeenCalledWith('cred-123');
expect(result).toEqual({ ...credential, accountIdentifier: 'user@example.com' });
});
it('returns undefined accountIdentifier when getAccountContext returns no identifier', async () => {
const context = createMockContext();
const credential = makeCredentialDetail();
(context.credentialService.get as jest.Mock).mockResolvedValue(credential);
context.credentialService.getAccountContext = jest
.fn()
.mockResolvedValue({ accountIdentifier: undefined });
const tool = createGetCredentialTool(context);
const result = (await tool.execute!({ credentialId: 'cred-123' }, {} as never)) as Record<
string,
unknown
>;
expect(result).toEqual({ ...credential, accountIdentifier: undefined });
});
it('omits accountIdentifier when getAccountContext is not available', async () => {
const context = createMockContext();
const credential = makeCredentialDetail();
(context.credentialService.get as jest.Mock).mockResolvedValue(credential);
const tool = createGetCredentialTool(context);
const result = (await tool.execute!({ credentialId: 'cred-123' }, {} as never)) as Record<
string,
unknown
>;
expect(result).toEqual(credential);
expect(result).not.toHaveProperty('accountIdentifier');
});
});
});

View file

@ -70,6 +70,30 @@ describe('list-credentials tool', () => {
expect(result.total).toBe(1);
});
it('enriches credentials with accountIdentifier when getAccountContext is available', async () => {
const context = createMockContext();
const credentials = [
makeCredential({ id: 'cred-1', name: 'Gmail OAuth' }),
makeCredential({ id: 'cred-2', name: 'Slack API', type: 'slackApi' }),
];
(context.credentialService.list as jest.Mock).mockResolvedValue(credentials);
context.credentialService.getAccountContext = jest
.fn()
.mockResolvedValueOnce({ accountIdentifier: 'user@gmail.com' })
.mockResolvedValueOnce({ accountIdentifier: undefined });
const tool = createListCredentialsTool(context);
const result = (await tool.execute!({}, {} as never)) as {
credentials: Array<{ id: string; name: string; type: string; accountIdentifier?: string }>;
total: number;
};
expect(result.credentials).toHaveLength(2);
expect(result.credentials[0].accountIdentifier).toBe('user@gmail.com');
expect(result.credentials[1].accountIdentifier).toBeUndefined();
expect(result.total).toBe(2);
});
it('passes type filter to the list call', async () => {
const context = createMockContext();
(context.credentialService.list as jest.Mock).mockResolvedValue([]);

View file

@ -11,16 +11,24 @@ export function createGetCredentialTool(context: InstanceAiContext) {
return createTool({
id: 'get-credential',
description:
'Get credential metadata (name, type, node access). Never returns decrypted secrets.',
'Get credential metadata (name, type, node access, account identifier). Never returns decrypted secrets.',
inputSchema: getCredentialInputSchema,
outputSchema: z.object({
id: z.string(),
name: z.string(),
type: z.string(),
nodesWithAccess: z.array(z.object({ nodeType: z.string() })).optional(),
accountIdentifier: z.string().optional(),
}),
execute: async (inputData: z.infer<typeof getCredentialInputSchema>) => {
return await context.credentialService.get(inputData.credentialId);
const detail = await context.credentialService.get(inputData.credentialId);
if (!context.credentialService.getAccountContext) {
return detail;
}
const ctx = await context.credentialService.getAccountContext(inputData.credentialId);
return { ...detail, accountIdentifier: ctx?.accountIdentifier };
},
});
}

View file

@ -1,4 +1,5 @@
import { createTool } from '@mastra/core/tools';
import pLimit from 'p-limit';
import { z } from 'zod';
import type { InstanceAiContext } from '../../types';
@ -29,6 +30,7 @@ export function createListCredentialsTool(context: InstanceAiContext) {
id: 'list-credentials',
description:
'List credentials accessible to the current user. Never exposes secret data. ' +
'Returns a masked accountIdentifier (e.g. "al***@gmail.com") when available, so you know which account each credential is connected to. ' +
'Results are paginated — use limit/offset to page through large sets, or filter by type to narrow results.',
inputSchema: listCredentialsInputSchema,
outputSchema: z.object({
@ -37,6 +39,7 @@ export function createListCredentialsTool(context: InstanceAiContext) {
id: z.string(),
name: z.string(),
type: z.string(),
accountIdentifier: z.string().optional(),
}),
),
total: z.number().describe('Total number of credentials matching the query'),
@ -51,10 +54,30 @@ export function createListCredentialsTool(context: InstanceAiContext) {
const limit = inputData.limit ?? DEFAULT_LIMIT;
const page = allCredentials.slice(offset, offset + limit);
return {
credentials: page.map(({ id, name, type }) => ({ id, name, type })),
total,
};
if (!context.credentialService.getAccountContext) {
return {
credentials: page.map(({ id, name, type }) => ({ id, name, type })),
total,
};
}
const concurrencyLimit = pLimit(10);
const enriched = await Promise.all(
page.map(
async (cred) =>
await concurrencyLimit(async () => {
const ctx = await context.credentialService.getAccountContext!(cred.id);
return {
id: cred.id,
name: cred.name,
type: cred.type,
accountIdentifier: ctx?.accountIdentifier,
};
}),
),
);
return { credentials: enriched, total };
},
});
}

View file

@ -234,6 +234,7 @@ export interface InstanceAiCredentialService {
): CredentialFieldInfo[] | Promise<CredentialFieldInfo[]>;
/** Search available credential types by keyword. Returns matching types with display names. */
searchCredentialTypes?(query: string): Promise<CredentialTypeSearchResult[]>;
getAccountContext?(credentialId: string): Promise<{ accountIdentifier?: string }>;
}
export interface CredentialFieldInfo {

View file

@ -1035,6 +1035,66 @@ export class InstanceAiAdapterService {
return results;
},
async getAccountContext(credentialId: string) {
const credential = await credentialsFinderService.findCredentialForUser(
credentialId,
user,
['credential:read'],
);
if (!credential) {
return { accountIdentifier: undefined };
}
const mask = (id: string): string => {
const atIdx = id.indexOf('@');
if (atIdx > 0) {
const local = id.slice(0, atIdx);
const domain = id.slice(atIdx);
const keep = Math.min(2, local.length);
return local.slice(0, keep) + '***' + domain;
}
if (id.length <= 3) return id;
return id.slice(0, 2) + '***' + id.slice(-1);
};
try {
// Use redacted decryption first — accountIdentifier is not a
// password field so it survives redaction. This avoids exposing
// the full secret payload (tokens, keys) in memory.
const redacted = credentialsService.decrypt(credential, false);
if (typeof redacted.accountIdentifier === 'string' && redacted.accountIdentifier) {
return { accountIdentifier: mask(redacted.accountIdentifier) };
}
for (const key of ['email', 'user', 'username', 'account', 'serviceAccountEmail']) {
const value = redacted[key];
if (typeof value === 'string' && value) {
return { accountIdentifier: mask(value) };
}
}
// Fallback for legacy credentials: oauthTokenData is blanked by
// redaction, so we need unredacted access here only.
const raw = credentialsService.decrypt(credential, true);
const tokenData = raw.oauthTokenData;
if (tokenData && typeof tokenData === 'object') {
const { OauthService } = await import('@/oauth/oauth.service');
const identifier = OauthService.extractAccountIdentifier(
tokenData as Record<string, unknown>,
);
if (identifier) {
return { accountIdentifier: mask(identifier) };
}
}
return { accountIdentifier: undefined };
} catch {
return { accountIdentifier: undefined };
}
},
};
}

View file

@ -2520,4 +2520,50 @@ describe('OauthService', () => {
).rejects.toThrow('Request token failed');
});
});
describe('extractAccountIdentifier', () => {
it('returns email from direct token field', () => {
expect(
OauthService.extractAccountIdentifier({ email: 'user@example.com', access_token: 'tok' }),
).toBe('user@example.com');
});
it('returns login from direct token field (GitHub-style)', () => {
expect(OauthService.extractAccountIdentifier({ login: 'octocat', access_token: 'tok' })).toBe(
'octocat',
);
});
it('extracts email from JWT id_token', () => {
const payload = { email: 'user@gmail.com', sub: '123' };
const idToken = `header.${Buffer.from(JSON.stringify(payload)).toString('base64url')}.sig`;
expect(OauthService.extractAccountIdentifier({ id_token: idToken })).toBe('user@gmail.com');
});
it('extracts preferred_username from JWT id_token when no email', () => {
const payload = { preferred_username: 'admin@contoso.com', sub: '123' };
const idToken = `header.${Buffer.from(JSON.stringify(payload)).toString('base64url')}.sig`;
expect(OauthService.extractAccountIdentifier({ id_token: idToken })).toBe(
'admin@contoso.com',
);
});
it('returns undefined for token data without identifiers', () => {
expect(
OauthService.extractAccountIdentifier({ access_token: 'tok', refresh_token: 'ref' }),
).toBeUndefined();
});
it('handles malformed JWT gracefully', () => {
expect(OauthService.extractAccountIdentifier({ id_token: 'not.a.jwt' })).toBeUndefined();
});
it('prefers direct fields over id_token', () => {
const payload = { email: 'jwt@example.com' };
const idToken = `h.${Buffer.from(JSON.stringify(payload)).toString('base64url')}.s`;
expect(
OauthService.extractAccountIdentifier({ email: 'direct@example.com', id_token: idToken }),
).toBe('direct@example.com');
});
});
});

View file

@ -175,6 +175,15 @@ export class OauthService {
toUpdate: ICredentialDataDecryptedObject,
toDelete: string[] = [],
) {
if (toUpdate.oauthTokenData && typeof toUpdate.oauthTokenData === 'object') {
const identifier = OauthService.extractAccountIdentifier(
toUpdate.oauthTokenData as Record<string, unknown>,
);
if (identifier) {
toUpdate.accountIdentifier = identifier;
}
}
const credentials = new Credentials(credential, credential.type, credential.data);
credentials.updateData(toUpdate, toDelete);
await this.credentialsRepository.update(credential.id, {
@ -183,6 +192,41 @@ export class OauthService {
});
}
static extractAccountIdentifier(tokenData: Record<string, unknown>): string | undefined {
for (const key of ['email', 'login', 'username', 'user', 'account']) {
if (typeof tokenData[key] === 'string' && tokenData[key]) {
return tokenData[key];
}
}
if (typeof tokenData.id_token === 'string') {
const parts = tokenData.id_token.split('.');
if (parts.length === 3) {
try {
const payload: Record<string, unknown> = JSON.parse(
Buffer.from(parts[1], 'base64url').toString(),
);
if (typeof payload.email === 'string' && payload.email) {
return payload.email;
}
if (typeof payload.preferred_username === 'string' && payload.preferred_username) {
return payload.preferred_username;
}
} catch {}
}
}
const authedUser = tokenData.authed_user;
if (authedUser && typeof authedUser === 'object') {
const user = authedUser as Record<string, unknown>;
if (typeof user.id === 'string' && user.id) {
return user.id;
}
}
return undefined;
}
/** Get a credential without user check */
protected async getCredentialWithoutUser(
credentialId: string,

View file

@ -1608,6 +1608,9 @@ importers:
nanoid:
specifier: 'catalog:'
version: 3.3.8
p-limit:
specifier: ^3.1.0
version: 3.1.0
pdf-parse:
specifier: ^1.1.1
version: 1.1.1