Add Chronos2Pipeline.embed (#361)

*Issue #, if available:* #354 

*Description of changes:* This PR adds `Chronos2Pipeline.embed` to
enable users to extract embeddings from the last encoder layer in an
easy way. The API and behavior is similar to what Chronos and
Chronos-Bolt provides.

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Abdul Fatir 2025-11-17 17:16:50 +01:00 committed by GitHub
parent e48f48071f
commit 111972a6cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 175 additions and 47 deletions

View file

@ -547,6 +547,74 @@ class Chronos2Model(PreTrainedModel):
return loss
def encode(
self,
context: torch.Tensor,
context_mask: torch.Tensor | None = None,
group_ids: torch.Tensor | None = None,
future_covariates: torch.Tensor | None = None,
future_covariates_mask: torch.Tensor | None = None,
num_output_patches: int = 1,
future_target: torch.Tensor | None = None,
future_target_mask: torch.Tensor | None = None,
output_attentions: bool = False,
):
self._validate_input(
context=context,
context_mask=context_mask,
future_covariates=future_covariates,
future_covariates_mask=future_covariates_mask,
group_ids=group_ids,
num_output_patches=num_output_patches,
future_target=future_target,
future_target_mask=future_target_mask,
)
batch_size = context.shape[0]
patched_context, attention_mask, loc_scale = self._prepare_patched_context(
context=context, context_mask=context_mask
)
num_context_patches = attention_mask.shape[-1]
# get input embeddings of shape (batch, num_context_patches, d_model)
input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
# append [REG] special token embedding, if needed
if self.chronos_config.use_reg_token:
reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
reg_embeds = self.shared(reg_input_ids)
input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
attention_mask = torch.cat(
[attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
)
patched_future, patched_future_covariates_mask = self._prepare_patched_future(
future_covariates=future_covariates,
future_covariates_mask=future_covariates_mask,
loc_scale=loc_scale,
num_output_patches=num_output_patches,
batch_size=batch_size,
)
future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)
# get future embeddings of shape (batch, num_output_patches, d_model)
future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)
# concatenate context and future embeddings and masks
input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)
if group_ids is None:
# by default, each time series is treated independently, i.e., no mixing across the batch
group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device)
encoder_outputs: Chronos2EncoderOutput = self.encoder(
attention_mask=attention_mask,
inputs_embeds=input_embeds,
group_ids=group_ids,
output_attentions=output_attentions,
)
return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches
def forward(
self,
context: torch.Tensor,
@ -625,63 +693,19 @@ class Chronos2Model(PreTrainedModel):
- enc_time_self_attn_weights: Time self attention weights, if output_attentions=True
- enc_group_self_attn_weights: Group self attention weights, if output_attentions=True
"""
self._validate_input(
batch_size = context.shape[0]
encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode(
context=context,
context_mask=context_mask,
group_ids=group_ids,
future_covariates=future_covariates,
future_covariates_mask=future_covariates_mask,
group_ids=group_ids,
num_output_patches=num_output_patches,
future_target=future_target,
future_target_mask=future_target_mask,
)
batch_size = context.shape[0]
patched_context, attention_mask, loc_scale = self._prepare_patched_context(
context=context, context_mask=context_mask
)
num_context_patches = attention_mask.shape[-1]
# get input embeddings of shape (batch, num_context_patches, d_model)
input_embeds: torch.Tensor = self.input_patch_embedding(patched_context)
# append [REG] special token embedding, if needed
if self.chronos_config.use_reg_token:
reg_input_ids = torch.full((batch_size, 1), self.config.reg_token_id, device=input_embeds.device)
reg_embeds = self.shared(reg_input_ids)
input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
attention_mask = torch.cat(
[attention_mask.to(self.dtype), torch.ones_like(reg_input_ids).to(self.dtype)], dim=-1
)
patched_future, patched_future_covariates_mask = self._prepare_patched_future(
future_covariates=future_covariates,
future_covariates_mask=future_covariates_mask,
loc_scale=loc_scale,
num_output_patches=num_output_patches,
batch_size=batch_size,
)
future_attention_mask = torch.ones(batch_size, num_output_patches, dtype=self.dtype, device=self.device)
# get future embeddings of shape (batch, num_output_patches, d_model)
future_embeds: torch.Tensor = self.input_patch_embedding(patched_future)
# concatenate context and future embeddings and masks
input_embeds = torch.cat([input_embeds, future_embeds], dim=-2)
attention_mask = torch.cat([attention_mask, future_attention_mask], dim=-1)
if group_ids is None:
# by default, each time series is treated independently, i.e., no mixing across the batch
group_ids = torch.arange(batch_size, dtype=torch.long, device=self.device)
encoder_outputs: Chronos2EncoderOutput = self.encoder(
attention_mask=attention_mask,
inputs_embeds=input_embeds,
group_ids=group_ids,
output_attentions=output_attentions,
)
hidden_states: torch.Tensor = encoder_outputs[0]
assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
# slice the last num_output_patches hidden states to be input into the output_patch_embedding

View file

@ -988,6 +988,81 @@ class Chronos2Pipeline(BaseChronosPipeline):
return predictions_per_window, inference_time_s
@torch.no_grad()
def embed(
self, inputs: TensorOrArray | Sequence[TensorOrArray], batch_size: int = 256, context_length: int | None = None
) -> tuple[list[torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]]]:
"""
Get encoder embeddings for the given time series.
Parameters
----------
inputs
The time series to get embeddings for, can be one of:
- A 3-dimensional `torch.Tensor` or `np.ndarray` of shape (batch, n_variates, history_length). When `n_variates > 1`, information
will be shared among the different variates of each time series in the batch.
- A list of `torch.Tensor` or `np.ndarray` where each element can either be 1-dimensional of shape (history_length,)
or 2-dimensional of shape (n_variates, history_length). The history_lengths may be different across elements; left-padding
will be applied, if needed.
batch_size
The batch size used for generating embeddings. Note that the batch size here means the total number of time series which are input into the model.
If your data has multiple variates, the effective number of time series tasks in a batch will be lower than this value, by default 256
context_length
The maximum context length used during for inference, by default set to the model's default context length
Returns
-------
embeddings
a list of `torch.Tensor` where each element has shape (n_variates, num_patches + 2, d_model) and the number of elements are equal to the number
of target time series (univariate or multivariate) in the `inputs`. The extra +2 is due to embeddings of the [REG] token and a masked output patch token.
loc_scale
a list of tuples with the mean and standard deviation of each time series.
"""
if context_length is None:
context_length = self.model_context_length
if context_length > self.model_context_length:
warnings.warn(
f"The specified context_length {context_length} is greater than the model's default context length {self.model_context_length}. "
f"Resetting context_length to {self.model_context_length}."
)
context_length = self.model_context_length
test_dataset = Chronos2Dataset.convert_inputs(
inputs=inputs,
context_length=context_length,
prediction_length=0,
batch_size=batch_size,
output_patch_size=self.model_output_patch_size,
mode=DatasetMode.TEST,
)
test_loader = DataLoader(
test_dataset, batch_size=None, num_workers=1, pin_memory=True, shuffle=False, drop_last=False
)
all_embeds: list[torch.Tensor] = []
all_loc_scales: list[tuple[torch.Tensor, torch.Tensor]] = []
for batch in test_loader:
assert batch["future_target"] is None
batch_context = batch["context"]
batch_group_ids = batch["group_ids"]
batch_target_idx_ranges = batch["target_idx_ranges"]
encoder_outputs, (locs, scales), *_ = self.model.encode(
context=batch_context.to(device=self.model.device, dtype=torch.float32),
group_ids=batch_group_ids.to(self.model.device),
)
batch_embeds = [encoder_outputs[0][start:end].cpu() for (start, end) in batch_target_idx_ranges]
batch_loc_scales = list(
zip(
[locs[start:end].cpu() for (start, end) in batch_target_idx_ranges],
[scales[start:end].cpu() for (start, end) in batch_target_idx_ranges],
)
)
all_embeds.extend(batch_embeds)
all_loc_scales.extend(batch_loc_scales)
return all_embeds, all_loc_scales
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
"""

View file

@ -340,6 +340,35 @@ def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: tor
validate_tensor(quantiles_item, (3, expected_num_quantiles, 7), dtype=torch.float32)
@pytest.mark.parametrize(
"inputs, expected_output_shapes",
[
# NOTE: d_model for the dummy model is 6
# Homogenous univariate task
(torch.rand(4, 1, 16), [(1, 3, 6)] * 4),
# Homogenous multivariate task
(torch.rand(4, 3, 37), [(3, 5, 6)] * 4),
# Heterogenous tasks with different history lengths
(
[torch.rand(100), torch.rand(2, 150), torch.rand(120)],
[(1, 12, 6), (2, 12, 6), (1, 12, 6)],
),
],
)
def test_when_input_is_valid_then_pipeline_can_embed(pipeline, inputs, expected_output_shapes):
embeds, loc_scales = pipeline.embed(inputs)
assert (
isinstance(embeds, list)
and len(embeds) == len(expected_output_shapes)
and len(loc_scales) == len(expected_output_shapes)
)
for embed, loc_scale, expected_shape in zip(embeds, loc_scales, expected_output_shapes):
validate_tensor(embed, expected_shape, dtype=torch.float32)
validate_tensor(loc_scale[0], (expected_shape[0], 1), dtype=torch.float32)
validate_tensor(loc_scale[1], (expected_shape[0], 1), dtype=torch.float32)
@pytest.mark.parametrize(
"task_kwargs",
[