Merge branch 'main' into transformers-v5

This commit is contained in:
Kashif Rasul 2026-04-26 14:25:21 +02:00 committed by GitHub
commit 0db7dc1bdb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 22 additions and 8 deletions

View file

@ -348,9 +348,16 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
if str(pretrained_model_name_or_path).startswith("s3://"):
from .boto_utils import cache_model_from_s3
local_model_path = cache_model_from_s3(
str(pretrained_model_name_or_path), force_download=force_s3_download
)
try:
local_model_path = cache_model_from_s3(
str(pretrained_model_name_or_path), force_download=force_s3_download
)
except ImportError as e:
raise ImportError(
"Loading models from s3:// URIs requires boto3. "
"Install the optional dependencies with: "
"pip install 'chronos-forecasting[extras]'"
) from e
return cls.from_pretrained(local_model_path, *model_args, **kwargs)
from transformers import AutoConfig

View file

@ -3,17 +3,19 @@
# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>
from __future__ import annotations
import logging
import os
import re
import warnings
from pathlib import Path
from typing import TYPE_CHECKING
import boto3
import requests # type: ignore
from botocore import UNSIGNED
from botocore.client import Config
from botocore.exceptions import ClientError, NoCredentialsError
if TYPE_CHECKING:
import boto3
logger = logging.getLogger(__name__)
@ -57,6 +59,11 @@ def download_model_files_from_s3(
force_download: bool = False,
boto3_session: boto3.Session | None = None,
) -> None:
import boto3
from botocore import UNSIGNED
from botocore.client import Config
from botocore.exceptions import ClientError, NoCredentialsError
boto3_session = boto3_session or boto3.Session()
s3_client = boto3_session.client("s3")
@ -101,7 +108,7 @@ def cache_model_from_s3(
boto3_session: boto3.Session | None = None,
):
assert re.match("^s3://([^/]+)/(.*?([^/]+)/?)$", s3_uri) is not None, f"Not a valid S3 URI: {s3_uri}"
cache_home = Path(os.environ.get("XGD_CACHE_HOME", os.path.expanduser("~/.cache")))
cache_home = Path(os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")))
cache_dir = cache_home / "chronos"
s3_uri = s3_uri.rstrip("/")
bucket, prefix = s3_uri.replace("s3://", "").split("/", 1)