Add pipeline.embed support for Chronos-Bolt (#247)

This commit is contained in:
Abdul Fatir 2024-12-22 13:56:41 +01:00 committed by GitHub
parent 28e7b3281f
commit ad410c9c0a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 106 additions and 17 deletions

View file

@ -12,7 +12,7 @@ on:
- labeled # When a label is added to the PR
jobs:
evaluate-and-post:
evaluate-and-print:
if: contains(github.event.pull_request.labels.*.name, 'run-eval') # Only run if 'run-eval' label is added
runs-on: ubuntu-latest
env:
@ -33,10 +33,5 @@ jobs:
- name: Run Eval Script
run: python scripts/evaluation/evaluate.py ci/evaluate/backtest_config.yaml $RESULTS_CSV --chronos-model-id=amazon/chronos-bolt-small --device=cpu --torch-dtype=float32
- name: Upload CSV
uses: actions/upload-artifact@v4
with:
name: eval-metrics
path: ${{ env.RESULTS_CSV }}
retention-days: 1
overwrite: true
- name: Print CSV
run: cat $RESULTS_CSV

View file

@ -25,7 +25,6 @@ from transformers.utils import ModelOutput
from .base import BaseChronosPipeline, ForecastType
logger = logging.getLogger(__file__)
@ -240,13 +239,11 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
):
module.output_layer.bias.data.zero_()
def forward(
self,
context: torch.Tensor,
mask: Optional[torch.Tensor] = None,
target: Optional[torch.Tensor] = None,
target_mask: Optional[torch.Tensor] = None,
) -> ChronosBoltOutput:
def encode(
self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> Tuple[
torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor
]:
mask = (
mask.to(context.dtype)
if mask is not None
@ -301,8 +298,21 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
attention_mask=attention_mask,
inputs_embeds=input_embeds,
)
hidden_states = encoder_outputs[0]
return encoder_outputs[0], loc_scale, input_embeds, attention_mask
def forward(
self,
context: torch.Tensor,
mask: Optional[torch.Tensor] = None,
target: Optional[torch.Tensor] = None,
target_mask: Optional[torch.Tensor] = None,
) -> ChronosBoltOutput:
batch_size = context.size(0)
hidden_states, loc_scale, input_embeds, attention_mask = self.encode(
context=context, mask=mask
)
sequence_output = self.decode(input_embeds, attention_mask, hidden_states)
quantile_preds_shape = (
@ -426,6 +436,46 @@ class ChronosBoltPipeline(BaseChronosPipeline):
def quantiles(self) -> List[float]:
return self.model.config.chronos_config["quantiles"]
@torch.no_grad()
def embed(
self, context: Union[torch.Tensor, List[torch.Tensor]]
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Get encoder embeddings for the given time series.
Parameters
----------
context
Input series. This is either a 1D tensor, or a list
of 1D tensors, or a 2D tensor whose first dimension
is batch. In the latter case, use left-padding with
``torch.nan`` to align series of different lengths.
Returns
-------
embeddings, loc_scale
A tuple of two items: the encoder embeddings and the loc_scale,
i.e., the mean and std of the original time series.
The encoder embeddings are shaped (batch_size, num_patches + 1, d_model),
where num_patches is the number of patches in the time series
and the extra 1 is for the [REG] token (if used by the model).
"""
context_tensor = self._prepare_and_validate_context(context=context)
model_context_length = self.model.config.chronos_config["context_length"]
if context_tensor.shape[-1] > model_context_length:
context_tensor = context_tensor[..., -model_context_length:]
context_tensor = context_tensor.to(
device=self.model.device,
dtype=torch.float32,
)
embeddings, loc_scale, *_ = self.model.encode(context=context_tensor)
return embeddings.cpu(), (
loc_scale[0].squeeze(-1).cpu(),
loc_scale[1].squeeze(-1).cpu(),
)
def predict( # type: ignore[override]
self,
context: Union[torch.Tensor, List[torch.Tensor]],

View file

@ -132,6 +132,50 @@ def test_pipeline_predict_quantiles(
validate_tensor(mean, (1, prediction_length), dtype=torch.float32)
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
pipeline = ChronosBoltPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-bolt-model",
device_map="cpu",
torch_dtype=model_dtype,
)
d_model = pipeline.model.config.d_model
context = 10 * torch.rand(size=(4, 16)) + 10
context = context.to(dtype=input_dtype)
# the patch size of dummy model is 16, so only 1 patch is created
expected_embed_length = 1 + (
1 if pipeline.model.config.chronos_config["use_reg_token"] else 0
)
# input: tensor of shape (batch_size, context_length)
embedding, loc_scale = pipeline.embed(context)
validate_tensor(
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
)
validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32)
validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32)
# input: batch_size-long list of tensors of shape (context_length,)
embedding, loc_scale = pipeline.embed(list(context))
validate_tensor(
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
)
validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32)
validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32)
# input: tensor of shape (context_length,)
embedding, loc_scale = pipeline.embed(context[0, ...])
validate_tensor(
embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype
)
validate_tensor(loc_scale[0], shape=(1,), dtype=torch.float32)
validate_tensor(loc_scale[1], shape=(1,), dtype=torch.float32)
# The following tests have been taken from
# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/tests/unittests/models/chronos/pipeline/test_chronos_bolt.py
# Author: Caner Turkmen <atturkm@amazon.com>