mirror of
https://github.com/NVIDIA-NeMo/DataDesigner
synced 2026-05-24 09:48:29 +00:00
320 lines
9.6 KiB
Python
320 lines
9.6 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import threading
|
|
|
|
import pytest
|
|
|
|
from data_designer.engine.progress.tracker import ProgressTracker
|
|
|
|
|
|
@pytest.fixture
def tracker() -> ProgressTracker:
    """Provide a fresh 100-record tracker labelled ``test column 'name'``."""
    shared_tracker = ProgressTracker(total_records=100, label="test column 'name'")
    return shared_tracker
|
|
|
|
|
|
def test_initializes_with_correct_values() -> None:
    """A new tracker keeps its constructor arguments and starts every counter at zero."""
    progress = ProgressTracker(total_records=100, label="test label")

    # Constructor arguments are exposed unchanged.
    assert (progress.total_records, progress.label) == (100, "test label")

    # Nothing has been recorded yet, so each counter must read zero.
    for counter_name in ("completed", "success", "failed", "skipped"):
        assert getattr(progress, counter_name) == 0
|
|
|
|
|
|
def test_calculates_log_interval_from_percentage() -> None:
    """``log_interval`` matches the expected record count for several percent settings."""
    # Each row: (total_records, log_interval_percent, expected log_interval)
    scenarios = [
        (100, 10, 10),
        (100, 25, 25),
        (1000, 5, 50),
    ]
    for total, percent, expected_interval in scenarios:
        progress = ProgressTracker(total_records=total, label="test", log_interval_percent=percent)
        assert progress.log_interval == expected_interval
|
|
|
|
|
|
def test_log_interval_minimum_is_one() -> None:
    """With 5 records and a 10% interval, ``log_interval`` is still at least 1."""
    small_tracker = ProgressTracker(total_records=5, label="test", log_interval_percent=10)
    interval = small_tracker.log_interval
    assert interval >= 1
|
|
|
|
|
|
def test_handles_zero_total_records() -> None:
    """A tracker built with zero records reports a log interval of 1 and a total of 0."""
    empty_tracker = ProgressTracker(total_records=0, label="test")
    observed = (empty_tracker.log_interval, empty_tracker.total_records)
    assert observed == (1, 0)
|
|
|
|
|
|
def test_record_success_increments_completed_and_success(tracker: ProgressTracker) -> None:
    """One recorded success bumps ``completed`` and ``success`` and leaves ``failed`` at zero."""
    tracker.record_success()

    # Order of the tuple: completed, success, failed.
    observed = (tracker.completed, tracker.success, tracker.failed)
    assert observed == (1, 1, 0)
|
|
|
|
|
|
def test_record_success_multiple_times(tracker: ProgressTracker) -> None:
    """Five recorded successes are all counted and no failure is registered."""
    calls_remaining = 5
    while calls_remaining:
        tracker.record_success()
        calls_remaining -= 1

    # Order of the tuple: completed, success, failed.
    assert (tracker.completed, tracker.success, tracker.failed) == (5, 5, 0)
|
|
|
|
|
|
def test_record_failure_increments_completed_and_failed(tracker: ProgressTracker) -> None:
    """One recorded failure bumps ``completed`` and ``failed`` and leaves ``success`` at zero."""
    tracker.record_failure()

    expected_counters = {"completed": 1, "success": 0, "failed": 1}
    for counter_name, expected_value in expected_counters.items():
        assert getattr(tracker, counter_name) == expected_value
|
|
|
|
|
|
def test_record_failure_multiple_times(tracker: ProgressTracker) -> None:
    """Five recorded failures are all counted and no success is registered."""
    calls_remaining = 5
    while calls_remaining:
        tracker.record_failure()
        calls_remaining -= 1

    # Order of the tuple: completed, success, failed.
    assert (tracker.completed, tracker.success, tracker.failed) == (5, 0, 5)
|
|
|
|
|
|
def test_record_skipped_increments_completed_and_skipped(tracker: ProgressTracker) -> None:
    """One skipped record bumps ``completed`` and ``skipped`` only."""
    tracker.record_skipped()

    expected_counters = {"completed": 1, "success": 0, "failed": 0, "skipped": 1}
    for counter_name, expected_value in expected_counters.items():
        assert getattr(tracker, counter_name) == expected_value
|
|
|
|
|
|
def test_tracks_mixed_successes_and_failures(tracker: ProgressTracker) -> None:
    """Interleaved successes and failures are tallied separately and summed in ``completed``."""
    # True records a success, False a failure:
    # success, success, failure, success, failure.
    outcomes = (True, True, False, True, False)
    for succeeded in outcomes:
        if succeeded:
            tracker.record_success()
        else:
            tracker.record_failure()

    assert tracker.completed == 5
    assert tracker.success == 3
    assert tracker.failed == 2
|
|
|
|
|
|
def test_get_snapshot_includes_skipped_counts() -> None:
    """``get_snapshot`` reports skipped records alongside the other counters.

    One success, one failure and one skip are recorded against a 10-record
    tracker, then the eight-field snapshot for ``elapsed=2.0`` is checked.
    """
    tracker = ProgressTracker(total_records=10, label="test")
    tracker.record_success()
    tracker.record_failure()
    tracker.record_skipped()

    completed, total_records, success, failed, skipped, percent, rate, emoji = tracker.get_snapshot(elapsed=2.0)

    # Integer counters are compared exactly.
    assert completed == 3
    assert total_records == 10
    assert success == 1
    assert failed == 1
    assert skipped == 1

    # percent and rate are checked against float values; compare with a
    # tolerance so the test does not hinge on the exact order of the
    # floating-point arithmetic inside get_snapshot.
    assert percent == pytest.approx(30.0)
    assert rate == pytest.approx(1.5)

    # Only truthiness is required of the emoji field.
    assert emoji
|
|
|
|
|
|
def test_log_start_logs_worker_info(tracker: ProgressTracker, caplog: pytest.LogCaptureFixture) -> None:
    """``log_start`` output mentions the worker count and the tracker label."""
    with caplog.at_level(logging.INFO):
        tracker.log_start(max_workers=8)

    captured = caplog.text
    for fragment in ("8 concurrent workers", "test column 'name'"):
        assert fragment in captured
|
|
|
|
|
|
def test_logs_progress_at_interval(caplog: pytest.LogCaptureFixture) -> None:
    """Recording 5 of 10 records with a 50% interval logs "5/10" and "50%"."""
    progress = ProgressTracker(total_records=10, label="test", log_interval_percent=50)

    with caplog.at_level(logging.INFO):
        recorded = 0
        while recorded < 5:
            progress.record_success()
            recorded += 1

    log_output = caplog.text
    assert "5/10" in log_output
    assert "50%" in log_output
|
|
|
|
|
|
def test_log_final_logs_remaining_progress(caplog: pytest.LogCaptureFixture) -> None:
    """``log_final`` reports "3/10" after three successes were recorded."""
    progress = ProgressTracker(total_records=10, label="test", log_interval_percent=50)

    # These calls happen before log capture is raised to INFO.
    recorded = 0
    while recorded < 3:
        progress.record_success()
        recorded += 1

    with caplog.at_level(logging.INFO):
        progress.log_final()

    log_output = caplog.text
    assert "3/10" in log_output
|
|
|
|
|
|
def test_progress_log_includes_rate_and_eta(caplog: pytest.LogCaptureFixture) -> None:
    """The progress line logged at the halfway point contains "rec/s" and "eta"."""
    progress = ProgressTracker(total_records=10, label="test", log_interval_percent=50)

    with caplog.at_level(logging.INFO):
        recorded = 0
        while recorded < 5:
            progress.record_success()
            recorded += 1

    log_output = caplog.text
    for token in ("rec/s", "eta"):
        assert token in log_output
|
|
|
|
|
|
def test_progress_log_shows_ok_and_failed_counts(caplog: pytest.LogCaptureFixture) -> None:
    """After 3 successes and 2 failures the log contains "3 ok" and "2 failed"."""
    progress = ProgressTracker(total_records=10, label="test", log_interval_percent=50)

    # Three successes first, then two failures, all while capturing at INFO.
    planned_calls = [progress.record_success] * 3 + [progress.record_failure] * 2
    with caplog.at_level(logging.INFO):
        for record_outcome in planned_calls:
            record_outcome()

    log_output = caplog.text
    assert "3 ok" in log_output
    assert "2 failed" in log_output
|
|
|
|
|
|
def test_concurrent_record_calls_are_thread_safe() -> None:
    """Counters stay consistent when several threads record outcomes at once.

    Even-indexed threads record successes and odd-indexed threads record
    failures. After every thread has been joined, the three counters must
    equal the number of calls each group made.
    """
    tracker = ProgressTracker(total_records=1000, label="test", log_interval_percent=100)
    num_threads = 10
    records_per_thread = 100

    def record_many_successes() -> None:
        for _ in range(records_per_thread):
            tracker.record_success()

    def record_many_failures() -> None:
        for _ in range(records_per_thread):
            tracker.record_failure()

    threads = [
        threading.Thread(target=record_many_successes if i % 2 == 0 else record_many_failures)
        for i in range(num_threads)
    ]

    # Start everything before joining so the workers actually overlap.
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # Indices 0, 2, 4, ... record successes, so the success group is the
    # larger half whenever num_threads is odd; derive both counts from that
    # assignment so the expectations stay correct if num_threads changes.
    failure_threads = num_threads // 2
    success_threads = num_threads - failure_threads
    expected_success = success_threads * records_per_thread
    expected_failed = failure_threads * records_per_thread

    assert tracker.completed == num_threads * records_per_thread
    assert tracker.success == expected_success
    assert tracker.failed == expected_failed
|
|
|
|
|
|
def test_handles_very_small_total_records() -> None:
    """A single-record tracker counts its one success."""
    single = ProgressTracker(total_records=1, label="test")
    single.record_success()

    # Order of the tuple: completed, success.
    assert (single.completed, single.success) == (1, 1)
|
|
|
|
|
|
def test_handles_log_interval_larger_than_total() -> None:
    """Recording all 5 records of a tracker with a 50% log interval completes normally."""
    progress = ProgressTracker(total_records=5, label="test", log_interval_percent=50)

    recorded = 0
    while recorded < 5:
        progress.record_success()
        recorded += 1

    completed_count = progress.completed
    assert completed_count == 5
|
|
|
|
|
|
def test_log_final_handles_zero_records_processed(caplog: pytest.LogCaptureFixture) -> None:
    """``log_final`` can be called on a tracker that has recorded nothing."""
    untouched = ProgressTracker(total_records=10, label="test")

    # No assertion on the captured output: the test passes as long as the call
    # returns without raising, whether or not anything is logged.
    with caplog.at_level(logging.INFO):
        untouched.log_final()
|
|
|
|
|
|
def test_progress_percentage_with_zero_total() -> None:
    """Recording a success on a zero-total tracker still increments ``completed``."""
    empty_tracker = ProgressTracker(total_records=0, label="test")

    # The call itself is part of the check: it must not raise a
    # division-by-zero error.
    empty_tracker.record_success()

    completed_count = empty_tracker.completed
    assert completed_count == 1
|
|
|
|
|
|
@pytest.mark.parametrize(
    "total_records,log_interval_percent",
    [
        (10, 10),  # Total is an exact multiple of the interval
        (10, 30),  # Not an exact multiple: logs at 3, 6, 9
        (1, 10),  # Only a single record
        (10, 100),  # Interval is as large as the total
    ],
)
def test_100_percent_logged_exactly_once(
    total_records: int, log_interval_percent: int, caplog: pytest.LogCaptureFixture
) -> None:
    """Completing every record and then calling ``log_final`` yields a single "100%" entry."""
    progress = ProgressTracker(total_records=total_records, label="test", log_interval_percent=log_interval_percent)

    with caplog.at_level(logging.INFO):
        remaining = total_records
        while remaining > 0:
            progress.record_success()
            remaining -= 1
        progress.log_final()

    occurrences = caplog.text.count("100%")
    assert occurrences == 1
|
|
|
|
|
|
def test_record_completion_never_logs_100_percent(caplog: pytest.LogCaptureFixture) -> None:
    """Recording the final record must not log "100%"; only ``log_final`` does."""
    progress = ProgressTracker(total_records=10, label="test", log_interval_percent=10)
    marker = "100%"

    # Phase 1: complete every record; the marker must not show up yet.
    with caplog.at_level(logging.INFO):
        recorded = 0
        while recorded < 10:
            progress.record_success()
            recorded += 1
    assert caplog.text.count(marker) == 0

    # Phase 2: the final log call contributes the one and only occurrence.
    with caplog.at_level(logging.INFO):
        progress.log_final()
    assert caplog.text.count(marker) == 1
|
|
|
|
|
|
def test_partial_completion_logs_correct_percentage(caplog: pytest.LogCaptureFixture) -> None:
    """With 7 of 10 records done, ``log_final`` reports "70%" and never "100%"."""
    progress = ProgressTracker(total_records=10, label="test", log_interval_percent=10)

    # These calls happen before log capture is raised to INFO.
    recorded = 0
    while recorded < 7:
        progress.record_success()
        recorded += 1

    with caplog.at_level(logging.INFO):
        progress.log_final()

    log_output = caplog.text
    assert "70%" in log_output
    assert log_output.count("100%") == 0
|
|
|
|
|
|
def test_concurrent_completion_logs_100_percent_once(caplog: pytest.LogCaptureFixture) -> None:
    """Ten threads complete all 100 records; "100%" must still be logged exactly once."""
    progress = ProgressTracker(total_records=100, label="test", log_interval_percent=100)

    def record_ten_successes() -> None:
        calls_remaining = 10
        while calls_remaining:
            progress.record_success()
            calls_remaining -= 1

    workers = []
    for _ in range(10):
        workers.append(threading.Thread(target=record_ten_successes))

    with caplog.at_level(logging.INFO):
        # Start all workers before joining any of them.
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        progress.log_final()

    assert progress.completed == 100
    assert caplog.text.count("100%") == 1
|