mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Add pipeline.embed support for Chronos-Bolt (#247)
This commit is contained in:
parent
28e7b3281f
commit
ad410c9c0a
3 changed files with 106 additions and 17 deletions
11
.github/workflows/eval-model.yml
vendored
11
.github/workflows/eval-model.yml
vendored
|
|
@ -12,7 +12,7 @@ on:
|
|||
- labeled # When a label is added to the PR
|
||||
|
||||
jobs:
|
||||
evaluate-and-post:
|
||||
evaluate-and-print:
|
||||
if: contains(github.event.pull_request.labels.*.name, 'run-eval') # Only run if 'run-eval' label is added
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
|
|
@ -33,10 +33,5 @@ jobs:
|
|||
- name: Run Eval Script
|
||||
run: python scripts/evaluation/evaluate.py ci/evaluate/backtest_config.yaml $RESULTS_CSV --chronos-model-id=amazon/chronos-bolt-small --device=cpu --torch-dtype=float32
|
||||
|
||||
- name: Upload CSV
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: eval-metrics
|
||||
path: ${{ env.RESULTS_CSV }}
|
||||
retention-days: 1
|
||||
overwrite: true
|
||||
- name: Print CSV
|
||||
run: cat $RESULTS_CSV
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ from transformers.utils import ModelOutput
|
|||
|
||||
from .base import BaseChronosPipeline, ForecastType
|
||||
|
||||
|
||||
# Name the logger after the module's import path, not its file path: a path-based
# name contains "/" and ".py", which creates a bogus dotted hierarchy and keeps
# this logger outside the package's logger tree (so it cannot be configured via
# the parent package logger).
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -240,13 +239,11 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
|
|||
):
|
||||
module.output_layer.bias.data.zero_()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
context: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
target: Optional[torch.Tensor] = None,
|
||||
target_mask: Optional[torch.Tensor] = None,
|
||||
) -> ChronosBoltOutput:
|
||||
def encode(
|
||||
self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
|
||||
) -> Tuple[
|
||||
torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor
|
||||
]:
|
||||
mask = (
|
||||
mask.to(context.dtype)
|
||||
if mask is not None
|
||||
|
|
@ -301,8 +298,21 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
|
|||
attention_mask=attention_mask,
|
||||
inputs_embeds=input_embeds,
|
||||
)
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
return encoder_outputs[0], loc_scale, input_embeds, attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
context: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
target: Optional[torch.Tensor] = None,
|
||||
target_mask: Optional[torch.Tensor] = None,
|
||||
) -> ChronosBoltOutput:
|
||||
batch_size = context.size(0)
|
||||
|
||||
hidden_states, loc_scale, input_embeds, attention_mask = self.encode(
|
||||
context=context, mask=mask
|
||||
)
|
||||
sequence_output = self.decode(input_embeds, attention_mask, hidden_states)
|
||||
|
||||
quantile_preds_shape = (
|
||||
|
|
@ -426,6 +436,46 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
def quantiles(self) -> List[float]:
|
||||
return self.model.config.chronos_config["quantiles"]
|
||||
|
||||
@torch.no_grad()
def embed(
    self, context: Union[torch.Tensor, List[torch.Tensor]]
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Compute encoder embeddings for the given time series.

    Parameters
    ----------
    context
        Input series. Accepted forms are a 1D tensor, a list of
        1D tensors, or a 2D tensor whose first dimension is batch.
        For the 2D form, left-pad with ``torch.nan`` to align
        series of different lengths.

    Returns
    -------
    embeddings, loc_scale
        A pair. The first item is the encoder embeddings, shaped
        (batch_size, num_patches + 1, d_model), where num_patches is
        the number of patches in the time series and the extra 1 is
        for the [REG] token (if used by the model). The second item
        is loc_scale, i.e., the mean and std of the original time
        series. Both are returned on CPU.
    """
    series = self._prepare_and_validate_context(context=context)

    # Keep only the most recent observations when the input is longer
    # than the context length configured for the model.
    max_len = self.model.config.chronos_config["context_length"]
    if series.shape[-1] > max_len:
        series = series[..., -max_len:]

    series = series.to(device=self.model.device, dtype=torch.float32)

    # encode() returns more than two items; only the first two are needed here.
    encoded = self.model.encode(context=series)
    embeddings = encoded[0]
    loc_scale = encoded[1]

    # Drop the trailing dimension of loc and scale before returning them.
    loc = loc_scale[0].squeeze(-1).cpu()
    scale = loc_scale[1].squeeze(-1).cpu()
    return embeddings.cpu(), (loc, scale)
|
||||
|
||||
def predict( # type: ignore[override]
|
||||
self,
|
||||
context: Union[torch.Tensor, List[torch.Tensor]],
|
||||
|
|
|
|||
|
|
@ -132,6 +132,50 @@ def test_pipeline_predict_quantiles(
|
|||
validate_tensor(mean, (1, prediction_length), dtype=torch.float32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
    """Check shapes and dtypes returned by ``embed`` for each accepted input form."""
    model_dir = Path(__file__).parent / "dummy-chronos-bolt-model"
    pipeline = ChronosBoltPipeline.from_pretrained(
        model_dir,
        device_map="cpu",
        torch_dtype=model_dtype,
    )
    d_model = pipeline.model.config.d_model
    context = (10 * torch.rand(size=(4, 16)) + 10).to(dtype=input_dtype)

    # the patch size of dummy model is 16, so only 1 patch is created;
    # one more position is added when the model uses a [REG] token
    uses_reg_token = pipeline.model.config.chronos_config["use_reg_token"]
    expected_embed_length = 2 if uses_reg_token else 1

    # Each case pairs an input form with the batch size expected in the output.
    cases = [
        # input: tensor of shape (batch_size, context_length)
        (context, 4),
        # input: batch_size-long list of tensors of shape (context_length,)
        (list(context), 4),
        # input: tensor of shape (context_length,)
        (context[0, ...], 1),
    ]
    for series, batch_size in cases:
        embedding, loc_scale = pipeline.embed(series)
        validate_tensor(
            embedding,
            shape=(batch_size, expected_embed_length, d_model),
            dtype=model_dtype,
        )
        validate_tensor(loc_scale[0], shape=(batch_size,), dtype=torch.float32)
        validate_tensor(loc_scale[1], shape=(batch_size,), dtype=torch.float32)
|
||||
|
||||
|
||||
# The following tests have been taken from
|
||||
# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/tests/unittests/models/chronos/pipeline/test_chronos_bolt.py
|
||||
# Author: Caner Turkmen <atturkm@amazon.com>
|
||||
|
|
|
|||
Loading…
Reference in a new issue