Split input_transform into context_input_transform and label_input_transform (#82)

*Description of changes:* This splits `input_transform` into
`context_input_transform` and `label_input_transform`. Previously,
`input_transform` was being used for both context and label during
training which would lead to incorrect results where `prediction_length`
> `context_length`.

TODO:

- [x] Update docstrings
- [x] Test the training script

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.com>
This commit is contained in:
Abdul Fatir 2024-05-28 09:58:22 +02:00 committed by GitHub
parent ea26e3d7a7
commit 223e576e2e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 94 additions and 41 deletions

View file

@ -387,9 +387,11 @@ class ChronosDataset(IterableDataset, ShuffleMixin):
def to_hf_format(self, entry: dict) -> dict:
past_target = torch.tensor(entry["past_target"]).unsqueeze(0)
input_ids, attention_mask, scale = self.tokenizer.input_transform(past_target)
input_ids, attention_mask, scale = self.tokenizer.context_input_transform(
past_target
)
future_target = torch.tensor(entry["future_target"]).unsqueeze(0)
labels, labels_mask, _ = self.tokenizer.input_transform(future_target, scale)
labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale)
labels[labels_mask == 0] = -100
return {
"input_ids": input_ids.squeeze(0),

View file

@ -26,14 +26,14 @@ class ChronosConfig:
tokenizer_class: str
tokenizer_kwargs: Dict[str, Any]
context_length: int
prediction_length: int
n_tokens: int
n_special_tokens: int
pad_token_id: int
eos_token_id: int
use_eos_token: bool
model_type: Literal["causal", "seq2seq"]
context_length: int
prediction_length: int
num_samples: int
temperature: float
top_k: int
@ -59,13 +59,12 @@ class ChronosTokenizer:
which concrete classes must implement.
"""
def input_transform(
def context_input_transform(
self,
context: torch.Tensor,
tokenizer_state: Any = None,
) -> Tuple[torch.Tensor, torch.Tensor, Any]:
) -> Tuple:
"""
Turn a batch of time series into token IDs, attention map, and scale.
Turn a batch of time series into token IDs, attention map, and tokenizer_state.
Parameters
----------
@ -73,13 +72,6 @@ class ChronosTokenizer:
A tensor shaped (batch_size, time_length), containing the
timeseries to forecast. Use left-padding with ``torch.nan``
to align time series of different lengths.
tokenizer_state
An object returned by ``input_transform`` containing
relevant information to preprocess data, such as location and
scale. The nature of this depends on the specific tokenizer.
This is useful when tokenizing the label (for training), in
order to use the same scaling used to tokenize the context;
when tokenizing the context, this argument should be ignored.
Returns
-------
@ -92,9 +84,41 @@ class ChronosTokenizer:
which input observations are not ``torch.nan`` (i.e. not
missing nor padding).
tokenizer_state
An object that will be passed to ``output_transform``.
Contains the relevant information to decode output samples into
real values, such as location and scale parameters.
An object that can be passed to ``label_input_transform``
and ``output_transform``. Contains the relevant information
to decode output samples into real values,
such as location and scale parameters.
"""
raise NotImplementedError()
def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple:
"""
Turn a batch of label slices of time series into token IDs and attention map
using the ``tokenizer_state`` provided by ``context_input_transform``.
Parameters
----------
context
A tensor shaped (batch_size, time_length), containing the
timeseries to forecast. Use left-padding with ``torch.nan``
to align time series of different lengths.
tokenizer_state
An object returned by ``context_input_transform`` containing
relevant information to preprocess data, such as location and
scale. The nature of this depends on the specific tokenizer.
This is used for tokenizing the label, in order to use the same
scaling used to tokenize the context.
Returns
-------
token_ids
A tensor of integers, shaped (batch_size, time_length + 1)
if ``config.use_eos_token`` and (batch_size, time_length)
otherwise, containing token IDs for the input series.
attention_mask
A boolean tensor, same shape as ``token_ids``, indicating
which input observations are not ``torch.nan`` (i.e. not
missing nor padding).
"""
raise NotImplementedError()
@ -141,14 +165,9 @@ class MeanScaleUniformBins(ChronosTokenizer):
)
)
def input_transform(
def _input_transform(
self, context: torch.Tensor, scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size, length = context.shape
if length > self.config.context_length:
context = context[..., -self.config.context_length :]
attention_mask = ~torch.isnan(context)
if scale is None:
@ -170,16 +189,51 @@ class MeanScaleUniformBins(ChronosTokenizer):
)
token_ids[~attention_mask] = self.config.pad_token_id
if self.config.use_eos_token:
eos_tokens = torch.full(
(batch_size, 1), fill_value=self.config.eos_token_id
return token_ids, attention_mask, scale
def _append_eos_token(
self, token_ids: torch.Tensor, attention_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = token_ids.shape[0]
eos_tokens = torch.full((batch_size, 1), fill_value=self.config.eos_token_id)
token_ids = torch.concat((token_ids, eos_tokens), dim=1)
eos_mask = torch.full((batch_size, 1), fill_value=True)
attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
return token_ids, attention_mask
def context_input_transform(
self, context: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
length = context.shape[-1]
if length > self.config.context_length:
context = context[..., -self.config.context_length :]
token_ids, attention_mask, scale = self._input_transform(context=context)
if self.config.use_eos_token and self.config.model_type == "seq2seq":
token_ids, attention_mask = self._append_eos_token(
token_ids=token_ids, attention_mask=attention_mask
)
token_ids = torch.concat((token_ids, eos_tokens), dim=1)
eos_mask = torch.full((batch_size, 1), fill_value=True)
attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
return token_ids, attention_mask, scale
def label_input_transform(
self, label: torch.Tensor, scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
length = label.shape[-1]
assert length == self.config.prediction_length
token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale)
if self.config.use_eos_token:
token_ids, attention_mask = self._append_eos_token(
token_ids=token_ids, attention_mask=attention_mask
)
return token_ids, attention_mask
def output_transform(
self, samples: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
@ -318,6 +372,7 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
return torch.stack(padded)
@dataclass
class ChronosPipeline:
"""
A ``ChronosPipeline`` uses the given tokenizer and model to forecast
@ -337,10 +392,6 @@ class ChronosPipeline:
tokenizer: ChronosTokenizer
model: ChronosModel
def __init__(self, tokenizer, model):
self.tokenizer = tokenizer
self.model = model
def _prepare_and_validate_context(
self, context: Union[torch.Tensor, List[torch.Tensor]]
):
@ -380,8 +431,8 @@ class ChronosPipeline:
provided, and the extra 1 is for EOS.
"""
context_tensor = self._prepare_and_validate_context(context=context)
token_ids, attention_mask, tokenizer_state = self.tokenizer.input_transform(
context_tensor
token_ids, attention_mask, tokenizer_state = (
self.tokenizer.context_input_transform(context_tensor)
)
embeddings = self.model.encode(
input_ids=token_ids.to(self.model.device),
@ -455,7 +506,7 @@ class ChronosPipeline:
remaining = prediction_length
while remaining > 0:
token_ids, attention_mask, scale = self.tokenizer.input_transform(
token_ids, attention_mask, scale = self.tokenizer.context_input_transform(
context_tensor
)
samples = self.model(

View file

@ -38,10 +38,10 @@ def test_tokenizer_consistency(n_numerical_tokens: int, n_special_tokens: int):
context = tokenizer.centers.unsqueeze(0) # add batch dimension
scale = torch.ones((1,)) # fix the scale to one to turn off scaling
token_ids, _, _ = tokenizer.input_transform(context, scale=scale)
token_ids, _, _ = tokenizer._input_transform(context, scale=scale)
samples = tokenizer.output_transform(
token_ids[:, :-1].unsqueeze(1), # remove final EOS, add sample dimension
token_ids.unsqueeze(1), # add sample dimension
scale=scale,
)
@ -85,7 +85,7 @@ def test_tokenizer_fixed_data(
)
batch_size, _ = context.shape
token_ids, attention_mask, scale = tokenizer.input_transform(context)
token_ids, attention_mask, scale = tokenizer.context_input_transform(context)
assert token_ids.shape == (batch_size, context_length + 1 * use_eos_token)
assert all(token_ids[:, 0] == torch.tensor([0]).repeat(batch_size))
@ -136,7 +136,7 @@ def test_tokenizer_random_data(use_eos_token: bool):
]
)
token_ids, attention_mask, scale = tokenizer.input_transform(context)
token_ids, attention_mask, scale = tokenizer.context_input_transform(context)
assert token_ids.shape == (
*context.shape[:-1],