mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-23 09:39:35 +00:00
Merge 6248da86ec into 1f099eb265
This commit is contained in:
commit
28267b1d8e
2 changed files with 43 additions and 4 deletions
|
|
@ -40,10 +40,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
forecast_type: ForecastType = ForecastType.QUANTILES
|
||||
default_context_length: int = 2048
|
||||
|
||||
def __init__(self, model: Chronos2Model):
|
||||
def __init__(self, model: Chronos2Model, revision: str | None = None):
    """
    Wrap a ``Chronos2Model`` in a pipeline and remember the revision it was loaded from.

    Parameters
    ----------
    model
        The model instance to wrap. It is forwarded to the base pipeline as
        ``inner_model`` and also kept on the instance as ``self.model``.
    revision
        Optional revision identifier, stored unchanged as ``self.revision``.
        Defaults to ``None``.
    """
    super().__init__(inner_model=model)
    self.model = model
    # Kept so that later steps (e.g. building a LoRA config in ``fit``) can read it back.
    self.revision = revision
|
||||
@staticmethod
|
||||
def _get_prob_mass_per_quantile_level(quantile_levels: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
|
|
@ -198,6 +198,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
model.load_state_dict(self.model.state_dict())
|
||||
|
||||
if finetune_mode == "lora":
|
||||
lora_revision = self.revision
|
||||
if lora_config is None:
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
|
|
@ -209,8 +210,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
"self_attention.o",
|
||||
"output_patch_embedding.output_layer",
|
||||
],
|
||||
revision=lora_revision,
|
||||
)
|
||||
elif isinstance(lora_config, dict):
|
||||
lora_config.setdefault("revision", lora_revision)
|
||||
lora_config = LoraConfig(**lora_config)
|
||||
else:
|
||||
assert isinstance(lora_config, LoraConfig), (
|
||||
|
|
@ -1182,6 +1185,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
Load the model, either from a local path, S3 prefix or from the HuggingFace Hub.
|
||||
Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``.
|
||||
"""
|
||||
revision = kwargs.get("revision")
|
||||
|
||||
# Check if the model is on S3 and cache it locally first
|
||||
# NOTE: Only base models (not LoRA adapters) are supported via S3
|
||||
|
|
@ -1199,7 +1203,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
|
||||
model = AutoPeftModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
||||
model = model.merge_and_unload()
|
||||
return cls(model=model)
|
||||
return cls(model=model, revision=revision)
|
||||
|
||||
# Handle the case for the base model
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
||||
|
|
@ -1213,7 +1217,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
class_ = Chronos2Model
|
||||
|
||||
model = class_.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
||||
return cls(model=model)
|
||||
return cls(model=model, revision=revision)
|
||||
|
||||
def save_pretrained(self, save_directory: str | Path, *args, **kwargs):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1143,3 +1143,38 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
|
|||
for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped):
|
||||
# Should match exactly or very close (numerical precision)
|
||||
assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("source_revision", ["my-test-branch", None])
def test_lora_config_uses_source_revision_from_instantiation(
    pipeline: Chronos2Pipeline, tmpdir, source_revision
):
    """
    Check that a LoRA fine-tuning run records the pipeline's ``revision``
    in the adapter config it writes.

    The pipeline's ``revision`` attribute is overwritten with the parametrized
    value, a single-step LoRA ``fit`` is run, and the resulting
    ``adapter_config.json`` is inspected.
    """
    checkpoint_root = Path(tmpdir)
    pipeline.revision = source_revision

    # One training step on a single random series keeps the test fast.
    pipeline.fit(
        inputs=[torch.rand(100)],
        prediction_length=10,
        finetune_mode="lora",
        output_dir=checkpoint_root,
        num_steps=1,
        batch_size=32,
    )

    config_file = checkpoint_root / "finetuned-ckpt" / "adapter_config.json"
    assert config_file.exists(), "adapter_config.json was not created"

    with open(config_file, "r") as handle:
        saved_config = json.load(handle)

    # A missing key and an explicit null both read back as None, so one comparison
    # covers both cases: the revision must match exactly when one was set, and
    # must be absent or null when none was set.
    assert saved_config.get("revision") == source_revision
|
||||
|
|
|
|||
Loading…
Reference in a new issue