diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py index dfabeaa..1b8a77a 100644 --- a/src/chronos/chronos2/dataset.py +++ b/src/chronos/chronos2/dataset.py @@ -227,9 +227,6 @@ def prepare_tasks( This function handles mode-specific preprocessing (e.g., filtering short series) and calls validate_and_prepare_single_dict_task for each task. """ - if isinstance(mode, str): - mode = DatasetMode(mode) - tasks: list[PreparedTask] = [] for idx, raw_task in enumerate(raw_tasks): @@ -268,20 +265,37 @@ def validate_prepared_schema(task: Any) -> None: if not isinstance(task, Mapping): raise TypeError( f"Expected task to be a dict-like, got {type(task).__name__}. " - "Set convert_inputs=True to preprocess raw inputs." + "Set convert_inputs=True when calling fit() to preprocess raw inputs." ) required_keys = {"context", "future_covariates", "n_targets", "n_covariates", "n_future_covariates"} missing = required_keys - set(task.keys()) if missing: - raise TypeError(f"Task is missing required keys: {missing}. Set convert_inputs=True to preprocess raw inputs.") + raise TypeError( + f"Task is missing required keys: {missing}. Set convert_inputs=True when calling fit() to preprocess raw inputs." + ) context = task["context"] - if not isinstance(context, (np.ndarray, torch.Tensor)) or context.ndim != 2: + if not isinstance(context, torch.Tensor) or context.ndim != 2: raise TypeError( - f"Expected 'context' to be 2-d array, got {type(context).__name__} " + f"Expected 'context' to be 2-d torch.Tensor, got {type(context).__name__} " f"with shape {getattr(context, 'shape', 'N/A')}. " - "Set convert_inputs=True to preprocess raw inputs." + "Set convert_inputs=True when calling fit() to preprocess raw inputs." + ) + + future_covariates = task["future_covariates"] + if not isinstance(future_covariates, torch.Tensor) or future_covariates.ndim != 2: + raise TypeError( + f"Expected 'future_covariates' to be 2-d torch.Tensor, got {type(future_covariates).__name__} " + f"with shape {getattr(future_covariates, 'shape', 'N/A')}. " + "Set convert_inputs=True when calling fit() to preprocess raw inputs." + ) + + if context.shape[0] != future_covariates.shape[0]: + raise ValueError( + f"Expected 'context' and 'future_covariates' to have the same first dimension, " + f"got {context.shape[0]} and {future_covariates.shape[0]}. " + "Set convert_inputs=True when calling fit() to preprocess raw inputs." ) @@ -512,7 +526,9 @@ class Chronos2Dataset(IterableDataset): if convert_inputs: if isinstance(inputs, (torch.Tensor, np.ndarray)): inputs = convert_tensor_input_to_list_of_dicts_input(inputs) - elif isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray)): + elif ( + isinstance(inputs, Sequence) and len(inputs) > 0 and isinstance(inputs[0], (torch.Tensor, np.ndarray)) + ): inputs = convert_list_of_tensors_input_to_list_of_dicts_input(cast(Sequence[TensorOrArray], inputs)) self.tasks = prepare_tasks(cast(Iterable[Mapping[str, Any]], inputs), prediction_length, min_past, mode) else: @@ -528,7 +544,7 @@ class Chronos2Dataset(IterableDataset): def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, int]: task = self.tasks[task_idx] - task_past_tensor = task["context"].clone() + task_past_tensor = task["context"].clone() # shape: (task_n_targets + task_n_covariates, history_length) task_future_tensor = task["future_covariates"].clone() task_n_targets = int(task["n_targets"]) task_n_covariates = int(task["n_covariates"]) @@ -672,39 +688,3 @@ class Chronos2Dataset(IterableDataset): yield batch else: yield from self._generate_sequential_batches() - - @classmethod - def convert_inputs( - cls, - inputs: TensorOrArray - | Sequence[TensorOrArray] - | Sequence[Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]], - context_length: int, - prediction_length: int, - batch_size: int, - output_patch_size: int, - min_past: int = 1, - mode: str | DatasetMode = DatasetMode.TRAIN, - ) -> "Chronos2Dataset": - """Convert from different input formats to a Chronos2Dataset.""" - if isinstance(inputs, (torch.Tensor, np.ndarray)): - inputs = convert_tensor_input_to_list_of_dicts_input(inputs) - elif isinstance(inputs, list) and all([isinstance(x, (torch.Tensor, np.ndarray)) for x in inputs]): - inputs = cast(list[TensorOrArray], inputs) - inputs = convert_list_of_tensors_input_to_list_of_dicts_input(inputs) - elif isinstance(inputs, list) and all([isinstance(x, dict) for x in inputs]): - pass - else: - raise ValueError("Unexpected inputs format") - - inputs = cast(list[dict[str, TensorOrArray | dict[str, TensorOrArray]]], inputs) - - return cls( - inputs, - context_length=context_length, - prediction_length=prediction_length, - batch_size=batch_size, - output_patch_size=output_patch_size, - min_past=min_past, - mode=mode, - ) diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 8f4642f..bff11d1 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -616,8 +616,8 @@ class Chronos2Pipeline(BaseChronosPipeline): ) context_length = self.model_context_length - test_dataset = Chronos2Dataset.convert_inputs( - inputs=inputs, + test_dataset = Chronos2Dataset( + inputs, context_length=context_length, prediction_length=prediction_length, batch_size=batch_size, @@ -1142,8 +1142,8 @@ class Chronos2Pipeline(BaseChronosPipeline): ) context_length = self.model_context_length - test_dataset = Chronos2Dataset.convert_inputs( - inputs=inputs, + test_dataset = Chronos2Dataset( + inputs, context_length=context_length, prediction_length=0, batch_size=batch_size,