Chronos-2: Ensure base model revision is saved in LoRA adapter config

Description:
This commit fixes a critical issue where the base model's revision was not recorded in `adapter_config.json` during LoRA fine-tuning.
Previously, `fit()` did not propagate the loaded model's revision to `LoraConfig`, resulting in the `revision` field being missing or None in the saved adapter configuration.

This omission meant that loading the saved adapter would silently default to the `main` branch of the base model, rather than the specific revision used during training,
 breaking reproducibility.

Key changes:
* Persist Revision: Store the `revision` passed to `from_pretrained` in `model.config._source_revision`.
* Propagate to Config: Update `fit()` to inject this stored revision into `LoraConfig` so it is correctly saved in `adapter_config.json`.
* Validate Integrity: Raise `ValueError` if an explicit `revision` argument in `fit()` conflicts with the loaded model's source revision.

Impact:
Ensures that adapter_config.json always contains the correct base model revision.
This guarantees that AutoGluon can faithfully reproduce the model state by simply loading the adapter,
without risking a silent version mismatch, even in scenarios where the base model receives minimal or infrequent updates.
This commit is contained in:
Ryuichi Ichinose 2025-12-11 23:51:07 +09:00
parent eb5b61234a
commit 53bc613244
No known key found for this signature in database
GPG key ID: D5D8D056D6A7CA53
2 changed files with 52 additions and 0 deletions

View file

@ -195,6 +195,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
model.load_state_dict(self.model.state_dict())
if finetune_mode == "lora":
lora_revision = getattr(self.model.config, "_source_revision", None)
if lora_config is None:
lora_config = LoraConfig(
r=8,
@ -206,8 +207,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
"self_attention.o",
"output_patch_embedding.output_layer",
],
revision=lora_revision,
)
elif isinstance(lora_config, dict):
lora_config.setdefault("revision", lora_revision)
lora_config = LoraConfig(**lora_config)
else:
assert isinstance(lora_config, LoraConfig), (
@ -1161,6 +1164,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
Load the model, either from a local path, S3 prefix or from the HuggingFace Hub.
Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``.
"""
revision = kwargs.get("revision")
# Check if the model is on S3 and cache it locally first
# NOTE: Only base models (not LoRA adapters) are supported via S3
@ -1178,6 +1182,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
model = AutoPeftModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
model = model.merge_and_unload()
if revision:
model.config._source_revision = revision
return cls(model=model)
# Handle the case for the base model
@ -1192,6 +1200,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
class_ = Chronos2Model
model = class_.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
if revision:
model.config._source_revision = revision
return cls(model=model)
def save_pretrained(self, save_directory: str | Path, *args, **kwargs):

View file

@ -1132,3 +1132,43 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped):
# Should match exactly or very close (numerical precision)
assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
@pytest.mark.parametrize("source_revision", ["my-test-branch", None])
def test_lora_config_uses_source_revision_from_instantiation(
    pipeline: Chronos2Pipeline, tmpdir, source_revision
):
    """
    Check that LoRA fine-tuning writes the model config's ``_source_revision``
    into the ``revision`` field of the saved ``adapter_config.json``.

    When a revision is set on the config, the saved field must equal it; when
    no revision is set, the field must be absent or ``None``.
    """
    config = pipeline.model.config

    # Put the config in the state from_pretrained would have left it in:
    # either carrying the requested revision, or carrying none at all.
    if source_revision:
        config._source_revision = source_revision
    elif hasattr(config, "_source_revision"):
        del config._source_revision

    output_dir = Path(tmpdir)
    pipeline.fit(
        inputs=[torch.rand(100)],
        prediction_length=10,
        finetune_mode="lora",
        output_dir=output_dir,
        num_steps=1,  # a single step keeps the test fast
        batch_size=32,
    )

    adapter_config_path = output_dir / "finetuned-ckpt" / "adapter_config.json"
    assert adapter_config_path.exists(), "adapter_config.json was not created"
    adapter_config = json.loads(adapter_config_path.read_text())

    if source_revision is None:
        # Either the key is missing or it was serialized as null.
        assert adapter_config.get("revision") is None
    else:
        assert "revision" in adapter_config
        assert adapter_config["revision"] == source_revision