Refactor revision handling: remove private attribute hack

This commit is contained in:
Ryuichi Ichinose 2025-12-12 23:13:44 +09:00
parent 53bc613244
commit 7f2f5a1d18
No known key found for this signature in database
GPG key ID: D5D8D056D6A7CA53
2 changed files with 8 additions and 22 deletions

View file

@ -40,10 +40,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
forecast_type: ForecastType = ForecastType.QUANTILES
default_context_length: int = 2048
def __init__(self, model: Chronos2Model):
def __init__(self, model: Chronos2Model, revision: str | None = None):
super().__init__(inner_model=model)
self.model = model
self.revision = revision
@staticmethod
def _get_prob_mass_per_quantile_level(quantile_levels: torch.Tensor) -> torch.Tensor:
"""
@ -195,7 +195,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
model.load_state_dict(self.model.state_dict())
if finetune_mode == "lora":
lora_revision = getattr(self.model.config, "_source_revision", None)
lora_revision = self.revision
if lora_config is None:
lora_config = LoraConfig(
r=8,
@ -210,7 +210,6 @@ class Chronos2Pipeline(BaseChronosPipeline):
revision=lora_revision,
)
elif isinstance(lora_config, dict):
lora_config.setdefault("revision", lora_revision)
lora_config = LoraConfig(**lora_config)
else:
assert isinstance(lora_config, LoraConfig), (
@ -1182,11 +1181,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
model = AutoPeftModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
model = model.merge_and_unload()
if revision:
model.config._source_revision = revision
return cls(model=model)
return cls(model=model, revision=revision)
# Handle the case for the base model
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
@ -1200,11 +1195,7 @@ class Chronos2Pipeline(BaseChronosPipeline):
class_ = Chronos2Model
model = class_.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
if revision:
model.config._source_revision = revision
return cls(model=model)
return cls(model=model, revision=revision)
def save_pretrained(self, save_directory: str | Path, *args, **kwargs):
"""

View file

@ -1139,18 +1139,13 @@ def test_lora_config_uses_source_revision_from_instantiation(
pipeline: Chronos2Pipeline, tmpdir, source_revision
):
"""
Test that fit in 'lora' mode correctly uses the '_source_revision'
that was (notionally) stored in the model's config during instantiation.
Test that fit in 'lora' mode correctly uses the 'revision'
stored in the pipeline instance.
"""
output_dir = Path(tmpdir)
dummy_inputs = [torch.rand(100)]
# Manually set the source revision on the config to simulate
# it being set during from_pretrained
if source_revision:
pipeline.model.config._source_revision = source_revision
elif hasattr(pipeline.model.config, "_source_revision"):
delattr(pipeline.model.config, "_source_revision")
pipeline.revision = source_revision
pipeline.fit(
inputs=dummy_inputs,