diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py index 86917b4..d542d47 100644 --- a/src/chronos/chronos2/dataset.py +++ b/src/chronos/chronos2/dataset.py @@ -26,8 +26,8 @@ RawTask = Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]] class PreparedTask(TypedDict): """A preprocessed time series task ready for model training/inference.""" - context: np.ndarray # (n_variates, history_length), float32 - future_covariates: np.ndarray # (n_variates, prediction_length), float32 + context: torch.Tensor # (n_variates, history_length), float32 + future_covariates: torch.Tensor # (n_variates, prediction_length), float32 n_targets: int n_covariates: int n_future_covariates: int @@ -74,8 +74,15 @@ def validate_and_prepare_single_dict_task( Returns ------ - PreparedTask - A dictionary containing preprocessed arrays ready for model consumption. + A `PreparedTask` dictionary whose fields (`context`, `future_covariates`, `n_targets`, `n_covariates`, `n_future_covariates`) hold, in this order: + - task_context_tensor: Concatenated tensor of target and past covariates of shape (group_size, history_length), + the first `task_n_targets` items along the first axis contain the target variables and the remaining items contain past-only covariates + and past values of known future covariates. + - task_future_covariates_tensor: Tensor of future covariates of shape (group_size, prediction_length). The last `task_n_future_covariates` + items along the first axis contain future covariates. All the remaining elements corresponding to target and past-only covariates are NaNs. 
+ - task_n_targets: Number of target variables + - task_n_covariates: Total number of covariates (sum of past-only and known future covariates) + - task_n_future_covariates: Number of known future covariates """ allowed_keys = {"target", "past_covariates", "future_covariates"} @@ -89,18 +96,18 @@ def validate_and_prepare_single_dict_task( if "target" not in keys: raise ValueError(f"Element at index {idx} does not contain the required key 'target'") - # validate target - convert to numpy float32 (handles bfloat16 and other dtypes) + # validate target task_target = task["target"] - if isinstance(task_target, torch.Tensor): - task_target = task_target.to(torch.float32).numpy() - task_target = np.asarray(task_target, dtype=np.float32) + if isinstance(task_target, np.ndarray): + task_target = torch.from_numpy(task_target) + assert isinstance(task_target, torch.Tensor) if task_target.ndim > 2: raise ValueError( "When the input is a list of dicts, the `target` should either be 1-d with shape (history_length,) " f" or 2-d with shape (n_variates, history_length). Found element at index {idx} with shape {tuple(task_target.shape)}." 
) history_length = task_target.shape[-1] - task_target = task_target.reshape(-1, history_length) + task_target = task_target.view(-1, history_length) # validate past_covariates cat_encoders: dict = {} @@ -128,81 +135,87 @@ def validate_and_prepare_single_dict_task( ) # create ordered keys: past-only first, then known-future (so known-future are the last rows) - task_past_only_keys = [k for k in task_covariates_keys if k not in task_future_covariates_keys] + task_past_only_keys = [k for k in task_covariates_keys if k not in task_future_covariates_keys] # past_only_keys task_ordered_covariate_keys = task_past_only_keys + task_future_covariates_keys - task_past_covariates_list: list[np.ndarray] = [] + task_past_covariates_list: list[torch.Tensor] = [] for key in task_ordered_covariate_keys: tensor = task_past_covariates[key] - if isinstance(tensor, torch.Tensor): - tensor = tensor.to(torch.float32).numpy() - tensor = np.asarray(tensor) - # apply encoding to categorical variates - if not np.issubdtype(tensor.dtype, np.number): - # target encoding, if the target is 1-d - if task_target.shape[0] == 1: - cat_encoder = TargetEncoder(target_type="continuous", smooth=1.0) - X = tensor.astype(str).reshape(-1, 1) - y = task_target.reshape(-1) - mask = np.isfinite(y) - cat_encoder.fit(X[mask], y[mask]) - # ordinal encoding, if the target is > 1-d - else: - cat_encoder = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=np.nan) - cat_encoder.fit(tensor.astype(str).reshape(-1, 1)) - tensor = cat_encoder.transform(tensor.astype(str).reshape(-1, 1)).reshape(tensor.shape) - cat_encoders[key] = cat_encoder + if isinstance(tensor, np.ndarray): + # apply encoding to categorical variates + if not np.issubdtype(tensor.dtype, np.number): + # target encoding, if the target is 1-d + if task_target.shape[0] == 1: + cat_encoder = TargetEncoder(target_type="continuous", smooth=1.0) + X = tensor.astype(str).reshape(-1, 1) + y = task_target.view(-1).numpy() + mask = np.isfinite(y) 
+ X = X[mask] + y = y[mask] + cat_encoder.fit(X, y) + # ordinal encoding, if the target is > 1-d + else: + cat_encoder = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=np.nan) + cat_encoder.fit(tensor.astype(str).reshape(-1, 1)) + tensor = cat_encoder.transform(tensor.astype(str).reshape(-1, 1)).reshape(tensor.shape) + cat_encoders[key] = cat_encoder + tensor = torch.from_numpy(tensor) + assert isinstance(tensor, torch.Tensor) if tensor.ndim != 1 or len(tensor) != history_length: raise ValueError( f"Individual `past_covariates` must be 1-d with length equal to the length of `target` (= {history_length}), " f"found: {key} with shape {tuple(tensor.shape)} in element at index {idx}" ) task_past_covariates_list.append(tensor) - task_past_covariates_array = ( - np.stack(task_past_covariates_list, axis=0) + task_past_covariates_tensor = ( + torch.stack(task_past_covariates_list, dim=0) if task_past_covariates_list - else np.zeros((0, history_length), dtype=np.float32) + else torch.zeros((0, history_length), device=task_target.device) ) # validate future_covariates (build rows in the same task_ordered_covariate_keys order) - task_future_covariates_list: list[np.ndarray] = [] + task_future_covariates_list: list[torch.Tensor] = [] for key in task_ordered_covariate_keys: # future values of past-only covariates are filled with NaNs - tensor = task_future_covariates.get(key, np.full(prediction_length, np.nan)) - if isinstance(tensor, torch.Tensor): - tensor = tensor.to(torch.float32).numpy() - tensor = np.asarray(tensor) - # apply encoding to categorical variates - if not np.issubdtype(tensor.dtype, np.number): - cat_encoder = cat_encoders[key] - tensor = cat_encoder.transform(tensor.astype(str).reshape(-1, 1)).reshape(tensor.shape) + tensor = task_future_covariates.get(key, torch.full((prediction_length,), fill_value=torch.nan)) + if isinstance(tensor, np.ndarray): + # apply encoding to categorical variates + if not np.issubdtype(tensor.dtype, np.number): + 
cat_encoder = cat_encoders[key] + tensor = cat_encoder.transform(tensor.astype(str).reshape(-1, 1)).reshape(tensor.shape) + tensor = torch.from_numpy(tensor) + assert isinstance(tensor, torch.Tensor) if tensor.ndim != 1 or len(tensor) != prediction_length: raise ValueError( f"Individual `future_covariates` must be 1-d with length equal to the {prediction_length=}, " f"found: {key} with shape {tuple(tensor.shape)} in element at index {idx}" ) task_future_covariates_list.append(tensor) - task_future_covariates_array = ( - np.stack(task_future_covariates_list, axis=0) + task_future_covariates_tensor = ( + torch.stack(task_future_covariates_list, dim=0) if task_future_covariates_list - else np.zeros((0, prediction_length), dtype=np.float32) + else torch.zeros((0, prediction_length), device=task_target.device) ) # future values of target series are filled with NaNs - task_future_covariates_target_padding = np.full( - (task_target.shape[0], prediction_length), np.nan, dtype=np.float32 + task_future_covariates_target_padding = torch.full( + (task_target.shape[0], prediction_length), fill_value=torch.nan, device=task_target.device ) - context = np.concatenate([task_target, task_past_covariates_array], axis=0).astype(np.float32) - future_covariates = np.concatenate( - [task_future_covariates_target_padding, task_future_covariates_array], axis=0 - ).astype(np.float32) + task_context_tensor = torch.cat([task_target, task_past_covariates_tensor], dim=0).to(dtype=torch.float32) + task_future_covariates_tensor = torch.cat( + [task_future_covariates_target_padding, task_future_covariates_tensor], dim=0 + ).to(dtype=torch.float32) + task_n_targets = task_target.shape[0] + task_n_covariates = task_past_covariates_tensor.shape[0] + # number of known-future covariates + task_n_future_covariates = len(task_future_covariates_keys) return PreparedTask( - context=context, - future_covariates=future_covariates, - n_targets=task_target.shape[0], - 
n_covariates=task_past_covariates_array.shape[0], - n_future_covariates=len(task_future_covariates_keys), + context=task_context_tensor, + future_covariates=task_future_covariates_tensor, + n_targets=task_n_targets, + n_covariates=task_n_covariates, + n_future_covariates=task_n_future_covariates, ) @@ -217,7 +230,6 @@ def prepare_tasks( This function handles mode-specific preprocessing (e.g., filtering short series) and calls validate_and_prepare_single_dict_task for each task. """ - # Import here to avoid issues with forward reference if isinstance(mode, str): mode = DatasetMode(mode) @@ -237,13 +249,13 @@ def prepare_tasks( raw_task = {**raw_task, "future_covariates": fixed_future_covariates} raw_task = cast(dict[str, TensorOrArray | Mapping[str, TensorOrArray]], raw_task) - prepared = validate_and_prepare_single_dict_task(raw_task, idx, prediction_length) + task = validate_and_prepare_single_dict_task(raw_task, idx, prediction_length) # Filter by minimum length (except in TEST mode) - if mode != DatasetMode.TEST and prepared["context"].shape[-1] < min_past + prediction_length: + if mode != DatasetMode.TEST and task["context"].shape[-1] < min_past + prediction_length: continue - tasks.append(prepared) + tasks.append(task) if len(tasks) == 0: raise ValueError( @@ -522,15 +534,8 @@ class Chronos2Dataset(IterableDataset): def _construct_slice(self, task_idx: int) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, int]: task = self.tasks[task_idx] - # Convert numpy arrays to torch tensors if needed - context = task["context"] - future_cov = task["future_covariates"] - if isinstance(context, np.ndarray): - context = torch.from_numpy(context) - if isinstance(future_cov, np.ndarray): - future_cov = torch.from_numpy(future_cov) - task_past_tensor = context.clone().to(torch.float32) - task_future_tensor = future_cov.clone().to(torch.float32) + task_past_tensor = task["context"].clone() + task_future_tensor = task["future_covariates"].clone() task_n_targets = 
task["n_targets"] task_n_covariates = task["n_covariates"] task_n_future_covariates = task["n_future_covariates"]