mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Fix mypy issues
This commit is contained in:
parent
19c1b72a94
commit
c93fadc57c
1 changed files with 7 additions and 13 deletions
|
|
@ -19,9 +19,6 @@ if TYPE_CHECKING:
|
|||
|
||||
TensorOrArray: TypeAlias = torch.Tensor | np.ndarray
|
||||
|
||||
# Type alias for raw input format
|
||||
RawTask = Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]
|
||||
|
||||
|
||||
class PreparedTask(TypedDict):
|
||||
"""A preprocessed time series task ready for model training/inference."""
|
||||
|
|
@ -220,7 +217,7 @@ def validate_and_prepare_single_dict_task(
|
|||
|
||||
|
||||
def prepare_tasks(
|
||||
raw_tasks: Iterable[RawTask],
|
||||
raw_tasks: Iterable[Mapping[str, Any]],
|
||||
prediction_length: int,
|
||||
min_past: int = 1,
|
||||
mode: "DatasetMode | str" = "train",
|
||||
|
|
@ -277,10 +274,7 @@ def validate_prepared_schema(task: Any) -> None:
|
|||
required_keys = {"context", "future_covariates", "n_targets", "n_covariates", "n_future_covariates"}
|
||||
missing = required_keys - set(task.keys())
|
||||
if missing:
|
||||
raise TypeError(
|
||||
f"Task is missing required keys: {missing}. "
|
||||
"Set convert_inputs=True to preprocess raw inputs."
|
||||
)
|
||||
raise TypeError(f"Task is missing required keys: {missing}. Set convert_inputs=True to preprocess raw inputs.")
|
||||
|
||||
context = task["context"]
|
||||
if not isinstance(context, (np.ndarray, torch.Tensor)) or context.ndim != 2:
|
||||
|
|
@ -502,7 +496,7 @@ class Chronos2Dataset(IterableDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Sequence[PreparedTask | RawTask],
|
||||
inputs: TensorOrArray | Sequence[TensorOrArray] | Sequence[Mapping[str, Any]] | Sequence[PreparedTask],
|
||||
context_length: int,
|
||||
prediction_length: int,
|
||||
batch_size: int,
|
||||
|
|
@ -514,16 +508,16 @@ class Chronos2Dataset(IterableDataset):
|
|||
super().__init__()
|
||||
assert mode in {DatasetMode.TRAIN, DatasetMode.VALIDATION, DatasetMode.TEST}, f"Invalid mode: {mode}"
|
||||
|
||||
self.tasks: Sequence[PreparedTask]
|
||||
if convert_inputs:
|
||||
# Convert various input formats to list of dicts if needed
|
||||
if isinstance(inputs, (torch.Tensor, np.ndarray)):
|
||||
inputs = convert_tensor_input_to_list_of_dicts_input(inputs)
|
||||
elif isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray)):
|
||||
inputs = convert_list_of_tensors_input_to_list_of_dicts_input(inputs)
|
||||
self.tasks = prepare_tasks(inputs, prediction_length, min_past, mode)
|
||||
inputs = convert_list_of_tensors_input_to_list_of_dicts_input(cast(Sequence[TensorOrArray], inputs))
|
||||
self.tasks = prepare_tasks(cast(Iterable[Mapping[str, Any]], inputs), prediction_length, min_past, mode)
|
||||
else:
|
||||
validate_prepared_schema(inputs[0])
|
||||
self.tasks = inputs
|
||||
self.tasks = cast(Sequence[PreparedTask], inputs)
|
||||
|
||||
self.context_length = context_length
|
||||
self.prediction_length = prediction_length
|
||||
|
|
|
|||
Loading…
Reference in a new issue