mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 09:39:35 +00:00
handle dtype warning
This commit is contained in:
parent
16a91d5b58
commit
f669e6c5ca
1 changed files with 5 additions and 3 deletions
|
|
@ -347,9 +347,11 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
|
||||
from transformers import AutoConfig
|
||||
|
||||
torch_dtype = kwargs.get("torch_dtype", "auto")
|
||||
if torch_dtype != "auto" and isinstance(torch_dtype, str):
|
||||
kwargs["torch_dtype"] = cls.dtypes[torch_dtype]
|
||||
# Handle both torch_dtype (deprecated) and dtype arguments
|
||||
dtype_value = kwargs.pop("torch_dtype", None) or kwargs.pop("dtype", "auto")
|
||||
if dtype_value != "auto" and isinstance(dtype_value, str):
|
||||
dtype_value = cls.dtypes[dtype_value]
|
||||
kwargs["dtype"] = dtype_value
|
||||
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr(config, "chronos_config")
|
||||
|
|
|
|||
Loading…
Reference in a new issue