diff --git a/scripts/training/train.py b/scripts/training/train.py index ee6f99d..01e1678 100644 --- a/scripts/training/train.py +++ b/scripts/training/train.py @@ -320,7 +320,7 @@ class ChronosDataset(IterableDataset, ShuffleMixin): self.tokenizer = tokenizer self.context_length = context_length self.prediction_length = prediction_length - self.drop_prob = drop_prob + self.drop_prob = drop_prob if model_type == "seq2seq" else 0.0 self.min_past = min_past or prediction_length self.model_type = model_type self.imputation_method = imputation_method or LeavesMissingValues()