# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 import pytest from data_designer.engine.models.usage import ( ImageUsageStats, ModelUsageStats, RequestUsageStats, TokenCountSource, TokenUsageStats, ToolUsageStats, ) def test_token_usage_stats() -> None: token_usage_stats = TokenUsageStats() assert token_usage_stats.input_tokens == 0 assert token_usage_stats.output_tokens == 0 assert token_usage_stats.reasoning_tokens is None assert token_usage_stats.reasoning_token_count_source is None assert token_usage_stats.total_tokens == 0 assert token_usage_stats.has_usage is False token_usage_stats.extend( input_tokens=10, output_tokens=20, reasoning_tokens=5, reasoning_token_count_source=TokenCountSource.PROVIDER, ) assert token_usage_stats.input_tokens == 10 assert token_usage_stats.output_tokens == 20 assert token_usage_stats.reasoning_tokens == 5 assert token_usage_stats.reasoning_token_count_source == TokenCountSource.PROVIDER assert token_usage_stats.total_tokens == 30 assert token_usage_stats.has_usage is True def test_token_usage_stats_reasoning_source_is_required() -> None: with pytest.raises(ValueError, match="reasoning_tokens requires reasoning_token_count_source"): TokenUsageStats(reasoning_tokens=1) with pytest.raises(ValueError, match="reasoning_token_count_source requires reasoning_tokens"): TokenUsageStats(reasoning_token_count_source=TokenCountSource.ESTIMATED) def test_token_usage_stats_uses_estimated_source_when_any_count_is_estimated() -> None: token_usage_stats = TokenUsageStats() token_usage_stats.extend( input_tokens=10, output_tokens=20, reasoning_tokens=5, reasoning_token_count_source=TokenCountSource.PROVIDER, ) token_usage_stats.extend( input_tokens=3, output_tokens=4, reasoning_tokens=2, reasoning_token_count_source=TokenCountSource.PROVIDER, ) assert token_usage_stats.reasoning_tokens == 7 assert token_usage_stats.reasoning_token_count_source == TokenCountSource.PROVIDER token_usage_stats.extend( input_tokens=1, output_tokens=1, reasoning_tokens=8, reasoning_token_count_source=TokenCountSource.ESTIMATED, ) assert token_usage_stats.reasoning_tokens == 15 assert token_usage_stats.reasoning_token_count_source == TokenCountSource.ESTIMATED def test_request_usage_stats() -> None: request_usage_stats = RequestUsageStats() assert request_usage_stats.successful_requests == 0 assert request_usage_stats.failed_requests == 0 assert request_usage_stats.total_requests == 0 assert request_usage_stats.has_usage is False request_usage_stats.extend(successful_requests=10, failed_requests=20) assert request_usage_stats.successful_requests == 10 assert request_usage_stats.failed_requests == 20 assert request_usage_stats.total_requests == 30 assert request_usage_stats.has_usage is True def test_image_usage_stats() -> None: image_usage_stats = ImageUsageStats() assert image_usage_stats.total_images == 0 assert image_usage_stats.has_usage is False image_usage_stats.extend(images=5) assert image_usage_stats.total_images == 5 assert image_usage_stats.has_usage is True image_usage_stats.extend(images=3) assert image_usage_stats.total_images == 8 assert image_usage_stats.has_usage is True def test_tool_usage_stats_empty_state() -> None: """Test ToolUsageStats initialization with empty state.""" tool_usage = ToolUsageStats() assert tool_usage.total_tool_calls == 0 assert tool_usage.total_tool_call_turns == 0 assert tool_usage.total_generations == 0 assert tool_usage.generations_with_tools == 0 assert tool_usage.has_usage is False def test_tool_usage_stats_single_generation_with_tools() -> None: """Test ToolUsageStats with a single generation that uses tools.""" tool_usage = ToolUsageStats() tool_usage.extend(tool_calls=5, tool_call_turns=2) assert tool_usage.total_tool_calls == 5 assert tool_usage.total_tool_call_turns == 2 assert tool_usage.total_generations == 1 assert tool_usage.generations_with_tools == 1 assert tool_usage.has_usage is True def test_tool_usage_stats_multiple_generations() -> None: """Test ToolUsageStats with multiple generations.""" tool_usage = ToolUsageStats() for _ in range(3): tool_usage.extend(tool_calls=4, tool_call_turns=3) assert tool_usage.total_tool_calls == 12 assert tool_usage.total_tool_call_turns == 9 assert tool_usage.total_generations == 3 assert tool_usage.generations_with_tools == 3 assert tool_usage.has_usage is True def test_tool_usage_stats_generation_without_tool_calls() -> None: """Test that extend with zero tool_calls still increments total_generations but not generations_with_tools.""" tool_usage = ToolUsageStats() tool_usage.extend(tool_calls=0, tool_call_turns=0) assert tool_usage.total_tool_calls == 0 assert tool_usage.total_tool_call_turns == 0 assert tool_usage.total_generations == 1 assert tool_usage.generations_with_tools == 0 assert tool_usage.has_usage is True def test_tool_usage_stats_mixed_generations() -> None: """Test ratio tracking with mix of generations with and without tools.""" tool_usage = ToolUsageStats() tool_usage.extend(tool_calls=0, tool_call_turns=0) # No tools used tool_usage.extend(tool_calls=4, tool_call_turns=2) # Tools used tool_usage.extend(tool_calls=0, tool_call_turns=0) # No tools used tool_usage.extend(tool_calls=6, tool_call_turns=4) # Tools used assert tool_usage.total_tool_calls == 10 assert tool_usage.total_tool_call_turns == 6 assert tool_usage.total_generations == 4 assert tool_usage.generations_with_tools == 2 assert tool_usage.has_usage is True def test_tool_usage_stats_merge() -> None: """Test that merging two ToolUsageStats objects works correctly.""" stats1 = ToolUsageStats() stats1.extend(tool_calls=2, tool_call_turns=1) stats1.extend(tool_calls=4, tool_call_turns=3) stats2 = ToolUsageStats() stats2.extend(tool_calls=6, tool_call_turns=2) stats2.extend(tool_calls=0, tool_call_turns=0) # No tools stats1.merge(stats2) assert stats1.total_tool_calls == 12 assert stats1.total_tool_call_turns == 6 assert stats1.total_generations == 4 assert stats1.generations_with_tools == 3 def test_tool_usage_stats_merge_empty() -> None: """Test merging an empty ToolUsageStats doesn't change values.""" stats1 = ToolUsageStats() stats1.extend(tool_calls=4, tool_call_turns=2) stats2 = ToolUsageStats() stats1.merge(stats2) assert stats1.total_tool_calls == 4 assert stats1.total_tool_call_turns == 2 assert stats1.total_generations == 1 assert stats1.generations_with_tools == 1 def test_model_usage_stats() -> None: model_usage_stats = ModelUsageStats() assert model_usage_stats.token_usage.input_tokens == 0 assert model_usage_stats.token_usage.output_tokens == 0 assert model_usage_stats.request_usage.successful_requests == 0 assert model_usage_stats.request_usage.failed_requests == 0 assert model_usage_stats.image_usage.total_images == 0 assert model_usage_stats.has_usage is False # tool_usage and image_usage are excluded when has_usage is False assert model_usage_stats.get_usage_stats(total_time_elapsed=10) == { "token_usage": { "input_tokens": 0, "output_tokens": 0, "reasoning_tokens": None, "reasoning_token_count_source": None, "total_tokens": 0, }, "request_usage": {"successful_requests": 0, "failed_requests": 0, "total_requests": 0}, "tokens_per_second": 0, "requests_per_minute": 0, } model_usage_stats.extend( token_usage=TokenUsageStats( input_tokens=10, output_tokens=20, reasoning_tokens=7, reasoning_token_count_source=TokenCountSource.PROVIDER, ), request_usage=RequestUsageStats(successful_requests=2, failed_requests=1), ) assert model_usage_stats.token_usage.input_tokens == 10 assert model_usage_stats.token_usage.output_tokens == 20 assert model_usage_stats.token_usage.reasoning_tokens == 7 assert model_usage_stats.request_usage.successful_requests == 2 assert model_usage_stats.request_usage.failed_requests == 1 assert model_usage_stats.has_usage is True # tool_usage and image_usage are excluded when has_usage is False assert model_usage_stats.get_usage_stats(total_time_elapsed=2) == { "token_usage": { "input_tokens": 10, "output_tokens": 20, "reasoning_tokens": 7, "reasoning_token_count_source": "provider", "total_tokens": 30, }, "request_usage": {"successful_requests": 2, "failed_requests": 1, "total_requests": 3}, "tokens_per_second": 15, "requests_per_minute": 90, } def test_model_usage_stats_extend_with_tool_usage() -> None: """Test that ModelUsageStats.extend properly merges tool usage.""" stats1 = ModelUsageStats() stats1.tool_usage.extend(tool_calls=2, tool_call_turns=1) stats1.tool_usage.extend(tool_calls=4, tool_call_turns=3) stats2 = ModelUsageStats() stats2.tool_usage.extend(tool_calls=6, tool_call_turns=2) stats2.tool_usage.extend(tool_calls=0, tool_call_turns=0) # No tools used stats1.extend(tool_usage=stats2.tool_usage) assert stats1.tool_usage.total_tool_calls == 12 assert stats1.tool_usage.total_tool_call_turns == 6 assert stats1.tool_usage.total_generations == 4 assert stats1.tool_usage.generations_with_tools == 3 def test_model_usage_stats_with_image_usage() -> None: """Test that ModelUsageStats includes image_usage when it has usage.""" model_usage_stats = ModelUsageStats() model_usage_stats.extend( token_usage=TokenUsageStats(input_tokens=10, output_tokens=20), request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), image_usage=ImageUsageStats(total_images=5), ) assert model_usage_stats.image_usage.total_images == 5 assert model_usage_stats.image_usage.has_usage is True # image_usage should be included in output usage_stats = model_usage_stats.get_usage_stats(total_time_elapsed=2) assert "image_usage" in usage_stats assert usage_stats["image_usage"] == {"total_images": 5} def test_model_usage_stats_has_usage_any_of() -> None: """Test that has_usage is True when any of token, request, or image usage is present.""" # Only token usage stats = ModelUsageStats() stats.extend(token_usage=TokenUsageStats(input_tokens=1, output_tokens=0)) assert stats.has_usage is True # Only request usage (e.g. diffusion API without token counts) stats = ModelUsageStats() stats.extend(request_usage=RequestUsageStats(successful_requests=1, failed_requests=0)) assert stats.has_usage is True # Only image usage stats = ModelUsageStats() stats.extend(image_usage=ImageUsageStats(total_images=2)) assert stats.has_usage is True # None of the three stats = ModelUsageStats() assert stats.has_usage is False def test_model_usage_stats_exclude_unused_stats() -> None: """Test that ModelUsageStats excludes tool_usage and image_usage when they have no usage.""" model_usage_stats = ModelUsageStats() model_usage_stats.extend( token_usage=TokenUsageStats(input_tokens=10, output_tokens=20), request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) usage_stats = model_usage_stats.get_usage_stats(total_time_elapsed=2) assert "tool_usage" not in usage_stats assert "image_usage" not in usage_stats assert "token_usage" in usage_stats assert "request_usage" in usage_stats