mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
Fix output transform, add test to enforce tokenizer consistency (#73)
*Description of changes:* The bin indexes were shifted by one between input transform and output transform. Subtracting 1 to the sampled tokens in output transform lead to the correct reconstruction of the signal. Add a test to ensure the consistency of the Chronos Tokenizer. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. Co-authored-by: Lorenzo Stella <stellalo@amazon.com> and Abdul Fatir Ansari <ansarnd@amazon.com>
This commit is contained in:
parent
02d1a1d73e
commit
3fe24ff8cd
2 changed files with 40 additions and 2 deletions
|
|
@ -185,7 +185,7 @@ class MeanScaleUniformBins(ChronosTokenizer):
|
|||
) -> torch.Tensor:
|
||||
scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)
|
||||
indices = torch.clamp(
|
||||
samples - self.config.n_special_tokens,
|
||||
samples - self.config.n_special_tokens - 1,
|
||||
min=0,
|
||||
max=len(self.centers) - 1,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,45 @@ from typing import Tuple
|
|||
import torch
|
||||
import pytest
|
||||
|
||||
from chronos import ChronosConfig, ChronosPipeline
|
||||
from chronos import ChronosConfig, ChronosPipeline, MeanScaleUniformBins
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27])
|
||||
@pytest.mark.parametrize("n_special_tokens", [2, 5, 13])
|
||||
def test_tokenizer_consistency(n_numerical_tokens: int, n_special_tokens: int):
|
||||
n_tokens = n_numerical_tokens + n_special_tokens
|
||||
|
||||
config = ChronosConfig(
|
||||
tokenizer_class="MeanScaleUniformBins",
|
||||
tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
|
||||
n_tokens=n_tokens,
|
||||
n_special_tokens=n_special_tokens,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
use_eos_token=True,
|
||||
model_type="seq2seq",
|
||||
context_length=512,
|
||||
prediction_length=64,
|
||||
num_samples=20,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
top_p=1.0,
|
||||
)
|
||||
|
||||
tokenizer = config.create_tokenizer()
|
||||
assert isinstance(tokenizer, MeanScaleUniformBins)
|
||||
|
||||
context = tokenizer.centers.unsqueeze(0) # add batch dimension
|
||||
scale = torch.ones((1,)) # fix the scale to one to turn off scaling
|
||||
|
||||
token_ids, _, _ = tokenizer.input_transform(context, scale=scale)
|
||||
|
||||
samples = tokenizer.output_transform(
|
||||
token_ids[:, :-1].unsqueeze(1), # remove final EOS, add sample dimension
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
assert (samples[0, 0, :] == context).all()
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
|
|
|
|||
Loading…
Reference in a new issue