mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 09:39:35 +00:00
update transformers version check
This commit is contained in:
parent
9ae8a656bf
commit
9ee10fc9b6
5 changed files with 11 additions and 8 deletions
|
|
@ -15,8 +15,8 @@ license = { file = "LICENSE" }
|
|||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"torch>=2.2,<3",
|
||||
"transformers>=4.41,<5",
|
||||
"accelerate>=0.34,<2",
|
||||
"transformers>=4.41",
|
||||
"accelerate>=1.1.0",
|
||||
"numpy>=1.21,<3",
|
||||
"einops>=0.7.0,<1",
|
||||
"scikit-learn>=1.6.0,<2",
|
||||
|
|
@ -41,14 +41,14 @@ path = "src/chronos/__about__.py"
|
|||
[project.optional-dependencies]
|
||||
extras = [
|
||||
"boto3>=1.10,<2",
|
||||
"peft>=0.13.0,<0.18",
|
||||
"peft>=0.18.1",
|
||||
"fev>=0.6.1",
|
||||
"pandas[pyarrow]>=2.0,<2.4",
|
||||
]
|
||||
test = [
|
||||
"pytest~=8.0",
|
||||
"boto3>=1.10,<2",
|
||||
"peft>=0.13.0,<1",
|
||||
"peft>=0.18.1",
|
||||
"fev>=0.6.1",
|
||||
"pandas[pyarrow]>=2.0,<2.4",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
from torch.utils.data import IterableDataset, get_worker_info
|
||||
import transformers
|
||||
from packaging import version
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForCausalLM,
|
||||
|
|
@ -46,6 +47,7 @@ from gluonts.transform import (
|
|||
|
||||
from chronos import ChronosConfig, ChronosTokenizer
|
||||
|
||||
_TRANSFORMERS_V5 = version.parse(transformers.__version__) >= version.parse("5.0.0")
|
||||
|
||||
app = typer.Typer(pretty_exceptions_enable=False)
|
||||
|
||||
|
|
@ -661,7 +663,7 @@ def main(
|
|||
per_device_train_batch_size=per_device_train_batch_size,
|
||||
learning_rate=learning_rate,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
warmup_ratio=warmup_ratio,
|
||||
**({"warmup_steps": round(warmup_ratio * max_steps)} if _TRANSFORMERS_V5 else {"warmup_ratio": warmup_ratio}),
|
||||
optim=optim,
|
||||
logging_strategy="steps",
|
||||
logging_steps=log_steps,
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from transformers.modeling_utils import PreTrainedModel
|
|||
from transformers.utils import ModelOutput
|
||||
|
||||
|
||||
_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0.dev0")
|
||||
_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0")
|
||||
|
||||
if _TRANSFORMERS_V5:
|
||||
from transformers import initialization as init
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from transformers.utils.peft_utils import find_adapter_config_file
|
|||
import chronos.chronos2
|
||||
from chronos.base import BaseChronosPipeline, ForecastType
|
||||
from chronos.chronos2 import Chronos2Model
|
||||
from chronos.chronos2.model import _TRANSFORMERS_V5
|
||||
from chronos.chronos2.dataset import Chronos2Dataset, DatasetMode, TensorOrArray
|
||||
from chronos.df_utils import convert_df_input_to_list_of_dicts_input
|
||||
from chronos.utils import interpolate_quantiles, weighted_quantile
|
||||
|
|
@ -270,7 +271,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
per_device_eval_batch_size=batch_size,
|
||||
learning_rate=learning_rate,
|
||||
lr_scheduler_type="linear",
|
||||
warmup_ratio=0.0,
|
||||
**({"warmup_steps": 0} if _TRANSFORMERS_V5 else {"warmup_ratio": 0.0}),
|
||||
optim="adamw_torch_fused",
|
||||
logging_strategy="steps",
|
||||
logging_steps=100,
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from .base import BaseChronosPipeline, ForecastType
|
|||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0.dev0")
|
||||
_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0")
|
||||
|
||||
# In transformers v5, use guarded init functions that check _is_hf_initialized
|
||||
# to avoid re-initializing weights loaded from checkpoint
|
||||
|
|
|
|||
Loading…
Reference in a new issue