From 3fe24ff8cd2e52d1aef45bbcbd30ae1c16fdce49 Mon Sep 17 00:00:00 2001 From: HugoSenetaire <32298113+HugoSenetaire@users.noreply.github.com> Date: Fri, 17 May 2024 15:29:18 +0200 Subject: [PATCH] Fix output transform, add test to enforce tokenizer consistency (#73) *Description of changes:* The bin indexes were shifted by one between the input transform and the output transform. Subtracting 1 from the sampled tokens in the output transform leads to the correct reconstruction of the signal. Add a test to ensure the consistency of the Chronos Tokenizer. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. Co-authored-by: Lorenzo Stella and Abdul Fatir Ansari --- src/chronos/chronos.py | 2 +- test/test_chronos.py | 40 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py index 9115abc..225282a 100644 --- a/src/chronos/chronos.py +++ b/src/chronos/chronos.py @@ -185,7 +185,7 @@ class MeanScaleUniformBins(ChronosTokenizer): ) -> torch.Tensor: scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1) indices = torch.clamp( - samples - self.config.n_special_tokens, + samples - self.config.n_special_tokens - 1, min=0, max=len(self.centers) - 1, ) diff --git a/test/test_chronos.py b/test/test_chronos.py index a7d63bc..480908c 100644 --- a/test/test_chronos.py +++ b/test/test_chronos.py @@ -7,7 +7,45 @@ from typing import Tuple import torch import pytest -from chronos import ChronosConfig, ChronosPipeline +from chronos import ChronosConfig, ChronosPipeline, MeanScaleUniformBins + + +@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27]) +@pytest.mark.parametrize("n_special_tokens", [2, 5, 13]) +def test_tokenizer_consistency(n_numerical_tokens: int, n_special_tokens: int): + n_tokens = n_numerical_tokens + n_special_tokens + + config = ChronosConfig( + tokenizer_class="MeanScaleUniformBins", 
tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0), + n_tokens=n_tokens, + n_special_tokens=n_special_tokens, + pad_token_id=0, + eos_token_id=1, + use_eos_token=True, + model_type="seq2seq", + context_length=512, + prediction_length=64, + num_samples=20, + temperature=1.0, + top_k=50, + top_p=1.0, + ) + + tokenizer = config.create_tokenizer() + assert isinstance(tokenizer, MeanScaleUniformBins) + + context = tokenizer.centers.unsqueeze(0) # add batch dimension + scale = torch.ones((1,)) # fix the scale to one to turn off scaling + + token_ids, _, _ = tokenizer.input_transform(context, scale=scale) + + samples = tokenizer.output_transform( + token_ids[:, :-1].unsqueeze(1), # remove final EOS, add sample dimension + scale=scale, + ) + + assert (samples[0, 0, :] == context).all() @pytest.mark.xfail