diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py index 3eddcd7..1f882c3 100644 --- a/src/chronos/chronos2/pipeline.py +++ b/src/chronos/chronos2/pipeline.py @@ -195,6 +195,7 @@ class Chronos2Pipeline(BaseChronosPipeline): model.load_state_dict(self.model.state_dict()) if finetune_mode == "lora": + lora_revision = getattr(self.model.config, "_source_revision", None) if lora_config is None: lora_config = LoraConfig( r=8, @@ -206,8 +207,10 @@ class Chronos2Pipeline(BaseChronosPipeline): "self_attention.o", "output_patch_embedding.output_layer", ], + revision=lora_revision, ) elif isinstance(lora_config, dict): + lora_config.setdefault("revision", lora_revision) lora_config = LoraConfig(**lora_config) else: assert isinstance(lora_config, LoraConfig), ( @@ -1161,6 +1164,7 @@ class Chronos2Pipeline(BaseChronosPipeline): Load the model, either from a local path, S3 prefix or from the HuggingFace Hub. Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``. """ + revision = kwargs.get("revision") # Check if the model is on S3 and cache it locally first # NOTE: Only base models (not LoRA adapters) are supported via S3 @@ -1178,6 +1182,10 @@ class Chronos2Pipeline(BaseChronosPipeline): model = AutoPeftModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) model = model.merge_and_unload() + + if revision: + model.config._source_revision = revision + return cls(model=model) # Handle the case for the base model @@ -1192,6 +1200,10 @@ class Chronos2Pipeline(BaseChronosPipeline): class_ = Chronos2Model model = class_.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + + if revision: + model.config._source_revision = revision + return cls(model=model) def save_pretrained(self, save_directory: str | Path, *args, **kwargs): diff --git a/test/test_chronos2.py b/test/test_chronos2.py index 3fac726..a570ad3 100644 --- a/test/test_chronos2.py +++ b/test/test_chronos2.py @@ -1132,3 +1132,43 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline): for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped): # Should match exactly or very close (numerical precision) assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4) + + +@pytest.mark.parametrize("source_revision", ["my-test-branch", None]) +def test_lora_config_uses_source_revision_from_instantiation( + pipeline: Chronos2Pipeline, tmpdir, source_revision +): + """ + Test that fit in 'lora' mode correctly uses the '_source_revision' + that was (notionally) stored in the model's config during instantiation. + """ + output_dir = Path(tmpdir) + dummy_inputs = [torch.rand(100)] + + # Manually set the source revision on the config to simulate + # it being set during from_pretrained + if source_revision: + pipeline.model.config._source_revision = source_revision + elif hasattr(pipeline.model.config, "_source_revision"): + delattr(pipeline.model.config, "_source_revision") + + pipeline.fit( + inputs=dummy_inputs, + prediction_length=10, + finetune_mode="lora", + output_dir=output_dir, + num_steps=1, # Keep it fast + batch_size=32, + ) + + adapter_config_path = output_dir / "finetuned-ckpt" / "adapter_config.json" + assert adapter_config_path.exists(), "adapter_config.json was not created" + + with open(adapter_config_path, "r") as f: + adapter_config = json.load(f) + + if source_revision is not None: + assert "revision" in adapter_config + assert adapter_config["revision"] == source_revision + else: + assert "revision" not in adapter_config or adapter_config["revision"] is None