mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 01:58:27 +00:00
Refactor revision handling: remove private attribute hack
This commit is contained in:
parent
53bc613244
commit
7f2f5a1d18
2 changed files with 8 additions and 22 deletions
|
|
@ -40,10 +40,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
forecast_type: ForecastType = ForecastType.QUANTILES
|
||||
default_context_length: int = 2048
|
||||
|
||||
def __init__(self, model: Chronos2Model):
|
||||
def __init__(self, model: Chronos2Model, revision: str | None = None):
|
||||
super().__init__(inner_model=model)
|
||||
self.model = model
|
||||
|
||||
self.revision = revision
|
||||
@staticmethod
|
||||
def _get_prob_mass_per_quantile_level(quantile_levels: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
|
|
@ -195,7 +195,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
model.load_state_dict(self.model.state_dict())
|
||||
|
||||
if finetune_mode == "lora":
|
||||
lora_revision = getattr(self.model.config, "_source_revision", None)
|
||||
lora_revision = self.revision
|
||||
if lora_config is None:
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
|
|
@ -210,7 +210,6 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
revision=lora_revision,
|
||||
)
|
||||
elif isinstance(lora_config, dict):
|
||||
lora_config.setdefault("revision", lora_revision)
|
||||
lora_config = LoraConfig(**lora_config)
|
||||
else:
|
||||
assert isinstance(lora_config, LoraConfig), (
|
||||
|
|
@ -1182,11 +1181,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
|
||||
model = AutoPeftModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if revision:
|
||||
model.config._source_revision = revision
|
||||
|
||||
return cls(model=model)
|
||||
return cls(model=model, revision=revision)
|
||||
|
||||
# Handle the case for the base model
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
||||
|
|
@ -1200,11 +1195,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
|
|||
class_ = Chronos2Model
|
||||
|
||||
model = class_.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
||||
|
||||
if revision:
|
||||
model.config._source_revision = revision
|
||||
|
||||
return cls(model=model)
|
||||
return cls(model=model, revision=revision)
|
||||
|
||||
def save_pretrained(self, save_directory: str | Path, *args, **kwargs):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1139,18 +1139,13 @@ def test_lora_config_uses_source_revision_from_instantiation(
|
|||
pipeline: Chronos2Pipeline, tmpdir, source_revision
|
||||
):
|
||||
"""
|
||||
Test that fit in 'lora' mode correctly uses the '_source_revision'
|
||||
that was (notionally) stored in the model's config during instantiation.
|
||||
Test that fit in 'lora' mode correctly uses the 'revision'
|
||||
stored in the pipeline instance.
|
||||
"""
|
||||
output_dir = Path(tmpdir)
|
||||
dummy_inputs = [torch.rand(100)]
|
||||
|
||||
# Manually set the source revision on the config to simulate
|
||||
# it being set during from_pretrained
|
||||
if source_revision:
|
||||
pipeline.model.config._source_revision = source_revision
|
||||
elif hasattr(pipeline.model.config, "_source_revision"):
|
||||
delattr(pipeline.model.config, "_source_revision")
|
||||
pipeline.revision = source_revision
|
||||
|
||||
pipeline.fit(
|
||||
inputs=dummy_inputs,
|
||||
|
|
|
|||
Loading…
Reference in a new issue