Fix types, add mypy to workflow (#42)

*Description of changes:* Fix some type checking issues, add mypy to
github workflow, apply black


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
This commit is contained in:
Lorenzo Stella 2024-04-05 15:36:39 +02:00 committed by GitHub
parent 96cedec3fa
commit 4b1d1c818b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 49 additions and 26 deletions

View file

@ -3,23 +3,44 @@ name: CI
on: [push, pull_request]
jobs:
test:
type-check:
strategy:
max-parallel: 4
fail-fast: false
matrix:
python-version: ['3.11']
python-version: ["3.11"]
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install ".[test]"
- name: Test with pytest
run: pytest
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install ".[typecheck]"
- name: Type checks with mypy
run: mypy src test
test:
strategy:
max-parallel: 4
fail-fast: false
matrix:
python-version: ["3.11"]
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install ".[test]"
- name: Test with pytest
run: pytest

View file

@ -2,18 +2,16 @@
name = "chronos"
version = "1.1.0"
requires-python = ">=3.8"
license = {file = "LICENSE"}
license = { file = "LICENSE" }
dependencies = [
"torch~=2.1", # package was tested on 2.2
"transformers~=4.31",
"accelerate"
"torch~=2.1", # package was tested on 2.2
"transformers~=4.31",
"accelerate",
]
[project.optional-dependencies]
test = [
"pytest~=8.0",
"numpy~=1.21"
]
test = ["pytest~=8.0", "numpy~=1.21"]
typecheck = ["mypy~=1.9"]
[tool.mypy]
ignore_missing_imports = true

View file

@ -367,9 +367,9 @@ class ChronosPipeline:
or the length of the longest time series, if a list of 1D tensors was
provided, and the extra 1 is for EOS.
"""
context = self._prepare_and_validate_context(context=context)
context_tensor = self._prepare_and_validate_context(context=context)
token_ids, attention_mask, tokenizer_state = self.tokenizer.input_transform(
context
context_tensor
)
embeddings = self.model.encode(
input_ids=token_ids.to(self.model.device),
@ -424,7 +424,7 @@ class ChronosPipeline:
Tensor of sample forecasts, of shape
(batch_size, num_samples, prediction_length).
"""
context = self._prepare_and_validate_context(context=context)
context_tensor = self._prepare_and_validate_context(context=context)
if prediction_length is None:
prediction_length = self.model.config.prediction_length
@ -443,7 +443,9 @@ class ChronosPipeline:
remaining = prediction_length
while remaining > 0:
token_ids, attention_mask, scale = self.tokenizer.input_transform(context)
token_ids, attention_mask, scale = self.tokenizer.input_transform(
context_tensor
)
samples = self.model(
token_ids.to(self.model.device),
attention_mask.to(self.model.device),
@ -463,7 +465,9 @@ class ChronosPipeline:
if remaining <= 0:
break
context = torch.cat([context, prediction.median(dim=1).values], dim=-1)
context_tensor = torch.cat(
[context_tensor, prediction.median(dim=1).values], dim=-1
)
return torch.cat(predictions, dim=-1)

View file

@ -59,7 +59,7 @@ def test_tokenizer_fixed_data(
samples = tokenizer.output_transform(
torch.arange(n_special_tokens, n_tokens).unsqueeze(0).repeat(batch_size, 1, 1),
decoding_context=scale,
tokenizer_state=scale,
)
assert (samples[:, 0, [0, -1]] == context).all()
@ -119,7 +119,7 @@ def test_tokenizer_random_data(use_eos_token: bool):
assert samples.shape == (2, 10, 4)
def validate_tensor(samples: torch.Tensor, shape: Tuple[int, int, int]) -> None:
def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None:
assert isinstance(samples, torch.Tensor)
assert samples.shape == shape