mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Add Chronos2Pipeline.embed (#361)
*Issue #, if available:* #354 *Description of changes:* This PR adds `Chronos2Pipeline.embed` so that users can easily extract embeddings from the last encoder layer. The API and behavior are similar to what Chronos and Chronos-Bolt provide. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
e48f48071f
commit
111972a6cc
3 changed files with 175 additions and 47 deletions
|
|
@ -547,6 +547,74 @@ class Chronos2Model(PreTrainedModel):
|
|||
|
||||
return loss
|
||||
|
||||
def encode(
    self,
    context: torch.Tensor,
    context_mask: torch.Tensor | None = None,
    group_ids: torch.Tensor | None = None,
    future_covariates: torch.Tensor | None = None,
    future_covariates_mask: torch.Tensor | None = None,
    num_output_patches: int = 1,
    future_target: torch.Tensor | None = None,
    future_target_mask: torch.Tensor | None = None,
    output_attentions: bool = False,
):
    """
    Run the encoder over the patched context followed by the patched future.

    The arguments are validated first. The context is then patched and embedded,
    the [REG] token embedding is appended when the model is configured to use it,
    the embedded future patches are appended after that, and the resulting
    sequence is passed through ``self.encoder``.

    ``future_target`` and ``future_target_mask`` are only forwarded to input
    validation; they do not take part in the encoder computation here.

    Returns
    -------
    A 4-tuple ``(encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches)``:
    the raw encoder output, the ``loc_scale`` value produced while patching the context,
    the patched future covariates mask, and the number of context patches
    (measured before the [REG] token and the future patches are appended).
    """
    self._validate_input(
        context=context,
        context_mask=context_mask,
        future_covariates=future_covariates,
        future_covariates_mask=future_covariates_mask,
        group_ids=group_ids,
        num_output_patches=num_output_patches,
        future_target=future_target,
        future_target_mask=future_target_mask,
    )

    n_series = context.shape[0]
    context_patches, attn_mask, loc_scale = self._prepare_patched_context(
        context=context, context_mask=context_mask
    )
    # Recorded now, i.e. before any extra tokens are appended to the mask.
    num_context_patches = attn_mask.shape[-1]

    # context embeddings of shape (batch, num_context_patches, d_model)
    sequence_embeds: torch.Tensor = self.input_patch_embedding(context_patches)

    if self.chronos_config.use_reg_token:
        # one [REG] special token per series, always attended to
        reg_ids = torch.full((n_series, 1), self.config.reg_token_id, device=sequence_embeds.device)
        sequence_embeds = torch.cat([sequence_embeds, self.shared(reg_ids)], dim=-2)
        reg_mask = torch.ones_like(reg_ids).to(self.dtype)
        attn_mask = torch.cat([attn_mask.to(self.dtype), reg_mask], dim=-1)

    future_patches, patched_future_covariates_mask = self._prepare_patched_future(
        future_covariates=future_covariates,
        future_covariates_mask=future_covariates_mask,
        loc_scale=loc_scale,
        num_output_patches=num_output_patches,
        batch_size=n_series,
    )
    # every future patch position is attended to
    future_mask = torch.ones(n_series, num_output_patches, dtype=self.dtype, device=self.device)

    # future embeddings of shape (batch, num_output_patches, d_model)
    future_embeds: torch.Tensor = self.input_patch_embedding(future_patches)

    # join context(+[REG]) and future along the patch axis, embeddings and masks alike
    sequence_embeds = torch.cat([sequence_embeds, future_embeds], dim=-2)
    attn_mask = torch.cat([attn_mask, future_mask], dim=-1)

    if group_ids is None:
        # a distinct id per series: by default nothing is mixed across the batch
        group_ids = torch.arange(n_series, dtype=torch.long, device=self.device)

    encoder_outputs = self.encoder(
        attention_mask=attn_mask,
        inputs_embeds=sequence_embeds,
        group_ids=group_ids,
        output_attentions=output_attentions,
    )
    return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches
|
||||
|
||||
def forward(
|
||||
self,
|
||||
context: torch.Tensor,
|
||||
|
|
@ -625,63 +693,19 @@ class Chronos2Model(PreTrainedModel):
|
|||
- enc_time_self_attn_weights: Time self attention weights, if output_attentions=True
|
||||
- enc_group_self_attn_weights: Group self attention weights, if output_attentions=True
|
||||
"""
|
||||
|
||||
self._validate_input(
|
||||
batch_size = context.shape[0]
|
||||
encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode(
|
||||
context=context,
|
||||
context_mask=context_mask,
|
||||
group_ids=group_ids,
|
||||
future_covariates=future_covariates,
|
||||
future_covariates_mask=future_covariates_mask,
|
||||
group_ids=group_ids,
|
||||
num_output_patches=num_output_patches,
|
||||
future_target=future_target,
|
||||
future_target_mask=future_target_mask,
|
||||
)
|
||||
|
||||
batch_size = context.shape[0]
|
||||
patched_context, attention_mask, loc_scale = self._prepare_patched_context(
|
||||
context=context, context_mask=context_mask
|
||||
)
|
||||
num_context_patches = attention_mask.shape[-1]
|
||||
|
||||
# get input embeddings of shape (batch, num_context_patches, d_model)
|
||||
input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
|
||||
# append [REG] special token embedding, if needed
|
||||
if self.chronos_config.use_reg_token:
|
||||
reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
|
||||
reg_embeds = self.shared(reg_input_ids)
|
||||
input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
|
||||
)
|
||||
|
||||
patched_future, patched_future_covariates_mask = self._prepare_patched_future(
|
||||
future_covariates=future_covariates,
|
||||
future_covariates_mask=future_covariates_mask,
|
||||
loc_scale=loc_scale,
|
||||
num_output_patches=num_output_patches,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)
|
||||
|
||||
# get future embeddings of shape (batch, num_output_patches, d_model)
|
||||
future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)
|
||||
|
||||
# concatenate context and future embeddings and masks
|
||||
input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
|
||||
attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)
|
||||
|
||||
if group_ids is None:
|
||||
# by default, each time series is treated independently, i.e., no mixing across the batch
|
||||
group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device)
|
||||
|
||||
encoder_outputs: Chronos2EncoderOutput = self.encoder(
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=input_embeds,
|
||||
group_ids=group_ids,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states: torch.Tensor = encoder_outputs[0]
|
||||
|
||||
assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
|
||||
|
||||
# slice the last num_output_patches hidden states to be input into the output_patch_embedding
|
||||
|
|
|
|||
|
|
@ -988,6 +988,81 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
|
||||
return predictions_per_window, inference_time_s
|
||||
|
||||
@torch.no_grad()
def embed(
    self, inputs: TensorOrArray | Sequence[TensorOrArray], batch_size: int = 256, context_length: int | None = None
) -> tuple[list[torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]]]:
    """
    Compute encoder embeddings for the given time series.

    Parameters
    ----------
    inputs
        The time series to embed, given as one of:
        - A 3-dimensional `torch.Tensor` or `np.ndarray` of shape (batch, n_variates, history_length).
          When `n_variates > 1`, information is shared among the variates of each time series in the batch.
        - A list of `torch.Tensor` or `np.ndarray` whose elements are either 1-dimensional with shape
          (history_length,) or 2-dimensional with shape (n_variates, history_length). History lengths may
          differ between elements; left-padding is applied when needed.
    batch_size
        The batch size used for generating embeddings, counted as the total number of time series fed to
        the model. For data with multiple variates, the effective number of time series tasks in a batch
        is lower than this value, by default 256
    context_length
        The maximum context length used during inference, by default set to the model's default context length

    Returns
    -------
    embeddings
        a list of `torch.Tensor`, one per target time series (univariate or multivariate) in `inputs`,
        each of shape (n_variates, num_patches + 2, d_model). The extra +2 accounts for the embeddings
        of the [REG] token and a masked output patch token.
    loc_scale
        a list of tuples with the mean and standard deviation of each time series.
    """
    if context_length is None:
        context_length = self.model_context_length
    elif context_length > self.model_context_length:
        # longer contexts than the model supports are clamped, with a warning
        warnings.warn(
            f"The specified context_length {context_length} is greater than the model's default context length {self.model_context_length}. "
            f"Resetting context_length to {self.model_context_length}."
        )
        context_length = self.model_context_length

    dataset = Chronos2Dataset.convert_inputs(
        inputs=inputs,
        context_length=context_length,
        prediction_length=0,
        batch_size=batch_size,
        output_patch_size=self.model_output_patch_size,
        mode=DatasetMode.TEST,
    )
    # batch_size=None: batches are taken from the dataset as-is rather than collated by the loader
    loader = DataLoader(dataset, batch_size=None, num_workers=1, pin_memory=True, shuffle=False, drop_last=False)

    all_embeds: list[torch.Tensor] = []
    all_loc_scales: list[tuple[torch.Tensor, torch.Tensor]] = []
    for batch in loader:
        assert batch["future_target"] is None
        context = batch["context"]
        group_ids = batch["group_ids"]
        idx_ranges = batch["target_idx_ranges"]

        encoder_outputs, (locs, scales), *_ = self.model.encode(
            context=context.to(device=self.model.device, dtype=torch.float32),
            group_ids=group_ids.to(self.model.device),
        )
        hidden_states = encoder_outputs[0]

        # each (start, end) range selects the batch rows belonging to one target time series
        for start, end in idx_ranges:
            all_embeds.append(hidden_states[start:end].cpu())
            all_loc_scales.append((locs[start:end].cpu(), scales[start:end].cpu()))

    return all_embeds, all_loc_scales
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -340,6 +340,35 @@ def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: tor
|
|||
validate_tensor(quantiles_item, (3, expected_num_quantiles, 7), dtype=torch.float32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "inputs, expected_output_shapes",
    [
        # NOTE: d_model for the dummy model is 6
        # Homogeneous univariate task
        (torch.rand(4, 1, 16), [(1, 3, 6)] * 4),
        # Homogeneous multivariate task
        (torch.rand(4, 3, 37), [(3, 5, 6)] * 4),
        # Heterogeneous tasks with different history lengths
        (
            [torch.rand(100), torch.rand(2, 150), torch.rand(120)],
            [(1, 12, 6), (2, 12, 6), (1, 12, 6)],
        ),
    ],
)
def test_when_input_is_valid_then_pipeline_can_embed(pipeline, inputs, expected_output_shapes):
    """Check that `embed` returns one embedding and one (loc, scale) pair per task, with the expected shapes."""
    embeds, loc_scales = pipeline.embed(inputs)
    num_tasks = len(expected_output_shapes)

    assert isinstance(embeds, list)
    assert len(embeds) == num_tasks
    assert len(loc_scales) == num_tasks

    for embedding, (loc, scale), shape in zip(embeds, loc_scales, expected_output_shapes):
        n_variates = shape[0]
        validate_tensor(embedding, shape, dtype=torch.float32)
        # loc and scale hold one value per variate
        validate_tensor(loc, (n_variates, 1), dtype=torch.float32)
        validate_tensor(scale, (n_variates, 1), dtype=torch.float32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"task_kwargs",
|
||||
[
|
||||
|
|
|
|||
Loading…
Reference in a new issue