DataDesigner/tests/engine/models/conftest.py
Nabin Mulepati 8370e4a00b
feat: support native embedding generation (#106)
* Add generation type to ModelConfig

* pass tests

* added generate_text_embeddings

* tests

* remove sensitive=True old artifact no longer needed

* Slight refactor

* slight refactor

* Added embedding generator

* chunk_separator -> chunk_pattern

* update tests

* rename for consistency

* Restructure InferenceParameters -> CompletionInferenceParameters, BaseInferenceParameters, EmbeddingInferenceParameters

* Remove purpose from consolidated kwargs

* WithModelConfiguration.inference_parameters should be typed with BaseInferenceParameters

* Type as WithModelGeneration

* Add image generation modality

* update return type for generate_kwargs

* make generation_type a field of ModelConfig as opposed to a prop resolved based on the type of InferenceParameters

* remove regex based chunking from embedding generator

* Remove image generation for now

* more tests and updates

* column_type_is_llm_generated -> column_type_is_model_generated

* change set to list: fix flaky tests

* CompletionInferenceParameters -> ChatCompletionInferenceParameters for consistency with generation_type

* Update docs

* fix deprecation warning originating from cli model settings

* update display of inference parameters in cli list

* save prog on inference parameter

* updates for the config builder

* update cli readme

* update cli for inference parameters

* update inference parameter names

* flip order of vars

* WithCompletion -> WithChatCompletion

* specify InferenceParamsT

* Update columns.md with EmbeddingColumnConfig info

* make generation_type a discriminator field in inference params. Add configuration support for max_parallel_requests and timeout

* DRY out some stuff in field.py

* Update nomenclature. prompt tokens -> input tokens, completion tokens -> output tokens in column statistics for consistency

* Add nvidia-embedding and openai-embedding to default model configs

* Fix typo in docs

* Make generate collab notebooks

* fine-tune -> adjust
2025-12-15 11:03:33 -07:00

74 lines
2.4 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
import pytest
from data_designer.config.models import (
ChatCompletionInferenceParams,
EmbeddingInferenceParams,
ModelConfig,
)
from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
from data_designer.engine.models.registry import ModelRegistry, create_model_registry
from data_designer.engine.secret_resolver import SecretsFileResolver
@pytest.fixture
def stub_secrets_resolver() -> SecretsFileResolver:
    """Return a SecretsFileResolver pointed at the stub secrets file.

    The file is ``stub_secrets.json``, located in the same directory as
    this conftest module.
    """
    secrets_file = Path(__file__).parent / "stub_secrets.json"
    return SecretsFileResolver(secrets_file)
@pytest.fixture
def stub_model_provider_registry() -> ModelProviderRegistry:
    """Return a provider registry holding a single stub provider.

    The provider is named ``stub-model-provider``, uses the ``openai``
    provider type, and references the ``STUB_API_KEY`` API key value.
    """
    stub_provider = ModelProvider(
        name="stub-model-provider",
        endpoint="https://api.example.com/v1",
        provider_type="openai",
        api_key="STUB_API_KEY",
    )
    return ModelProviderRegistry(providers=[stub_provider])
@pytest.fixture
def stub_model_configs() -> list[ModelConfig]:
    """Return the stub model configurations used by the model tests.

    Three configs are produced, all bound to ``stub-model-provider``:
    two chat-completion aliases (``stub-text`` and ``stub-reasoning``)
    with identical inference parameter values, and one embedding alias
    (``stub-embedding``).
    """
    # Each chat config gets its own parameter object; the values are the
    # same but the instances are deliberately not shared.
    text_config = ModelConfig(
        alias="stub-text",
        model="stub-model-text",
        provider="stub-model-provider",
        inference_parameters=ChatCompletionInferenceParams(
            temperature=0.80,
            top_p=0.95,
            max_tokens=100,
            max_parallel_requests=10,
            timeout=100,
        ),
    )
    reasoning_config = ModelConfig(
        alias="stub-reasoning",
        model="stub-model-reasoning",
        provider="stub-model-provider",
        inference_parameters=ChatCompletionInferenceParams(
            temperature=0.80,
            top_p=0.95,
            max_tokens=100,
            max_parallel_requests=10,
            timeout=100,
        ),
    )
    embedding_config = ModelConfig(
        alias="stub-embedding",
        model="stub-model-embedding",
        provider="stub-model-provider",
        inference_parameters=EmbeddingInferenceParams(dimensions=100),
    )
    return [text_config, reasoning_config, embedding_config]
@pytest.fixture
def stub_model_registry(stub_model_configs, stub_secrets_resolver, stub_model_provider_registry) -> ModelRegistry:
    """Return a ModelRegistry built from the stub fixtures.

    Passes the stub model configs, the stub secrets resolver, and the stub
    provider registry to ``create_model_registry``.
    """
    registry = create_model_registry(
        model_provider_registry=stub_model_provider_registry,
        secret_resolver=stub_secrets_resolver,
        model_configs=stub_model_configs,
    )
    return registry