From 2f92a126d3b6ec81111bb120953a5310e5ee8502 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Thu, 13 Jun 2024 17:37:04 +0200 Subject: [PATCH] Add support for causal models (#113) *Description of changes:* This PR adds support for training causal/decoder-only models. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. Co-authored-by: Abdul Fatir Ansari --- scripts/README.md | 3 ++ scripts/training/configs/chronos-gpt2.yaml | 35 +++++++++++++ scripts/training/train.py | 60 ++++++++++++++++++++-- src/chronos/chronos.py | 2 +- 4 files changed, 96 insertions(+), 4 deletions(-) create mode 100644 scripts/training/configs/chronos-gpt2.yaml diff --git a/scripts/README.md b/scripts/README.md index 925aefc..3dadd90 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -89,6 +89,9 @@ The output and checkpoints will be saved in `output/run-{id}/`. > [!TIP] > If the initial training step is too slow, you might want to change the `shuffle_buffer_length` and/or set `torch_compile` to `false`. + +> [!IMPORTANT] +> When pretraining causal models (such as GPT2), the training script does [`LastValueImputation`](https://github.com/awslabs/gluonts/blob/f0f2266d520cb980f4c1ce18c28b003ad5cd2599/src/gluonts/transform/feature.py#L103) for missing values by default. If you pretrain causal models, please ensure that missing values are imputed similarly before passing the context tensor to `ChronosPipeline.predict()` for accurate results. - (Optional) Once trained, you can easily push your fine-tuned model to HuggingFace🤗 Hub. Before that, do not forget to [create an access token](https://huggingface.co/settings/tokens) with **write permissions** and put it in `~/.cache/huggingface/token`. Here's a snippet that will push a fine-tuned model to HuggingFace🤗 Hub at `/chronos-t5-small-fine-tuned`. ```py from chronos import ChronosPipeline diff --git a/scripts/training/configs/chronos-gpt2.yaml b/scripts/training/configs/chronos-gpt2.yaml new file mode 100644 index 0000000..4d917ad --- /dev/null +++ b/scripts/training/configs/chronos-gpt2.yaml @@ -0,0 +1,35 @@ +training_data_paths: +- "/home/ubuntu/tsmixup-data.arrow" +- "/home/ubuntu/kernelsynth-data.arrow" +probability: +- 0.9 +- 0.1 +context_length: 512 +prediction_length: 64 +min_past: 60 +max_steps: 200_000 +save_steps: 100_000 +log_steps: 500 +per_device_train_batch_size: 32 +learning_rate: 0.001 +optim: adamw_torch_fused +num_samples: 20 +shuffle_buffer_length: 100_000 +gradient_accumulation_steps: 1 +model_id: openai-community/gpt2 +model_type: causal +random_init: false +tie_embeddings: false +output_dir: ./output/ +tf32: true +torch_compile: true +tokenizer_class: "MeanScaleUniformBins" +tokenizer_kwargs: + low_limit: -15.0 + high_limit: 15.0 +n_tokens: 4096 +lr_scheduler_type: linear +warmup_ratio: 0.0 +dataloader_num_workers: 1 +max_missing_prop: 0.1 +use_eos_token: true diff --git a/scripts/training/train.py b/scripts/training/train.py index d14aa75..ee6f99d 100644 --- a/scripts/training/train.py +++ b/scripts/training/train.py @@ -39,6 +39,9 @@ from gluonts.transform import ( ValidationSplitSampler, InstanceSplitter, ExpectedNumInstanceSampler, + MissingValueImputation, + LeavesMissingValues, + LastValueImputation, ) from chronos import ChronosConfig, ChronosTokenizer @@ -301,6 +304,8 @@ class ChronosDataset(IterableDataset, ShuffleMixin): prediction_length: int = 64, drop_prob: float = 0.2, min_past: Optional[int] = None, + model_type: str = "seq2seq", + imputation_method: Optional[MissingValueImputation] = None, mode: str = "training", np_dtype=np.float32, ) -> None: @@ -308,6 +313,7 @@ class ChronosDataset(IterableDataset, ShuffleMixin): assert len(probabilities) == len(datasets) assert mode in ("training", "validation", "test") + assert model_type in ("seq2seq", "causal") self.datasets = datasets self.probabilities = probabilities @@ -316,6 +322,8 @@ class ChronosDataset(IterableDataset, ShuffleMixin): self.prediction_length = prediction_length self.drop_prob = drop_prob self.min_past = min_past or prediction_length + self.model_type = model_type + self.imputation_method = imputation_method or LeavesMissingValues() self.mode = mode self.np_dtype = np_dtype @@ -324,6 +332,11 @@ class ChronosDataset(IterableDataset, ShuffleMixin): entry["target"] = np.asarray(entry["target"], dtype=self.np_dtype) assert entry["target"].ndim == 1, f"got {entry['target'].ndim=}, expected 1" + if self.model_type == "causal": + # Causal models do not play nice with missing values, so it is + # recommended to use an imputation method, e.g., LastValueImputation + entry["target"] = self.imputation_method(entry["target"]) + if mode == "training" and self.drop_prob > 0: target = entry["target"].copy() drop_p = np.random.uniform(low=0.0, high=self.drop_prob) @@ -386,6 +399,48 @@ class ChronosDataset(IterableDataset, ShuffleMixin): future_target = torch.tensor(entry["future_target"]).unsqueeze(0) labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale) labels[labels_mask == 0] = -100 + + if self.model_type == "causal": + # The InstanceSplitter pads time series on the left to be equal to the + # context_length. However, certain models (e.g., GPT2) with absolute + # position embeddings should not be trained with left padding. + # The following piece of code moves padding from left to right. + + assert input_ids.shape[-1] == entry["past_is_pad"].shape[0] + + # Find the index where padding starts + pad_start_idx = np.searchsorted(1 - entry["past_is_pad"], 1) + padded_input_ids, obs_input_ids = torch.tensor_split( + input_ids, [pad_start_idx], dim=-1 + ) + padded_attention_mask, obs_attention_mask = torch.tensor_split( + attention_mask, [pad_start_idx], dim=-1 + ) + + # Move padding to the right + input_ids = torch.cat( + [ + obs_input_ids, + labels, + padded_input_ids, + ], + axis=-1, + ) + attention_mask = torch.cat( + [ + obs_attention_mask, + labels_mask, + padded_attention_mask, + ], + axis=-1, + ) + + # labels for causal models are same as the input_ids. + # Internally transformers shifts the labels by one during training. + labels = input_ids.clone() + input_ids[~attention_mask] = self.tokenizer.config.pad_token_id + labels[~attention_mask] = -100 + return { "input_ids": input_ids.squeeze(0), "attention_mask": attention_mask.squeeze(0), @@ -520,9 +575,6 @@ def main( assert model_type in ["seq2seq", "causal"] - if not model_type == "seq2seq": - raise NotImplementedError("Only seq2seq models are currently supported") - output_dir = get_next_path("run", base_dir=output_dir, file_type="") log_on_main(f"Logging dir: {output_dir}", logger) @@ -588,6 +640,8 @@ def main( context_length=context_length, prediction_length=prediction_length, min_past=min_past, + model_type=model_type, + imputation_method=LastValueImputation() if model_type == "causal" else None, mode="training", ).shuffle(shuffle_buffer_length=shuffle_buffer_length) diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py index f4b3377..3b17502 100644 --- a/src/chronos/chronos.py +++ b/src/chronos/chronos.py @@ -551,7 +551,7 @@ class ChronosPipeline: if chronos_config.model_type == "seq2seq": inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs) else: - assert config.model_type == "causal" + assert chronos_config.model_type == "causal" inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs) return cls(