Fix type-checking issues (#295)

*Issue #, if available:* See example build
https://github.com/amazon-science/chronos-forecasting/actions/runs/14302765904/job/40313421985

*Description of changes:*
- Address type-checker complaints, where possible
- Bump bugfix version of the package


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Lorenzo Stella 2025-04-10 17:23:59 +02:00 committed by GitHub
parent eec771e339
commit f40a266a55
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 10 additions and 6 deletions

View file

@ -1,6 +1,6 @@
[project]
name = "chronos-forecasting"
version = "1.5.0"
version = "1.5.1"
authors = [
{ name="Abdul Fatir Ansari", email="ansarnd@amazon.com" },
{ name="Lorenzo Stella", email="stellalo@amazon.com" },

View file

@ -305,6 +305,8 @@ class ChronosModel(nn.Module):
assert (
self.config.model_type == "seq2seq"
), "Encoder embeddings are only supported for encoder-decoder models"
assert hasattr(self.model, "encoder")
return self.model.encoder(
input_ids=input_ids, attention_mask=attention_mask
).last_hidden_state
@ -344,6 +346,8 @@ class ChronosModel(nn.Module):
if top_p is None:
top_p = self.config.top_p
assert hasattr(self.model, "generate")
preds = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,

View file

@ -136,12 +136,12 @@ class ResidualBlock(nn.Module):
class ChronosBoltModelForForecasting(T5PreTrainedModel):
_keys_to_ignore_on_load_missing = [
_keys_to_ignore_on_load_missing = [ # type: ignore
r"input_patch_embedding\.",
r"output_patch_embedding\.",
]
_keys_to_ignore_on_load_unexpected = [r"lm_head.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_keys_to_ignore_on_load_unexpected = [r"lm_head.weight"] # type: ignore
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] # type: ignore
def __init__(self, config: T5Config):
assert hasattr(config, "chronos_config"), "Not a Chronos config file"
@ -358,7 +358,7 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
(target - quantile_preds)
* (
(target <= quantile_preds).float()
- self.quantiles.view(1, self.num_quantiles, 1)
- self.quantiles.view(1, self.num_quantiles, 1) # type: ignore
)
)
* target_mask.float()
@ -429,7 +429,7 @@ class ChronosBoltPipeline(BaseChronosPipeline):
default_context_length: int = 2048
def __init__(self, model: ChronosBoltModelForForecasting):
super().__init__(inner_model=model)
super().__init__(inner_model=model) # type: ignore
self.model = model
@property