handle dtype warning

This commit is contained in:
Kashif Rasul 2026-01-09 20:44:20 +01:00
parent 16a91d5b58
commit f669e6c5ca

View file

@ -347,9 +347,11 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
from transformers import AutoConfig
torch_dtype = kwargs.get("torch_dtype", "auto")
if torch_dtype != "auto" and isinstance(torch_dtype, str):
kwargs["torch_dtype"] = cls.dtypes[torch_dtype]
# Handle both torch_dtype (deprecated) and dtype arguments
dtype_value = kwargs.pop("torch_dtype", None) or kwargs.pop("dtype", "auto")
if dtype_value != "auto" and isinstance(dtype_value, str):
dtype_value = cls.dtypes[dtype_value]
kwargs["dtype"] = dtype_value
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr(config, "chronos_config")