diff --git a/pyproject.toml b/pyproject.toml index d9e7117..d21994f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "numpy>=1.21,<3", "einops>=0.7.0,<1", "scikit-learn>=1.6.0,<2", + "typing_extensions~=4.0", ] classifiers = [ "Programming Language :: Python :: 3", diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py index 20d2528..dbde6fe 100644 --- a/src/chronos/chronos2/dataset.py +++ b/src/chronos/chronos2/dataset.py @@ -5,7 +5,7 @@ import math from enum import Enum -from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, NotRequired, Sequence, TypeAlias, TypedDict, cast +from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, Sequence, TypeAlias, TypedDict, cast import numpy as np import torch @@ -16,6 +16,11 @@ if TYPE_CHECKING: import datasets import fev +try: + from typing import NotRequired # Python 3.11+ +except ImportError: + from typing_extensions import NotRequired + TensorOrArray: TypeAlias = torch.Tensor | np.ndarray