# Mirror of https://github.com/NVIDIA-NeMo/DataDesigner
# Synced 2026-05-24 09:48:29 +00:00
# 443 lines, 14 KiB, Python
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
"""Benchmark MCPIOService overhead with mocked MCP servers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import contextlib
|
|
import random
|
|
import time
|
|
import tracemalloc
|
|
from collections.abc import Iterator
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from data_designer.config.mcp import MCPProvider
|
|
from data_designer.engine.mcp.io import MCPIOService
|
|
|
|
try:
|
|
import resource
|
|
except ImportError: # pragma: no cover - resource not available on all platforms
|
|
resource = None
|
|
|
|
|
|
# Number of bytes in one megabyte (2**20); divisor used when reporting memory in MB.
BYTES_PER_MB = 1024**2
|
|
|
|
|
|
@dataclass(frozen=True)
class LatencyProfile:
    """Describe how long mocked MCP responses should take.

    Tool-call delays come from ``fast_range_sec`` by default; when
    ``tail_probability`` is positive, that fraction of draws comes from
    ``tail_range_sec`` instead. ``list_tools_delay_sec`` is a fixed delay
    applied to tool listing.
    """

    name: str
    fast_range_sec: tuple[float, float]
    tail_probability: float
    tail_range_sec: tuple[float, float]
    list_tools_delay_sec: float

    def sample_delay(self, rng: random.Random) -> float:
        """Draw one tool-call delay, in seconds, from ``rng``.

        The tail coin-flip consumes a random number only when
        ``tail_probability`` is positive, so a profile without a tail makes
        exactly one ``uniform`` draw per call.
        """
        in_tail = self.tail_probability > 0 and rng.random() < self.tail_probability
        bounds = self.tail_range_sec if in_tail else self.fast_range_sec
        return rng.uniform(*bounds)
|
|
|
|
|
|
@dataclass(frozen=True)
class ProviderSpec:
    """Pair an MCP provider with the tool definitions its fake server exposes."""

    # Provider configuration passed to the service under test.
    provider: MCPProvider
    # Tool definition dicts returned by the fake server's tool listing.
    tools: tuple[dict[str, Any], ...]
|
|
|
|
|
|
@dataclass(frozen=True)
class FakeTransport:
    """State carried from a fake server connection to the fake client session."""

    # Name reported back in every tool-call payload.
    provider_name: str
    # Tool definitions served by the tool listing.
    tools: tuple[dict[str, Any], ...]
    # Delay model for listing and for calls that carry no explicit delay.
    latency_profile: LatencyProfile
    # Random source used when a delay has to be sampled at call time.
    rng: random.Random
|
|
|
|
|
|
@dataclass(frozen=True)
class FakeToolResult:
    """Minimal result object returned by the fake session's tool calls."""

    # Payload describing the call that was served.
    content: dict[str, Any]
    # Error flag; the camelCase name presumably mirrors the real MCP result attribute (confirm).
    isError: bool = False
|
|
|
|
|
|
@dataclass(frozen=True)
class CallSpec:
    """One pre-planned tool call, including the delay its fake server will apply."""

    # Provider whose fake server receives the call.
    provider: MCPProvider
    # Name of the tool to invoke on that provider.
    tool_name: str
    # Arguments sent with the call.
    arguments: dict[str, Any]
    # Delay, in seconds, planned for this call.
    delay_sec: float
|
|
|
|
|
|
@dataclass(frozen=True)
class BenchmarkResult:
    """Metrics collected for one benchmark scenario."""

    # Identification and volume.
    scenario_name: str
    total_calls: int

    # Timing: measured wall time, calls per minute, simulated server time,
    # and the difference attributed to the service (absolute and percent).
    elapsed_sec: float
    rpm: float
    server_walltime_sec: float
    overhead_sec: float
    overhead_pct: float

    # Process RSS readings in MB; None when RSS could not be read.
    rss_start_mb: float | None
    rss_end_mb: float | None
    rss_growth_mb: float | None

    # Python allocations in MB as reported by tracemalloc.
    py_current_mb: float
    py_peak_mb: float
    py_growth_mb: float
|
|
|
|
|
|
class FakeServerContext:
    """Async context manager that hands out a fake transport.

    Entering yields a two-item tuple: the transport and ``None``.
    """

    def __init__(self, transport: FakeTransport) -> None:
        self._transport = transport

    async def __aenter__(self) -> tuple[FakeTransport, None]:
        return (self._transport, None)

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: Any | None,
    ) -> None:
        # Nothing to release; returning None lets any exception propagate.
        del exc_type, exc, tb
|
|
|
|
|
|
class FakeClientSession:
    """In-memory stand-in for the MCP client session used by the benchmark."""

    def __init__(self, read: FakeTransport, write: Any) -> None:
        # The "read" argument is the fake transport that carries all server state.
        self._transport = read
        self._write = write
        self._closed = False

    async def __aenter__(self) -> FakeClientSession:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: Any | None,
    ) -> None:
        del exc_type, exc, tb
        self._closed = True

    async def initialize(self) -> None:
        """Do nothing; the fake needs no handshake."""
        return None

    async def list_tools(self) -> list[dict[str, Any]]:
        """Return a fresh list of the transport's tools after the profile's listing delay."""
        pause = self._transport.latency_profile.list_tools_delay_sec
        if pause > 0:
            await asyncio.sleep(pause)
        return [*self._transport.tools]

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> FakeToolResult:
        """Sleep for the call's delay and return a payload naming the tool and provider.

        A ``_benchmark_delay_sec`` entry in ``arguments`` takes precedence;
        a delay is sampled from the latency profile only when it is absent.
        """
        transport = self._transport
        pause = arguments.get("_benchmark_delay_sec")
        if pause is None:
            pause = transport.latency_profile.sample_delay(transport.rng)
        if pause > 0:
            await asyncio.sleep(pause)
        return FakeToolResult(
            content={"tool": tool_name, "provider": transport.provider_name},
            isError=False,
        )
|
|
|
|
|
|
class FakeServerRegistry:
    """Map provider endpoints to fake server contexts."""

    def __init__(self, providers: list[ProviderSpec], latency_profile: LatencyProfile, rng: random.Random) -> None:
        by_endpoint = {}
        for spec in providers:
            by_endpoint[spec.provider.endpoint] = spec
        self._providers_by_endpoint = by_endpoint
        self._latency_profile = latency_profile
        self._rng = rng

    def context_for_endpoint(self, endpoint: str) -> FakeServerContext:
        """Build a new fake server context for ``endpoint``.

        Raises:
            ValueError: If no registered provider uses ``endpoint``.
        """
        if endpoint not in self._providers_by_endpoint:
            raise ValueError(f"Unknown endpoint {endpoint!r} in fake registry.")
        spec = self._providers_by_endpoint[endpoint]
        # Every context shares the registry's latency profile and random source.
        return FakeServerContext(
            FakeTransport(
                provider_name=spec.provider.name,
                tools=spec.tools,
                latency_profile=self._latency_profile,
                rng=self._rng,
            )
        )
|
|
|
|
|
|
def _format_mb(value: float | None) -> str:
    """Render a megabyte figure with two decimals, or ``"n/a"`` when it is None."""
    return "n/a" if value is None else f"{value:.2f} MB"
|
|
|
|
|
|
def _get_rss_bytes() -> int | None:
    """Return the process's maximum resident set size in bytes.

    The value comes from ``ru_maxrss``, which is a high-water mark rather
    than the current RSS, so it never decreases between two readings.

    Returns:
        The peak RSS in bytes, or None when the ``resource`` module is not
        available on this platform.
    """
    if resource is None:
        return None

    import sys  # Local import keeps this block self-contained.

    max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # macOS reports ru_maxrss in bytes; Linux and the BSDs report kilobytes.
    # Deciding by platform avoids guessing the unit from the magnitude, which
    # misreports very large processes on Linux and very small ones on macOS.
    if sys.platform != "darwin":
        max_rss *= 1024
    return int(max_rss)
|
|
|
|
|
|
def _build_tools(provider_name: str, tools_per_provider: int) -> tuple[dict[str, Any], ...]:
    """Create ``tools_per_provider`` tool definitions named after ``provider_name``.

    Each definition has a name, a description, and an input schema with a
    required string ``query`` and an optional integer ``limit``.
    """

    def make_tool(idx: int) -> dict[str, Any]:
        # A fresh schema dict per tool so definitions never share mutable state.
        return {
            "name": f"{provider_name}-tool-{idx}",
            "description": f"Tool {idx} for {provider_name}.",
            "inputSchema": {
                "type": "object",
                "properties": {"query": {"type": "string"}, "limit": {"type": "integer"}},
                "required": ["query"],
            },
        }

    return tuple(make_tool(idx) for idx in range(tools_per_provider))
|
|
|
|
|
|
def _build_provider_specs(
    provider_count: int,
    tools_per_provider: int,
) -> tuple[list[ProviderSpec], dict[str, tuple[str, ...]]]:
    """Create ``provider_count`` fake providers, each with its own tool set.

    Returns:
        A pair of the provider specs (in creation order) and a mapping from
        provider name to the names of that provider's tools.
    """
    specs: list[ProviderSpec] = []
    names_by_provider: dict[str, tuple[str, ...]] = {}
    for provider_name in (f"provider-{idx}" for idx in range(provider_count)):
        provider = MCPProvider(
            name=provider_name,
            endpoint=f"fake://{provider_name}",
            api_key="benchmark-key",
        )
        provider_tools = _build_tools(provider_name, tools_per_provider)
        names_by_provider[provider_name] = tuple(tool["name"] for tool in provider_tools)
        specs.append(ProviderSpec(provider=provider, tools=provider_tools))
    return specs, names_by_provider
|
|
|
|
|
|
def _build_call_specs(
    providers: list[ProviderSpec],
    tool_names: dict[str, tuple[str, ...]],
    latency_profile: LatencyProfile,
    rng: random.Random,
    total_calls: int,
    payload_size: int,
) -> list[CallSpec]:
    """Plan ``total_calls`` tool calls spread across every provider/tool pair.

    Pairs are visited in repeated shuffled passes until enough calls exist,
    so usage stays balanced across pairs. Each call gets a delay sampled
    from ``latency_profile``; the delay is stored on the spec and also placed
    in the arguments under ``_benchmark_delay_sec``. The finished list is
    shuffled once more before it is returned.

    Args:
        providers: Provider specs to target.
        tool_names: Tool names per provider name; must contain every provider.
        latency_profile: Source of per-call delays.
        rng: Random source for shuffling and delay sampling.
        total_calls: Number of calls to plan.
        payload_size: Length of the filler ``payload`` string in each call.

    Returns:
        The planned calls in randomized order.

    Raises:
        ValueError: If calls are requested but there is no provider/tool pair
            to call.
    """
    provider_tool_pairs: list[tuple[ProviderSpec, str]] = []
    for spec in providers:
        provider_tool_pairs.extend((spec, tool_name) for tool_name in tool_names[spec.provider.name])

    # Without this guard the fill loop below would never make progress.
    if total_calls > 0 and not provider_tool_pairs:
        raise ValueError(
            "Cannot build call specs: no provider/tool pairs are available "
            "(need at least one provider with at least one tool)."
        )

    # The filler payload is identical for every call; build it once.
    payload = "x" * payload_size

    calls: list[CallSpec] = []
    while len(calls) < total_calls:
        rng.shuffle(provider_tool_pairs)
        remaining = total_calls - len(calls)
        for spec, tool_name in provider_tool_pairs[:remaining]:
            call_id = len(calls)
            delay = latency_profile.sample_delay(rng)
            arguments = {
                "query": f"call-{call_id}",
                "limit": call_id % 10,
                "payload": payload,
                "_benchmark_delay_sec": delay,
                "_benchmark_call_id": call_id,
            }
            calls.append(CallSpec(provider=spec.provider, tool_name=tool_name, arguments=arguments, delay_sec=delay))

    rng.shuffle(calls)
    return calls
|
|
|
|
|
|
def _chunk_calls(calls: list[CallSpec], batch_size: int) -> list[list[CallSpec]]:
    """Split ``calls`` into consecutive batches of at most ``batch_size`` items.

    The final batch may be shorter. An empty ``calls`` list yields no batches.

    Raises:
        ValueError: If ``batch_size`` is not positive. A negative size would
            otherwise silently produce no batches at all.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")
    return [calls[start : start + batch_size] for start in range(0, len(calls), batch_size)]
|
|
|
|
|
|
@contextlib.contextmanager
def _patch_mcp_io(registry: FakeServerRegistry) -> Iterator[None]:
    """Temporarily replace the MCP I/O module's client hooks with fakes.

    While the context is active, the module's ``ClientSession`` is
    ``FakeClientSession`` and its ``sse_client`` resolves endpoints through
    ``registry``. The original attributes are restored on exit, even when
    the body raises.
    """
    import data_designer.engine.mcp.io as mcp_io

    def fake_sse_client(endpoint: str, headers: dict[str, Any] | None = None) -> FakeServerContext:
        del headers  # Accepted for call compatibility; the fake ignores it.
        return registry.context_for_endpoint(endpoint)

    saved = (mcp_io.ClientSession, mcp_io.sse_client)
    mcp_io.ClientSession, mcp_io.sse_client = FakeClientSession, fake_sse_client
    try:
        yield
    finally:
        mcp_io.ClientSession, mcp_io.sse_client = saved
|
|
|
|
|
|
def _run_scenario(
    *,
    scenario_name: str,
    latency_profile: LatencyProfile,
    provider_count: int,
    tools_per_provider: int,
    total_calls: int,
    batch_size: int,
    payload_size: int,
    warmup_batches: int,
    seed: int,
) -> BenchmarkResult:
    """Run one benchmark scenario against ``MCPIOService`` with fake servers.

    The scenario plans ``total_calls`` measured calls plus
    ``warmup_batches * batch_size`` warmup calls, lists every provider's
    tools, runs the warmup batches untimed, and then times the measured
    batches while tracking memory.

    Args:
        scenario_name: Label copied into the result.
        latency_profile: Delay model for the fake servers.
        provider_count: Number of fake providers to create.
        tools_per_provider: Number of tools each provider exposes.
        total_calls: Number of measured tool calls.
        batch_size: Number of calls submitted per ``call_tools`` invocation.
        payload_size: Length of the filler payload in each call.
        warmup_batches: Number of untimed batches run before measurement.
        seed: Seed for the scenario's random source.

    Returns:
        Timing, throughput, overhead, and memory metrics for the scenario.
    """
    rng = random.Random(seed)
    providers, tool_names = _build_provider_specs(provider_count, tools_per_provider)
    registry = FakeServerRegistry(providers, latency_profile, rng)

    warmup_calls_total = warmup_batches * batch_size
    all_calls = _build_call_specs(
        providers=providers,
        tool_names=tool_names,
        latency_profile=latency_profile,
        rng=rng,
        total_calls=total_calls + warmup_calls_total,
        payload_size=payload_size,
    )

    warmup_calls = all_calls[:warmup_calls_total]
    measured_calls = all_calls[warmup_calls_total:]
    warmup_batches_list = _chunk_calls(warmup_calls, batch_size) if warmup_calls else []
    measured_batches = _chunk_calls(measured_calls, batch_size)

    # Baseline: the time the measured batches would take if every call within
    # a batch overlapped perfectly, i.e. the slowest planned delay per batch.
    server_walltime_sec = sum(max(call.delay_sec for call in batch) for batch in measured_batches if batch)

    with _patch_mcp_io(registry):
        service = MCPIOService()
        try:
            # NOTE(review): each provider is listed twice before timing;
            # presumably the repeat exercises a cached path. Confirm.
            for spec in providers:
                service.list_tools(spec.provider)
                service.list_tools(spec.provider)

            for batch in warmup_batches_list:
                service.call_tools([(call.provider, call.tool_name, call.arguments) for call in batch])

            tracemalloc.start()
            try:
                start_current = tracemalloc.get_traced_memory()[0]
                rss_start = _get_rss_bytes()
                start_time = time.perf_counter()

                for batch in measured_batches:
                    service.call_tools([(call.provider, call.tool_name, call.arguments) for call in batch])

                elapsed_sec = time.perf_counter() - start_time
                current_bytes, peak_bytes = tracemalloc.get_traced_memory()
            finally:
                # Stop tracing even when a batch fails so tracing overhead
                # does not leak into whatever runs next in this process.
                tracemalloc.stop()
            rss_end = _get_rss_bytes()
        finally:
            service.shutdown()

    rpm = (total_calls / elapsed_sec * 60.0) if elapsed_sec > 0 else 0.0
    # Overhead is floored at zero: elapsed time can undercut the baseline.
    overhead_sec = max(0.0, elapsed_sec - server_walltime_sec)
    overhead_pct = (overhead_sec / elapsed_sec * 100.0) if elapsed_sec > 0 else 0.0

    rss_start_mb = (rss_start / BYTES_PER_MB) if rss_start is not None else None
    rss_end_mb = (rss_end / BYTES_PER_MB) if rss_end is not None else None
    rss_growth_mb = ((rss_end - rss_start) / BYTES_PER_MB) if rss_start is not None and rss_end is not None else None

    py_current_mb = current_bytes / BYTES_PER_MB
    py_peak_mb = peak_bytes / BYTES_PER_MB
    py_growth_mb = (current_bytes - start_current) / BYTES_PER_MB

    return BenchmarkResult(
        scenario_name=scenario_name,
        total_calls=total_calls,
        elapsed_sec=elapsed_sec,
        rpm=rpm,
        server_walltime_sec=server_walltime_sec,
        overhead_sec=overhead_sec,
        overhead_pct=overhead_pct,
        rss_start_mb=rss_start_mb,
        rss_end_mb=rss_end_mb,
        rss_growth_mb=rss_growth_mb,
        py_current_mb=py_current_mb,
        py_peak_mb=py_peak_mb,
        py_growth_mb=py_growth_mb,
    )
|
|
|
|
|
|
def _print_result(result: BenchmarkResult, provider_count: int, tools_per_provider: int, batch_size: int) -> None:
    """Print a human-readable report for one scenario to standard output.

    The report has a heading followed by bullet lines for configuration,
    timing, throughput, RSS memory, and Python allocations.
    """
    print(f"\nScenario: {result.scenario_name}")

    config_parts = [
        f"- providers: {provider_count}",
        f"tools/provider: {tools_per_provider}",
        f"batch size: {batch_size}",
        f"calls: {result.total_calls}",
    ]
    print(" | ".join(config_parts))

    timing_parts = [
        f"- walltime: {result.elapsed_sec:.3f}s",
        f"server walltime: {result.server_walltime_sec:.3f}s",
        f"overhead: {result.overhead_sec:.3f}s ({result.overhead_pct:.1f}%)",
    ]
    print(" | ".join(timing_parts))

    print(f"- throughput: {result.rpm:,.1f} RPM")

    rss_parts = [
        f"RSS start {_format_mb(result.rss_start_mb)}",
        f"RSS end {_format_mb(result.rss_end_mb)}",
        f"RSS growth {_format_mb(result.rss_growth_mb)}",
    ]
    print("- memory: " + ", ".join(rss_parts))

    alloc_parts = [
        f"current {result.py_current_mb:.2f} MB",
        f"peak {result.py_peak_mb:.2f} MB",
        f"growth {result.py_growth_mb:.2f} MB",
    ]
    print("- python alloc: " + ", ".join(alloc_parts))
|
|
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parse the benchmark's command-line options from ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Benchmark MCPIOService overhead with mocked MCP servers.")

    # Integer options, registered in the order they appear in --help.
    integer_options = (
        ("--providers", 100, "Number of MCP providers to simulate."),
        ("--tools-per-provider", 6, "Number of tools per provider."),
        ("--total-calls", 2000, "Total tool calls per scenario."),
        ("--batch-size", 200, "Number of tool calls per batch."),
        ("--payload-size", 128, "Size of payload argument for each call."),
        ("--warmup-batches", 1, "Warmup batches excluded from timing."),
        ("--seed", 7, "Random seed for reproducibility."),
    )
    for flag, default, help_text in integer_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)

    parser.add_argument(
        "--profile",
        type=str,
        default="all",
        choices=("all", "instant", "long-tail"),
        help="Which latency profile to run.",
    )
    return parser.parse_args()
|
|
|
|
|
|
def main() -> None:
    """Run the selected latency scenarios and print a report for each.

    ``--profile all`` runs the instant profile followed by the long-tail
    profile. Each scenario's seed is the base seed plus its position in the
    run order.
    """
    args = _parse_args()

    instant = LatencyProfile(
        name="instant",
        fast_range_sec=(0.0, 0.001),
        tail_probability=0.0,
        tail_range_sec=(0.0, 0.0),
        list_tools_delay_sec=0.0005,
    )
    long_tail = LatencyProfile(
        name="long-tail",
        fast_range_sec=(0.001, 0.004),
        tail_probability=0.05,
        tail_range_sec=(0.05, 0.2),
        list_tools_delay_sec=0.001,
    )
    profiles = {profile.name: profile for profile in (instant, long_tail)}

    selected_profiles: list[LatencyProfile] = (
        [instant, long_tail] if args.profile == "all" else [profiles[args.profile]]
    )

    for offset, profile in enumerate(selected_profiles):
        result = _run_scenario(
            scenario_name=profile.name,
            latency_profile=profile,
            provider_count=args.providers,
            tools_per_provider=args.tools_per_provider,
            total_calls=args.total_calls,
            batch_size=args.batch_size,
            payload_size=args.payload_size,
            warmup_batches=args.warmup_batches,
            seed=args.seed + offset,
        )
        _print_result(result, args.providers, args.tools_per_provider, args.batch_size)
|
|
|
|
|
|
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|