mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 01:58:27 +00:00
Merge branch 'main' into transformers-v5
This commit is contained in:
commit
0db7dc1bdb
2 changed files with 22 additions and 8 deletions
|
|
@ -348,9 +348,16 @@ class BaseChronosPipeline(metaclass=PipelineRegistry):
|
|||
if str(pretrained_model_name_or_path).startswith("s3://"):
|
||||
from .boto_utils import cache_model_from_s3
|
||||
|
||||
local_model_path = cache_model_from_s3(
|
||||
str(pretrained_model_name_or_path), force_download=force_s3_download
|
||||
)
|
||||
try:
|
||||
local_model_path = cache_model_from_s3(
|
||||
str(pretrained_model_name_or_path), force_download=force_s3_download
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Loading models from s3:// URIs requires boto3. "
|
||||
"Install the optional dependencies with: "
|
||||
"pip install 'chronos-forecasting[extras]'"
|
||||
) from e
|
||||
return cls.from_pretrained(local_model_path, *model_args, **kwargs)
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
|
|
|||
|
|
@ -3,17 +3,19 @@
|
|||
|
||||
# Authors: Abdul Fatir Ansari <ansarnd@amazon.com>
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import boto3
|
||||
import requests # type: ignore
|
||||
from botocore import UNSIGNED
|
||||
from botocore.client import Config
|
||||
from botocore.exceptions import ClientError, NoCredentialsError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import boto3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -57,6 +59,11 @@ def download_model_files_from_s3(
|
|||
force_download: bool = False,
|
||||
boto3_session: boto3.Session | None = None,
|
||||
) -> None:
|
||||
import boto3
|
||||
from botocore import UNSIGNED
|
||||
from botocore.client import Config
|
||||
from botocore.exceptions import ClientError, NoCredentialsError
|
||||
|
||||
boto3_session = boto3_session or boto3.Session()
|
||||
s3_client = boto3_session.client("s3")
|
||||
|
||||
|
|
@ -101,7 +108,7 @@ def cache_model_from_s3(
|
|||
boto3_session: boto3.Session | None = None,
|
||||
):
|
||||
assert re.match("^s3://([^/]+)/(.*?([^/]+)/?)$", s3_uri) is not None, f"Not a valid S3 URI: {s3_uri}"
|
||||
cache_home = Path(os.environ.get("XGD_CACHE_HOME", os.path.expanduser("~/.cache")))
|
||||
cache_home = Path(os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")))
|
||||
cache_dir = cache_home / "chronos"
|
||||
s3_uri = s3_uri.rstrip("/")
|
||||
bucket, prefix = s3_uri.replace("s3://", "").split("/", 1)
|
||||
|
|
|
|||
Loading…
Reference in a new issue