From c93fadc57c4b289926572c3bc8c75fe59548952b Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Wed, 18 Feb 2026 15:51:44 +0000 Subject: [PATCH] Fix mypy issues --- src/chronos/chronos2/dataset.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py index afbc987..dfabeaa 100644 --- a/src/chronos/chronos2/dataset.py +++ b/src/chronos/chronos2/dataset.py @@ -19,9 +19,6 @@ if TYPE_CHECKING: TensorOrArray: TypeAlias = torch.Tensor | np.ndarray -# Type alias for raw input format -RawTask = Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]] - class PreparedTask(TypedDict): """A preprocessed time series task ready for model training/inference.""" @@ -220,7 +217,7 @@ def validate_and_prepare_single_dict_task( def prepare_tasks( - raw_tasks: Iterable[RawTask], + raw_tasks: Iterable[Mapping[str, Any]], prediction_length: int, min_past: int = 1, mode: "DatasetMode | str" = "train", @@ -277,10 +274,7 @@ def validate_prepared_schema(task: Any) -> None: required_keys = {"context", "future_covariates", "n_targets", "n_covariates", "n_future_covariates"} missing = required_keys - set(task.keys()) if missing: - raise TypeError( - f"Task is missing required keys: {missing}. " - "Set convert_inputs=True to preprocess raw inputs." - ) + raise TypeError(f"Task is missing required keys: {missing}. Set convert_inputs=True to preprocess raw inputs.") context = task["context"] if not isinstance(context, (np.ndarray, torch.Tensor)) or context.ndim != 2: @@ -502,7 +496,7 @@ class Chronos2Dataset(IterableDataset): def __init__( self, - inputs: Sequence[PreparedTask | RawTask], + inputs: TensorOrArray | Sequence[TensorOrArray] | Sequence[Mapping[str, Any]] | Sequence[PreparedTask], context_length: int, prediction_length: int, batch_size: int, @@ -514,16 +508,16 @@ class Chronos2Dataset(IterableDataset): super().__init__() assert mode in {DatasetMode.TRAIN, DatasetMode.VALIDATION, DatasetMode.TEST}, f"Invalid mode: {mode}" + self.tasks: Sequence[PreparedTask] if convert_inputs: - # Convert various input formats to list of dicts if needed if isinstance(inputs, (torch.Tensor, np.ndarray)): inputs = convert_tensor_input_to_list_of_dicts_input(inputs) elif isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray)): - inputs = convert_list_of_tensors_input_to_list_of_dicts_input(inputs) - self.tasks = prepare_tasks(inputs, prediction_length, min_past, mode) + inputs = convert_list_of_tensors_input_to_list_of_dicts_input(cast(Sequence[TensorOrArray], inputs)) + self.tasks = prepare_tasks(cast(Iterable[Mapping[str, Any]], inputs), prediction_length, min_past, mode) else: validate_prepared_schema(inputs[0]) - self.tasks = inputs + self.tasks = cast(Sequence[PreparedTask], inputs) self.context_length = context_length self.prediction_length = prediction_length