mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Chronos-2: Ensure base model revision is saved in LoRA adapter config
Description: This commit fixes a critical issue where the base model's revision was not recorded in `adapter_config.json` during LoRA fine-tuning. Previously, `fit()` did not propagate the loaded model's revision to `LoraConfig`, resulting in the `revision` field being missing or None in the saved adapter configuration. This omission meant that loading the saved adapter would silently default to the `main` branch of the base model, rather than the specific revision used during training, breaking reproducibility. Key changes: * Persist Revision: Store the `revision` passed to `from_pretrained` in `model.config._source_revision`. * Propagate to Config: Update `fit()` to inject this stored revision into `LoraConfig` so it is correctly saved in `adapter_config.json`. * Validate Integrity: Raise `ValueError` if an explicit `revision` argument in `fit()` conflicts with the loaded model's source revision. Impact: Ensures that `adapter_config.json` always contains the correct base model revision. This guarantees that AutoGluon can faithfully reproduce the model state by simply loading the adapter, without risking a silent version mismatch, even in scenarios where the base model receives minimal or infrequent updates.
This commit is contained in:
parent
eb5b61234a
commit
53bc613244
2 changed files with 52 additions and 0 deletions
|
|
@@ -195,6 +195,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
model.load_state_dict(self.model.state_dict())
|
||||
|
||||
if finetune_mode == "lora":
|
||||
lora_revision = getattr(self.model.config, "_source_revision", None)
|
||||
if lora_config is None:
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
|
|
@@ -206,8 +207,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
"self_attention.o",
|
||||
"output_patch_embedding.output_layer",
|
||||
],
|
||||
revision=lora_revision,
|
||||
)
|
||||
elif isinstance(lora_config, dict):
|
||||
lora_config.setdefault("revision", lora_revision)
|
||||
lora_config = LoraConfig(**lora_config)
|
||||
else:
|
||||
assert isinstance(lora_config, LoraConfig), (
|
||||
|
|
@@ -1161,6 +1164,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
Load the model, either from a local path, S3 prefix or from the HuggingFace Hub.
|
||||
Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``.
|
||||
"""
|
||||
revision = kwargs.get("revision")
|
||||
|
||||
# Check if the model is on S3 and cache it locally first
|
||||
# NOTE: Only base models (not LoRA adapters) are supported via S3
|
||||
|
|
@@ -1178,6 +1182,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
|
||||
model = AutoPeftModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if revision:
|
||||
model.config._source_revision = revision
|
||||
|
||||
return cls(model=model)
|
||||
|
||||
# Handle the case for the base model
|
||||
|
|
@@ -1192,6 +1200,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
class_ = Chronos2Model
|
||||
|
||||
model = class_.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
||||
|
||||
if revision:
|
||||
model.config._source_revision = revision
|
||||
|
||||
return cls(model=model)
|
||||
|
||||
def save_pretrained(self, save_directory: str | Path, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@@ -1132,3 +1132,43 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
|
|||
for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped):
|
||||
# Should match exactly or very close (numerical precision)
|
||||
assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("source_revision", ["my-test-branch", None])
def test_lora_config_uses_source_revision_from_instantiation(
    pipeline: Chronos2Pipeline, tmpdir, source_revision
):
    """
    Check that a LoRA ``fit`` writes the model's ``_source_revision`` into
    the saved ``adapter_config.json``.

    The attribute is set (or cleared) directly on ``pipeline.model.config``
    to stand in for a value stored at load time; no revision is actually
    fetched by this test.
    """
    model_config = pipeline.model.config

    # Put the config into the state under test: either carrying the
    # parametrized revision, or carrying no ``_source_revision`` at all.
    if source_revision:
        model_config._source_revision = source_revision
    elif hasattr(model_config, "_source_revision"):
        delattr(model_config, "_source_revision")

    output_dir = Path(tmpdir)
    pipeline.fit(
        inputs=[torch.rand(100)],
        prediction_length=10,
        finetune_mode="lora",
        output_dir=output_dir,
        num_steps=1,  # one step only, to keep the test fast
        batch_size=32,
    )

    adapter_config_path = output_dir / "finetuned-ckpt" / "adapter_config.json"
    assert adapter_config_path.exists(), "adapter_config.json was not created"

    with open(adapter_config_path, "r") as f:
        adapter_config = json.load(f)

    if source_revision is None:
        # Without a stored revision the field may be absent or null.
        assert "revision" not in adapter_config or adapter_config["revision"] is None
    else:
        assert "revision" in adapter_config
        assert adapter_config["revision"] == source_revision
|
||||
|
|
|
|||
Loading…
Reference in a new issue