mirror of
https://github.com/amazon-science/chronos-forecasting
synced 2026-05-24 10:08:33 +00:00
Fix types, add mypy to workflow (#42)
*Description of changes:* Fix some type checking issues, add mypy to github workflow, apply black By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
96cedec3fa
commit
4b1d1c818b
4 changed files with 49 additions and 26 deletions
43
.github/workflows/ci.yml
vendored
43
.github/workflows/ci.yml
vendored
|
|
@ -3,23 +3,44 @@ name: CI
|
|||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
type-check:
|
||||
strategy:
|
||||
max-parallel: 4
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ['3.11']
|
||||
python-version: ["3.11"]
|
||||
platform: [ubuntu-latest]
|
||||
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: pip install ".[test]"
|
||||
- name: Test with pytest
|
||||
run: pytest
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: pip install ".[typecheck]"
|
||||
- name: Type checks with mypy
|
||||
run: mypy src test
|
||||
|
||||
test:
|
||||
strategy:
|
||||
max-parallel: 4
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
platform: [ubuntu-latest]
|
||||
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: pip install ".[test]"
|
||||
- name: Test with pytest
|
||||
run: pytest
|
||||
|
|
|
|||
|
|
@ -2,18 +2,16 @@
|
|||
name = "chronos"
|
||||
version = "1.1.0"
|
||||
requires-python = ">=3.8"
|
||||
license = {file = "LICENSE"}
|
||||
license = { file = "LICENSE" }
|
||||
dependencies = [
|
||||
"torch~=2.1", # package was tested on 2.2
|
||||
"transformers~=4.31",
|
||||
"accelerate"
|
||||
"torch~=2.1", # package was tested on 2.2
|
||||
"transformers~=4.31",
|
||||
"accelerate",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest~=8.0",
|
||||
"numpy~=1.21"
|
||||
]
|
||||
test = ["pytest~=8.0", "numpy~=1.21"]
|
||||
typecheck = ["mypy~=1.9"]
|
||||
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true
|
||||
|
|
|
|||
|
|
@ -367,9 +367,9 @@ class ChronosPipeline:
|
|||
or the length of the longest time series, if a list of 1D tensors was
|
||||
provided, and the extra 1 is for EOS.
|
||||
"""
|
||||
context = self._prepare_and_validate_context(context=context)
|
||||
context_tensor = self._prepare_and_validate_context(context=context)
|
||||
token_ids, attention_mask, tokenizer_state = self.tokenizer.input_transform(
|
||||
context
|
||||
context_tensor
|
||||
)
|
||||
embeddings = self.model.encode(
|
||||
input_ids=token_ids.to(self.model.device),
|
||||
|
|
@ -424,7 +424,7 @@ class ChronosPipeline:
|
|||
Tensor of sample forecasts, of shape
|
||||
(batch_size, num_samples, prediction_length).
|
||||
"""
|
||||
context = self._prepare_and_validate_context(context=context)
|
||||
context_tensor = self._prepare_and_validate_context(context=context)
|
||||
|
||||
if prediction_length is None:
|
||||
prediction_length = self.model.config.prediction_length
|
||||
|
|
@ -443,7 +443,9 @@ class ChronosPipeline:
|
|||
remaining = prediction_length
|
||||
|
||||
while remaining > 0:
|
||||
token_ids, attention_mask, scale = self.tokenizer.input_transform(context)
|
||||
token_ids, attention_mask, scale = self.tokenizer.input_transform(
|
||||
context_tensor
|
||||
)
|
||||
samples = self.model(
|
||||
token_ids.to(self.model.device),
|
||||
attention_mask.to(self.model.device),
|
||||
|
|
@ -463,7 +465,9 @@ class ChronosPipeline:
|
|||
if remaining <= 0:
|
||||
break
|
||||
|
||||
context = torch.cat([context, prediction.median(dim=1).values], dim=-1)
|
||||
context_tensor = torch.cat(
|
||||
[context_tensor, prediction.median(dim=1).values], dim=-1
|
||||
)
|
||||
|
||||
return torch.cat(predictions, dim=-1)
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ def test_tokenizer_fixed_data(
|
|||
|
||||
samples = tokenizer.output_transform(
|
||||
torch.arange(n_special_tokens, n_tokens).unsqueeze(0).repeat(batch_size, 1, 1),
|
||||
decoding_context=scale,
|
||||
tokenizer_state=scale,
|
||||
)
|
||||
|
||||
assert (samples[:, 0, [0, -1]] == context).all()
|
||||
|
|
@ -119,7 +119,7 @@ def test_tokenizer_random_data(use_eos_token: bool):
|
|||
assert samples.shape == (2, 10, 4)
|
||||
|
||||
|
||||
def validate_tensor(samples: torch.Tensor, shape: Tuple[int, int, int]) -> None:
|
||||
def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None:
|
||||
assert isinstance(samples, torch.Tensor)
|
||||
assert samples.shape == shape
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue