Fix output transform, add test to enforce tokenizer consistency (#73)

*Description of changes:* 

The bin indexes were shifted by one between input transform and output
transform. Subtracting 1 to the sampled tokens in output transform lead
to the correct reconstruction of the signal.

Add a test to ensure the consistency of the Chronos Tokenizer.

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

Co-authored-by: Lorenzo Stella <stellalo@amazon.com> and Abdul Fatir
Ansari <ansarnd@amazon.com>
This commit is contained in:
HugoSenetaire 2024-05-17 15:29:18 +02:00 committed by GitHub
parent 02d1a1d73e
commit 3fe24ff8cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 40 additions and 2 deletions

View file

@ -185,7 +185,7 @@ class MeanScaleUniformBins(ChronosTokenizer):
) -> torch.Tensor:
scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)
indices = torch.clamp(
samples - self.config.n_special_tokens,
samples - self.config.n_special_tokens - 1,
min=0,
max=len(self.centers) - 1,
)

View file

@ -7,7 +7,45 @@ from typing import Tuple
import torch
import pytest
from chronos import ChronosConfig, ChronosPipeline
from chronos import ChronosConfig, ChronosPipeline, MeanScaleUniformBins
@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27])
@pytest.mark.parametrize("n_special_tokens", [2, 5, 13])
def test_tokenizer_consistency(n_numerical_tokens: int, n_special_tokens: int):
n_tokens = n_numerical_tokens + n_special_tokens
config = ChronosConfig(
tokenizer_class="MeanScaleUniformBins",
tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
n_tokens=n_tokens,
n_special_tokens=n_special_tokens,
pad_token_id=0,
eos_token_id=1,
use_eos_token=True,
model_type="seq2seq",
context_length=512,
prediction_length=64,
num_samples=20,
temperature=1.0,
top_k=50,
top_p=1.0,
)
tokenizer = config.create_tokenizer()
assert isinstance(tokenizer, MeanScaleUniformBins)
context = tokenizer.centers.unsqueeze(0) # add batch dimension
scale = torch.ones((1,)) # fix the scale to one to turn off scaling
token_ids, _, _ = tokenizer.input_transform(context, scale=scale)
samples = tokenizer.output_transform(
token_ids[:, :-1].unsqueeze(1), # remove final EOS, add sample dimension
scale=scale,
)
assert (samples[0, 0, :] == context).all()
@pytest.mark.xfail