diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py
index 9115abc..225282a 100644
--- a/src/chronos/chronos.py
+++ b/src/chronos/chronos.py
@@ -185,7 +185,7 @@ class MeanScaleUniformBins(ChronosTokenizer):
     ) -> torch.Tensor:
         scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)
         indices = torch.clamp(
-            samples - self.config.n_special_tokens,
+            samples - self.config.n_special_tokens - 1,
             min=0,
             max=len(self.centers) - 1,
         )
diff --git a/test/test_chronos.py b/test/test_chronos.py
index a7d63bc..480908c 100644
--- a/test/test_chronos.py
+++ b/test/test_chronos.py
@@ -7,7 +7,52 @@ from typing import Tuple
 
 import torch
 import pytest
-from chronos import ChronosConfig, ChronosPipeline
+from chronos import ChronosConfig, ChronosPipeline, MeanScaleUniformBins
+
+
+@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27])
+@pytest.mark.parametrize("n_special_tokens", [2, 5, 13])
+def test_tokenizer_consistency(n_numerical_tokens: int, n_special_tokens: int):
+    """Check that tokenizing then detokenizing the bin centers is lossless.
+
+    The context fed to ``input_transform`` is the tokenizer's own ``centers``
+    and the scale is pinned to one, so ``output_transform`` is expected to
+    return exactly the values that went in.
+    """
+    n_tokens = n_numerical_tokens + n_special_tokens
+
+    config = ChronosConfig(
+        tokenizer_class="MeanScaleUniformBins",
+        tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
+        n_tokens=n_tokens,
+        n_special_tokens=n_special_tokens,
+        pad_token_id=0,
+        eos_token_id=1,
+        use_eos_token=True,
+        model_type="seq2seq",
+        context_length=512,
+        prediction_length=64,
+        num_samples=20,
+        temperature=1.0,
+        top_k=50,
+        top_p=1.0,
+    )
+
+    tokenizer = config.create_tokenizer()
+    assert isinstance(tokenizer, MeanScaleUniformBins)
+
+    context = tokenizer.centers.unsqueeze(0)  # add batch dimension
+    scale = torch.ones((1,))  # fix the scale to one to turn off scaling
+
+    token_ids, _, _ = tokenizer.input_transform(context, scale=scale)
+
+    samples = tokenizer.output_transform(
+        token_ids[:, :-1].unsqueeze(1),  # remove final EOS, add sample dimension
+        scale=scale,
+    )
+
+    # Exact comparison on purpose: any index shift in the round trip must fail.
+    assert (samples[0, 0, :] == context).all()
 
 
 @pytest.mark.xfail