feat(chronos2): add native support for static covariates

This commit is contained in:
Jorge Emiliano 2026-01-28 09:03:24 -03:00
parent 1f099eb265
commit 9e1d8ff961
3 changed files with 62 additions and 5 deletions

View file

@ -110,6 +110,7 @@ class Chronos2ForecastingConfig:
use_arcsinh: bool = False
max_output_patches: int = 1
time_encoding_scale: int | None = None
n_static_covariates: int = 0
@classmethod
def editable_fields(cls) -> list[str]:

View file

@ -39,7 +39,7 @@ def left_pad_and_cat_2D(tensors: list[torch.Tensor]) -> torch.Tensor:
def validate_and_prepare_single_dict_task(
task: Mapping[str, TensorOrArray | Mapping[str, TensorOrArray]], idx: int, prediction_length: int
) -> tuple[torch.Tensor, torch.Tensor, int, int, int]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
"""Validates and prepares a single dictionary task for Chronos2Model.
Parameters
@ -72,7 +72,7 @@ def validate_and_prepare_single_dict_task(
- task_n_future_covariates: Number of known future covariates
"""
allowed_keys = {"target", "past_covariates", "future_covariates"}
allowed_keys = {"target", "past_covariates", "future_covariates", "static_covariantes"}
# validate keys
keys = set(task.keys())
@ -96,6 +96,22 @@ def validate_and_prepare_single_dict_task(
history_length = task_target.shape[-1]
task_target = task_target.view(-1, history_length)
# Validate static_covariates
task_static_covariates = task.get("static_covariates", None)
if task_static_covariates is not None:
if isinstance(task_static_covariates, np.ndarray):
task_static_covariates = torch.from_numpy(task_static_covariates)
if not isinstance(task_static_covariates, torch.Tensor):
raise ValueError(f"static_covariates must be a numpy array or torch tensor at index {idx}")
if task_static_covariates.ndim == 0:
task_static_covariates = task_static_covariates.unsqueeze(0)
elif task_static_covariates.ndim > 1:
task_static_covariates = task_static_covariates.view(-1)
else:
task_static_covariantes = torch.zeros((0,), device=task_target.device)
# validate past_covariates
cat_encoders: dict = {}
task_past_covariates = task.get("past_covariates", {})
@ -200,6 +216,7 @@ def validate_and_prepare_single_dict_task(
return (
task_context_tensor,
task_future_covariates_tensor,
task_static_covariates,
task_n_targets,
task_n_covariates,
task_n_future_covariates,
@ -473,6 +490,7 @@ class Chronos2Dataset(IterableDataset):
(
task_past_tensor, # shape: (task_n_targets + task_n_covariates, history_length)
task_future_tensor,
task_static_tensor,
task_n_targets,
task_n_covariates,
task_n_future_covariates,
@ -533,25 +551,27 @@ class Chronos2Dataset(IterableDataset):
# task_future_covariates: (task_n_targets + task_n_past_only_covariates + task_n_future_covariates, prediction_length),
# the entries corresponding to targets and past-only covariates are NaNs
return task_context, task_future_target, task_future_covariates, task_n_targets
return task_context, task_future_target, task_future_covariates, task_static_tensor, task_n_targets
def _build_batch(self, task_indices: list[int]) -> dict[str, torch.Tensor | int | list[tuple[int, int]] | None]:
"""Build a batch from given task indices."""
batch_context_tensor_list = []
batch_future_target_tensor_list = []
batch_future_covariates_tensor_list = []
batch_static_covariates_list = []
batch_group_ids_list = []
target_idx_ranges: list[tuple[int, int]] = []
target_start_idx = 0
for group_id, task_idx in enumerate(task_indices):
task_context, task_future_target, task_future_covariates, task_n_targets = self._construct_slice(task_idx)
task_context, task_future_target, task_future_covariates, static_tensor, task_n_targets = self._construct_slice(task_idx)
group_size = task_context.shape[0]
task_group_ids = torch.full((group_size,), fill_value=group_id)
batch_context_tensor_list.append(task_context)
batch_future_target_tensor_list.append(task_future_target)
batch_future_covariates_tensor_list.append(task_future_covariates)
batch_static_covariates_list.append(static_tensor)
batch_group_ids_list.append(task_group_ids)
target_idx_ranges.append((target_start_idx, target_start_idx + task_n_targets))
target_start_idx += group_size
@ -562,6 +582,7 @@ class Chronos2Dataset(IterableDataset):
if self.mode == DatasetMode.TEST
else torch.cat(cast(list[torch.Tensor], batch_future_target_tensor_list), dim=0),
"future_covariates": torch.cat(batch_future_covariates_tensor_list, dim=0),
"static_covariates": torch.stack(batch_static_covariates_list, dim=0),
"group_ids": torch.cat(batch_group_ids_list, dim=0),
"num_output_patches": self.num_output_patches,
"target_idx_ranges": target_idx_ranges,

View file

@ -235,6 +235,16 @@ class Chronos2Model(PreTrainedModel):
dropout_p=config.dropout_rate,
)
# Embedding for static covariables
if self.chronos_config.n_static_covariates > 0:
self.static_covariates_embedding = ResidualBlock(
in_dim=self.chronos_config.n_static_covariates,
h_dim=config.d_ff,
out_dim=config.d_model,
act_fn_name=config.dense_act_fn,
dropout_p=config.dropout_rate,
)
# patching layer
self.patch = Patch(
patch_size=self.chronos_config.input_patch_size, patch_stride=self.chronos_config.input_patch_stride
@ -322,12 +332,18 @@ class Chronos2Model(PreTrainedModel):
num_output_patches: int,
future_target: torch.Tensor | None,
future_target_mask: torch.Tensor | None,
static_covariates: torch.Tensor | None,
):
output_patch_size = self.chronos_config.output_patch_size
if context.ndim != 2:
raise ValueError(f"context must have shape (batch_size, context_length), found: {tuple(context.shape)}")
if context_mask is not None and context_mask.shape != context.shape:
raise ValueError(f"mask must have shape {tuple(context.shape)}, found: {tuple(context_mask.shape)}")
if self.chronos_config.n_static_covariates > 0:
if static_covariates is None:
raise ValueError(f"Model expects {self.chronos_config.n_static_covariates} static covariates, but None provided.")
if static_covariates.shape[0] != context.shape[0] or static_covariates.shape[-1] != self.chronos_config.n_static_covariates:
raise ValueError(f"static_covariates must have shape (batch_size, {self.chronos_config.n_static_covariates}), found {tuple(static_covariates.shape)}")
if future_covariates is not None:
if future_covariates.shape[0] != context.shape[0] or future_covariates.ndim != 2:
raise ValueError(
@ -557,6 +573,7 @@ class Chronos2Model(PreTrainedModel):
num_output_patches: int = 1,
future_target: torch.Tensor | None = None,
future_target_mask: torch.Tensor | None = None,
static_covariates: torch.Tensor | None = None,
output_attentions: bool = False,
):
self._validate_input(
@ -568,6 +585,7 @@ class Chronos2Model(PreTrainedModel):
num_output_patches=num_output_patches,
future_target=future_target,
future_target_mask=future_target_mask,
static_covariates=static_covariates,
)
batch_size = context.shape[0]
@ -578,6 +596,14 @@ class Chronos2Model(PreTrainedModel):
# get input embeddings of shape (batch, num_context_patches, d_model)
input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
# Injection Static Covariates
if self.chronos_config.n_static_covariates > 0:
static_embeds = self.static_covariates_embedding(static_covariates.to(self.dtype))
static_embeds = static_embeds.unsqueeze(1)
input_embeds = torch.cat([static_embeds, input_embeds], dim=-2)
static_mask = torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat([static_mask, attention_mask], dim=-1)
# append [REG] special token embedding, if needed
if self.chronos_config.use_reg_token:
reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
@ -625,6 +651,7 @@ class Chronos2Model(PreTrainedModel):
num_output_patches: int = 1,
future_target: torch.Tensor | None = None,
future_target_mask: torch.Tensor | None = None,
static_covariates: torch.Tensor | None = None,
output_attentions: bool = False,
) -> Chronos2Output:
"""Forward pass of the Chronos2 model.
@ -703,10 +730,18 @@ class Chronos2Model(PreTrainedModel):
num_output_patches=num_output_patches,
future_target=future_target,
future_target_mask=future_target_mask,
static_covariates=static_covariates,
output_attentions=output_attentions,
)
hidden_states: torch.Tensor = encoder_outputs[0]
assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
expect_seq_len = num_context_patches + 1 + num_output_patches
if self.chronos_config.use_reg_token:
expect_seq_len += 1
if self.chronos_config.n_static_covariates > 0:
expect_seq_len += 1
# assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
# slice the last num_output_patches hidden states to be input into the output_patch_embedding
forecast_embeds = hidden_states[:, -num_output_patches:]