mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
Split input_transform into context_input_transform and label_input_transform (#82)
*Description of changes:* This splits `input_transform` into `context_input_transform` and `label_input_transform`. Previously, `input_transform` was being used for both context and label during training which would lead to incorrect results where `prediction_length` > `context_length`. TODO: - [x] Update docstrings - [x] Test the training script By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.com>
This commit is contained in:
parent
ea26e3d7a7
commit
223e576e2e
3 changed files with 94 additions and 41 deletions
|
|
@ -387,9 +387,11 @@ class ChronosDataset(IterableDataset, ShuffleMixin):
|
|||
|
||||
def to_hf_format(self, entry: dict) -> dict:
|
||||
past_target = torch.tensor(entry["past_target"]).unsqueeze(0)
|
||||
input_ids, attention_mask, scale = self.tokenizer.input_transform(past_target)
|
||||
input_ids, attention_mask, scale = self.tokenizer.context_input_transform(
|
||||
past_target
|
||||
)
|
||||
future_target = torch.tensor(entry["future_target"]).unsqueeze(0)
|
||||
labels, labels_mask, _ = self.tokenizer.input_transform(future_target, scale)
|
||||
labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale)
|
||||
labels[labels_mask == 0] = -100
|
||||
return {
|
||||
"input_ids": input_ids.squeeze(0),
|
||||
|
|
|
|||
|
|
@ -26,14 +26,14 @@ class ChronosConfig:
|
|||
|
||||
tokenizer_class: str
|
||||
tokenizer_kwargs: Dict[str, Any]
|
||||
context_length: int
|
||||
prediction_length: int
|
||||
n_tokens: int
|
||||
n_special_tokens: int
|
||||
pad_token_id: int
|
||||
eos_token_id: int
|
||||
use_eos_token: bool
|
||||
model_type: Literal["causal", "seq2seq"]
|
||||
context_length: int
|
||||
prediction_length: int
|
||||
num_samples: int
|
||||
temperature: float
|
||||
top_k: int
|
||||
|
|
@ -59,13 +59,12 @@ class ChronosTokenizer:
|
|||
which concrete classes must implement.
|
||||
"""
|
||||
|
||||
def input_transform(
|
||||
def context_input_transform(
|
||||
self,
|
||||
context: torch.Tensor,
|
||||
tokenizer_state: Any = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Any]:
|
||||
) -> Tuple:
|
||||
"""
|
||||
Turn a batch of time series into token IDs, attention map, and scale.
|
||||
Turn a batch of time series into token IDs, attention map, and tokenizer_state.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -73,13 +72,6 @@ class ChronosTokenizer:
|
|||
A tensor shaped (batch_size, time_length), containing the
|
||||
timeseries to forecast. Use left-padding with ``torch.nan``
|
||||
to align time series of different lengths.
|
||||
tokenizer_state
|
||||
An object returned by ``input_transform`` containing
|
||||
relevant information to preprocess data, such as location and
|
||||
scale. The nature of this depends on the specific tokenizer.
|
||||
This is useful when tokenizing the label (for training), in
|
||||
order to use the same scaling used to tokenize the context;
|
||||
when tokenizing the context, this argument should be ignored.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -92,9 +84,41 @@ class ChronosTokenizer:
|
|||
which input observations are not ``torch.nan`` (i.e. not
|
||||
missing nor padding).
|
||||
tokenizer_state
|
||||
An object that will be passed to ``output_transform``.
|
||||
Contains the relevant information to decode output samples into
|
||||
real values, such as location and scale parameters.
|
||||
An object that can be passed to ``label_input_transform``
|
||||
and ``output_transform``. Contains the relevant information
|
||||
to decode output samples into real values,
|
||||
such as location and scale parameters.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple:
|
||||
"""
|
||||
Turn a batch of label slices of time series into token IDs and attention map
|
||||
using the ``tokenizer_state`` provided by ``context_input_transform``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
context
|
||||
A tensor shaped (batch_size, time_length), containing the
|
||||
timeseries to forecast. Use left-padding with ``torch.nan``
|
||||
to align time series of different lengths.
|
||||
tokenizer_state
|
||||
An object returned by ``context_input_transform`` containing
|
||||
relevant information to preprocess data, such as location and
|
||||
scale. The nature of this depends on the specific tokenizer.
|
||||
This is used for tokenizing the label, in order to use the same
|
||||
scaling used to tokenize the context.
|
||||
|
||||
Returns
|
||||
-------
|
||||
token_ids
|
||||
A tensor of integers, shaped (batch_size, time_length + 1)
|
||||
if ``config.use_eos_token`` and (batch_size, time_length)
|
||||
otherwise, containing token IDs for the input series.
|
||||
attention_mask
|
||||
A boolean tensor, same shape as ``token_ids``, indicating
|
||||
which input observations are not ``torch.nan`` (i.e. not
|
||||
missing nor padding).
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
@ -141,14 +165,9 @@ class MeanScaleUniformBins(ChronosTokenizer):
|
|||
)
|
||||
)
|
||||
|
||||
def input_transform(
|
||||
def _input_transform(
|
||||
self, context: torch.Tensor, scale: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
batch_size, length = context.shape
|
||||
|
||||
if length > self.config.context_length:
|
||||
context = context[..., -self.config.context_length :]
|
||||
|
||||
attention_mask = ~torch.isnan(context)
|
||||
|
||||
if scale is None:
|
||||
|
|
@ -170,16 +189,51 @@ class MeanScaleUniformBins(ChronosTokenizer):
|
|||
)
|
||||
token_ids[~attention_mask] = self.config.pad_token_id
|
||||
|
||||
if self.config.use_eos_token:
|
||||
eos_tokens = torch.full(
|
||||
(batch_size, 1), fill_value=self.config.eos_token_id
|
||||
return token_ids, attention_mask, scale
|
||||
|
||||
def _append_eos_token(
|
||||
self, token_ids: torch.Tensor, attention_mask: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = token_ids.shape[0]
|
||||
eos_tokens = torch.full((batch_size, 1), fill_value=self.config.eos_token_id)
|
||||
token_ids = torch.concat((token_ids, eos_tokens), dim=1)
|
||||
eos_mask = torch.full((batch_size, 1), fill_value=True)
|
||||
attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
|
||||
|
||||
return token_ids, attention_mask
|
||||
|
||||
def context_input_transform(
|
||||
self, context: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
length = context.shape[-1]
|
||||
|
||||
if length > self.config.context_length:
|
||||
context = context[..., -self.config.context_length :]
|
||||
|
||||
token_ids, attention_mask, scale = self._input_transform(context=context)
|
||||
|
||||
if self.config.use_eos_token and self.config.model_type == "seq2seq":
|
||||
token_ids, attention_mask = self._append_eos_token(
|
||||
token_ids=token_ids, attention_mask=attention_mask
|
||||
)
|
||||
token_ids = torch.concat((token_ids, eos_tokens), dim=1)
|
||||
eos_mask = torch.full((batch_size, 1), fill_value=True)
|
||||
attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
|
||||
|
||||
return token_ids, attention_mask, scale
|
||||
|
||||
def label_input_transform(
|
||||
self, label: torch.Tensor, scale: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
length = label.shape[-1]
|
||||
|
||||
assert length == self.config.prediction_length
|
||||
token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale)
|
||||
|
||||
if self.config.use_eos_token:
|
||||
token_ids, attention_mask = self._append_eos_token(
|
||||
token_ids=token_ids, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
return token_ids, attention_mask
|
||||
|
||||
def output_transform(
|
||||
self, samples: torch.Tensor, scale: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
|
|
@ -318,6 +372,7 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
|
|||
return torch.stack(padded)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChronosPipeline:
|
||||
"""
|
||||
A ``ChronosPipeline`` uses the given tokenizer and model to forecast
|
||||
|
|
@ -337,10 +392,6 @@ class ChronosPipeline:
|
|||
tokenizer: ChronosTokenizer
|
||||
model: ChronosModel
|
||||
|
||||
def __init__(self, tokenizer, model):
|
||||
self.tokenizer = tokenizer
|
||||
self.model = model
|
||||
|
||||
def _prepare_and_validate_context(
|
||||
self, context: Union[torch.Tensor, List[torch.Tensor]]
|
||||
):
|
||||
|
|
@ -380,8 +431,8 @@ class ChronosPipeline:
|
|||
provided, and the extra 1 is for EOS.
|
||||
"""
|
||||
context_tensor = self._prepare_and_validate_context(context=context)
|
||||
token_ids, attention_mask, tokenizer_state = self.tokenizer.input_transform(
|
||||
context_tensor
|
||||
token_ids, attention_mask, tokenizer_state = (
|
||||
self.tokenizer.context_input_transform(context_tensor)
|
||||
)
|
||||
embeddings = self.model.encode(
|
||||
input_ids=token_ids.to(self.model.device),
|
||||
|
|
@ -455,7 +506,7 @@ class ChronosPipeline:
|
|||
remaining = prediction_length
|
||||
|
||||
while remaining > 0:
|
||||
token_ids, attention_mask, scale = self.tokenizer.input_transform(
|
||||
token_ids, attention_mask, scale = self.tokenizer.context_input_transform(
|
||||
context_tensor
|
||||
)
|
||||
samples = self.model(
|
||||
|
|
|
|||
|
|
@ -38,10 +38,10 @@ def test_tokenizer_consistency(n_numerical_tokens: int, n_special_tokens: int):
|
|||
context = tokenizer.centers.unsqueeze(0) # add batch dimension
|
||||
scale = torch.ones((1,)) # fix the scale to one to turn off scaling
|
||||
|
||||
token_ids, _, _ = tokenizer.input_transform(context, scale=scale)
|
||||
token_ids, _, _ = tokenizer._input_transform(context, scale=scale)
|
||||
|
||||
samples = tokenizer.output_transform(
|
||||
token_ids[:, :-1].unsqueeze(1), # remove final EOS, add sample dimension
|
||||
token_ids.unsqueeze(1), # add sample dimension
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ def test_tokenizer_fixed_data(
|
|||
)
|
||||
batch_size, _ = context.shape
|
||||
|
||||
token_ids, attention_mask, scale = tokenizer.input_transform(context)
|
||||
token_ids, attention_mask, scale = tokenizer.context_input_transform(context)
|
||||
|
||||
assert token_ids.shape == (batch_size, context_length + 1 * use_eos_token)
|
||||
assert all(token_ids[:, 0] == torch.tensor([0]).repeat(batch_size))
|
||||
|
|
@ -136,7 +136,7 @@ def test_tokenizer_random_data(use_eos_token: bool):
|
|||
]
|
||||
)
|
||||
|
||||
token_ids, attention_mask, scale = tokenizer.input_transform(context)
|
||||
token_ids, attention_mask, scale = tokenizer.context_input_transform(context)
|
||||
|
||||
assert token_ids.shape == (
|
||||
*context.shape[:-1],
|
||||
|
|
|
|||
Loading…
Reference in a new issue