mirror of
https://github.com/open-metadata/OpenMetadata
synced 2026-05-24 09:39:11 +00:00
fix(mcp): optimize MCP tool responses to prevent LLM context overflow (#25300)
* fix(mcp): optimize MCP tool responses to prevent LLM context overflow * fix: add input validation for maxAggregationBuckets parsing
This commit is contained in:
parent
b5b07093ba
commit
fb6eaef8a5
5 changed files with 469 additions and 21 deletions
|
|
@ -3,6 +3,8 @@ package org.openmetadata.mcp.tools;
|
|||
import static org.openmetadata.schema.type.MetadataOperation.VIEW_ALL;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.openmetadata.schema.utils.JsonUtils;
|
||||
|
|
@ -15,6 +17,35 @@ import org.openmetadata.service.security.policyevaluator.ResourceContext;
|
|||
|
||||
@Slf4j
|
||||
public class GetEntityTool implements McpTool {
|
||||
|
||||
// Fields to exclude from response to optimize LLM context usage
|
||||
// These fields are typically verbose and not useful for LLM understanding
|
||||
private static final List<String> EXCLUDE_FIELDS =
|
||||
List.of(
|
||||
"version",
|
||||
"updatedAt",
|
||||
"updatedBy",
|
||||
"changeDescription",
|
||||
"followers",
|
||||
"votes",
|
||||
"totalVotes",
|
||||
"usageSummary",
|
||||
"lifeCycle",
|
||||
"sourceHash",
|
||||
"fqnParts",
|
||||
"fqnHash",
|
||||
"entityRelationship",
|
||||
"processedLineage",
|
||||
"upstreamLineage",
|
||||
"changeSummary",
|
||||
"tierSources",
|
||||
"tagSources",
|
||||
"descriptionSources",
|
||||
"columnDescriptionStatus",
|
||||
"descriptionStatus",
|
||||
"embeddings",
|
||||
"extension");
|
||||
|
||||
@Override
|
||||
public Map<String, Object> execute(
|
||||
Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params)
|
||||
|
|
@ -27,7 +58,24 @@ public class GetEntityTool implements McpTool {
|
|||
new ResourceContext<>(entityType));
|
||||
LOG.info("Getting details for entity type: {}, FQN: {}", entityType, fqn);
|
||||
String fields = "*";
|
||||
return JsonUtils.getMap(Entity.getEntityByName(entityType, fqn, fields, null));
|
||||
Map<String, Object> entityData =
|
||||
JsonUtils.getMap(Entity.getEntityByName(entityType, fqn, fields, null));
|
||||
|
||||
// Clean response to optimize LLM context usage
|
||||
return cleanEntityResponse(entityData);
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes verbose fields from entity response to optimize LLM context. Keeps essential fields
|
||||
* while removing metadata that adds little value for LLM understanding.
|
||||
*/
|
||||
private static Map<String, Object> cleanEntityResponse(Map<String, Object> entityData) {
|
||||
if (entityData == null) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
Map<String, Object> cleaned = new HashMap<>(entityData);
|
||||
EXCLUDE_FIELDS.forEach(cleaned::remove);
|
||||
return cleaned;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
|||
|
|
@ -14,6 +14,11 @@ import org.openmetadata.service.security.auth.CatalogSecurityContext;
|
|||
@Slf4j
|
||||
public class GetLineageTool implements McpTool {
|
||||
|
||||
// Defaults matching ai-platform GetLineageTool.kt for consistency
|
||||
private static final int DEFAULT_DEPTH = 3;
|
||||
// Maximum depth to prevent exponential response growth (lineage graphs can explode)
|
||||
private static final int MAX_DEPTH = 10;
|
||||
|
||||
@Override
|
||||
public Map<String, Object> execute(
|
||||
Authorizer authorizer, CatalogSecurityContext securityContext, Map<String, Object> params) {
|
||||
|
|
@ -23,8 +28,18 @@ public class GetLineageTool implements McpTool {
|
|||
}
|
||||
String entityType = (String) params.get("entityType");
|
||||
String fqn = (String) params.get("fqn");
|
||||
Integer upstreamDepth = (Integer) params.get("upstreamDepth");
|
||||
Integer downstreamDepth = (Integer) params.get("downstreamDepth");
|
||||
|
||||
// Parse and validate upstream depth with default and max limits
|
||||
int upstreamDepth = parseDepthParameter(params.get("upstreamDepth"), DEFAULT_DEPTH);
|
||||
// Parse and validate downstream depth with default and max limits
|
||||
int downstreamDepth = parseDepthParameter(params.get("downstreamDepth"), DEFAULT_DEPTH);
|
||||
|
||||
LOG.info(
|
||||
"Getting lineage for entity type: {}, FQN: {}, upstreamDepth: {}, downstreamDepth: {}",
|
||||
entityType,
|
||||
fqn,
|
||||
upstreamDepth,
|
||||
downstreamDepth);
|
||||
|
||||
return JsonUtils.getMap(
|
||||
new LineageRepository().getByName(entityType, fqn, upstreamDepth, downstreamDepth));
|
||||
|
|
@ -33,6 +48,28 @@ public class GetLineageTool implements McpTool {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses depth parameter with default value and enforces maximum limit to prevent excessive
|
||||
* response sizes that could overwhelm LLM context.
|
||||
*/
|
||||
private static int parseDepthParameter(Object depthObj, int defaultValue) {
|
||||
if (depthObj == null) {
|
||||
return Math.min(Math.max(defaultValue, 1), MAX_DEPTH);
|
||||
}
|
||||
int depth = defaultValue;
|
||||
if (depthObj instanceof Number number) {
|
||||
depth = number.intValue();
|
||||
} else if (depthObj instanceof String string) {
|
||||
try {
|
||||
depth = Integer.parseInt(string);
|
||||
} catch (NumberFormatException e) {
|
||||
depth = defaultValue;
|
||||
}
|
||||
}
|
||||
// Enforce maximum depth to prevent exponential response growth
|
||||
return Math.min(Math.max(depth, 1), MAX_DEPTH);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> execute(
|
||||
Authorizer authorizer,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import static org.openmetadata.service.security.DefaultAuthorizer.getSubjectCont
|
|||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import jakarta.ws.rs.core.Response;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
|
|
@ -26,6 +27,11 @@ import org.openmetadata.service.security.policyevaluator.SubjectContext;
|
|||
@Slf4j
|
||||
public class SearchMetadataTool implements McpTool {
|
||||
|
||||
private static final int DEFAULT_MAX_AGGREGATION_BUCKETS = 10;
|
||||
private static final int MAX_ALLOWED_AGGREGATION_BUCKETS = 50;
|
||||
private static final int DESCRIPTION_MAX_LENGTH = 500;
|
||||
private static final int DESCRIPTION_TRUNCATE_LENGTH = 450;
|
||||
|
||||
private static final List<String> ESSENTIAL_FIELDS_ONLY =
|
||||
List.of(
|
||||
"name",
|
||||
|
|
@ -123,6 +129,34 @@ public class SearchMetadataTool implements McpTool {
|
|||
}
|
||||
}
|
||||
|
||||
// Parse includeAggregations - defaults to false to keep LLM context size manageable
|
||||
boolean includeAggregations = false;
|
||||
if (params.containsKey("includeAggregations")) {
|
||||
Object aggObj = params.get("includeAggregations");
|
||||
if (aggObj instanceof Boolean booleanValue) {
|
||||
includeAggregations = booleanValue;
|
||||
} else if (aggObj instanceof String) {
|
||||
includeAggregations = "true".equals(aggObj);
|
||||
}
|
||||
}
|
||||
|
||||
// Parse maxAggregationBuckets - limit aggregation size to prevent context overflow
|
||||
int maxAggregationBuckets = DEFAULT_MAX_AGGREGATION_BUCKETS;
|
||||
if (params.containsKey("maxAggregationBuckets")) {
|
||||
Object maxBucketsObj = params.get("maxAggregationBuckets");
|
||||
if (maxBucketsObj instanceof Number number) {
|
||||
maxAggregationBuckets =
|
||||
Math.min(Math.max(number.intValue(), 1), MAX_ALLOWED_AGGREGATION_BUCKETS);
|
||||
} else if (maxBucketsObj instanceof String string) {
|
||||
try {
|
||||
maxAggregationBuckets =
|
||||
Math.min(Math.max(Integer.parseInt(string), 1), MAX_ALLOWED_AGGREGATION_BUCKETS);
|
||||
} catch (NumberFormatException e) {
|
||||
maxAggregationBuckets = DEFAULT_MAX_AGGREGATION_BUCKETS;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
List<String> requestedFields = new ArrayList<>();
|
||||
if (params.containsKey("fields")) {
|
||||
String fieldsParam = (String) params.get("fields");
|
||||
|
|
@ -148,8 +182,6 @@ public class SearchMetadataTool implements McpTool {
|
|||
queryFilter = JsonUtils.pojoToJson(queryNode);
|
||||
}
|
||||
LOG.debug("Applied query filter to query: {}", queryFilter);
|
||||
} else {
|
||||
|
||||
}
|
||||
|
||||
LOG.info(
|
||||
|
|
@ -203,7 +235,8 @@ public class SearchMetadataTool implements McpTool {
|
|||
searchResponse = JsonUtils.convertValue(response.getEntity(), Map.class);
|
||||
}
|
||||
|
||||
return buildEnhancedSearchResponse(searchResponse, query, size, requestedFields);
|
||||
return buildEnhancedSearchResponse(
|
||||
searchResponse, query, size, requestedFields, includeAggregations, maxAggregationBuckets);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -216,11 +249,14 @@ public class SearchMetadataTool implements McpTool {
|
|||
"SearchMetadataTool does not support limits enforcement.");
|
||||
}
|
||||
|
||||
public static Map<String, Object> buildEnhancedSearchResponse(
|
||||
@VisibleForTesting
|
||||
static Map<String, Object> buildEnhancedSearchResponse(
|
||||
Map<String, Object> searchResponse,
|
||||
String query,
|
||||
int requestedLimit,
|
||||
List<String> requestedFields) {
|
||||
List<String> requestedFields,
|
||||
boolean includeAggregations,
|
||||
int maxAggregationBuckets) {
|
||||
if (searchResponse == null) {
|
||||
return createEmptyResponse();
|
||||
}
|
||||
|
|
@ -262,9 +298,23 @@ public class SearchMetadataTool implements McpTool {
|
|||
result.put("returnedCount", cleanedResults.size());
|
||||
result.put("query", query);
|
||||
|
||||
// Add aggregations if present in search response
|
||||
if (searchResponse.containsKey("aggregations")) {
|
||||
result.put("aggregations", searchResponse.get("aggregations"));
|
||||
// Handle aggregations based on includeAggregations flag
|
||||
if (includeAggregations && searchResponse.containsKey("aggregations")) {
|
||||
Map<String, Object> rawAggregations = safeGetMap(searchResponse.get("aggregations"));
|
||||
if (rawAggregations != null && !rawAggregations.isEmpty()) {
|
||||
Map<String, Object> truncatedAggregations =
|
||||
truncateAggregations(rawAggregations, maxAggregationBuckets);
|
||||
result.put("aggregations", truncatedAggregations.get("aggregations"));
|
||||
if (truncatedAggregations.containsKey("aggregationsTruncated")) {
|
||||
result.put("aggregationsTruncated", true);
|
||||
result.put(
|
||||
"aggregationsMessage",
|
||||
String.format(
|
||||
"Aggregation buckets truncated to %d per field to optimize LLM context. "
|
||||
+ "Set maxAggregationBuckets parameter for more (max %d).",
|
||||
maxAggregationBuckets, MAX_ALLOWED_AGGREGATION_BUCKETS));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (totalResults > requestedLimit) {
|
||||
|
|
@ -297,11 +347,11 @@ public class SearchMetadataTool implements McpTool {
|
|||
}
|
||||
}
|
||||
|
||||
// Cleanup Description in case of huge description
|
||||
// Truncate long descriptions to optimize LLM context usage
|
||||
if (result.containsKey("description")) {
|
||||
String description = (String) result.get("description");
|
||||
if (description.length() > 3000) {
|
||||
result.put("description", description.substring(0, 300) + "...");
|
||||
Object descObj = result.get("description");
|
||||
if (descObj instanceof String description && description.length() > DESCRIPTION_MAX_LENGTH) {
|
||||
result.put("description", description.substring(0, DESCRIPTION_TRUNCATE_LENGTH) + "...");
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
|
@ -322,6 +372,64 @@ public class SearchMetadataTool implements McpTool {
|
|||
return object;
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncates aggregation buckets to prevent excessive response size that could overwhelm LLM
|
||||
* context windows. Based on industry best practices, LLM performance degrades when context
|
||||
* utilization exceeds 85%, so keeping responses concise is critical.
|
||||
*
|
||||
* @param aggregations Raw aggregations from search response
|
||||
* @param maxBuckets Maximum number of buckets to keep per aggregation field
|
||||
* @return Map containing truncated aggregations and a flag if any were truncated
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
private static Map<String, Object> truncateAggregations(
|
||||
Map<String, Object> aggregations, int maxBuckets) {
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
Map<String, Object> truncatedAggs = new HashMap<>();
|
||||
boolean anyTruncated = false;
|
||||
|
||||
for (Map.Entry<String, Object> entry : aggregations.entrySet()) {
|
||||
String aggName = entry.getKey();
|
||||
Object aggValue = entry.getValue();
|
||||
|
||||
if (aggValue instanceof Map) {
|
||||
Map<String, Object> aggMap = (Map<String, Object>) aggValue;
|
||||
|
||||
// Check if this aggregation has buckets
|
||||
if (aggMap.containsKey("buckets")) {
|
||||
Object bucketsObj = aggMap.get("buckets");
|
||||
if (bucketsObj instanceof List) {
|
||||
List<Object> buckets = (List<Object>) bucketsObj;
|
||||
if (buckets.size() > maxBuckets) {
|
||||
// Truncate buckets
|
||||
Map<String, Object> truncatedAgg = new HashMap<>(aggMap);
|
||||
truncatedAgg.put("buckets", buckets.subList(0, maxBuckets));
|
||||
truncatedAgg.put("_originalBucketCount", buckets.size());
|
||||
truncatedAgg.put("_truncated", true);
|
||||
truncatedAggs.put(aggName, truncatedAgg);
|
||||
anyTruncated = true;
|
||||
} else {
|
||||
truncatedAggs.put(aggName, aggMap);
|
||||
}
|
||||
} else {
|
||||
truncatedAggs.put(aggName, aggMap);
|
||||
}
|
||||
} else {
|
||||
// Not a bucket aggregation (e.g., value_count, sum, etc.)
|
||||
truncatedAggs.put(aggName, aggMap);
|
||||
}
|
||||
} else {
|
||||
truncatedAggs.put(aggName, aggValue);
|
||||
}
|
||||
}
|
||||
|
||||
result.put("aggregations", truncatedAggs);
|
||||
if (anyTruncated) {
|
||||
result.put("aggregationsTruncated", true);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static Map<String, Object> safeGetMap(Object obj) {
|
||||
return (obj instanceof Map) ? (Map<String, Object>) obj : null;
|
||||
|
|
|
|||
|
|
@ -37,6 +37,16 @@
|
|||
"fields": {
|
||||
"type": "string",
|
||||
"description": "Comma-separated additional fields to include. Default returns: name, displayName, fullyQualifiedName, description, entityType, service, database, databaseSchema, serviceType, href, tags, owners, tier, tableType, columnNames.\n\nAdditional fields by entity type:\n- Table entities: columns, schemaDefinition, queries, upstreamLineage, entityRelationship\n- Topic entities: messageSchema, partitions, replicationFactor \n- Dashboard entities: charts, dataModels, project\n- Pipeline entities: tasks, pipelineUrl, scheduleInterval\n- All entities: createdAt, updatedAt, changeDescription, extension, domain, dataProducts, lifeCycle, sourceHash\n\nExample: 'columns,queries' for table column details and sample queries."
|
||||
},
|
||||
"includeAggregations": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to include aggregation data (facets) in the response. Defaults to false to optimize LLM context size. Set to true only when you need aggregation statistics like counts by service type, owner, tags, etc. Aggregations can significantly increase response size.",
|
||||
"default": false
|
||||
},
|
||||
"maxAggregationBuckets": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of aggregation buckets to return per field when includeAggregations is true. Limits response size to prevent context overflow. Default is 10, maximum allowed is 50.",
|
||||
"default": 10
|
||||
}
|
||||
},
|
||||
"required": ["queryFilter"],
|
||||
|
|
@ -82,7 +92,7 @@
|
|||
},
|
||||
{
|
||||
"name": "get_entity_details",
|
||||
"description": "Get detailed information about a specific entity",
|
||||
"description": "Get detailed information about a specific entity. Response is optimized for LLM context by excluding verbose metadata fields.",
|
||||
"parameters": {
|
||||
"description": "Fqn is the fully qualified name of the entity. Entity type could be table, topic etc.",
|
||||
"type": "object",
|
||||
|
|
@ -223,18 +233,18 @@
|
|||
},
|
||||
"upstreamDepth": {
|
||||
"type": "integer",
|
||||
"description": "Depth for reaching upstream entities. Default is 5."
|
||||
"description": "Number of upstream hops to traverse. Default is 3, maximum is 10 to prevent excessive response size.",
|
||||
"default": 3
|
||||
},
|
||||
"downstreamDepth": {
|
||||
"type": "integer",
|
||||
"description": "Depth for reaching downstream entities. Default is 5."
|
||||
"description": "Number of downstream hops to traverse. Default is 3, maximum is 10 to prevent excessive response size.",
|
||||
"default": 3
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"entityType",
|
||||
"fqn",
|
||||
"upstreamDepth",
|
||||
"downstreamDepth"
|
||||
"fqn"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,245 @@
|
|||
package org.openmetadata.mcp.tools;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import java.util.*;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
/**
|
||||
* Unit tests for SearchMetadataTool aggregation truncation functionality. Tests the fix for issue
|
||||
* #25091: MCP server can return excessive data in aggregations, overwhelming LLM context.
|
||||
*/
|
||||
class SearchMetadataAggregationTest {
|
||||
|
||||
@Test
|
||||
void testAggregationsExcludedByDefault() {
|
||||
// Simulate search response with aggregations
|
||||
Map<String, Object> searchResponse = createSearchResponseWithAggregations(20);
|
||||
|
||||
// Call with includeAggregations=false (default behavior)
|
||||
Map<String, Object> result =
|
||||
SearchMetadataTool.buildEnhancedSearchResponse(
|
||||
searchResponse, "test query", 10, Collections.emptyList(), false, 10);
|
||||
|
||||
// Aggregations should NOT be present
|
||||
assertFalse(result.containsKey("aggregations"));
|
||||
assertFalse(result.containsKey("aggregationsTruncated"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAggregationsIncludedWhenRequested() {
|
||||
Map<String, Object> searchResponse = createSearchResponseWithAggregations(5);
|
||||
|
||||
// Call with includeAggregations=true
|
||||
Map<String, Object> result =
|
||||
SearchMetadataTool.buildEnhancedSearchResponse(
|
||||
searchResponse, "test query", 10, Collections.emptyList(), true, 10);
|
||||
|
||||
// Aggregations should be present
|
||||
assertTrue(result.containsKey("aggregations"));
|
||||
assertNotNull(result.get("aggregations"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAggregationsTruncatedWhenExceedingLimit() {
|
||||
// Create response with 20 buckets
|
||||
Map<String, Object> searchResponse = createSearchResponseWithAggregations(20);
|
||||
|
||||
// Call with maxAggregationBuckets=5
|
||||
Map<String, Object> result =
|
||||
SearchMetadataTool.buildEnhancedSearchResponse(
|
||||
searchResponse, "test query", 10, Collections.emptyList(), true, 5);
|
||||
|
||||
// Verify aggregations are present
|
||||
assertTrue(result.containsKey("aggregations"));
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> aggregations = (Map<String, Object>) result.get("aggregations");
|
||||
|
||||
// Verify serviceType aggregation is truncated
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> serviceTypeAgg = (Map<String, Object>) aggregations.get("serviceType");
|
||||
assertNotNull(serviceTypeAgg);
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
List<Object> buckets = (List<Object>) serviceTypeAgg.get("buckets");
|
||||
assertEquals(5, buckets.size(), "Buckets should be truncated to maxAggregationBuckets");
|
||||
|
||||
// Verify truncation metadata
|
||||
assertEquals(20, serviceTypeAgg.get("_originalBucketCount"));
|
||||
assertEquals(true, serviceTypeAgg.get("_truncated"));
|
||||
|
||||
// Verify global truncation flag
|
||||
assertEquals(true, result.get("aggregationsTruncated"));
|
||||
assertTrue(result.containsKey("aggregationsMessage"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAggregationsNotTruncatedWhenUnderLimit() {
|
||||
// Create response with 5 buckets
|
||||
Map<String, Object> searchResponse = createSearchResponseWithAggregations(5);
|
||||
|
||||
// Call with maxAggregationBuckets=10 (more than bucket count)
|
||||
Map<String, Object> result =
|
||||
SearchMetadataTool.buildEnhancedSearchResponse(
|
||||
searchResponse, "test query", 10, Collections.emptyList(), true, 10);
|
||||
|
||||
// Verify aggregations are present
|
||||
assertTrue(result.containsKey("aggregations"));
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> aggregations = (Map<String, Object>) result.get("aggregations");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> serviceTypeAgg = (Map<String, Object>) aggregations.get("serviceType");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
List<Object> buckets = (List<Object>) serviceTypeAgg.get("buckets");
|
||||
assertEquals(5, buckets.size(), "All buckets should be preserved");
|
||||
|
||||
// No truncation metadata should be present
|
||||
assertFalse(serviceTypeAgg.containsKey("_truncated"));
|
||||
assertFalse(result.containsKey("aggregationsTruncated"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNonBucketAggregationsPreserved() {
|
||||
// Create response with value_count aggregation (no buckets)
|
||||
Map<String, Object> searchResponse = createSearchResponseWithValueCountAgg();
|
||||
|
||||
Map<String, Object> result =
|
||||
SearchMetadataTool.buildEnhancedSearchResponse(
|
||||
searchResponse, "test query", 10, Collections.emptyList(), true, 5);
|
||||
|
||||
assertTrue(result.containsKey("aggregations"));
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> aggregations = (Map<String, Object>) result.get("aggregations");
|
||||
|
||||
// Value count aggregation should be preserved as-is
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> totalCount = (Map<String, Object>) aggregations.get("total_count");
|
||||
assertNotNull(totalCount);
|
||||
assertEquals(100, totalCount.get("value"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testEmptyAggregationsHandled() {
|
||||
Map<String, Object> searchResponse = new HashMap<>();
|
||||
searchResponse.put("hits", createEmptyHits());
|
||||
searchResponse.put("aggregations", new HashMap<>());
|
||||
|
||||
Map<String, Object> result =
|
||||
SearchMetadataTool.buildEnhancedSearchResponse(
|
||||
searchResponse, "test query", 10, Collections.emptyList(), true, 10);
|
||||
|
||||
// Should not throw and should not add aggregations key for empty map
|
||||
assertFalse(result.containsKey("aggregations"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMultipleAggregationsTruncatedIndependently() {
|
||||
// Create response with multiple aggregations of different sizes
|
||||
Map<String, Object> searchResponse = createSearchResponseWithMultipleAggregations();
|
||||
|
||||
Map<String, Object> result =
|
||||
SearchMetadataTool.buildEnhancedSearchResponse(
|
||||
searchResponse, "test query", 10, Collections.emptyList(), true, 5);
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> aggregations = (Map<String, Object>) result.get("aggregations");
|
||||
|
||||
// serviceType has 20 buckets - should be truncated
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> serviceTypeAgg = (Map<String, Object>) aggregations.get("serviceType");
|
||||
@SuppressWarnings("unchecked")
|
||||
List<Object> serviceTypeBuckets = (List<Object>) serviceTypeAgg.get("buckets");
|
||||
assertEquals(5, serviceTypeBuckets.size());
|
||||
assertEquals(true, serviceTypeAgg.get("_truncated"));
|
||||
|
||||
// owners has 3 buckets - should NOT be truncated
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> ownersAgg = (Map<String, Object>) aggregations.get("owners");
|
||||
@SuppressWarnings("unchecked")
|
||||
List<Object> ownersBuckets = (List<Object>) ownersAgg.get("buckets");
|
||||
assertEquals(3, ownersBuckets.size());
|
||||
assertFalse(ownersAgg.containsKey("_truncated"));
|
||||
}
|
||||
|
||||
// Helper methods to create test data
|
||||
|
||||
private Map<String, Object> createSearchResponseWithAggregations(int bucketCount) {
|
||||
Map<String, Object> response = new HashMap<>();
|
||||
response.put("hits", createEmptyHits());
|
||||
|
||||
Map<String, Object> aggregations = new HashMap<>();
|
||||
Map<String, Object> serviceTypeAgg = new HashMap<>();
|
||||
List<Map<String, Object>> buckets = new ArrayList<>();
|
||||
|
||||
for (int i = 0; i < bucketCount; i++) {
|
||||
Map<String, Object> bucket = new HashMap<>();
|
||||
bucket.put("key", "Service" + i);
|
||||
bucket.put("doc_count", 100 - i);
|
||||
buckets.add(bucket);
|
||||
}
|
||||
|
||||
serviceTypeAgg.put("buckets", buckets);
|
||||
aggregations.put("serviceType", serviceTypeAgg);
|
||||
response.put("aggregations", aggregations);
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
private Map<String, Object> createSearchResponseWithValueCountAgg() {
|
||||
Map<String, Object> response = new HashMap<>();
|
||||
response.put("hits", createEmptyHits());
|
||||
|
||||
Map<String, Object> aggregations = new HashMap<>();
|
||||
Map<String, Object> totalCount = new HashMap<>();
|
||||
totalCount.put("value", 100);
|
||||
aggregations.put("total_count", totalCount);
|
||||
response.put("aggregations", aggregations);
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
private Map<String, Object> createSearchResponseWithMultipleAggregations() {
|
||||
Map<String, Object> response = new HashMap<>();
|
||||
response.put("hits", createEmptyHits());
|
||||
|
||||
Map<String, Object> aggregations = new HashMap<>();
|
||||
|
||||
// Large aggregation (20 buckets)
|
||||
Map<String, Object> serviceTypeAgg = new HashMap<>();
|
||||
List<Map<String, Object>> serviceTypeBuckets = new ArrayList<>();
|
||||
for (int i = 0; i < 20; i++) {
|
||||
Map<String, Object> bucket = new HashMap<>();
|
||||
bucket.put("key", "Service" + i);
|
||||
bucket.put("doc_count", 100 - i);
|
||||
serviceTypeBuckets.add(bucket);
|
||||
}
|
||||
serviceTypeAgg.put("buckets", serviceTypeBuckets);
|
||||
aggregations.put("serviceType", serviceTypeAgg);
|
||||
|
||||
// Small aggregation (3 buckets)
|
||||
Map<String, Object> ownersAgg = new HashMap<>();
|
||||
List<Map<String, Object>> ownersBuckets = new ArrayList<>();
|
||||
for (int i = 0; i < 3; i++) {
|
||||
Map<String, Object> bucket = new HashMap<>();
|
||||
bucket.put("key", "Owner" + i);
|
||||
bucket.put("doc_count", 50 - i);
|
||||
ownersBuckets.add(bucket);
|
||||
}
|
||||
ownersAgg.put("buckets", ownersBuckets);
|
||||
aggregations.put("owners", ownersAgg);
|
||||
|
||||
response.put("aggregations", aggregations);
|
||||
return response;
|
||||
}
|
||||
|
||||
private Map<String, Object> createEmptyHits() {
|
||||
Map<String, Object> hits = new HashMap<>();
|
||||
hits.put("hits", Collections.emptyList());
|
||||
Map<String, Object> total = new HashMap<>();
|
||||
total.put("value", 0);
|
||||
hits.put("total", total);
|
||||
return hits;
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue