diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 1f882c3..7ffafee 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -40,10 +40,10 @@ class Chronos2Pipeline(BaseChronosPipeline): forecast_type: ForecastType = ForecastType.QUANTILES default_context_length: int = 2048 - def __init__(self, model: Chronos2Model): + def __init__(self, model: Chronos2Model, revision: str | None = None): super().__init__(inner_model=model) self.model = model - + self.revision = revision @staticmethod def _get_prob_mass_per_quantile_level(quantile_levels: torch.Tensor) -> torch.Tensor: """ @@ -195,7 +195,7 @@ class Chronos2Pipeline(BaseChronosPipeline): model.load_state_dict(self.model.state_dict()) if finetune_mode == "lora": - lora_revision = getattr(self.model.config, "_source_revision", None) + lora_revision = self.revision if lora_config is None: lora_config = LoraConfig( r=8, @@ -210,7 +210,6 @@ class Chronos2Pipeline(BaseChronosPipeline): revision=lora_revision, ) elif isinstance(lora_config, dict): - lora_config.setdefault("revision", lora_revision) lora_config = LoraConfig(**lora_config) else: assert isinstance(lora_config, LoraConfig), ( @@ -1182,11 +1181,7 @@ class Chronos2Pipeline(BaseChronosPipeline): model = AutoPeftModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) model = model.merge_and_unload() - - if revision: - model.config._source_revision = revision - - return cls(model=model) + return cls(model=model, revision=revision) # Handle the case for the base model config = AutoConfig.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) @@ -1200,11 +1195,7 @@ class Chronos2Pipeline(BaseChronosPipeline): class_ = Chronos2Model model = class_.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) - - if revision: - model.config._source_revision = revision - - return cls(model=model) + return cls(model=model, revision=revision) def save_pretrained(self, save_directory: str | Path, *args, **kwargs): """ diff --git a/test/test_chronos2.py b/test/test_chronos2.py index a570ad3..14f8179 100644 --- a/test/test_chronos2.py +++ b/test/test_chronos2.py @@ -1139,18 +1139,13 @@ def test_lora_config_uses_source_revision_from_instantiation( pipeline: Chronos2Pipeline, tmpdir, source_revision ): """ - Test that fit in 'lora' mode correctly uses the '_source_revision' - that was (notionally) stored in the model's config during instantiation. + Test that fit in 'lora' mode correctly uses the 'revision' + stored in the pipeline instance. """ output_dir = Path(tmpdir) dummy_inputs = [torch.rand(100)] - # Manually set the source revision on the config to simulate - # it being set during from_pretrained - if source_revision: - pipeline.model.config._source_revision = source_revision - elif hasattr(pipeline.model.config, "_source_revision"): - delattr(pipeline.model.config, "_source_revision") + pipeline.revision = source_revision pipeline.fit( inputs=dummy_inputs,