From 4b1d1c818b561cc95274050cce42d44da79d71ed Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 5 Apr 2024 15:36:39 +0200 Subject: [PATCH] Fix types, add mypy to workflow (#42) *Description of changes:* Fix some type checking issues, add mypy to github workflow, apply black By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --- .github/workflows/ci.yml | 43 ++++++++++++++++++++++++++++++---------- pyproject.toml | 14 ++++++------- src/chronos/chronos.py | 14 ++++++++----- test/test_chronos.py | 4 ++-- 4 files changed, 49 insertions(+), 26 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0c8cba5..7763950 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,23 +3,44 @@ name: CI on: [push, pull_request] jobs: - test: + type-check: strategy: max-parallel: 4 fail-fast: false matrix: - python-version: ['3.11'] + python-version: ["3.11"] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: pip install ".[test]" - - name: Test with pytest - run: pytest + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: pip install ".[typecheck]" + - name: Type checks with mypy + run: mypy src test + + test: + strategy: + max-parallel: 4 + fail-fast: false + matrix: + python-version: ["3.11"] + platform: [ubuntu-latest] + + runs-on: ${{ matrix.platform }} + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: pip install ".[test]" + - name: Test with pytest + run: pytest diff --git a/pyproject.toml b/pyproject.toml index db91d78..61cafa7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,18 +2,16 @@ name = "chronos" version = "1.1.0" requires-python = ">=3.8" -license = {file = "LICENSE"} +license = { file = "LICENSE" } dependencies = [ - "torch~=2.1", # package was tested on 2.2 - "transformers~=4.31", - "accelerate" + "torch~=2.1", # package was tested on 2.2 + "transformers~=4.31", + "accelerate", ] [project.optional-dependencies] -test = [ - "pytest~=8.0", - "numpy~=1.21" -] +test = ["pytest~=8.0", "numpy~=1.21"] +typecheck = ["mypy~=1.9"] [tool.mypy] ignore_missing_imports = true diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py index 92f1033..175b5db 100644 --- a/src/chronos/chronos.py +++ b/src/chronos/chronos.py @@ -367,9 +367,9 @@ class ChronosPipeline: or the length of the longest time series, if a list of 1D tensors was provided, and the extra 1 is for EOS. """ - context = self._prepare_and_validate_context(context=context) + context_tensor = self._prepare_and_validate_context(context=context) token_ids, attention_mask, tokenizer_state = self.tokenizer.input_transform( - context + context_tensor ) embeddings = self.model.encode( input_ids=token_ids.to(self.model.device), @@ -424,7 +424,7 @@ class ChronosPipeline: Tensor of sample forecasts, of shape (batch_size, num_samples, prediction_length). """ - context = self._prepare_and_validate_context(context=context) + context_tensor = self._prepare_and_validate_context(context=context) if prediction_length is None: prediction_length = self.model.config.prediction_length @@ -443,7 +443,9 @@ class ChronosPipeline: remaining = prediction_length while remaining > 0: - token_ids, attention_mask, scale = self.tokenizer.input_transform(context) + token_ids, attention_mask, scale = self.tokenizer.input_transform( + context_tensor + ) samples = self.model( token_ids.to(self.model.device), attention_mask.to(self.model.device), @@ -463,7 +465,9 @@ class ChronosPipeline: if remaining <= 0: break - context = torch.cat([context, prediction.median(dim=1).values], dim=-1) + context_tensor = torch.cat( + [context_tensor, prediction.median(dim=1).values], dim=-1 + ) return torch.cat(predictions, dim=-1) diff --git a/test/test_chronos.py b/test/test_chronos.py index 0fba658..a7d63bc 100644 --- a/test/test_chronos.py +++ b/test/test_chronos.py @@ -59,7 +59,7 @@ def test_tokenizer_fixed_data( samples = tokenizer.output_transform( torch.arange(n_special_tokens, n_tokens).unsqueeze(0).repeat(batch_size, 1, 1), - decoding_context=scale, + tokenizer_state=scale, ) assert (samples[:, 0, [0, -1]] == context).all() @@ -119,7 +119,7 @@ def test_tokenizer_random_data(use_eos_token: bool): assert samples.shape == (2, 10, 4) -def validate_tensor(samples: torch.Tensor, shape: Tuple[int, int, int]) -> None: +def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None: assert isinstance(samples, torch.Tensor) assert samples.shape == shape