diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 74b4771..8c10e16 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -331,7 +331,8 @@ class Chronos2Pipeline(BaseChronosPipeline): trainer.train() - # update max_output_patches, if the model was fine-tuned with longer prediction_length + # update context_length and max_output_patches, if the model was fine-tuned with larger values + model.chronos_config.context_length = max(model.chronos_config.context_length, context_length) model.chronos_config.max_output_patches = max( model.chronos_config.max_output_patches, math.ceil(prediction_length / self.model_output_patch_size) )