mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
feat(chronos2): add native support for static covariates
This commit is contained in:
parent
1f099eb265
commit
9e1d8ff961
3 changed files with 62 additions and 5 deletions
|
|
@ -110,6 +110,7 @@ class Chronos2ForecastingConfig:
|
|||
use_arcsinh: bool = False
|
||||
max_output_patches: int = 1
|
||||
time_encoding_scale: int | None = None
|
||||
n_static_covariates: int = 0
|
||||
|
||||
@classmethod
|
||||
def editable_fields(cls) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def left_pad_and_cat_2D(tensors: list[torch.Tensor]) -> torch.Tensor:
|
|||
|
||||
def validate_and_prepare_single_dict_task(
|
||||
task: Mapping[str, TensorOrArray | Mapping[str, TensorOrArray]], idx: int, prediction_length: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, int, int, int]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
|
||||
"""Validates and prepares a single dictionary task for Chronos2Model.
|
||||
|
||||
Parameters
|
||||
|
|
@ -72,7 +72,7 @@ def validate_and_prepare_single_dict_task(
|
|||
- task_n_future_covariates: Number of known future covariates
|
||||
"""
|
||||
|
||||
allowed_keys = {"target", "past_covariates", "future_covariates"}
|
||||
allowed_keys = {"target", "past_covariates", "future_covariates", "static_covariantes"}
|
||||
|
||||
# validate keys
|
||||
keys = set(task.keys())
|
||||
|
|
@ -96,6 +96,22 @@ def validate_and_prepare_single_dict_task(
|
|||
history_length = task_target.shape[-1]
|
||||
task_target = task_target.view(-1, history_length)
|
||||
|
||||
# Validate static_covariates
|
||||
task_static_covariates = task.get("static_covariates", None)
|
||||
if task_static_covariates is not None:
|
||||
if isinstance(task_static_covariates, np.ndarray):
|
||||
task_static_covariates = torch.from_numpy(task_static_covariates)
|
||||
if not isinstance(task_static_covariates, torch.Tensor):
|
||||
raise ValueError(f"static_covariates must be a numpy array or torch tensor at index {idx}")
|
||||
|
||||
if task_static_covariates.ndim == 0:
|
||||
task_static_covariates = task_static_covariates.unsqueeze(0)
|
||||
elif task_static_covariates.ndim > 1:
|
||||
task_static_covariates = task_static_covariates.view(-1)
|
||||
|
||||
else:
|
||||
task_static_covariantes = torch.zeros((0,), device=task_target.device)
|
||||
|
||||
# validate past_covariates
|
||||
cat_encoders: dict = {}
|
||||
task_past_covariates = task.get("past_covariates", {})
|
||||
|
|
@ -200,6 +216,7 @@ def validate_and_prepare_single_dict_task(
|
|||
return (
|
||||
task_context_tensor,
|
||||
task_future_covariates_tensor,
|
||||
task_static_covariates,
|
||||
task_n_targets,
|
||||
task_n_covariates,
|
||||
task_n_future_covariates,
|
||||
|
|
@ -473,6 +490,7 @@ class Chronos2Dataset(IterableDataset):
|
|||
(
|
||||
task_past_tensor, # shape: (task_n_targets + task_n_covariates, history_length)
|
||||
task_future_tensor,
|
||||
task_static_tensor,
|
||||
task_n_targets,
|
||||
task_n_covariates,
|
||||
task_n_future_covariates,
|
||||
|
|
@ -533,25 +551,27 @@ class Chronos2Dataset(IterableDataset):
|
|||
# task_future_covariates: (task_n_targets + task_n_past_only_covariates + task_n_future_covariates, prediction_length),
|
||||
# the entries corresponding to targets and past-only covariates are NaNs
|
||||
|
||||
return task_context, task_future_target, task_future_covariates, task_n_targets
|
||||
return task_context, task_future_target, task_future_covariates, task_static_tensor, task_n_targets
|
||||
|
||||
def _build_batch(self, task_indices: list[int]) -> dict[str, torch.Tensor | int | list[tuple[int, int]] | None]:
|
||||
"""Build a batch from given task indices."""
|
||||
batch_context_tensor_list = []
|
||||
batch_future_target_tensor_list = []
|
||||
batch_future_covariates_tensor_list = []
|
||||
batch_static_covariates_list = []
|
||||
batch_group_ids_list = []
|
||||
target_idx_ranges: list[tuple[int, int]] = []
|
||||
|
||||
target_start_idx = 0
|
||||
for group_id, task_idx in enumerate(task_indices):
|
||||
task_context, task_future_target, task_future_covariates, task_n_targets = self._construct_slice(task_idx)
|
||||
task_context, task_future_target, task_future_covariates, static_tensor, task_n_targets = self._construct_slice(task_idx)
|
||||
|
||||
group_size = task_context.shape[0]
|
||||
task_group_ids = torch.full((group_size,), fill_value=group_id)
|
||||
batch_context_tensor_list.append(task_context)
|
||||
batch_future_target_tensor_list.append(task_future_target)
|
||||
batch_future_covariates_tensor_list.append(task_future_covariates)
|
||||
batch_static_covariates_list.append(static_tensor)
|
||||
batch_group_ids_list.append(task_group_ids)
|
||||
target_idx_ranges.append((target_start_idx, target_start_idx + task_n_targets))
|
||||
target_start_idx += group_size
|
||||
|
|
@ -562,6 +582,7 @@ class Chronos2Dataset(IterableDataset):
|
|||
if self.mode == DatasetMode.TEST
|
||||
else torch.cat(cast(list[torch.Tensor], batch_future_target_tensor_list), dim=0),
|
||||
"future_covariates": torch.cat(batch_future_covariates_tensor_list, dim=0),
|
||||
"static_covariates": torch.stack(batch_static_covariates_list, dim=0),
|
||||
"group_ids": torch.cat(batch_group_ids_list, dim=0),
|
||||
"num_output_patches": self.num_output_patches,
|
||||
"target_idx_ranges": target_idx_ranges,
|
||||
|
|
|
|||
|
|
@ -235,6 +235,16 @@ class Chronos2Model(PreTrainedModel):
|
|||
dropout_p=config.dropout_rate,
|
||||
)
|
||||
|
||||
# Embedding for static covariables
|
||||
if self.chronos_config.n_static_covariates > 0:
|
||||
self.static_covariates_embedding = ResidualBlock(
|
||||
in_dim=self.chronos_config.n_static_covariates,
|
||||
h_dim=config.d_ff,
|
||||
out_dim=config.d_model,
|
||||
act_fn_name=config.dense_act_fn,
|
||||
dropout_p=config.dropout_rate,
|
||||
)
|
||||
|
||||
# patching layer
|
||||
self.patch = Patch(
|
||||
patch_size=self.chronos_config.input_patch_size, patch_stride=self.chronos_config.input_patch_stride
|
||||
|
|
@ -322,12 +332,18 @@ class Chronos2Model(PreTrainedModel):
|
|||
num_output_patches: int,
|
||||
future_target: torch.Tensor | None,
|
||||
future_target_mask: torch.Tensor | None,
|
||||
static_covariates: torch.Tensor | None,
|
||||
):
|
||||
output_patch_size = self.chronos_config.output_patch_size
|
||||
if context.ndim != 2:
|
||||
raise ValueError(f"context must have shape (batch_size, context_length), found: {tuple(context.shape)}")
|
||||
if context_mask is not None and context_mask.shape != context.shape:
|
||||
raise ValueError(f"mask must have shape {tuple(context.shape)}, found: {tuple(context_mask.shape)}")
|
||||
if self.chronos_config.n_static_covariates > 0:
|
||||
if static_covariates is None:
|
||||
raise ValueError(f"Model expects {self.chronos_config.n_static_covariates} static covariates, but None provided.")
|
||||
if static_covariates.shape[0] != context.shape[0] or static_covariates.shape[-1] != self.chronos_config.n_static_covariates:
|
||||
raise ValueError(f"static_covariates must have shape (batch_size, {self.chronos_config.n_static_covariates}), found {tuple(static_covariates.shape)}")
|
||||
if future_covariates is not None:
|
||||
if future_covariates.shape[0] != context.shape[0] or future_covariates.ndim != 2:
|
||||
raise ValueError(
|
||||
|
|
@ -557,6 +573,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
num_output_patches: int = 1,
|
||||
future_target: torch.Tensor | None = None,
|
||||
future_target_mask: torch.Tensor | None = None,
|
||||
static_covariates: torch.Tensor | None = None,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
self._validate_input(
|
||||
|
|
@ -568,6 +585,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
num_output_patches=num_output_patches,
|
||||
future_target=future_target,
|
||||
future_target_mask=future_target_mask,
|
||||
static_covariates=static_covariates,
|
||||
)
|
||||
|
||||
batch_size = context.shape[0]
|
||||
|
|
@ -578,6 +596,14 @@ class Chronos2Model(PreTrainedModel):
|
|||
|
||||
# get input embeddings of shape (batch, num_context_patches, d_model)
|
||||
input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
|
||||
|
||||
# Injection Static Covariates
|
||||
if self.chronos_config.n_static_covariates > 0:
|
||||
static_embeds = self.static_covariates_embedding(static_covariates.to(self.dtype))
|
||||
static_embeds = static_embeds.unsqueeze(1)
|
||||
input_embeds = torch.cat([static_embeds, input_embeds], dim=-2)
|
||||
static_mask = torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)
|
||||
attention_mask = torch.cat([static_mask, attention_mask], dim=-1)
|
||||
# append [REG] special token embedding, if needed
|
||||
if self.chronos_config.use_reg_token:
|
||||
reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
|
||||
|
|
@ -625,6 +651,7 @@ class Chronos2Model(PreTrainedModel):
|
|||
num_output_patches: int = 1,
|
||||
future_target: torch.Tensor | None = None,
|
||||
future_target_mask: torch.Tensor | None = None,
|
||||
static_covariates: torch.Tensor | None = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Chronos2Output:
|
||||
"""Forward pass of the Chronos2 model.
|
||||
|
|
@ -703,10 +730,18 @@ class Chronos2Model(PreTrainedModel):
|
|||
num_output_patches=num_output_patches,
|
||||
future_target=future_target,
|
||||
future_target_mask=future_target_mask,
|
||||
static_covariates=static_covariates,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states: torch.Tensor = encoder_outputs[0]
|
||||
assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
|
||||
|
||||
expect_seq_len = num_context_patches + 1 + num_output_patches
|
||||
if self.chronos_config.use_reg_token:
|
||||
expect_seq_len += 1
|
||||
if self.chronos_config.n_static_covariates > 0:
|
||||
expect_seq_len += 1
|
||||
|
||||
# assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
|
||||
|
||||
# slice the last num_output_patches hidden states to be input into the output_patch_embedding
|
||||
forecast_embeds = hidden_states[:, -num_output_patches:]
|
||||
|
|
|
|||
Loading…
Reference in a new issue