mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 17:48:23 +00:00
Add support for causal models (#113)
*Description of changes:* This PR adds support for training causal/decoder-only models. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.de>
This commit is contained in:
parent
79028e3154
commit
2f92a126d3
4 changed files with 96 additions and 4 deletions
|
|
@ -89,6 +89,9 @@
|
|||
The output and checkpoints will be saved in `output/run-{id}/`.
|
||||
> [!TIP]
|
||||
> If the initial training step is too slow, you might want to change the `shuffle_buffer_length` and/or set `torch_compile` to `false`.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> When pretraining causal models (such as GPT2), the training script does [`LastValueImputation`](https://github.com/awslabs/gluonts/blob/f0f2266d520cb980f4c1ce18c28b003ad5cd2599/src/gluonts/transform/feature.py#L103) for missing values by default. If you pretrain causal models, please ensure that missing values are imputed similarly before passing the context tensor to `ChronosPipeline.predict()` for accurate results.
|
||||
- (Optional) Once training is complete, you can easily push your fine-tuned model to the HuggingFace🤗 Hub. Before that, do not forget to [create an access token](https://huggingface.co/settings/tokens) with **write permissions** and put it in `~/.cache/huggingface/token`. Here's a snippet that will push a fine-tuned model to the HuggingFace🤗 Hub at `<your_hf_username>/chronos-t5-small-fine-tuned`.
|
||||
```py
|
||||
from chronos import ChronosPipeline
|
||||
|
|
|
|||
35
scripts/training/configs/chronos-gpt2.yaml
Normal file
35
scripts/training/configs/chronos-gpt2.yaml
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
training_data_paths:
|
||||
- "/home/ubuntu/tsmixup-data.arrow"
|
||||
- "/home/ubuntu/kernelsynth-data.arrow"
|
||||
probability:
|
||||
- 0.9
|
||||
- 0.1
|
||||
context_length: 512
|
||||
prediction_length: 64
|
||||
min_past: 60
|
||||
max_steps: 200_000
|
||||
save_steps: 100_000
|
||||
log_steps: 500
|
||||
per_device_train_batch_size: 32
|
||||
learning_rate: 0.001
|
||||
optim: adamw_torch_fused
|
||||
num_samples: 20
|
||||
shuffle_buffer_length: 100_000
|
||||
gradient_accumulation_steps: 1
|
||||
model_id: openai-community/gpt2
|
||||
model_type: causal
|
||||
random_init: false
|
||||
tie_embeddings: false
|
||||
output_dir: ./output/
|
||||
tf32: true
|
||||
torch_compile: true
|
||||
tokenizer_class: "MeanScaleUniformBins"
|
||||
tokenizer_kwargs:
|
||||
low_limit: -15.0
|
||||
high_limit: 15.0
|
||||
n_tokens: 4096
|
||||
lr_scheduler_type: linear
|
||||
warmup_ratio: 0.0
|
||||
dataloader_num_workers: 1
|
||||
max_missing_prop: 0.1
|
||||
use_eos_token: true
|
||||
|
|
@ -39,6 +39,9 @@ from gluonts.transform import (
|
|||
ValidationSplitSampler,
|
||||
InstanceSplitter,
|
||||
ExpectedNumInstanceSampler,
|
||||
MissingValueImputation,
|
||||
LeavesMissingValues,
|
||||
LastValueImputation,
|
||||
)
|
||||
|
||||
from chronos import ChronosConfig, ChronosTokenizer
|
||||
|
|
@ -301,6 +304,8 @@ class ChronosDataset(IterableDataset, ShuffleMixin):
|
|||
prediction_length: int = 64,
|
||||
drop_prob: float = 0.2,
|
||||
min_past: Optional[int] = None,
|
||||
model_type: str = "seq2seq",
|
||||
imputation_method: Optional[MissingValueImputation] = None,
|
||||
mode: str = "training",
|
||||
np_dtype=np.float32,
|
||||
) -> None:
|
||||
|
|
@ -308,6 +313,7 @@ class ChronosDataset(IterableDataset, ShuffleMixin):
|
|||
|
||||
assert len(probabilities) == len(datasets)
|
||||
assert mode in ("training", "validation", "test")
|
||||
assert model_type in ("seq2seq", "causal")
|
||||
|
||||
self.datasets = datasets
|
||||
self.probabilities = probabilities
|
||||
|
|
@ -316,6 +322,8 @@ class ChronosDataset(IterableDataset, ShuffleMixin):
|
|||
self.prediction_length = prediction_length
|
||||
self.drop_prob = drop_prob
|
||||
self.min_past = min_past or prediction_length
|
||||
self.model_type = model_type
|
||||
self.imputation_method = imputation_method or LeavesMissingValues()
|
||||
self.mode = mode
|
||||
self.np_dtype = np_dtype
|
||||
|
||||
|
|
@ -324,6 +332,11 @@ class ChronosDataset(IterableDataset, ShuffleMixin):
|
|||
entry["target"] = np.asarray(entry["target"], dtype=self.np_dtype)
|
||||
assert entry["target"].ndim == 1, f"got {entry['target'].ndim=}, expected 1"
|
||||
|
||||
if self.model_type == "causal":
|
||||
# Causal models do not play nice with missing values, so it is
|
||||
# recommended to use an imputation method, e.g., LastValueImputation
|
||||
entry["target"] = self.imputation_method(entry["target"])
|
||||
|
||||
if mode == "training" and self.drop_prob > 0:
|
||||
target = entry["target"].copy()
|
||||
drop_p = np.random.uniform(low=0.0, high=self.drop_prob)
|
||||
|
|
@ -386,6 +399,48 @@ class ChronosDataset(IterableDataset, ShuffleMixin):
|
|||
future_target = torch.tensor(entry["future_target"]).unsqueeze(0)
|
||||
labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale)
|
||||
labels[labels_mask == 0] = -100
|
||||
|
||||
if self.model_type == "causal":
|
||||
# The InstanceSplitter pads time series on the left to be equal to the
|
||||
# context_length. However, certain models (e.g., GPT2) with absolute
|
||||
# position embeddings should not be trained with left padding.
|
||||
# The following piece of code moves padding from left to right.
|
||||
|
||||
assert input_ids.shape[-1] == entry["past_is_pad"].shape[0]
|
||||
|
||||
# Find the index where the left padding ends, i.e., where the observed values begin
|
||||
pad_start_idx = np.searchsorted(1 - entry["past_is_pad"], 1)
|
||||
padded_input_ids, obs_input_ids = torch.tensor_split(
|
||||
input_ids, [pad_start_idx], dim=-1
|
||||
)
|
||||
padded_attention_mask, obs_attention_mask = torch.tensor_split(
|
||||
attention_mask, [pad_start_idx], dim=-1
|
||||
)
|
||||
|
||||
# Move padding to the right
|
||||
input_ids = torch.cat(
|
||||
[
|
||||
obs_input_ids,
|
||||
labels,
|
||||
padded_input_ids,
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
obs_attention_mask,
|
||||
labels_mask,
|
||||
padded_attention_mask,
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
# Labels for causal models are the same as the input_ids.
|
||||
# Internally transformers shifts the labels by one during training.
|
||||
labels = input_ids.clone()
|
||||
input_ids[~attention_mask] = self.tokenizer.config.pad_token_id
|
||||
labels[~attention_mask] = -100
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.squeeze(0),
|
||||
"attention_mask": attention_mask.squeeze(0),
|
||||
|
|
@ -520,9 +575,6 @@ def main(
|
|||
|
||||
assert model_type in ["seq2seq", "causal"]
|
||||
|
||||
if not model_type == "seq2seq":
|
||||
raise NotImplementedError("Only seq2seq models are currently supported")
|
||||
|
||||
output_dir = get_next_path("run", base_dir=output_dir, file_type="")
|
||||
|
||||
log_on_main(f"Logging dir: {output_dir}", logger)
|
||||
|
|
@ -588,6 +640,8 @@ def main(
|
|||
context_length=context_length,
|
||||
prediction_length=prediction_length,
|
||||
min_past=min_past,
|
||||
model_type=model_type,
|
||||
imputation_method=LastValueImputation() if model_type == "causal" else None,
|
||||
mode="training",
|
||||
).shuffle(shuffle_buffer_length=shuffle_buffer_length)
|
||||
|
||||
|
|
|
|||
|
|
@ -551,7 +551,7 @@ class ChronosPipeline:
|
|||
if chronos_config.model_type == "seq2seq":
|
||||
inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
|
||||
else:
|
||||
assert config.model_type == "causal"
|
||||
assert chronos_config.model_type == "causal"
|
||||
inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)
|
||||
|
||||
return cls(
|
||||
|
|
|
|||
Loading…
Reference in a new issue