From f40a266a550e37ac16cb6c46dd29a6183f39f618 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Thu, 10 Apr 2025 17:23:59 +0200 Subject: [PATCH] Fix type-checking issues (#295) *Issue #, if available:* See example build https://github.com/amazon-science/chronos-forecasting/actions/runs/14302765904/job/40313421985 *Description of changes:* - Address type-checker complaints, where possible - Bump bugfix version of the package By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --- pyproject.toml | 2 +- src/chronos/chronos.py | 4 ++++ src/chronos/chronos_bolt.py | 10 +++++----- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c722699..0eff211 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "chronos-forecasting" -version = "1.5.0" +version = "1.5.1" authors = [ { name="Abdul Fatir Ansari", email="ansarnd@amazon.com" }, { name="Lorenzo Stella", email="stellalo@amazon.com" }, diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py index 31df48a..870bfa6 100644 --- a/src/chronos/chronos.py +++ b/src/chronos/chronos.py @@ -305,6 +305,8 @@ class ChronosModel(nn.Module): assert ( self.config.model_type == "seq2seq" ), "Encoder embeddings are only supported for encoder-decoder models" + assert hasattr(self.model, "encoder") + return self.model.encoder( input_ids=input_ids, attention_mask=attention_mask ).last_hidden_state @@ -344,6 +346,8 @@ class ChronosModel(nn.Module): if top_p is None: top_p = self.config.top_p + assert hasattr(self.model, "generate") + preds = self.model.generate( input_ids=input_ids, attention_mask=attention_mask, diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py index 47bca45..f099e04 100644 --- a/src/chronos/chronos_bolt.py +++ b/src/chronos/chronos_bolt.py @@ -136,12 +136,12 @@ class ResidualBlock(nn.Module): class ChronosBoltModelForForecasting(T5PreTrainedModel): - _keys_to_ignore_on_load_missing = [ + _keys_to_ignore_on_load_missing = [ # type: ignore r"input_patch_embedding\.", r"output_patch_embedding\.", ] - _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"] # type: ignore + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] # type: ignore def __init__(self, config: T5Config): assert hasattr(config, "chronos_config"), "Not a Chronos config file" @@ -358,7 +358,7 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel): (target - quantile_preds) * ( (target <= quantile_preds).float() - - self.quantiles.view(1, self.num_quantiles, 1) + - self.quantiles.view(1, self.num_quantiles, 1) # type: ignore ) ) * target_mask.float() @@ -429,7 +429,7 @@ class ChronosBoltPipeline(BaseChronosPipeline): default_context_length: int = 2048 def __init__(self, model: ChronosBoltModelForForecasting): - super().__init__(inner_model=model) + super().__init__(inner_model=model) # type: ignore self.model = model @property