This commit is contained in:
Ryuichi Ichinose 2026-01-19 14:07:39 +00:00 committed by GitHub
commit 28267b1d8e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 43 additions and 4 deletions

View file

@ -40,10 +40,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
forecast_type: ForecastType = ForecastType.QUANTILES
default_context_length: int = 2048
def __init__(self, model: Chronos2Model):
def __init__(self, model: Chronos2Model, revision: str | None = None):
super().__init__(inner_model=model)
self.model = model
self.revision = revision
@staticmethod
def _get_prob_mass_per_quantile_level(quantile_levels: torch.Tensor) -> torch.Tensor:
"""
@ -198,6 +198,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
model.load_state_dict(self.model.state_dict())
if finetune_mode == "lora":
lora_revision = self.revision
if lora_config is None:
lora_config = LoraConfig(
r=8,
@ -209,8 +210,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
"self_attention.o",
"output_patch_embedding.output_layer",
],
revision=lora_revision,
)
elif isinstance(lora_config, dict):
lora_config.setdefault("revision", lora_revision)
lora_config = LoraConfig(**lora_config)
else:
assert isinstance(lora_config, LoraConfig), (
@ -1182,6 +1185,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
Load the model, either from a local path, S3 prefix or from the HuggingFace Hub.
Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``.
"""
revision = kwargs.get("revision")
# Check if the model is on S3 and cache it locally first
# NOTE: Only base models (not LoRA adapters) are supported via S3
@ -1199,7 +1203,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
model = AutoPeftModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
model = model.merge_and_unload()
return cls(model=model)
return cls(model=model, revision=revision)
# Handle the case for the base model
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
@ -1213,7 +1217,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
class_ = Chronos2Model
model = class_.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
return cls(model=model)
return cls(model=model, revision=revision)
def save_pretrained(self, save_directory: str | Path, *args, **kwargs):
"""

View file

@ -1143,3 +1143,38 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped):
# Should match exactly or very close (numerical precision)
assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
@pytest.mark.parametrize("source_revision", ["my-test-branch", None])
def test_lora_config_uses_source_revision_from_instantiation(
    pipeline: Chronos2Pipeline, tmpdir, source_revision
):
    """
    Check that a LoRA fine-tuning run records the pipeline's ``revision``.

    The pipeline's ``revision`` attribute is overwritten with the parametrized
    value, a single-step ``fit`` in ``"lora"`` mode is run, and the
    ``adapter_config.json`` written under ``finetuned-ckpt`` is inspected:

    - with a concrete revision, the saved config must contain that exact value;
    - with ``None``, the saved config must either omit ``revision`` or store
      ``None`` for it.
    """
    output_dir = Path(tmpdir)
    series = [torch.rand(100)]

    # Override whatever revision the fixture was created with.
    pipeline.revision = source_revision

    fit_kwargs = dict(
        inputs=series,
        prediction_length=10,
        finetune_mode="lora",
        output_dir=output_dir,
        num_steps=1,  # one optimisation step keeps the test fast
        batch_size=32,
    )
    pipeline.fit(**fit_kwargs)

    adapter_config_path = output_dir / "finetuned-ckpt" / "adapter_config.json"
    assert adapter_config_path.exists(), "adapter_config.json was not created"

    adapter_config = json.loads(adapter_config_path.read_text())
    saved_revision = adapter_config.get("revision")

    if source_revision is None:
        # A missing key and an explicit null are both acceptable here.
        assert saved_revision is None
        return

    assert "revision" in adapter_config
    assert saved_revision == source_revision