diff --git a/pyproject.toml b/pyproject.toml index d9e7117..dd4e4d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,8 +15,8 @@ license = { file = "LICENSE" } requires-python = ">=3.10" dependencies = [ "torch>=2.2,<3", - "transformers>=4.41,<5", - "accelerate>=0.34,<2", + "transformers>=4.41", + "accelerate>=1.1.0", "numpy>=1.21,<3", "einops>=0.7.0,<1", "scikit-learn>=1.6.0,<2", @@ -41,14 +41,14 @@ path = "src/chronos/__about__.py" [project.optional-dependencies] extras = [ "boto3>=1.10,<2", - "peft>=0.13.0,<0.18", + "peft>=0.18.1", "fev>=0.6.1", "pandas[pyarrow]>=2.0,<2.4", ] test = [ "pytest~=8.0", "boto3>=1.10,<2", - "peft>=0.13.0,<1", + "peft>=0.18.1", "fev>=0.6.1", "pandas[pyarrow]>=2.0,<2.4", ] diff --git a/scripts/training/train.py b/scripts/training/train.py index 09d5d8e..c586460 100644 --- a/scripts/training/train.py +++ b/scripts/training/train.py @@ -21,6 +21,7 @@ import torch import torch.distributed as dist from torch.utils.data import IterableDataset, get_worker_info import transformers +from packaging import version from transformers import ( AutoModelForSeq2SeqLM, AutoModelForCausalLM, @@ -46,6 +47,7 @@ from gluonts.transform import ( from chronos import ChronosConfig, ChronosTokenizer +_TRANSFORMERS_V5 = version.parse(transformers.__version__) >= version.parse("5.0.0") app = typer.Typer(pretty_exceptions_enable=False) @@ -661,7 +663,7 @@ def main( per_device_train_batch_size=per_device_train_batch_size, learning_rate=learning_rate, lr_scheduler_type=lr_scheduler_type, - warmup_ratio=warmup_ratio, + **({"warmup_steps": round(warmup_ratio * max_steps)} if _TRANSFORMERS_V5 else {"warmup_ratio": warmup_ratio}), optim=optim, logging_strategy="steps", logging_steps=log_steps, diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py index c96d0bc..99eefb0 100644 --- a/src/chronos/chronos2/model.py +++ b/src/chronos/chronos2/model.py @@ -16,7 +16,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.utils import ModelOutput -_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0.dev0") +_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0") if _TRANSFORMERS_V5: from transformers import initialization as init diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 223689d..2cefa5f 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -22,6 +22,7 @@ from transformers.utils.peft_utils import find_adapter_config_file import chronos.chronos2 from chronos.base import BaseChronosPipeline, ForecastType from chronos.chronos2 import Chronos2Model +from chronos.chronos2.model import _TRANSFORMERS_V5 from chronos.chronos2.dataset import Chronos2Dataset, DatasetMode, TensorOrArray from chronos.df_utils import convert_df_input_to_list_of_dicts_input from chronos.utils import interpolate_quantiles, weighted_quantile @@ -270,7 +271,7 @@ class Chronos2Pipeline(BaseChronosPipeline): per_device_eval_batch_size=batch_size, learning_rate=learning_rate, lr_scheduler_type="linear", - warmup_ratio=0.0, + **({"warmup_steps": 0} if _TRANSFORMERS_V5 else {"warmup_ratio": 0.0}), optim="adamw_torch_fused", logging_strategy="steps", logging_steps=100, diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py index 5b68029..aa9c1b7 100644 --- a/src/chronos/chronos_bolt.py +++ b/src/chronos/chronos_bolt.py @@ -29,7 +29,7 @@ from .base import BaseChronosPipeline, ForecastType logger = logging.getLogger(__file__) -_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0.dev0") +_TRANSFORMERS_V5 = version.parse(transformers_version) >= version.parse("5.0.0") # In transformers v5, use guarded init functions that check _is_hf_initialized # to avoid re-initializing weights loaded from checkpoint