diff --git a/src/chronos/base.py b/src/chronos/base.py index 3dc1775..bf57b55 100644 --- a/src/chronos/base.py +++ b/src/chronos/base.py @@ -67,7 +67,8 @@ class BaseChronosPipeline(metaclass=PipelineRegistry): **kwargs, ): """ - Get forecasts for the given time series. + Get forecasts for the given time series. Predictions will be + returned in fp32 on the cpu. Parameters ---------- @@ -97,6 +98,7 @@ class BaseChronosPipeline(metaclass=PipelineRegistry): ) -> Tuple[torch.Tensor, torch.Tensor]: """ Get quantile and mean forecasts for given time series. + Predictions will be returned in fp32 on the cpu. Parameters ---------- diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py index ef226f6..31df48a 100644 --- a/src/chronos/chronos.py +++ b/src/chronos/chronos.py @@ -500,9 +500,6 @@ class ChronosPipeline(BaseChronosPipeline): raise ValueError(msg) logger.warning(msg) - input_dtype = context_tensor.dtype - input_device = context_tensor.device - predictions = [] remaining = prediction_length @@ -533,7 +530,7 @@ class ChronosPipeline(BaseChronosPipeline): [context_tensor, prediction.median(dim=1).values], dim=-1 ) - return torch.cat(predictions, dim=-1).to(dtype=input_dtype, device=input_device) + return torch.cat(predictions, dim=-1).to(dtype=torch.float32, device="cpu") def predict_quantiles( self, diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py index e3182f9..4825466 100644 --- a/src/chronos/chronos_bolt.py +++ b/src/chronos/chronos_bolt.py @@ -487,13 +487,14 @@ class ChronosBoltPipeline(BaseChronosPipeline): # TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast # horizon that the model was trained with (i.e., 64). This results in variance collapsing # every 64 steps. + context_tensor = context_tensor.to( + device=self.model.device, + dtype=torch.float32, + ) while remaining > 0: with torch.no_grad(): prediction = self.model( - context=context_tensor.to( - device=self.model.device, - dtype=torch.float32, # scaling should be done in 32-bit precision - ), + context=context_tensor, ).quantile_preds.to(context_tensor) predictions.append(prediction) @@ -507,7 +508,9 @@ class ChronosBoltPipeline(BaseChronosPipeline): context_tensor = torch.cat([context_tensor, central_prediction], dim=-1) - return torch.cat(predictions, dim=-1)[..., :prediction_length] + return torch.cat(predictions, dim=-1)[..., :prediction_length].to( + dtype=torch.float32, device="cpu" + ) def predict_quantiles( self, diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..03f633a --- /dev/null +++ b/test/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file diff --git a/test/test_chronos.py b/test/test_chronos.py index b0235c0..763fde2 100644 --- a/test/test_chronos.py +++ b/test/test_chronos.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path -from typing import Optional, Tuple import pytest import torch @@ -13,6 +12,7 @@ from chronos import ( ChronosPipeline, MeanScaleUniformBins, ) +from test.util import validate_tensor def test_base_chronos_pipeline_loads_from_huggingface(): @@ -166,30 +166,21 @@ def test_tokenizer_random_data(use_eos_token: bool): assert samples.shape == (2, 10, 4) -def validate_tensor( - a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None -) -> None: - assert isinstance(a, torch.Tensor) - assert a.shape == shape - - if dtype is not None: - assert a.dtype == dtype - - @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype): pipeline = ChronosPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos-model", device_map="cpu", torch_dtype=model_dtype, ) - context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10 + context = 10 * torch.rand(size=(4, 16)) + 10 + context = context.to(dtype=input_dtype) # input: tensor of shape (batch_size, context_length) samples = pipeline.predict(context, num_samples=12, prediction_length=3) - validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype) + validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32) with pytest.raises(ValueError): samples = pipeline.predict( @@ -199,12 +190,12 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype): samples = pipeline.predict( context, num_samples=7, prediction_length=65, limit_prediction_length=False ) - validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype) + validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32) # input: batch_size-long list of tensors of shape (context_length,) samples = pipeline.predict(list(context), num_samples=12, prediction_length=3) - validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype) + validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32) with pytest.raises(ValueError): samples = pipeline.predict( @@ -220,12 +211,12 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype): prediction_length=65, limit_prediction_length=False, ) - validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype) + validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32) # input: tensor of shape (context_length,) samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3) - validate_tensor(samples, shape=(1, 12, 3), dtype=input_dtype) + validate_tensor(samples, shape=(1, 12, 3), dtype=torch.float32) with pytest.raises(ValueError): samples = pipeline.predict( @@ -240,16 +231,18 @@ def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype): num_samples=7, prediction_length=65, ) - validate_tensor(samples, shape=(1, 7, 65), dtype=input_dtype) + validate_tensor(samples, shape=(1, 7, 65), dtype=torch.float32) @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) @pytest.mark.parametrize("prediction_length", [3, 65]) @pytest.mark.parametrize( "quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]] ) def test_pipeline_predict_quantiles( model_dtype: torch.dtype, + input_dtype: torch.dtype, prediction_length: int, quantile_levels: list[int], ): @@ -259,6 +252,7 @@ def test_pipeline_predict_quantiles( torch_dtype=model_dtype, ) context = 10 * torch.rand(size=(4, 16)) + 10 + context = context.to(dtype=input_dtype) num_expected_quantiles = len(quantile_levels) # input: tensor of shape (batch_size, context_length) @@ -269,8 +263,10 @@ def test_pipeline_predict_quantiles( prediction_length=prediction_length, quantile_levels=quantile_levels, ) - validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles)) - validate_tensor(mean, (4, prediction_length)) + validate_tensor( + quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32 + ) + validate_tensor(mean, (4, prediction_length), dtype=torch.float32) # input: batch_size-long list of tensors of shape (context_length,) @@ -280,8 +276,10 @@ def test_pipeline_predict_quantiles( prediction_length=prediction_length, quantile_levels=quantile_levels, ) - validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles)) - validate_tensor(mean, (4, prediction_length)) + validate_tensor( + quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32 + ) + validate_tensor(mean, (4, prediction_length), dtype=torch.float32) # input: tensor of shape (context_length,) @@ -291,12 +289,14 @@ def test_pipeline_predict_quantiles( prediction_length=prediction_length, quantile_levels=quantile_levels, ) - validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles)) - validate_tensor(mean, (1, prediction_length)) + validate_tensor( + quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32 + ) + validate_tensor(mean, (1, prediction_length), dtype=torch.float32) @pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype): pipeline = ChronosPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos-model", @@ -304,7 +304,8 @@ def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype): torch_dtype=model_dtype, ) d_model = pipeline.model.model.config.d_model - context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10 + context = 10 * torch.rand(size=(4, 16)) + 10 + context = context.to(dtype=input_dtype) expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0) # input: tensor of shape (batch_size, context_length) diff --git a/test/test_chronos_bolt.py b/test/test_chronos_bolt.py index c4c3db7..4b72568 100644 --- a/test/test_chronos_bolt.py +++ b/test/test_chronos_bolt.py @@ -1,16 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + from pathlib import Path -from typing import Tuple import pytest import torch from chronos import BaseChronosPipeline, ChronosBoltPipeline from chronos.chronos_bolt import InstanceNorm, Patch - - -def validate_tensor(input: torch.Tensor, shape: Tuple[int, ...]) -> None: - assert isinstance(input, torch.Tensor) - assert input.shape == shape +from test.util import validate_tensor def test_base_chronos_pipeline_loads_from_huggingface(): @@ -18,19 +16,21 @@ def test_base_chronos_pipeline_loads_from_huggingface(): @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) -def test_pipeline_predict(torch_dtype: str): +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) +def test_pipeline_predict(torch_dtype: torch.dtype, input_dtype: torch.dtype): pipeline = ChronosBoltPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos-bolt-model", device_map="cpu", torch_dtype=torch_dtype, ) context = 10 * torch.rand(size=(4, 16)) + 10 + context = context.to(dtype=input_dtype) expected_num_quantiles = len(pipeline.quantiles) # input: tensor of shape (batch_size, context_length) quantiles = pipeline.predict(context, prediction_length=3) - validate_tensor(quantiles, (4, expected_num_quantiles, 3)) + validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32) with pytest.raises(ValueError): quantiles = pipeline.predict( @@ -43,7 +43,7 @@ def test_pipeline_predict(torch_dtype: str): # input: batch_size-long list of tensors of shape (context_length,) quantiles = pipeline.predict(list(context), prediction_length=3) - validate_tensor(quantiles, (4, expected_num_quantiles, 3)) + validate_tensor(quantiles, (4, expected_num_quantiles, 3), dtype=torch.float32) with pytest.raises(ValueError): quantiles = pipeline.predict( @@ -53,12 +53,12 @@ def test_pipeline_predict(torch_dtype: str): ) quantiles = pipeline.predict(list(context), prediction_length=65) - validate_tensor(quantiles, (4, expected_num_quantiles, 65)) + validate_tensor(quantiles, (4, expected_num_quantiles, 65), dtype=torch.float32) # input: tensor of shape (context_length,) quantiles = pipeline.predict(context[0, ...], prediction_length=3) - validate_tensor(quantiles, (1, expected_num_quantiles, 3)) + validate_tensor(quantiles, (1, expected_num_quantiles, 3), dtype=torch.float32) with pytest.raises(ValueError): quantiles = pipeline.predict( @@ -71,16 +71,20 @@ def test_pipeline_predict(torch_dtype: str): context[0, ...], prediction_length=65, ) - validate_tensor(quantiles, (1, expected_num_quantiles, 65)) + validate_tensor(quantiles, (1, expected_num_quantiles, 65), dtype=torch.float32) @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) @pytest.mark.parametrize("prediction_length", [3, 65]) @pytest.mark.parametrize( "quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]] ) def test_pipeline_predict_quantiles( - torch_dtype: str, prediction_length: int, quantile_levels: list[int] + torch_dtype: torch.dtype, + input_dtype: torch.dtype, + prediction_length: int, + quantile_levels: list[int], ): pipeline = ChronosBoltPipeline.from_pretrained( Path(__file__).parent / "dummy-chronos-bolt-model", @@ -88,6 +92,7 @@ def test_pipeline_predict_quantiles( torch_dtype=torch_dtype, ) context = 10 * torch.rand(size=(4, 16)) + 10 + context = context.to(dtype=input_dtype) num_expected_quantiles = len(quantile_levels) # input: tensor of shape (batch_size, context_length) @@ -97,8 +102,10 @@ def test_pipeline_predict_quantiles( prediction_length=prediction_length, quantile_levels=quantile_levels, ) - validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles)) - validate_tensor(mean, (4, prediction_length)) + validate_tensor( + quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32 + ) + validate_tensor(mean, (4, prediction_length), dtype=torch.float32) # input: batch_size-long list of tensors of shape (context_length,) @@ -107,8 +114,10 @@ def test_pipeline_predict_quantiles( prediction_length=prediction_length, quantile_levels=quantile_levels, ) - validate_tensor(quantiles, (4, prediction_length, num_expected_quantiles)) - validate_tensor(mean, (4, prediction_length)) + validate_tensor( + quantiles, (4, prediction_length, num_expected_quantiles), dtype=torch.float32 + ) + validate_tensor(mean, (4, prediction_length), dtype=torch.float32) # input: tensor of shape (context_length,) @@ -117,8 +126,10 @@ def test_pipeline_predict_quantiles( prediction_length=prediction_length, quantile_levels=quantile_levels, ) - validate_tensor(quantiles, (1, prediction_length, num_expected_quantiles)) - validate_tensor(mean, (1, prediction_length)) + validate_tensor( + quantiles, (1, prediction_length, num_expected_quantiles), dtype=torch.float32 + ) + validate_tensor(mean, (1, prediction_length), dtype=torch.float32) # The following tests have been taken from diff --git a/test/util.py b/test/util.py new file mode 100644 index 0000000..37a2c3b --- /dev/null +++ b/test/util.py @@ -0,0 +1,13 @@ +from typing import Optional, Tuple + +import torch + + +def validate_tensor( + a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None +) -> None: + assert isinstance(a, torch.Tensor) + assert a.shape == shape + + if dtype is not None: + assert a.dtype == dtype \ No newline at end of file