Fix mypy issues

This commit is contained in:
Oleksandr Shchur 2026-02-18 15:51:44 +00:00
parent 19c1b72a94
commit c93fadc57c

View file

@ -19,9 +19,6 @@ if TYPE_CHECKING:
TensorOrArray: TypeAlias = torch.Tensor | np.ndarray
# Type alias for raw input format
RawTask = Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]
class PreparedTask(TypedDict):
"""A preprocessed time series task ready for model training/inference."""
@ -220,7 +217,7 @@ def validate_and_prepare_single_dict_task(
def prepare_tasks(
raw_tasks: Iterable[RawTask],
raw_tasks: Iterable[Mapping[str, Any]],
prediction_length: int,
min_past: int = 1,
mode: "DatasetMode | str" = "train",
@ -277,10 +274,7 @@ def validate_prepared_schema(task: Any) -> None:
required_keys = {"context", "future_covariates", "n_targets", "n_covariates", "n_future_covariates"}
missing = required_keys - set(task.keys())
if missing:
raise TypeError(
f"Task is missing required keys: {missing}. "
"Set convert_inputs=True to preprocess raw inputs."
)
raise TypeError(f"Task is missing required keys: {missing}. Set convert_inputs=True to preprocess raw inputs.")
context = task["context"]
if not isinstance(context, (np.ndarray, torch.Tensor)) or context.ndim != 2:
@ -502,7 +496,7 @@ class Chronos2Dataset(IterableDataset):
def __init__(
self,
inputs: Sequence[PreparedTask | RawTask],
inputs: TensorOrArray | Sequence[TensorOrArray] | Sequence[Mapping[str, Any]] | Sequence[PreparedTask],
context_length: int,
prediction_length: int,
batch_size: int,
@ -514,16 +508,16 @@ class Chronos2Dataset(IterableDataset):
super().__init__()
assert mode in {DatasetMode.TRAIN, DatasetMode.VALIDATION, DatasetMode.TEST}, f"Invalid mode: {mode}"
self.tasks: Sequence[PreparedTask]
if convert_inputs:
# Convert various input formats to list of dicts if needed
if isinstance(inputs, (torch.Tensor, np.ndarray)):
inputs = convert_tensor_input_to_list_of_dicts_input(inputs)
elif isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray)):
inputs = convert_list_of_tensors_input_to_list_of_dicts_input(inputs)
self.tasks = prepare_tasks(inputs, prediction_length, min_past, mode)
inputs = convert_list_of_tensors_input_to_list_of_dicts_input(cast(Sequence[TensorOrArray], inputs))
self.tasks = prepare_tasks(cast(Iterable[Mapping[str, Any]], inputs), prediction_length, min_past, mode)
else:
validate_prepared_schema(inputs[0])
self.tasks = inputs
self.tasks = cast(Sequence[PreparedTask], inputs)
self.context_length = context_length
self.prediction_length = prediction_length