From f669e6c5ca58a85659a22471ab3ab47dce970cd3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 9 Jan 2026 20:44:20 +0100 Subject: [PATCH] handle dtype warning --- src/chronos/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/chronos/base.py b/src/chronos/base.py index 7592c46..3f92972 100644 --- a/src/chronos/base.py +++ b/src/chronos/base.py @@ -347,9 +347,11 @@ class BaseChronosPipeline(metaclass=PipelineRegistry): from transformers import AutoConfig - torch_dtype = kwargs.get("torch_dtype", "auto") - if torch_dtype != "auto" and isinstance(torch_dtype, str): - kwargs["torch_dtype"] = cls.dtypes[torch_dtype] + # Handle both torch_dtype (deprecated) and dtype arguments + dtype_value = kwargs.pop("torch_dtype", None) or kwargs.pop("dtype", "auto") + if dtype_value != "auto" and isinstance(dtype_value, str): + dtype_value = cls.dtypes[dtype_value] + kwargs["dtype"] = dtype_value config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr(config, "chronos_config")