From 9e1d8ff96121c022c8777c0fcd552cafc6afb269 Mon Sep 17 00:00:00 2001 From: Jorge Emiliano Date: Wed, 28 Jan 2026 09:03:24 -0300 Subject: [PATCH] feat(chronos2): add native support for static covariates --- src/chronos/chronos2/config.py | 1 + src/chronos/chronos2/dataset.py | 29 ++++++++++++++++++++++---- src/chronos/chronos2/model.py | 37 ++++++++++++++++++++++++++++++++- 3 files changed, 62 insertions(+), 5 deletions(-) diff --git a/src/chronos/chronos2/config.py b/src/chronos/chronos2/config.py index c6e011c..43fa3e9 100644 --- a/src/chronos/chronos2/config.py +++ b/src/chronos/chronos2/config.py @@ -110,6 +110,7 @@ class Chronos2ForecastingConfig: use_arcsinh: bool = False max_output_patches: int = 1 time_encoding_scale: int | None = None + n_static_covariates: int = 0 @classmethod def editable_fields(cls) -> list[str]: diff --git a/src/chronos/chronos2/dataset.py b/src/chronos/chronos2/dataset.py index cb75571..253cb34 100644 --- a/src/chronos/chronos2/dataset.py +++ b/src/chronos/chronos2/dataset.py @@ -39,7 +39,7 @@ def left_pad_and_cat_2D(tensors: list[torch.Tensor]) -> torch.Tensor: def validate_and_prepare_single_dict_task( task: Mapping[str, TensorOrArray | Mapping[str, TensorOrArray]], idx: int, prediction_length: int -) -> tuple[torch.Tensor, torch.Tensor, int, int, int]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: """Validates and prepares a single dictionary task for Chronos2Model. Parameters @@ -72,7 +72,7 @@ def validate_and_prepare_single_dict_task( - task_n_future_covariates: Number of known future covariates """ - allowed_keys = {"target", "past_covariates", "future_covariates"} + allowed_keys = {"target", "past_covariates", "future_covariates", "static_covariantes"} # validate keys keys = set(task.keys()) @@ -96,6 +96,22 @@ def validate_and_prepare_single_dict_task( history_length = task_target.shape[-1] task_target = task_target.view(-1, history_length) + # Validate static_covariates + task_static_covariates = task.get("static_covariates", None) + if task_static_covariates is not None: + if isinstance(task_static_covariates, np.ndarray): + task_static_covariates = torch.from_numpy(task_static_covariates) + if not isinstance(task_static_covariates, torch.Tensor): + raise ValueError(f"static_covariates must be a numpy array or torch tensor at index {idx}") + + if task_static_covariates.ndim == 0: + task_static_covariates = task_static_covariates.unsqueeze(0) + elif task_static_covariates.ndim > 1: + task_static_covariates = task_static_covariates.view(-1) + + else: + task_static_covariantes = torch.zeros((0,), device=task_target.device) + # validate past_covariates cat_encoders: dict = {} task_past_covariates = task.get("past_covariates", {}) @@ -200,6 +216,7 @@ def validate_and_prepare_single_dict_task( return ( task_context_tensor, task_future_covariates_tensor, + task_static_covariates, task_n_targets, task_n_covariates, task_n_future_covariates, @@ -473,6 +490,7 @@ class Chronos2Dataset(IterableDataset): ( task_past_tensor, # shape: (task_n_targets + task_n_covariates, history_length) task_future_tensor, + task_static_tensor, task_n_targets, task_n_covariates, task_n_future_covariates, @@ -533,25 +551,27 @@ class Chronos2Dataset(IterableDataset): # task_future_covariates: (task_n_targets + task_n_past_only_covariates + task_n_future_covariates, prediction_length), # the entries corresponding to targets and past-only covariates are NaNs - return task_context, task_future_target, task_future_covariates, task_n_targets + return task_context, task_future_target, task_future_covariates, task_static_tensor, task_n_targets def _build_batch(self, task_indices: list[int]) -> dict[str, torch.Tensor | int | list[tuple[int, int]] | None]: """Build a batch from given task indices.""" batch_context_tensor_list = [] batch_future_target_tensor_list = [] batch_future_covariates_tensor_list = [] + batch_static_covariates_list = [] batch_group_ids_list = [] target_idx_ranges: list[tuple[int, int]] = [] target_start_idx = 0 for group_id, task_idx in enumerate(task_indices): - task_context, task_future_target, task_future_covariates, task_n_targets = self._construct_slice(task_idx) + task_context, task_future_target, task_future_covariates, static_tensor, task_n_targets = self._construct_slice(task_idx) group_size = task_context.shape[0] task_group_ids = torch.full((group_size,), fill_value=group_id) batch_context_tensor_list.append(task_context) batch_future_target_tensor_list.append(task_future_target) batch_future_covariates_tensor_list.append(task_future_covariates) + batch_static_covariates_list.append(static_tensor) batch_group_ids_list.append(task_group_ids) target_idx_ranges.append((target_start_idx, target_start_idx + task_n_targets)) target_start_idx += group_size @@ -562,6 +582,7 @@ class Chronos2Dataset(IterableDataset): if self.mode == DatasetMode.TEST else torch.cat(cast(list[torch.Tensor], batch_future_target_tensor_list), dim=0), "future_covariates": torch.cat(batch_future_covariates_tensor_list, dim=0), + "static_covariates": torch.stack(batch_static_covariates_list, dim=0), "group_ids": torch.cat(batch_group_ids_list, dim=0), "num_output_patches": self.num_output_patches, "target_idx_ranges": target_idx_ranges, diff --git a/src/chronos/chronos2/model.py b/src/chronos/chronos2/model.py index 0397be2..9416827 100644 --- a/src/chronos/chronos2/model.py +++ b/src/chronos/chronos2/model.py @@ -235,6 +235,16 @@ class Chronos2Model(PreTrainedModel): dropout_p=config.dropout_rate, ) + # Embedding for static covariables + if self.chronos_config.n_static_covariates > 0: + self.static_covariates_embedding = ResidualBlock( + in_dim=self.chronos_config.n_static_covariates, + h_dim=config.d_ff, + out_dim=config.d_model, + act_fn_name=config.dense_act_fn, + dropout_p=config.dropout_rate, + ) + # patching layer self.patch = Patch( patch_size=self.chronos_config.input_patch_size, patch_stride=self.chronos_config.input_patch_stride @@ -322,12 +332,18 @@ class Chronos2Model(PreTrainedModel): num_output_patches: int, future_target: torch.Tensor | None, future_target_mask: torch.Tensor | None, + static_covariates: torch.Tensor | None, ): output_patch_size = self.chronos_config.output_patch_size if context.ndim != 2: raise ValueError(f"context must have shape (batch_size, context_length), found: {tuple(context.shape)}") if context_mask is not None and context_mask.shape != context.shape: raise ValueError(f"mask must have shape {tuple(context.shape)}, found: {tuple(context_mask.shape)}") + if self.chronos_config.n_static_covariates > 0: + if static_covariates is None: + raise ValueError(f"Model expects {self.chronos_config.n_static_covariates} static covariates, but None provided.") + if static_covariates.shape[0] != context.shape[0] or static_covariates.shape[-1] != self.chronos_config.n_static_covariates: + raise ValueError(f"static_covariates must have shape (batch_size, {self.chronos_config.n_static_covariates}), found {tuple(static_covariates.shape)}") if future_covariates is not None: if future_covariates.shape[0] != context.shape[0] or future_covariates.ndim != 2: raise ValueError( @@ -557,6 +573,7 @@ class Chronos2Model(PreTrainedModel): num_output_patches: int = 1, future_target: torch.Tensor | None = None, future_target_mask: torch.Tensor | None = None, + static_covariates: torch.Tensor | None = None, output_attentions: bool = False, ): self._validate_input( @@ -568,6 +585,7 @@ class Chronos2Model(PreTrainedModel): num_output_patches=num_output_patches, future_target=future_target, future_target_mask=future_target_mask, + static_covariates=static_covariates, ) batch_size = context.shape[0] @@ -578,6 +596,14 @@ class Chronos2Model(PreTrainedModel): # get input embeddings of shape (batch, num_context_patches, d_model) input_embeds: torch.Tensor = self.input_patch_embedding(patched_context) + + # Injection Static Covariates + if self.chronos_config.n_static_covariates > 0: + static_embeds = self.static_covariates_embedding(static_covariates.to(self.dtype)) + static_embeds = static_embeds.unsqueeze(1) + input_embeds = torch.cat([static_embeds, input_embeds], dim=-2) + static_mask = torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([static_mask, attention_mask], dim=-1) # append [REG] special token embedding, if needed if self.chronos_config.use_reg_token: reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device) @@ -625,6 +651,7 @@ class Chronos2Model(PreTrainedModel): num_output_patches: int = 1, future_target: torch.Tensor | None = None, future_target_mask: torch.Tensor | None = None, + static_covariates: torch.Tensor | None = None, output_attentions: bool = False, ) -> Chronos2Output: """Forward pass of the Chronos2 model. @@ -703,10 +730,18 @@ class Chronos2Model(PreTrainedModel): num_output_patches=num_output_patches, future_target=future_target, future_target_mask=future_target_mask, + static_covariates=static_covariates, output_attentions=output_attentions, ) hidden_states: torch.Tensor = encoder_outputs[0] - assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim) + + expect_seq_len = num_context_patches + 1 + num_output_patches + if self.chronos_config.use_reg_token: + expect_seq_len += 1 + if self.chronos_config.n_static_covariates > 0: + expect_seq_len += 1 + + # assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim) # slice the last num_output_patches hidden states to be input into the output_patch_embedding forecast_embeds = hidden_states[:, -num_output_patches:]