DataDesigner/scripts/health_checks.py
Andre Manoel 16db8d61fa
fix(config): update OpenRouter vision model id (#630)
* fix(config): update OpenRouter vision model id

* fix(ci): harden provider health checks

* fix(config): use nano omni for OpenRouter vision

* docs: warn about hosted provider data handling

* fix(config): align OpenRouter vision params
2026-05-11 13:49:17 -03:00

135 lines
4.5 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Health checks for all predefined model providers.

Verifies that each model in each provider can respond to a basic request.
Providers without an API key set in the environment are skipped.

Usage:
    uv run python scripts/health_checks.py
"""
import os
import sys
import traceback
from data_designer.config.models import (
ChatCompletionInferenceParams,
EmbeddingInferenceParams,
ModelConfig,
ModelProvider,
)
from data_designer.config.utils.constants import (
NVIDIA_API_KEY_ENV_VAR_NAME,
NVIDIA_PROVIDER_NAME,
OPENAI_API_KEY_ENV_VAR_NAME,
OPENAI_PROVIDER_NAME,
OPENROUTER_API_KEY_ENV_VAR_NAME,
OPENROUTER_PROVIDER_NAME,
PREDEFINED_PROVIDERS,
PREDEFINED_PROVIDERS_MODEL_MAP,
)
from data_designer.engine.model_provider import ModelProviderRegistry
from data_designer.engine.models.clients.factory import create_model_client
from data_designer.engine.models.facade import ModelFacade
from data_designer.engine.secret_resolver import EnvironmentResolver
# Provider name -> name of the environment variable expected to hold that
# provider's API key. main() walks this mapping in insertion order, so the
# order below is also the order in which providers are reported.
PROVIDER_API_KEY_ENV_VARS: dict[str, str] = {
    NVIDIA_PROVIDER_NAME: NVIDIA_API_KEY_ENV_VAR_NAME,
    OPENAI_PROVIDER_NAME: OPENAI_API_KEY_ENV_VAR_NAME,
    OPENROUTER_PROVIDER_NAME: OPENROUTER_API_KEY_ENV_VAR_NAME,
}

# How many times a chat-completion request is issued before an empty or
# non-string reply is reported as a failure.
CHAT_COMPLETION_ATTEMPTS: int = 2
def _get_provider_registry(provider_name: str) -> ModelProviderRegistry:
    """Build a registry that contains only the named predefined provider.

    Args:
        provider_name: Compared against the ``"name"`` entry of each item in
            ``PREDEFINED_PROVIDERS``; the first match is used.

    Returns:
        A ``ModelProviderRegistry`` constructed with the single matching
        provider.

    Raises:
        ValueError: If no entry in ``PREDEFINED_PROVIDERS`` has that name.
    """
    # Pass a default to next() so a missing provider does not surface as a
    # bare StopIteration, which carries no message and is easy to misread.
    provider_data = next(
        (p for p in PREDEFINED_PROVIDERS if p["name"] == provider_name),
        None,
    )
    if provider_data is None:
        known = ", ".join(repr(p["name"]) for p in PREDEFINED_PROVIDERS)
        raise ValueError(
            f"Unknown provider {provider_name!r}; predefined providers: {known}"
        )
    return ModelProviderRegistry(providers=[ModelProvider(**provider_data)])
def _check_model(provider_name: str, model_type: str) -> None:
    """Exercise one predefined model and raise if its reply is unusable.

    Looks up the ``provider_name``/``model_type`` entry in
    ``PREDEFINED_PROVIDERS_MODEL_MAP``, builds a ``ModelConfig`` and a
    ``ModelFacade`` for it, and issues a minimal request. The ``"embedding"``
    model type is asked to embed one string; every other model type goes
    through ``facade.generate`` up to ``CHAT_COMPLETION_ATTEMPTS`` times.

    Args:
        provider_name: Key into ``PREDEFINED_PROVIDERS_MODEL_MAP``; also passed
            as the ``provider`` of the model config.
        model_type: Key into that provider's model map. ``"embedding"``
            selects the embedding path.

    Raises:
        AssertionError: If the embedding call does not yield exactly one
            non-empty embedding, or if no chat attempt yields a non-empty
            string.
    """
    registry = _get_provider_registry(provider_name)
    resolver = EnvironmentResolver()

    spec = PREDEFINED_PROVIDERS_MODEL_MAP[provider_name][model_type]
    model_name = spec["model"]
    raw_params = spec["inference_parameters"]

    is_embedding = model_type == "embedding"
    params_cls = EmbeddingInferenceParams if is_embedding else ChatCompletionInferenceParams
    config = ModelConfig(
        alias=f"{provider_name}-{model_type}",
        model=model_name,
        inference_parameters=params_cls(**raw_params),
        provider=provider_name,
    )
    facade = ModelFacade(
        config,
        registry,
        client=create_model_client(config, resolver, registry),
    )

    if is_embedding:
        embeddings = facade.generate_text_embeddings(
            input_texts=["Hello!"],
            skip_usage_tracking=True,
        )
        # One input text was sent, so exactly one non-empty vector is expected.
        if len(embeddings) == 1 and len(embeddings[0]) != 0:
            return
        raise AssertionError(
            f"Expected one non-empty embedding from {provider_name}/{model_type} "
            f"({model_name}); got {len(embeddings)} embeddings"
        )

    # Only an empty or non-string reply triggers another attempt; an exception
    # raised by facade.generate propagates to the caller immediately.
    reply = None
    for attempts_left in reversed(range(CHAT_COMPLETION_ATTEMPTS)):
        reply, _ = facade.generate(
            prompt="Say 'OK' and nothing else.",
            parser=lambda text: text,
            system_prompt="You are a helpful assistant.",
            max_correction_steps=0,
            max_conversation_restarts=0,
            skip_usage_tracking=True,
        )
        if isinstance(reply, str) and reply:
            return
        if attempts_left:
            print(f"RETRY {provider_name}/{model_type} ({model_name}) returned {reply!r}")

    raise AssertionError(
        f"Expected non-empty chat response from {provider_name}/{model_type} "
        f"({model_name}) after {CHAT_COMPLETION_ATTEMPTS} attempts; got {reply!r}"
    )
def main() -> int:
    """Run the health check for every model of every listed provider.

    Iterates ``PROVIDER_API_KEY_ENV_VARS``. A provider whose API-key
    environment variable is unset or empty is skipped, and all of its models
    are counted as skipped. Otherwise each model type in
    ``PREDEFINED_PROVIDERS_MODEL_MAP`` for that provider is checked with
    ``_check_model``. One ``PASS``/``FAIL``/``SKIP`` line is printed per
    result, followed by a summary line.

    Returns:
        ``1`` if at least one check failed, otherwise ``0`` (including when
        every provider was skipped).
    """
    passed = failed = skipped = 0

    for provider_name, env_var in PROVIDER_API_KEY_ENV_VARS.items():
        model_types = PREDEFINED_PROVIDERS_MODEL_MAP[provider_name]

        if not os.environ.get(env_var):
            skipped += len(model_types)
            print(f"SKIP {provider_name} ({env_var} not set)")
            continue

        for model_type in model_types:
            label = f"{provider_name}/{model_type}"
            try:
                _check_model(provider_name, model_type)
            except Exception:
                # Top-level boundary: record the failure with its traceback
                # and keep going so the remaining models are still checked.
                failed += 1
                print(f"FAIL {label}\n{traceback.format_exc()}")
            else:
                # Kept outside the try body so an error while reporting
                # success cannot count the same model as passed and failed.
                passed += 1
                print(f"PASS {label}")

    print(f"\n{passed} passed, {failed} failed, {skipped} skipped")
    return 1 if failed else 0
# Script entry point: the process exit status mirrors main()'s return value
# (non-zero when any health check failed).
if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)