diff --git a/src/chronos/base.py b/src/chronos/base.py index 2bba087..807f91d 100644 --- a/src/chronos/base.py +++ b/src/chronos/base.py @@ -348,9 +348,16 @@ class BaseChronosPipeline(metaclass=PipelineRegistry): if str(pretrained_model_name_or_path).startswith("s3://"): from .boto_utils import cache_model_from_s3 - local_model_path = cache_model_from_s3( - str(pretrained_model_name_or_path), force_download=force_s3_download - ) + try: + local_model_path = cache_model_from_s3( + str(pretrained_model_name_or_path), force_download=force_s3_download + ) + except ImportError as e: + raise ImportError( + "Loading models from s3:// URIs requires boto3. " + "Install the optional dependencies with: " + "pip install 'chronos-forecasting[extras]'" + ) from e return cls.from_pretrained(local_model_path, *model_args, **kwargs) from transformers import AutoConfig diff --git a/src/chronos/boto_utils.py b/src/chronos/boto_utils.py index d339333..48393d6 100644 --- a/src/chronos/boto_utils.py +++ b/src/chronos/boto_utils.py @@ -3,17 +3,19 @@ # Authors: Abdul Fatir Ansari +from __future__ import annotations + import logging import os import re import warnings from pathlib import Path +from typing import TYPE_CHECKING -import boto3 import requests # type: ignore -from botocore import UNSIGNED -from botocore.client import Config -from botocore.exceptions import ClientError, NoCredentialsError + +if TYPE_CHECKING: + import boto3 logger = logging.getLogger(__name__) @@ -57,6 +59,11 @@ def download_model_files_from_s3( force_download: bool = False, boto3_session: boto3.Session | None = None, ) -> None: + import boto3 + from botocore import UNSIGNED + from botocore.client import Config + from botocore.exceptions import ClientError, NoCredentialsError + boto3_session = boto3_session or boto3.Session() s3_client = boto3_session.client("s3")