From 223e576e2ecb27aaf4f55c6962136a41da755014 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Tue, 28 May 2024 09:58:22 +0200 Subject: [PATCH] Split `input_transform` into `context_input_transform` and `label_input_transform` (#82) *Description of changes:* This splits `input_transform` into `context_input_transform` and `label_input_transform`. Previously, `input_transform` was being used for both context and label during training which would lead to incorrect results where `prediction_length` > `context_length`. TODO: - [x] Update docstrings - [x] Test the training script By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Abdul Fatir Ansari --- scripts/training/train.py | 6 +- src/chronos/chronos.py | 121 +++++++++++++++++++++++++++----------- test/test_chronos.py | 8 +-- 3 files changed, 94 insertions(+), 41 deletions(-) diff --git a/scripts/training/train.py b/scripts/training/train.py index 1973121..bbdf6cd 100644 --- a/scripts/training/train.py +++ b/scripts/training/train.py @@ -387,9 +387,11 @@ class ChronosDataset(IterableDataset, ShuffleMixin): def to_hf_format(self, entry: dict) -> dict: past_target = torch.tensor(entry["past_target"]).unsqueeze(0) - input_ids, attention_mask, scale = self.tokenizer.input_transform(past_target) + input_ids, attention_mask, scale = self.tokenizer.context_input_transform( + past_target + ) future_target = torch.tensor(entry["future_target"]).unsqueeze(0) - labels, labels_mask, _ = self.tokenizer.input_transform(future_target, scale) + labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale) labels[labels_mask == 0] = -100 return { "input_ids": input_ids.squeeze(0), diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py index 225282a..f4b3377 100644 --- a/src/chronos/chronos.py +++ b/src/chronos/chronos.py @@ -26,14 +26,14 @@ class ChronosConfig: tokenizer_class: str 
tokenizer_kwargs: Dict[str, Any] + context_length: int + prediction_length: int n_tokens: int n_special_tokens: int pad_token_id: int eos_token_id: int use_eos_token: bool model_type: Literal["causal", "seq2seq"] - context_length: int - prediction_length: int num_samples: int temperature: float top_k: int @@ -59,13 +59,12 @@ class ChronosTokenizer: which concrete classes must implement. """ - def input_transform( + def context_input_transform( self, context: torch.Tensor, - tokenizer_state: Any = None, - ) -> Tuple[torch.Tensor, torch.Tensor, Any]: + ) -> Tuple: """ - Turn a batch of time series into token IDs, attention map, and scale. + Turn a batch of time series into token IDs, attention map, and tokenizer_state. Parameters ---------- @@ -73,13 +72,6 @@ class ChronosTokenizer: A tensor shaped (batch_size, time_length), containing the timeseries to forecast. Use left-padding with ``torch.nan`` to align time series of different lengths. - tokenizer_state - An object returned by ``input_transform`` containing - relevant information to preprocess data, such as location and - scale. The nature of this depends on the specific tokenizer. - This is useful when tokenizing the label (for training), in - order to use the same scaling used to tokenize the context; - when tokenizing the context, this argument should be ignored. Returns ------- @@ -92,9 +84,41 @@ class ChronosTokenizer: which input observations are not ``torch.nan`` (i.e. not missing nor padding). tokenizer_state - An object that will be passed to ``output_transform``. - Contains the relevant information to decode output samples into - real values, such as location and scale parameters. + An object that can be passed to ``label_input_transform`` + and ``output_transform``. Contains the relevant information + to decode output samples into real values, + such as location and scale parameters. 
+ """ + raise NotImplementedError() + + def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple: + """ + Turn a batch of label slices of time series into token IDs and attention map + using the ``tokenizer_state`` provided by ``context_input_transform``. + + Parameters + ---------- + label + A tensor shaped (batch_size, time_length), containing the + label slice of the timeseries. Use left-padding with ``torch.nan`` + to align time series of different lengths. + tokenizer_state + An object returned by ``context_input_transform`` containing + relevant information to preprocess data, such as location and + scale. The nature of this depends on the specific tokenizer. + This is used for tokenizing the label, in order to use the same + scaling used to tokenize the context. + + Returns + ------- + token_ids + A tensor of integers, shaped (batch_size, time_length + 1) + if ``config.use_eos_token`` and (batch_size, time_length) + otherwise, containing token IDs for the input series. + attention_mask + A boolean tensor, same shape as ``token_ids``, indicating + which input observations are not ``torch.nan`` (i.e. not + missing nor padding). 
""" raise NotImplementedError() @@ -141,14 +165,9 @@ class MeanScaleUniformBins(ChronosTokenizer): ) ) - def input_transform( + def _input_transform( self, context: torch.Tensor, scale: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - batch_size, length = context.shape - - if length > self.config.context_length: - context = context[..., -self.config.context_length :] - attention_mask = ~torch.isnan(context) if scale is None: @@ -170,16 +189,51 @@ class MeanScaleUniformBins(ChronosTokenizer): ) token_ids[~attention_mask] = self.config.pad_token_id - if self.config.use_eos_token: - eos_tokens = torch.full( - (batch_size, 1), fill_value=self.config.eos_token_id + return token_ids, attention_mask, scale + + def _append_eos_token( + self, token_ids: torch.Tensor, attention_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = token_ids.shape[0] + eos_tokens = torch.full((batch_size, 1), fill_value=self.config.eos_token_id) + token_ids = torch.concat((token_ids, eos_tokens), dim=1) + eos_mask = torch.full((batch_size, 1), fill_value=True) + attention_mask = torch.concat((attention_mask, eos_mask), dim=1) + + return token_ids, attention_mask + + def context_input_transform( + self, context: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + length = context.shape[-1] + + if length > self.config.context_length: + context = context[..., -self.config.context_length :] + + token_ids, attention_mask, scale = self._input_transform(context=context) + + if self.config.use_eos_token and self.config.model_type == "seq2seq": + token_ids, attention_mask = self._append_eos_token( + token_ids=token_ids, attention_mask=attention_mask ) - token_ids = torch.concat((token_ids, eos_tokens), dim=1) - eos_mask = torch.full((batch_size, 1), fill_value=True) - attention_mask = torch.concat((attention_mask, eos_mask), dim=1) return token_ids, attention_mask, scale + def label_input_transform( + self, label: 
torch.Tensor, scale: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + length = label.shape[-1] + + assert length == self.config.prediction_length + token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale) + + if self.config.use_eos_token: + token_ids, attention_mask = self._append_eos_token( + token_ids=token_ids, attention_mask=attention_mask + ) + + return token_ids, attention_mask + def output_transform( self, samples: torch.Tensor, scale: torch.Tensor ) -> torch.Tensor: @@ -318,6 +372,7 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor: return torch.stack(padded) +@dataclass class ChronosPipeline: """ A ``ChronosPipeline`` uses the given tokenizer and model to forecast @@ -337,10 +392,6 @@ class ChronosPipeline: tokenizer: ChronosTokenizer model: ChronosModel - def __init__(self, tokenizer, model): - self.tokenizer = tokenizer - self.model = model - def _prepare_and_validate_context( self, context: Union[torch.Tensor, List[torch.Tensor]] ): @@ -380,8 +431,8 @@ class ChronosPipeline: provided, and the extra 1 is for EOS. 
""" context_tensor = self._prepare_and_validate_context(context=context) - token_ids, attention_mask, tokenizer_state = self.tokenizer.input_transform( - context_tensor + token_ids, attention_mask, tokenizer_state = ( + self.tokenizer.context_input_transform(context_tensor) ) embeddings = self.model.encode( input_ids=token_ids.to(self.model.device), @@ -455,7 +506,7 @@ class ChronosPipeline: remaining = prediction_length while remaining > 0: - token_ids, attention_mask, scale = self.tokenizer.input_transform( + token_ids, attention_mask, scale = self.tokenizer.context_input_transform( context_tensor ) samples = self.model( diff --git a/test/test_chronos.py b/test/test_chronos.py index 480908c..9cd039c 100644 --- a/test/test_chronos.py +++ b/test/test_chronos.py @@ -38,10 +38,10 @@ def test_tokenizer_consistency(n_numerical_tokens: int, n_special_tokens: int): context = tokenizer.centers.unsqueeze(0) # add batch dimension scale = torch.ones((1,)) # fix the scale to one to turn off scaling - token_ids, _, _ = tokenizer.input_transform(context, scale=scale) + token_ids, _, _ = tokenizer._input_transform(context, scale=scale) samples = tokenizer.output_transform( - token_ids[:, :-1].unsqueeze(1), # remove final EOS, add sample dimension + token_ids.unsqueeze(1), # add sample dimension scale=scale, ) @@ -85,7 +85,7 @@ def test_tokenizer_fixed_data( ) batch_size, _ = context.shape - token_ids, attention_mask, scale = tokenizer.input_transform(context) + token_ids, attention_mask, scale = tokenizer.context_input_transform(context) assert token_ids.shape == (batch_size, context_length + 1 * use_eos_token) assert all(token_ids[:, 0] == torch.tensor([0]).repeat(batch_size)) @@ -136,7 +136,7 @@ def test_tokenizer_random_data(use_eos_token: bool): ] ) - token_ids, attention_mask, scale = tokenizer.input_transform(context) + token_ids, attention_mask, scale = tokenizer.context_input_transform(context) assert token_ids.shape == ( *context.shape[:-1],