mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 01:58:27 +00:00
Fix type-checking issues (#295)
*Issue #, if available:* See example build https://github.com/amazon-science/chronos-forecasting/actions/runs/14302765904/job/40313421985 *Description of changes:* - Address type-checker complaints, where possible - Bump bugfix version of the package By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
eec771e339
commit
f40a266a55
3 changed files with 10 additions and 6 deletions
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "chronos-forecasting"
|
||||
version = "1.5.0"
|
||||
version = "1.5.1"
|
||||
authors = [
|
||||
{ name="Abdul Fatir Ansari", email="ansarnd@amazon.com" },
|
||||
{ name="Lorenzo Stella", email="stellalo@amazon.com" },
|
||||
|
|
|
|||
|
|
@ -305,6 +305,8 @@ class ChronosModel(nn.Module):
|
|||
assert (
|
||||
self.config.model_type == "seq2seq"
|
||||
), "Encoder embeddings are only supported for encoder-decoder models"
|
||||
assert hasattr(self.model, "encoder")
|
||||
|
||||
return self.model.encoder(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
).last_hidden_state
|
||||
|
|
@ -344,6 +346,8 @@ class ChronosModel(nn.Module):
|
|||
if top_p is None:
|
||||
top_p = self.config.top_p
|
||||
|
||||
assert hasattr(self.model, "generate")
|
||||
|
||||
preds = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
|
|
|
|||
|
|
@ -136,12 +136,12 @@ class ResidualBlock(nn.Module):
|
|||
|
||||
|
||||
class ChronosBoltModelForForecasting(T5PreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
_keys_to_ignore_on_load_missing = [ # type: ignore
|
||||
r"input_patch_embedding\.",
|
||||
r"output_patch_embedding\.",
|
||||
]
|
||||
_keys_to_ignore_on_load_unexpected = [r"lm_head.weight"]
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
_keys_to_ignore_on_load_unexpected = [r"lm_head.weight"] # type: ignore
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] # type: ignore
|
||||
|
||||
def __init__(self, config: T5Config):
|
||||
assert hasattr(config, "chronos_config"), "Not a Chronos config file"
|
||||
|
|
@ -358,7 +358,7 @@ class ChronosBoltModelForForecasting(T5PreTrainedModel):
|
|||
(target - quantile_preds)
|
||||
* (
|
||||
(target <= quantile_preds).float()
|
||||
- self.quantiles.view(1, self.num_quantiles, 1)
|
||||
- self.quantiles.view(1, self.num_quantiles, 1) # type: ignore
|
||||
)
|
||||
)
|
||||
* target_mask.float()
|
||||
|
|
@ -429,7 +429,7 @@ class ChronosBoltPipeline(BaseChronosPipeline):
|
|||
default_context_length: int = 2048
|
||||
|
||||
def __init__(self, model: ChronosBoltModelForForecasting):
|
||||
super().__init__(inner_model=model)
|
||||
super().__init__(inner_model=model) # type: ignore
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
|
|
|
|||
Loading…
Reference in a new issue