fix(mcp): optimize MCP tool responses to prevent LLM context overflow (#25300)

* fix(mcp): optimize MCP tool responses to prevent LLM context overflow

* fix: add input validation for maxAggregationBuckets parsing
This commit is contained in:
Vishnu Jain 2026-01-15 18:38:00 +05:30 committed by GitHub
parent b5b07093ba
commit fb6eaef8a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 469 additions and 21 deletions

View file

@ -3,6 +3,8 @@ package org.openmetadata.mcp.tools;
import static org.openmetadata.schema.type.MetadataOperation.VIEW_ALL;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.schema.utils.JsonUtils;
@ -15,6 +17,35 @@ import org.openmetadata.service.security.policyevaluator.ResourceContext;
@Slf4j
public class GetEntityTool implements McpTool {
// Entity keys stripped from GetEntityTool responses to optimize LLM context usage.
// These fields are typically verbose and not useful for LLM understanding.
// NOTE: cleanEntityResponse removes these keys from the top level of the entity map only;
// identically named keys nested inside child objects are left untouched.
private static final List<String> EXCLUDE_FIELDS =
List.of(
"version",
"updatedAt",
"updatedBy",
"changeDescription",
"followers",
"votes",
"totalVotes",
"usageSummary",
"lifeCycle",
"sourceHash",
"fqnParts",
"fqnHash",
"entityRelationship",
"processedLineage",
"upstreamLineage",
"changeSummary",
"tierSources",
"tagSources",
"descriptionSources",
"columnDescriptionStatus",
"descriptionStatus",
"embeddings",
"extension");
@Override
public Map<String, Object> execute(
Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params)
@ -27,7 +58,24 @@ public class GetEntityTool implements McpTool {
new ResourceContext<>(entityType));
LOG.info("Getting details for entity type: {}, FQN: {}", entityType, fqn);
String fields = "*";
return JsonUtils.getMap(Entity.getEntityByName(entityType, fqn, fields, null));
Map<String, Object> entityData =
JsonUtils.getMap(Entity.getEntityByName(entityType, fqn, fields, null));
// Clean response to optimize LLM context usage
return cleanEntityResponse(entityData);
}
/**
 * Builds a trimmed copy of an entity map for LLM consumption.
 *
 * <p>Every key listed in {@code EXCLUDE_FIELDS} is dropped from the top level of the copy; the
 * map passed in is never modified. A {@code null} input yields a new, empty, mutable map.
 */
private static Map<String, Object> cleanEntityResponse(Map<String, Object> entityData) {
  if (entityData == null) {
    return new HashMap<>();
  }

  // Work on a shallow copy so the caller's map keeps all of its fields.
  Map<String, Object> trimmed = new HashMap<>(entityData);
  for (String excludedField : EXCLUDE_FIELDS) {
    trimmed.remove(excludedField);
  }
  return trimmed;
}
@Override

View file

@ -14,6 +14,11 @@ import org.openmetadata.service.security.auth.CatalogSecurityContext;
@Slf4j
public class GetLineageTool implements McpTool {
// Depth applied to upstreamDepth/downstreamDepth when the parameter is missing or cannot be
// parsed. Defaults matching ai-platform GetLineageTool.kt for consistency
private static final int DEFAULT_DEPTH = 3;
// Maximum depth to prevent exponential response growth (lineage graphs can explode).
// parseDepthParameter clamps every resolved depth into the range [1, MAX_DEPTH].
private static final int MAX_DEPTH = 10;
@Override
public Map<String, Object> execute(
Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params) {
@ -23,8 +28,18 @@ public class GetLineageTool implements McpTool {
}
String entityType = (String) params.get("entityType");
String fqn = (String) params.get("fqn");
Integer upstreamDepth = (Integer) params.get("upstreamDepth");
Integer downstreamDepth = (Integer) params.get("downstreamDepth");
// Parse and validate upstream depth with default and max limits
int upstreamDepth = parseDepthParameter(params.get("upstreamDepth"), DEFAULT_DEPTH);
// Parse and validate downstream depth with default and max limits
int downstreamDepth = parseDepthParameter(params.get("downstreamDepth"), DEFAULT_DEPTH);
LOG.info(
"Getting lineage for entity type: {}, FQN: {}, upstreamDepth: {}, downstreamDepth: {}",
entityType,
fqn,
upstreamDepth,
downstreamDepth);
return JsonUtils.getMap(
new LineageRepository().getByName(entityType, fqn, upstreamDepth, downstreamDepth));
@ -33,6 +48,28 @@ public class GetLineageTool implements McpTool {
}
}
/**
 * Resolves a raw lineage-depth argument to a value in the range {@code [1, MAX_DEPTH]}.
 *
 * <p>Accepts any {@link Number} or a decimal integer string (surrounding whitespace is ignored).
 * A {@code null} value, an unparsable string, or a value of any other type falls back to {@code
 * defaultValue}. The result, including the fallback, is always clamped so that a single call
 * can never request an unbounded lineage traversal.
 *
 * @param depthObj raw value taken from the tool's parameter map; may be {@code null}
 * @param defaultValue depth to use when {@code depthObj} is absent or unusable
 * @return the resolved depth, never below 1 and never above {@code MAX_DEPTH}
 */
private static int parseDepthParameter(Object depthObj, int defaultValue) {
  // Track the depth as a long so out-of-int-range requests are clamped instead of wrapping.
  long depth = defaultValue;
  if (depthObj instanceof Number number) {
    // intValue() would turn e.g. 2147483648L into a negative number, which the clamp below
    // would then raise to 1 rather than cap at MAX_DEPTH.
    depth = number.longValue();
  } else if (depthObj instanceof String string) {
    try {
      // parseLong rejects padding, so strip it first; " 5 " is a perfectly clear request.
      depth = Long.parseLong(string.strip());
    } catch (NumberFormatException ignored) {
      // Deliberate best effort: a malformed depth should not fail the whole lineage request.
      depth = defaultValue;
    }
  }
  // Enforce both bounds in one place; the result fits in an int because MAX_DEPTH is an int.
  return (int) Math.min(Math.max(depth, 1L), MAX_DEPTH);
}
@Override
public Map<String, Object> execute(
Authorizer authorizer,

View file

@ -6,6 +6,7 @@ import static org.openmetadata.service.security.DefaultAuthorizer.getSubjectCont
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.annotations.VisibleForTesting;
import jakarta.ws.rs.core.Response;
import java.io.IOException;
import java.util.ArrayList;
@ -26,6 +27,11 @@ import org.openmetadata.service.security.policyevaluator.SubjectContext;
@Slf4j
public class SearchMetadataTool implements McpTool {
// Buckets kept per aggregation when maxAggregationBuckets is not supplied or cannot be parsed
private static final int DEFAULT_MAX_AGGREGATION_BUCKETS = 10;
// Upper bound that a caller-supplied maxAggregationBuckets value is clamped to
private static final int MAX_ALLOWED_AGGREGATION_BUCKETS = 50;
// Descriptions longer than this many characters are shortened in search results
private static final int DESCRIPTION_MAX_LENGTH = 500;
// Number of leading characters kept (followed by "...") when a description is shortened
private static final int DESCRIPTION_TRUNCATE_LENGTH = 450;
private static final List<String> ESSENTIAL_FIELDS_ONLY =
List.of(
"name",
@ -123,6 +129,34 @@ public class SearchMetadataTool implements McpTool {
}
}
// Parse includeAggregations - defaults to false to keep LLM context size manageable
boolean includeAggregations = false;
if (params.containsKey("includeAggregations")) {
Object aggObj = params.get("includeAggregations");
if (aggObj instanceof Boolean booleanValue) {
includeAggregations = booleanValue;
} else if (aggObj instanceof String) {
includeAggregations = "true".equals(aggObj);
}
}
// Parse maxAggregationBuckets - limit aggregation size to prevent context overflow
int maxAggregationBuckets = DEFAULT_MAX_AGGREGATION_BUCKETS;
if (params.containsKey("maxAggregationBuckets")) {
Object maxBucketsObj = params.get("maxAggregationBuckets");
if (maxBucketsObj instanceof Number number) {
maxAggregationBuckets =
Math.min(Math.max(number.intValue(), 1), MAX_ALLOWED_AGGREGATION_BUCKETS);
} else if (maxBucketsObj instanceof String string) {
try {
maxAggregationBuckets =
Math.min(Math.max(Integer.parseInt(string), 1), MAX_ALLOWED_AGGREGATION_BUCKETS);
} catch (NumberFormatException e) {
maxAggregationBuckets = DEFAULT_MAX_AGGREGATION_BUCKETS;
}
}
}
List<String> requestedFields = new ArrayList<>();
if (params.containsKey("fields")) {
String fieldsParam = (String) params.get("fields");
@ -148,8 +182,6 @@ public class SearchMetadataTool implements McpTool {
queryFilter = JsonUtils.pojoToJson(queryNode);
}
LOG.debug("Applied query filter to query: {}", queryFilter);
} else {
}
LOG.info(
@ -203,7 +235,8 @@ public class SearchMetadataTool implements McpTool {
searchResponse = JsonUtils.convertValue(response.getEntity(), Map.class);
}
return buildEnhancedSearchResponse(searchResponse, query, size, requestedFields);
return buildEnhancedSearchResponse(
searchResponse, query, size, requestedFields, includeAggregations, maxAggregationBuckets);
}
@Override
@ -216,11 +249,14 @@ public class SearchMetadataTool implements McpTool {
"SearchMetadataTool does not support limits enforcement.");
}
public static Map<String, Object> buildEnhancedSearchResponse(
@VisibleForTesting
static Map<String, Object> buildEnhancedSearchResponse(
Map<String, Object> searchResponse,
String query,
int requestedLimit,
List<String> requestedFields) {
List<String> requestedFields,
boolean includeAggregations,
int maxAggregationBuckets) {
if (searchResponse == null) {
return createEmptyResponse();
}
@ -262,9 +298,23 @@ public class SearchMetadataTool implements McpTool {
result.put("returnedCount", cleanedResults.size());
result.put("query", query);
// Add aggregations if present in search response
if (searchResponse.containsKey("aggregations")) {
result.put("aggregations", searchResponse.get("aggregations"));
// Handle aggregations based on includeAggregations flag
if (includeAggregations && searchResponse.containsKey("aggregations")) {
Map<String, Object> rawAggregations = safeGetMap(searchResponse.get("aggregations"));
if (rawAggregations != null && !rawAggregations.isEmpty()) {
Map<String, Object> truncatedAggregations =
truncateAggregations(rawAggregations, maxAggregationBuckets);
result.put("aggregations", truncatedAggregations.get("aggregations"));
if (truncatedAggregations.containsKey("aggregationsTruncated")) {
result.put("aggregationsTruncated", true);
result.put(
"aggregationsMessage",
String.format(
"Aggregation buckets truncated to %d per field to optimize LLM context. "
+ "Set maxAggregationBuckets parameter for more (max %d).",
maxAggregationBuckets, MAX_ALLOWED_AGGREGATION_BUCKETS));
}
}
}
if (totalResults > requestedLimit) {
@ -297,11 +347,11 @@ public class SearchMetadataTool implements McpTool {
}
}
// Cleanup Description in case of huge description
// Truncate long descriptions to optimize LLM context usage
if (result.containsKey("description")) {
String description = (String) result.get("description");
if (description.length() > 3000) {
result.put("description", description.substring(0, 300) + "...");
Object descObj = result.get("description");
if (descObj instanceof String description && description.length() > DESCRIPTION_MAX_LENGTH) {
result.put("description", description.substring(0, DESCRIPTION_TRUNCATE_LENGTH) + "...");
}
}
return result;
@ -322,6 +372,64 @@ public class SearchMetadataTool implements McpTool {
return object;
}
/**
 * Caps the number of buckets in each top-level bucket aggregation so that a search response
 * stays small enough for an LLM context window.
 *
 * <p>Only aggregations whose value is a map holding a {@code buckets} list longer than the limit
 * are rewritten: they are shallow-copied, given the first {@code maxBuckets} buckets, and
 * annotated with {@code _originalBucketCount} and {@code _truncated}. Everything else (metric
 * aggregations, short bucket lists, non-map values) is passed through by reference. The input
 * map and its lists are never modified, and a truncated bucket list is an independent copy
 * rather than a view of the original list. Buckets nested inside other buckets are not
 * inspected.
 *
 * @param aggregations raw aggregations from the search response; {@code null} is treated as
 *     empty
 * @param maxBuckets maximum number of buckets to keep per aggregation; negative values are
 *     treated as 0
 * @return a map with key {@code aggregations} (the processed aggregations) and, only when at
 *     least one aggregation was shortened, key {@code aggregationsTruncated} set to {@code true}
 */
private static Map<String, Object> truncateAggregations(
    Map<String, Object> aggregations, int maxBuckets) {
  // subList(0, n) rejects a negative n, so clamp instead of letting a bad limit blow up.
  int limit = Math.max(maxBuckets, 0);
  Map<String, Object> truncatedAggs = new HashMap<>();
  boolean anyTruncated = false;

  if (aggregations != null) {
    for (Map.Entry<String, Object> entry : aggregations.entrySet()) {
      Object aggValue = entry.getValue();
      if (aggValue instanceof Map<?, ?> aggMap
          && aggMap.get("buckets") instanceof List<?> buckets
          && buckets.size() > limit) {
        @SuppressWarnings("unchecked") // aggregation maps come from JSON, keyed by strings
        Map<String, Object> truncatedAgg = new HashMap<>((Map<String, Object>) aggMap);
        // Copy the kept prefix: a bare subList is a live view that breaks if the source list
        // is structurally modified later and keeps the full list reachable.
        truncatedAgg.put("buckets", new ArrayList<>(buckets.subList(0, limit)));
        truncatedAgg.put("_originalBucketCount", buckets.size());
        truncatedAgg.put("_truncated", true);
        truncatedAggs.put(entry.getKey(), truncatedAgg);
        anyTruncated = true;
      } else {
        // Metric aggregations, short bucket lists and non-map values pass through untouched.
        truncatedAggs.put(entry.getKey(), aggValue);
      }
    }
  }

  Map<String, Object> result = new HashMap<>();
  result.put("aggregations", truncatedAggs);
  if (anyTruncated) {
    result.put("aggregationsTruncated", true);
  }
  return result;
}
@SuppressWarnings("unchecked")
private static Map<String, Object> safeGetMap(Object obj) {
return (obj instanceof Map) ? (Map<String, Object>) obj : null;

View file

@ -37,6 +37,16 @@
"fields": {
"type": "string",
"description": "Comma-separated additional fields to include. Default returns: name, displayName, fullyQualifiedName, description, entityType, service, database, databaseSchema, serviceType, href, tags, owners, tier, tableType, columnNames.\n\nAdditional fields by entity type:\n- Table entities: columns, schemaDefinition, queries, upstreamLineage, entityRelationship\n- Topic entities: messageSchema, partitions, replicationFactor \n- Dashboard entities: charts, dataModels, project\n- Pipeline entities: tasks, pipelineUrl, scheduleInterval\n- All entities: createdAt, updatedAt, changeDescription, extension, domain, dataProducts, lifeCycle, sourceHash\n\nExample: 'columns,queries' for table column details and sample queries."
},
"includeAggregations": {
"type": "boolean",
"description": "Whether to include aggregation data (facets) in the response. Defaults to false to optimize LLM context size. Set to true only when you need aggregation statistics like counts by service type, owner, tags, etc. Aggregations can significantly increase response size.",
"default": false
},
"maxAggregationBuckets": {
"type": "integer",
"description": "Maximum number of aggregation buckets to return per field when includeAggregations is true. Limits response size to prevent context overflow. Default is 10, maximum allowed is 50.",
"default": 10
}
},
"required": ["queryFilter"],
@ -82,7 +92,7 @@
},
{
"name": "get_entity_details",
"description": "Get detailed information about a specific entity",
"description": "Get detailed information about a specific entity. Response is optimized for LLM context by excluding verbose metadata fields.",
"parameters": {
"description": "Fqn is the fully qualified name of the entity. Entity type could be table, topic etc.",
"type": "object",
@ -223,18 +233,18 @@
},
"upstreamDepth": {
"type": "integer",
"description": "Depth for reaching upstream entities. Default is 5."
"description": "Number of upstream hops to traverse. Default is 3, maximum is 10 to prevent excessive response size.",
"default": 3
},
"downstreamDepth": {
"type": "integer",
"description": "Depth for reaching downstream entities. Default is 5."
"description": "Number of downstream hops to traverse. Default is 3, maximum is 10 to prevent excessive response size.",
"default": 3
}
},
"required": [
"entityType",
"fqn",
"upstreamDepth",
"downstreamDepth"
"fqn"
]
}
}

View file

@ -0,0 +1,245 @@
package org.openmetadata.mcp.tools;
import static org.junit.jupiter.api.Assertions.*;
import java.util.*;
import org.junit.jupiter.api.Test;
/**
 * Exercises the aggregation handling of {@code SearchMetadataTool.buildEnhancedSearchResponse}:
 * aggregations are omitted unless requested, bucket lists are capped per field when requested,
 * and capping is reported through truncation markers. Regression coverage for issue #25091 (MCP
 * server returning excessive aggregation data that overwhelms LLM context).
 */
class SearchMetadataAggregationTest {

  @Test
  void testAggregationsExcludedByDefault() {
    // includeAggregations=false mirrors the tool's default behaviour.
    Map<String, Object> result = search(createSearchResponseWithAggregations(20), false, 10);

    assertFalse(result.containsKey("aggregations"));
    assertFalse(result.containsKey("aggregationsTruncated"));
  }

  @Test
  void testAggregationsIncludedWhenRequested() {
    Map<String, Object> result = search(createSearchResponseWithAggregations(5), true, 10);

    assertTrue(result.containsKey("aggregations"));
    assertNotNull(result.get("aggregations"));
  }

  @Test
  void testAggregationsTruncatedWhenExceedingLimit() {
    // 20 buckets available, only 5 allowed.
    Map<String, Object> result = search(createSearchResponseWithAggregations(20), true, 5);

    assertTrue(result.containsKey("aggregations"));
    Map<String, Object> serviceTypeAgg =
        asMap(asMap(result.get("aggregations")).get("serviceType"));
    assertNotNull(serviceTypeAgg);

    List<Object> buckets = asList(serviceTypeAgg.get("buckets"));
    assertEquals(5, buckets.size(), "Buckets should be truncated to maxAggregationBuckets");

    // Per-aggregation truncation markers.
    assertEquals(20, serviceTypeAgg.get("_originalBucketCount"));
    assertEquals(true, serviceTypeAgg.get("_truncated"));

    // Response-level truncation markers.
    assertEquals(true, result.get("aggregationsTruncated"));
    assertTrue(result.containsKey("aggregationsMessage"));
  }

  @Test
  void testAggregationsNotTruncatedWhenUnderLimit() {
    // 5 buckets available, 10 allowed.
    Map<String, Object> result = search(createSearchResponseWithAggregations(5), true, 10);

    assertTrue(result.containsKey("aggregations"));
    Map<String, Object> serviceTypeAgg =
        asMap(asMap(result.get("aggregations")).get("serviceType"));
    List<Object> buckets = asList(serviceTypeAgg.get("buckets"));
    assertEquals(5, buckets.size(), "All buckets should be preserved");

    // Nothing was cut, so no markers may appear at either level.
    assertFalse(serviceTypeAgg.containsKey("_truncated"));
    assertFalse(result.containsKey("aggregationsTruncated"));
  }

  @Test
  void testNonBucketAggregationsPreserved() {
    // A value_count style aggregation carries a single value and no buckets.
    Map<String, Object> result = search(createSearchResponseWithValueCountAgg(), true, 5);

    assertTrue(result.containsKey("aggregations"));
    Map<String, Object> totalCount = asMap(asMap(result.get("aggregations")).get("total_count"));
    assertNotNull(totalCount);
    assertEquals(100, totalCount.get("value"));
  }

  @Test
  void testEmptyAggregationsHandled() {
    Map<String, Object> result = search(responseWith(new HashMap<>()), true, 10);

    // An empty aggregations map must neither throw nor surface an "aggregations" key.
    assertFalse(result.containsKey("aggregations"));
  }

  @Test
  void testMultipleAggregationsTruncatedIndependently() {
    Map<String, Object> result = search(createSearchResponseWithMultipleAggregations(), true, 5);
    Map<String, Object> aggregations = asMap(result.get("aggregations"));

    // serviceType carries 20 buckets, so it is cut down to the limit.
    Map<String, Object> serviceTypeAgg = asMap(aggregations.get("serviceType"));
    List<Object> serviceTypeBuckets = asList(serviceTypeAgg.get("buckets"));
    assertEquals(5, serviceTypeBuckets.size());
    assertEquals(true, serviceTypeAgg.get("_truncated"));

    // owners carries only 3 buckets, so it is left alone.
    Map<String, Object> ownersAgg = asMap(aggregations.get("owners"));
    List<Object> ownersBuckets = asList(ownersAgg.get("buckets"));
    assertEquals(3, ownersBuckets.size());
    assertFalse(ownersAgg.containsKey("_truncated"));
  }

  // ---- invocation and cast helpers ----

  /** Runs the method under test with the query, limit and field list shared by every test. */
  private static Map<String, Object> search(
      Map<String, Object> searchResponse, boolean includeAggregations, int maxBuckets) {
    return SearchMetadataTool.buildEnhancedSearchResponse(
        searchResponse, "test query", 10, Collections.emptyList(), includeAggregations, maxBuckets);
  }

  @SuppressWarnings("unchecked")
  private static Map<String, Object> asMap(Object value) {
    return (Map<String, Object>) value;
  }

  @SuppressWarnings("unchecked")
  private static List<Object> asList(Object value) {
    return (List<Object>) value;
  }

  // ---- test data builders ----

  /** Response with a single {@code serviceType} aggregation holding {@code bucketCount} buckets. */
  private static Map<String, Object> createSearchResponseWithAggregations(int bucketCount) {
    Map<String, Object> aggregations = new HashMap<>();
    aggregations.put("serviceType", bucketAggregation("Service", bucketCount, 100));
    return responseWith(aggregations);
  }

  /** Response whose only aggregation is a bucket-less {@code total_count} metric of 100. */
  private static Map<String, Object> createSearchResponseWithValueCountAgg() {
    Map<String, Object> totalCount = new HashMap<>();
    totalCount.put("value", 100);

    Map<String, Object> aggregations = new HashMap<>();
    aggregations.put("total_count", totalCount);
    return responseWith(aggregations);
  }

  /** Response with a large (20 bucket) {@code serviceType} and a small (3 bucket) {@code owners}. */
  private static Map<String, Object> createSearchResponseWithMultipleAggregations() {
    Map<String, Object> aggregations = new HashMap<>();
    aggregations.put("serviceType", bucketAggregation("Service", 20, 100));
    aggregations.put("owners", bucketAggregation("Owner", 3, 50));
    return responseWith(aggregations);
  }

  /** Wraps the given aggregations in a search response that has no hits. */
  private static Map<String, Object> responseWith(Map<String, Object> aggregations) {
    Map<String, Object> response = new HashMap<>();
    response.put("hits", createEmptyHits());
    response.put("aggregations", aggregations);
    return response;
  }

  /**
   * Builds a bucket aggregation whose i-th bucket has key {@code keyPrefix + i} and doc_count
   * {@code firstDocCount - i}.
   */
  private static Map<String, Object> bucketAggregation(
      String keyPrefix, int bucketCount, int firstDocCount) {
    List<Map<String, Object>> buckets = new ArrayList<>();
    for (int i = 0; i < bucketCount; i++) {
      Map<String, Object> bucket = new HashMap<>();
      bucket.put("key", keyPrefix + i);
      bucket.put("doc_count", firstDocCount - i);
      buckets.add(bucket);
    }

    Map<String, Object> aggregation = new HashMap<>();
    aggregation.put("buckets", buckets);
    return aggregation;
  }

  /** Hits section reporting zero results. */
  private static Map<String, Object> createEmptyHits() {
    Map<String, Object> total = new HashMap<>();
    total.put("value", 0);

    Map<String, Object> hits = new HashMap<>();
    hits.put("hits", Collections.emptyList());
    hits.put("total", total);
    return hits;
  }
}